import json import operator import re from typing import Annotated, TypedDict from langchain_core.messages import AIMessage from langchain_core.tools import render_text_description from langgraph.graph import END, StateGraph from langgraph.prebuilt import ToolNode from modules.tools import ( get_patient_data_manifest, get_patient_fhir_resource, ) _LLM_INVOKE_ARGS = {"max_tokens": 8000, "temperature": 0.6} def exclude_thinking_component(text: str) -> str: """Removes the thinking block (delimited by and ) from a string.""" return re.sub(r".*?", "", text, flags=re.DOTALL).strip() def strip_json_decoration(text: str) -> str: """Removes JSON markdown fences from the start and end of a string.""" match = re.search(r"```(?:json)?\s*([\{\[].*[\]\}])\s*```", text, re.DOTALL) if match: return match.group(1) return text.strip() class AgentState(TypedDict): messages: Annotated[list, operator.add] patient_fhir_manifest: dict tool_output_summary: Annotated[list, operator.add] tool_calls_to_execute: Annotated[list, operator.add] relevant_resource_types: list manifest_tool_call_request: AIMessage sdt_idx: int edr_idx: int resource_type_processed: str resource_type_retrieved: str summary_generated: bool resource_type_to_retrieve: str resource_type_to_process: str fhir_tool_output: str resource_being_summarized: str tool_call: dict resource_manifest_codes: list def create_agent(llm, fhir_store_url): """Creates and compiles the LangGraph agent.""" manifest_tool_node = ToolNode([get_patient_data_manifest]) data_retrieval_tool_node = ToolNode([get_patient_fhir_resource]) def generate_manifest_tool_call_node(state): """The first step: uses the LLM to find the patient_id from the initial question and generates a tool call for get_patient_data_manifest. """ last_message = state["messages"][-1] extraction_prompt = ( f"USER QUESTION: {last_message.content}\\n\\nYou are an API request" " generator. Your task is to identify the patient ID from the user's" " question and output a JSON object to call the" " `get_patient_data_manifest` tool.\\n\\nYour available tool" f" is:\\n{render_text_description([get_patient_data_manifest])}\\n\\nGenerate" " the correct JSON to call the tool. Respond with only a single, raw" ' JSON object.\\n\\nEXAMPLE:\\n{\\n "name":' ' "get_patient_data_manifest",\\n "args": {\\n "patient_id":' ' "some-patient-id-from-the-question"\\n }\\n}\\n' ) print( "--- generate_manifest_tool_call_node PROMPT" f" ---\n{extraction_prompt}\n-----------------------------" ) response_str = llm.invoke(extraction_prompt, **_LLM_INVOKE_ARGS) print( "--- generate_manifest_tool_call_node RESPONSE" f" ---\n{response_str}\n------------------------------" ) try: cleaned_response = strip_json_decoration(response_str) tool_call_json = json.loads(cleaned_response) tool_call_json["args"]["fhir_store_url"] = fhir_store_url tool_call_msg = AIMessage( content="", tool_calls=[{**tool_call_json, "id": "manifest_call"}], ) return { "manifest_tool_call_request": tool_call_msg, "tool_call": tool_call_json, } except Exception as e: print(f"Error generating manifest tool call: {e}") raise e def execute_manifest_tool_call_node(state): """Executes the get_patient_data_manifest tool call and puts the result in state.""" try: tool_call_msg = state["manifest_tool_call_request"] tool_output_message = manifest_tool_node.invoke([tool_call_msg])[0] manifest_dict = json.loads(tool_output_message.content) print(f"Manifest dict: {manifest_dict}") return {"patient_fhir_manifest": manifest_dict} except Exception as e: print(f"Error calling manifest tool: {e}") raise e def identify_relevant_resource_types(state): """Uses the manifest and user question to identify relevant FHIR resource types.""" print("Identifying Relevant Resource Types") manifest = state.get("patient_fhir_manifest", {}) user_question = state["messages"][1].content manifest_content = "" for resource_type, codes in manifest.items(): manifest_content += f"**{resource_type}**: " if codes: manifest_content += f"Available codes include: {', '.join(codes)}\\n" else: manifest_content += "Present (no specific codes found)\\n" prompt = ( "SYSTEM INSTRUCTION: think silently if needed.\\nUSER QUESTION:" f" {user_question}\\n\\nPATIENT DATA" f" MANIFEST:\\n{manifest_content}\\n\\nYou are a medical assistant" " analyzing a patient's FHIR data manifest to answer a user" " question.\\nBased on the user question, identify the specific FHIR" " resource types from the manifest that are most likely to contain the" " information needed to answer the question.\\nOutput a JSON list of" " the relevant resource types. Do not include any other text or" ' formatting.\\nExample:\n["Condition", "Observation",' ' "MedicationRequest"]\n' ) print( "--- identify_relevant_resource_types PROMPT" f" ---\n{prompt}\n------------------------------------------" ) response_str = llm.invoke(prompt, **_LLM_INVOKE_ARGS) print( "--- identify_relevant_resource_types RESPONSE" f" ---\n{response_str}\n-------------------------------------------" ) try: relevant_resource_types = json.loads(strip_json_decoration(response_str)) except json.JSONDecodeError: print( "Could not decode JSON response for relevant resource types:" f" {response_str}" ) relevant_resource_types = [] print( "Relevant Resource Types Identified:" f" {', '.join(relevant_resource_types)}" ) return { "relevant_resource_types": relevant_resource_types, "sdt_idx": 0, "tool_calls_to_execute": [], } def announce_sdt_node(state): sdt_idx = state["sdt_idx"] relevant_resource_types = state.get("relevant_resource_types", []) resource_type = relevant_resource_types[sdt_idx] manifest = state.get("patient_fhir_manifest", {}) resource_manifest = manifest.get(resource_type, []) print(f"Announcing data selection for {resource_type}") return { "resource_type_to_process": resource_type, "resource_manifest_codes": resource_manifest, } def select_data_to_retrieve(state): """Uses the manifest and relevant resource types to determine which FHIR resources to retrieve.""" sdt_idx = state["sdt_idx"] manifest = state.get("patient_fhir_manifest", {}) relevant_resource_types = state.get("relevant_resource_types", []) tools_string = render_text_description([get_patient_fhir_resource]) resource_type = relevant_resource_types[sdt_idx] print(f"Data Selection for {resource_type}") if resource_type not in manifest: print(f"No data found for {resource_type} in the manifest.") return {"sdt_idx": sdt_idx + 1, "resource_type_processed": resource_type} manifest_content = f"**{resource_type}**: " if len(manifest.get(resource_type, [])) > 0: manifest_content += ( f"Available codes include: {', '.join(manifest[resource_type])}\\n" ) else: manifest_content += "Present (no specific codes found)\\n" prompt = ( "SYSTEM INSTRUCTION: think silently if needed.\\n" + "FOR CONTEXT ONLY, USER QUESTION:" f" {state['messages'][1].content}\\n\\n" + f"PATIENT DATA MANIFEST: {manifest_content}\\n\\n" + "You are a specialized API request generator. Your SOLE task is to" " output a JSON of a tool call to gather the necessary information" " to answer the user's question. Respond with ONLY a JSON, no" " explanations or prose.\\n" + f"Your available tool is:\\n{tools_string}\\n\\n" + f"**At this stage you can only call {resource_type}.**\\n" + "EXAMPLE:\\n" + '{\\"name\\": \\"get_patient_fhir_resource\\", \\"args\\":' ' {\\"patient_id\\": \\"some-patient-id\\",' ' \\"fhir_resource\\": \\"' + resource_type + '\\", \\"filter_code\\": \\"csv-codes-from-manifest\\"}}' ) print( f"--- select_data_to_retrieve PROMPT ({resource_type})" f" ---\n{prompt}\n------------------------------------------" ) response_str = llm.invoke(prompt, **_LLM_INVOKE_ARGS) print( f"--- select_data_to_retrieve RESPONSE ({resource_type})" f" ---\n{response_str}\n-------------------------------------------" ) try: tool_call = json.loads(strip_json_decoration(response_str)) tool_call["args"]["fhir_store_url"] = fhir_store_url return { "tool_calls_to_execute": [{**tool_call, "id": resource_type}], "sdt_idx": sdt_idx + 1, "resource_type_processed": resource_type, } except json.JSONDecodeError: print( f"Could not decode JSON response for {resource_type}: {response_str}" ) # If we fail to decode, we just skip this resource type. return {"sdt_idx": sdt_idx + 1, "resource_type_processed": resource_type} def sdt_conditional_edge(state): if state["sdt_idx"] < len(state["relevant_resource_types"]): return "announce_sdt" return "init_edr_idx" def init_edr_idx_node(state): return {"edr_idx": 0} def init_edr_conditional_edge(state): if state["tool_calls_to_execute"]: return "announce_retrieval" return "final_answer" def announce_retrieval_node(state): edr_idx = state["edr_idx"] tool_calls = state.get("tool_calls_to_execute", []) tool_call = tool_calls[edr_idx] resource_type = tool_call.get("id", "unknown_resource") print(f"Announcing retrieval of {resource_type}") return {"resource_type_to_retrieve": resource_type} def execute_data_retrieval(state): """Executes the planned tool calls and summarizes the output.""" edr_idx = state["edr_idx"] tool_calls = state.get("tool_calls_to_execute", []) tool_call = tool_calls[edr_idx] resource_type = tool_call.get("id", "unknown_resource") print(f"Fetching FHIR data for {resource_type}") tool_output_list = data_retrieval_tool_node.invoke( [AIMessage(content="", tool_calls=[tool_call])] ) if not tool_output_list: print(f"No tool output received for {resource_type}") return { "resource_type_retrieved": resource_type, "summary_generated": False, "fhir_tool_output": "", } tool_output = tool_output_list[0].content return { "resource_type_retrieved": resource_type, "summary_generated": True, "fhir_tool_output": tool_output, } def announce_summarization_node(state): resource_type = state["resource_type_retrieved"] print(f"Announcing summarization of {resource_type}") return {"resource_being_summarized": resource_type} def summarize_node(state: AgentState) -> dict: if not state["summary_generated"]: return {"edr_idx": state["edr_idx"] + 1} resource_type = state["resource_type_retrieved"] tool_output = state["fhir_tool_output"] concise_facts_prompt = ( "SYSTEM INSTRUCTION: think silently if needed.\\nFOR CONTEXT ONLY," f" USER QUESTION: {state['messages'][1].content}\\n\\nTOOL" f" OUTPUT:\\n{tool_output}\\n\\nYou are a fact summarizing agent." " Your output will be used to answer the USER QUESTION.\\nCollect" " from the 'TOOL OUTPUT' facts ONLY if it is relevant to answer the" " USER QUESTION.\\nWrite a very concise English summary, only facts" " relevant to the user question. DO NOT OUTPUT JSON.\\nYou are not" " authorized to answer the user question. Do not provide any output" " beyond concise facts. Filter out any facts which are not helpful" " for the user question. Include date or date ranges. Only for the" " most critical facts, include FHIR record references [record" " type/record id]. For repeating multiple times provide summarize" " and provide only a single reference and date range." ) print( f"--- summarize_node PROMPT ({resource_type})" f" ---\n{concise_facts_prompt}\n------------------------------------------" ) current_summary = llm.invoke(concise_facts_prompt, **_LLM_INVOKE_ARGS) print( f"--- summarize_node RESPONSE ({resource_type})" f" ---\n{current_summary}\n-------------------------------------------" ) return { "tool_output_summary": [exclude_thinking_component(current_summary)], "edr_idx": state["edr_idx"] + 1, "resource_type_retrieved": resource_type, } def should_summarize_edge(state): if state["summary_generated"]: return "announce_summarization" return "summarize_node" def edr_conditional_edge(state): if state["edr_idx"] < len(state["tool_calls_to_execute"]): return "announce_retrieval" return "final_answer" def get_final_answer(state): """If we have enough data, this node generates the final answer.""" summary = "\\n\\n".join(state["tool_output_summary"]) prompt = ( "Synthesize all information from the 'SUMMARIZED INFORMATION' to" " provide a comprehensive final answer. Preserve relevant FHIR" " references.\\n\\nUSER QUESTION:" f" {state['messages'][1].content}\\n\\nSUMMARIZED INFORMATION:" f" {summary}\\n\\nFinal Answer using markdown:" ) print( "--- get_final_answer PROMPT" f" ---\n{prompt}\n------------------------------------------" ) response = llm.invoke(prompt, **_LLM_INVOKE_ARGS) print( "--- get_final_answer RESPONSE" f" ---\n{response}\n-------------------------------------------" ) return {"messages": [AIMessage(content=response)]} workflow = StateGraph(AgentState) workflow.add_node( "generate_manifest_tool_call", generate_manifest_tool_call_node ) workflow.add_node( "execute_manifest_tool_call", execute_manifest_tool_call_node ) workflow.add_node( "identify_relevant_resource_types", identify_relevant_resource_types ) workflow.add_node("announce_sdt", announce_sdt_node) workflow.add_node("select_data_to_retrieve", select_data_to_retrieve) workflow.add_node("init_edr_idx", init_edr_idx_node) workflow.add_node("announce_retrieval", announce_retrieval_node) workflow.add_node("execute_data_retrieval", execute_data_retrieval) workflow.add_node("announce_summarization", announce_summarization_node) workflow.add_node("summarize_node", summarize_node) workflow.add_node("final_answer", get_final_answer) workflow.set_entry_point("generate_manifest_tool_call") workflow.add_edge("generate_manifest_tool_call", "execute_manifest_tool_call") workflow.add_edge( "execute_manifest_tool_call", "identify_relevant_resource_types" ) workflow.add_edge( "identify_relevant_resource_types", "announce_sdt" ) workflow.add_edge("announce_sdt", "select_data_to_retrieve") workflow.add_conditional_edges( "select_data_to_retrieve", sdt_conditional_edge, { "announce_sdt": "announce_sdt", "init_edr_idx": "init_edr_idx", }, ) workflow.add_conditional_edges( "init_edr_idx", init_edr_conditional_edge, { "announce_retrieval": "announce_retrieval", "final_answer": "final_answer", }, ) workflow.add_edge("announce_retrieval", "execute_data_retrieval") workflow.add_conditional_edges( "execute_data_retrieval", should_summarize_edge, { "announce_summarization": "announce_summarization", "summarize_node": "summarize_node", }, ) workflow.add_edge("announce_summarization", "summarize_node") workflow.add_conditional_edges( "summarize_node", edr_conditional_edge, { "announce_retrieval": "announce_retrieval", "final_answer": "final_answer", }, ) workflow.add_edge("final_answer", END) return workflow.compile()