lirony's picture
Update
ff5e06a
import json
import operator
import re
from typing import Annotated, TypedDict
from langchain_core.messages import AIMessage
from langchain_core.tools import render_text_description
from langgraph.graph import END, StateGraph
from langgraph.prebuilt import ToolNode
from modules.tools import (
get_patient_data_manifest,
get_patient_fhir_resource,
)
_LLM_INVOKE_ARGS = {"max_tokens": 8000, "temperature": 0.6}
def exclude_thinking_component(text: str) -> str:
"""Removes the thinking block (delimited by <unused94> and <unused95>) from a string."""
return re.sub(r"<unused94>.*?<unused95>", "", text, flags=re.DOTALL).strip()
def strip_json_decoration(text: str) -> str:
"""Removes JSON markdown fences from the start and end of a string."""
match = re.search(r"```(?:json)?\s*([\{\[].*[\]\}])\s*```", text, re.DOTALL)
if match:
return match.group(1)
return text.strip()
class AgentState(TypedDict):
messages: Annotated[list, operator.add]
patient_fhir_manifest: dict
tool_output_summary: Annotated[list, operator.add]
tool_calls_to_execute: Annotated[list, operator.add]
relevant_resource_types: list
manifest_tool_call_request: AIMessage
sdt_idx: int
edr_idx: int
resource_type_processed: str
resource_type_retrieved: str
summary_generated: bool
resource_type_to_retrieve: str
resource_type_to_process: str
fhir_tool_output: str
resource_being_summarized: str
tool_call: dict
resource_manifest_codes: list
def create_agent(llm, fhir_store_url):
"""Creates and compiles the LangGraph agent."""
manifest_tool_node = ToolNode([get_patient_data_manifest])
data_retrieval_tool_node = ToolNode([get_patient_fhir_resource])
def generate_manifest_tool_call_node(state):
"""The first step: uses the LLM to find the patient_id from the initial question
and generates a tool call for get_patient_data_manifest.
"""
last_message = state["messages"][-1]
extraction_prompt = (
f"USER QUESTION: {last_message.content}\\n\\nYou are an API request"
" generator. Your task is to identify the patient ID from the user's"
" question and output a JSON object to call the"
" `get_patient_data_manifest` tool.\\n\\nYour available tool"
f" is:\\n{render_text_description([get_patient_data_manifest])}\\n\\nGenerate"
" the correct JSON to call the tool. Respond with only a single, raw"
' JSON object.\\n\\nEXAMPLE:\\n{\\n "name":'
' "get_patient_data_manifest",\\n "args": {\\n "patient_id":'
' "some-patient-id-from-the-question"\\n }\\n}\\n'
)
print(
"--- generate_manifest_tool_call_node PROMPT"
f" ---\n{extraction_prompt}\n-----------------------------"
)
response_str = llm.invoke(extraction_prompt, **_LLM_INVOKE_ARGS)
print(
"--- generate_manifest_tool_call_node RESPONSE"
f" ---\n{response_str}\n------------------------------"
)
try:
cleaned_response = strip_json_decoration(response_str)
tool_call_json = json.loads(cleaned_response)
tool_call_json["args"]["fhir_store_url"] = fhir_store_url
tool_call_msg = AIMessage(
content="",
tool_calls=[{**tool_call_json, "id": "manifest_call"}],
)
return {
"manifest_tool_call_request": tool_call_msg,
"tool_call": tool_call_json,
}
except Exception as e:
print(f"Error generating manifest tool call: {e}")
raise e
def execute_manifest_tool_call_node(state):
"""Executes the get_patient_data_manifest tool call and puts the result in state."""
try:
tool_call_msg = state["manifest_tool_call_request"]
tool_output_message = manifest_tool_node.invoke([tool_call_msg])[0]
manifest_dict = json.loads(tool_output_message.content)
print(f"Manifest dict: {manifest_dict}")
return {"patient_fhir_manifest": manifest_dict}
except Exception as e:
print(f"Error calling manifest tool: {e}")
raise e
def identify_relevant_resource_types(state):
"""Uses the manifest and user question to identify relevant FHIR resource types."""
print("Identifying Relevant Resource Types")
manifest = state.get("patient_fhir_manifest", {})
user_question = state["messages"][1].content
manifest_content = ""
for resource_type, codes in manifest.items():
manifest_content += f"**{resource_type}**: "
if codes:
manifest_content += f"Available codes include: {', '.join(codes)}\\n"
else:
manifest_content += "Present (no specific codes found)\\n"
prompt = (
"SYSTEM INSTRUCTION: think silently if needed.\\nUSER QUESTION:"
f" {user_question}\\n\\nPATIENT DATA"
f" MANIFEST:\\n{manifest_content}\\n\\nYou are a medical assistant"
" analyzing a patient's FHIR data manifest to answer a user"
" question.\\nBased on the user question, identify the specific FHIR"
" resource types from the manifest that are most likely to contain the"
" information needed to answer the question.\\nOutput a JSON list of"
" the relevant resource types. Do not include any other text or"
' formatting.\\nExample:\n["Condition", "Observation",'
' "MedicationRequest"]\n'
)
print(
"--- identify_relevant_resource_types PROMPT"
f" ---\n{prompt}\n------------------------------------------"
)
response_str = llm.invoke(prompt, **_LLM_INVOKE_ARGS)
print(
"--- identify_relevant_resource_types RESPONSE"
f" ---\n{response_str}\n-------------------------------------------"
)
try:
relevant_resource_types = json.loads(strip_json_decoration(response_str))
except json.JSONDecodeError:
print(
"Could not decode JSON response for relevant resource types:"
f" {response_str}"
)
relevant_resource_types = []
print(
"Relevant Resource Types Identified:"
f" {', '.join(relevant_resource_types)}"
)
return {
"relevant_resource_types": relevant_resource_types,
"sdt_idx": 0,
"tool_calls_to_execute": [],
}
def announce_sdt_node(state):
sdt_idx = state["sdt_idx"]
relevant_resource_types = state.get("relevant_resource_types", [])
resource_type = relevant_resource_types[sdt_idx]
manifest = state.get("patient_fhir_manifest", {})
resource_manifest = manifest.get(resource_type, [])
print(f"Announcing data selection for {resource_type}")
return {
"resource_type_to_process": resource_type,
"resource_manifest_codes": resource_manifest,
}
def select_data_to_retrieve(state):
"""Uses the manifest and relevant resource types to determine which FHIR resources to retrieve."""
sdt_idx = state["sdt_idx"]
manifest = state.get("patient_fhir_manifest", {})
relevant_resource_types = state.get("relevant_resource_types", [])
tools_string = render_text_description([get_patient_fhir_resource])
resource_type = relevant_resource_types[sdt_idx]
print(f"Data Selection for {resource_type}")
if resource_type not in manifest:
print(f"No data found for {resource_type} in the manifest.")
return {"sdt_idx": sdt_idx + 1, "resource_type_processed": resource_type}
manifest_content = f"**{resource_type}**: "
if len(manifest.get(resource_type, [])) > 0:
manifest_content += (
f"Available codes include: {', '.join(manifest[resource_type])}\\n"
)
else:
manifest_content += "Present (no specific codes found)\\n"
prompt = (
"SYSTEM INSTRUCTION: think silently if needed.\\n"
+ "FOR CONTEXT ONLY, USER QUESTION:"
f" {state['messages'][1].content}\\n\\n"
+ f"PATIENT DATA MANIFEST: {manifest_content}\\n\\n"
+ "You are a specialized API request generator. Your SOLE task is to"
" output a JSON of a tool call to gather the necessary information"
" to answer the user's question. Respond with ONLY a JSON, no"
" explanations or prose.\\n"
+ f"Your available tool is:\\n{tools_string}\\n\\n"
+ f"**At this stage you can only call {resource_type}.**\\n"
+ "EXAMPLE:\\n"
+ '{\\"name\\": \\"get_patient_fhir_resource\\", \\"args\\":'
' {\\"patient_id\\": \\"some-patient-id\\",'
' \\"fhir_resource\\": \\"'
+ resource_type
+ '\\", \\"filter_code\\": \\"csv-codes-from-manifest\\"}}'
)
print(
f"--- select_data_to_retrieve PROMPT ({resource_type})"
f" ---\n{prompt}\n------------------------------------------"
)
response_str = llm.invoke(prompt, **_LLM_INVOKE_ARGS)
print(
f"--- select_data_to_retrieve RESPONSE ({resource_type})"
f" ---\n{response_str}\n-------------------------------------------"
)
try:
tool_call = json.loads(strip_json_decoration(response_str))
tool_call["args"]["fhir_store_url"] = fhir_store_url
return {
"tool_calls_to_execute": [{**tool_call, "id": resource_type}],
"sdt_idx": sdt_idx + 1,
"resource_type_processed": resource_type,
}
except json.JSONDecodeError:
print(
f"Could not decode JSON response for {resource_type}: {response_str}"
)
# If we fail to decode, we just skip this resource type.
return {"sdt_idx": sdt_idx + 1, "resource_type_processed": resource_type}
def sdt_conditional_edge(state):
if state["sdt_idx"] < len(state["relevant_resource_types"]):
return "announce_sdt"
return "init_edr_idx"
def init_edr_idx_node(state):
return {"edr_idx": 0}
def init_edr_conditional_edge(state):
if state["tool_calls_to_execute"]:
return "announce_retrieval"
return "final_answer"
def announce_retrieval_node(state):
edr_idx = state["edr_idx"]
tool_calls = state.get("tool_calls_to_execute", [])
tool_call = tool_calls[edr_idx]
resource_type = tool_call.get("id", "unknown_resource")
print(f"Announcing retrieval of {resource_type}")
return {"resource_type_to_retrieve": resource_type}
def execute_data_retrieval(state):
"""Executes the planned tool calls and summarizes the output."""
edr_idx = state["edr_idx"]
tool_calls = state.get("tool_calls_to_execute", [])
tool_call = tool_calls[edr_idx]
resource_type = tool_call.get("id", "unknown_resource")
print(f"Fetching FHIR data for {resource_type}")
tool_output_list = data_retrieval_tool_node.invoke(
[AIMessage(content="", tool_calls=[tool_call])]
)
if not tool_output_list:
print(f"No tool output received for {resource_type}")
return {
"resource_type_retrieved": resource_type,
"summary_generated": False,
"fhir_tool_output": "",
}
tool_output = tool_output_list[0].content
return {
"resource_type_retrieved": resource_type,
"summary_generated": True,
"fhir_tool_output": tool_output,
}
def announce_summarization_node(state):
resource_type = state["resource_type_retrieved"]
print(f"Announcing summarization of {resource_type}")
return {"resource_being_summarized": resource_type}
def summarize_node(state: AgentState) -> dict:
if not state["summary_generated"]:
return {"edr_idx": state["edr_idx"] + 1}
resource_type = state["resource_type_retrieved"]
tool_output = state["fhir_tool_output"]
concise_facts_prompt = (
"SYSTEM INSTRUCTION: think silently if needed.\\nFOR CONTEXT ONLY,"
f" USER QUESTION: {state['messages'][1].content}\\n\\nTOOL"
f" OUTPUT:\\n{tool_output}\\n\\nYou are a fact summarizing agent."
" Your output will be used to answer the USER QUESTION.\\nCollect"
" from the 'TOOL OUTPUT' facts ONLY if it is relevant to answer the"
" USER QUESTION.\\nWrite a very concise English summary, only facts"
" relevant to the user question. DO NOT OUTPUT JSON.\\nYou are not"
" authorized to answer the user question. Do not provide any output"
" beyond concise facts. Filter out any facts which are not helpful"
" for the user question. Include date or date ranges. Only for the"
" most critical facts, include FHIR record references [record"
" type/record id]. For repeating multiple times provide summarize"
" and provide only a single reference and date range."
)
print(
f"--- summarize_node PROMPT ({resource_type})"
f" ---\n{concise_facts_prompt}\n------------------------------------------"
)
current_summary = llm.invoke(concise_facts_prompt, **_LLM_INVOKE_ARGS)
print(
f"--- summarize_node RESPONSE ({resource_type})"
f" ---\n{current_summary}\n-------------------------------------------"
)
return {
"tool_output_summary": [exclude_thinking_component(current_summary)],
"edr_idx": state["edr_idx"] + 1,
"resource_type_retrieved": resource_type,
}
def should_summarize_edge(state):
if state["summary_generated"]:
return "announce_summarization"
return "summarize_node"
def edr_conditional_edge(state):
if state["edr_idx"] < len(state["tool_calls_to_execute"]):
return "announce_retrieval"
return "final_answer"
def get_final_answer(state):
"""If we have enough data, this node generates the final answer."""
summary = "\\n\\n".join(state["tool_output_summary"])
prompt = (
"Synthesize all information from the 'SUMMARIZED INFORMATION' to"
" provide a comprehensive final answer. Preserve relevant FHIR"
" references.\\n\\nUSER QUESTION:"
f" {state['messages'][1].content}\\n\\nSUMMARIZED INFORMATION:"
f" {summary}\\n\\nFinal Answer using markdown:"
)
print(
"--- get_final_answer PROMPT"
f" ---\n{prompt}\n------------------------------------------"
)
response = llm.invoke(prompt, **_LLM_INVOKE_ARGS)
print(
"--- get_final_answer RESPONSE"
f" ---\n{response}\n-------------------------------------------"
)
return {"messages": [AIMessage(content=response)]}
workflow = StateGraph(AgentState)
workflow.add_node(
"generate_manifest_tool_call", generate_manifest_tool_call_node
)
workflow.add_node(
"execute_manifest_tool_call", execute_manifest_tool_call_node
)
workflow.add_node(
"identify_relevant_resource_types", identify_relevant_resource_types
)
workflow.add_node("announce_sdt", announce_sdt_node)
workflow.add_node("select_data_to_retrieve", select_data_to_retrieve)
workflow.add_node("init_edr_idx", init_edr_idx_node)
workflow.add_node("announce_retrieval", announce_retrieval_node)
workflow.add_node("execute_data_retrieval", execute_data_retrieval)
workflow.add_node("announce_summarization", announce_summarization_node)
workflow.add_node("summarize_node", summarize_node)
workflow.add_node("final_answer", get_final_answer)
workflow.set_entry_point("generate_manifest_tool_call")
workflow.add_edge("generate_manifest_tool_call", "execute_manifest_tool_call")
workflow.add_edge(
"execute_manifest_tool_call", "identify_relevant_resource_types"
)
workflow.add_edge(
"identify_relevant_resource_types", "announce_sdt"
)
workflow.add_edge("announce_sdt", "select_data_to_retrieve")
workflow.add_conditional_edges(
"select_data_to_retrieve",
sdt_conditional_edge,
{
"announce_sdt": "announce_sdt",
"init_edr_idx": "init_edr_idx",
},
)
workflow.add_conditional_edges(
"init_edr_idx",
init_edr_conditional_edge,
{
"announce_retrieval": "announce_retrieval",
"final_answer": "final_answer",
},
)
workflow.add_edge("announce_retrieval", "execute_data_retrieval")
workflow.add_conditional_edges(
"execute_data_retrieval",
should_summarize_edge,
{
"announce_summarization": "announce_summarization",
"summarize_node": "summarize_node",
},
)
workflow.add_edge("announce_summarization", "summarize_node")
workflow.add_conditional_edges(
"summarize_node",
edr_conditional_edge,
{
"announce_retrieval": "announce_retrieval",
"final_answer": "final_answer",
},
)
workflow.add_edge("final_answer", END)
return workflow.compile()