Commit be234f51 authored by fushen's avatar fushen

Initial commit

parents
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="DataSourceManagerImpl" format="xml" multifile-model="true">
<data-source source="LOCAL" name="@localhost" uuid="48fd6c5e-3149-4657-9034-b7741bd917f5">
<driver-ref>redis</driver-ref>
<synchronize>true</synchronize>
<jdbc-driver>jdbc.RedisDriver</jdbc-driver>
<jdbc-url>jdbc:redis://localhost:6379/</jdbc-url>
<working-dir>$ProjectFileDir$</working-dir>
</data-source>
<data-source source="LOCAL" name="1@localhost" uuid="a6a6e607-270c-4596-b92d-e7d1b2a189c3">
<driver-ref>redis</driver-ref>
<synchronize>true</synchronize>
<jdbc-driver>jdbc.RedisDriver</jdbc-driver>
<jdbc-url>jdbc:redis://localhost:6379/1</jdbc-url>
<working-dir>$ProjectFileDir$</working-dir>
</data-source>
</component>
</project>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>
\ No newline at end of file
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
<inspection_tool class="Stylelint" enabled="true" level="ERROR" enabled_by_default="true" />
</profile>
</component>
\ No newline at end of file
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="fusion_agent" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="fusion_agent" project-jdk-type="Python SDK" />
</project>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/fusion_agent.iml" filepath="$PROJECT_DIR$/.idea/fusion_agent.iml" />
</modules>
</component>
</project>
\ No newline at end of file
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
</component>
</project>
\ No newline at end of file
import json
from openai import OpenAI
from redis import Redis
from agent.memory.memory import RedisMemory
class BaseAgent:
    """Minimal OpenAI-compatible chat agent with streaming and tool-call support.

    Keeps an in-process message list (optionally seeded from a memory backend)
    and exposes a streaming chat loop that detects tool calls.
    """

    def __init__(self, api_key, base_url, model_name, tools=None, memory=None, system_prompt=None):
        self.client = OpenAI(api_key=api_key, base_url=base_url)
        self.model = model_name
        self.tools = tools or []
        self.memory = memory
        self.messages = []
        # Inject the system prompt at the very front of the history.
        if system_prompt:
            self.messages.append({"role": "system", "content": system_prompt})
        # Load persisted history, if a memory backend was supplied.
        if self.memory:
            self.messages += self.memory.load()

    def add_user_message(self, content):
        """Append a user-role message to the conversation."""
        self.messages.append({"role": "user", "content": content})

    def add_assistant_message(self, content=None, tool_call=None):
        """Append an assistant message carrying text and/or one tool call."""
        msg = {"role": "assistant"}
        if content:
            msg["content"] = content
        if tool_call:
            msg["tool_calls"] = [tool_call]
        self.messages.append(msg)

    def add_tool_message(self, tool_call_id, name, content):
        """Append a tool-result message for a previous tool call.

        BUGFIX: the chat-completions API expects role "tool"; the original
        used "tools", which servers reject as an invalid role.
        """
        self.messages.append({
            "role": "tool",
            "tool_call_id": tool_call_id,
            "name": name,
            "content": content
        })

    def process_stream_chunk(self, delta):
        """Parse one streamed delta.

        Returns (text, tool_call_triggered, tool_call_info) where
        tool_call_info is a flat dict or None.
        """
        text = ""
        tool_call_triggered = False
        tool_call_info = None
        if hasattr(delta, "content") and delta.content:
            text = delta.content
        if hasattr(delta, "tool_calls") and delta.tool_calls:
            tool_call_triggered = True
            tc = delta.tool_calls[0]
            tool_call_info = {
                "id": getattr(tc, "id", None),
                "type": getattr(tc, "type", None),
                "function_name": getattr(tc.function, "name", None) if tc.function else None,
                "arguments": getattr(tc.function, "arguments", "") if tc.function else ""
            }
        return text, tool_call_triggered, tool_call_info

    def call_tool(self, tool_call_info):
        """Execute the local tool named in tool_call_info; only get_weather
        is supported in this demo. Returns the tool result or None."""
        if not tool_call_info:
            return None
        fname = tool_call_info.get("function_name")
        args_str = tool_call_info.get("arguments", "")
        try:
            args = json.loads(args_str)
        except Exception:
            # Malformed/partial JSON arguments degrade to an empty dict.
            args = {}
        if fname == "get_weather":
            city = args.get("city", "")
            return self.get_weather(city)
        # More tools can be dispatched here.
        return None

    def get_weather(self, city):
        """Demo tool: canned weather strings for a few cities."""
        fake_data = {
            "上海": "晴,27℃,无降水",
            "北京": "多云,22℃,轻微降水",
            "广州": "阴,29℃,无降水",
        }
        return fake_data.get(city, "无法获取该城市的天气信息")

    def chat_stream(self, user_input):
        """Stream one assistant turn; print text chunks and report tool calls.

        BUGFIX: process_stream_chunk returns a 3-tuple, but the original loop
        unpacked four values (and branched on "think_*" chunk types this base
        class never produces), raising ValueError on the very first chunk.
        """
        self.add_user_message(user_input)
        tool_call_info = None
        tool_call_triggered = False
        print("【模型回应】")
        response = self.client.chat.completions.create(
            model=self.model,
            messages=self.messages,
            tools=self.tools,
            tool_choice="auto",
            temperature=0.7,
            stream=True,
        )
        for chunk in response:
            delta = chunk.choices[0].delta
            text, triggered, tcall_info = self.process_stream_chunk(delta)
            if text:
                print(text, end="", flush=True)
            if triggered:
                tool_call_triggered = True
                tool_call_info = tcall_info
        if tool_call_triggered and tool_call_info:
            print("\n【触发工具调用】")
            print(f"函数名:{tool_call_info.get('function_name')}")
            print(f"参数:{tool_call_info.get('arguments')}")
class Qwen3Agent(BaseAgent):
    """BaseAgent variant for Qwen3-style models: strips <think>...</think>
    markers out of the content stream and accumulates fragmented tool-call
    arguments across stream chunks."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.in_think_phase = False  # currently inside a <think> span
        # A single tool call accumulated across stream fragments.
        self.tool_call_acc = {
            "id": None,
            "type": None,
            "function": {"name": None, "arguments": ""}
        }

    def process_stream_chunk(self, delta):
        """Classify one streamed delta.

        Returns (text, chunk_type) where chunk_type is one of
        "think_start" / "think_continue" / "think_end" / "content" /
        "tool_call" / None.
        """
        text = ""
        chunk_type = None
        # Handle text content and the think-phase markers.
        if hasattr(delta, "content") and delta.content:
            c = delta.content
            if "<think>" in c:
                self.in_think_phase = True
                c = c.replace("<think>", "")
                chunk_type = "think_start"
            if "</think>" in c:
                self.in_think_phase = False
                c = c.replace("</think>", "")
                chunk_type = "think_end"
            if self.in_think_phase:
                if chunk_type is None:
                    chunk_type = "think_continue"
                text = c
            else:
                if chunk_type is None:
                    chunk_type = "content"
                text = c
        # Handle tool calls: argument strings arrive in fragments and are
        # concatenated into tool_call_acc.
        if hasattr(delta, "tool_calls") and delta.tool_calls:
            tool_call = delta.tool_calls[0]
            if tool_call.id:
                self.tool_call_acc["id"] = tool_call.id
            if tool_call.type:
                self.tool_call_acc["type"] = tool_call.type
            if tool_call.function and tool_call.function.name:
                self.tool_call_acc["function"]["name"] = tool_call.function.name
            if tool_call.function and tool_call.function.arguments:
                self.tool_call_acc["function"]["arguments"] += tool_call.function.arguments
            chunk_type = "tool_call"
        return text, chunk_type

    def chat_stream(self, user_input):
        """Stream one turn, printing think/content phases; any tool call is
        collected and reported once the stream ends."""
        self.add_user_message(user_input)
        tool_call_triggered = False
        tool_call_info = None
        print("\n【模型回应】")
        response = self.client.chat.completions.create(
            model=self.model,
            messages=self.messages,
            tools=self.tools,
            tool_choice="auto",
            temperature=0.7,
            stream=True,
        )
        for chunk in response:
            delta = chunk.choices[0].delta
            text, chunk_type = self.process_stream_chunk(delta)
            if chunk_type == "think_start":
                print("\n【模型思考阶段开始】", end="", flush=True)
            elif chunk_type == "think_continue":
                print(text, end="", flush=True)
            elif chunk_type == "think_end":
                print(text, flush=True)
                print("【模型思考阶段结束】\n")
            elif chunk_type == "content":
                print(text, end="", flush=True)
            elif chunk_type == "tool_call":
                tool_call_triggered = True
                # Tool-call info keeps accumulating in self.tool_call_acc.
                tool_call_info = {
                    "id": self.tool_call_acc["id"],
                    "type": self.tool_call_acc["type"],
                    "function_name": self.tool_call_acc["function"]["name"],
                    "arguments": self.tool_call_acc["function"]["arguments"]
                }
        if tool_call_triggered and tool_call_info:
            print("\n【触发工具调用】")
            print(f"函数名:{tool_call_info['function_name']}")
            print(f"参数:{tool_call_info['arguments']}")
            # Invoke the tool:
            # result = self.call_tool(tool_call_info)
            # ...follow-up conversation logic would go here.
import logging
# Silence everything below WARNING globally.
logging.basicConfig(level=logging.WARNING)
# Explicitly quiet the chatty third-party libraries (openai, httpx, urllib3).
logging.getLogger("openai").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)
if __name__ == "__main__":
    # Replace with your real API key and endpoint.
    API_KEY = "your_api_key"
    BASE_URL = "http://180.163.119.106:16031/v1"
    MODEL_NAME = "Fusion2-chat-v2.0"  # or a qwen3 model name
    #
    # NOTE(review): a live-looking OpenAI API key ("sk-proj-...") was committed
    # here and has been redacted — rotate that key and load secrets from the
    # environment instead of source control.
    # BASE_URL = "https://api.openai.com/v1"
    # MODEL_NAME = "gpt-4"  # or a qwen3 model name
    # Tool schema (OpenAI function-calling format).
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "获取指定城市的天气情况",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": {
                            "type": "string",
                            "description": "城市名称,用于查询天气"
                        }
                    },
                    "required": ["city"]
                }
            }
        }
    ]
    # session_id / agent_id are assumed to be controlled by the caller layer.
    session_id = "session001"
    agent_id = "weather_agent"
    # NOTE(review): hard-coded Redis password — move to config/env.
    redis_client = Redis(host="localhost", port=6379,password='21221', decode_responses=True)
    memory = RedisMemory(session_id="session001", agent_id="weather", redis_client=redis_client)
    agent = Qwen3Agent(
        api_key=API_KEY,
        base_url=BASE_URL,
        model_name=MODEL_NAME,
        tools=tools,
        memory=memory
    )
    user_question = "请帮我查询上海的天气"
    # user_question = "写一篇200字的关于人工智能的文章"
    agent.chat_stream(user_question)
import asyncio
import json
import re
import time
import uuid
from redis import Redis
from openai import AsyncOpenAI
from agent.memory.memory import RedisMemory
from agent.tools.tools import tool_manager, ToolManager
from agent.memory.memory import BaseMemory
from abc import ABC, abstractmethod
import json
class AsyncBaseAgent(ABC):
    """Async streaming agent base.

    Builds message contexts (system prompt + memory + new messages), streams
    completions, accumulates streamed tool-call fragments, and persists each
    turn to an optional memory backend. Subclasses implement
    process_stream_chunk to parse model-specific deltas.
    """

    def __init__(self, api_key, base_url, model_name, tool_manager: ToolManager, memory: BaseMemory = None,
                 sys_prompt=None):
        self.system_prompt = sys_prompt
        self.client = AsyncOpenAI(api_key=api_key, base_url=base_url)
        self.model = model_name
        self.tool_manager = tool_manager
        self.memory = memory
        # Streamed tool-call fragments accumulated by tool-call index.
        self.tool_calls_acc = {}
        # When False, "/nothink" is appended to prompts to suppress the
        # model's <think> phase.
        self.think = False
        self.cur_stream_id = None

    def build_messages(self, user_input: str = None, system_prompt: str = None, tool_msgs: list[dict] = None):
        """Assemble the full message list (system + memory + new messages) and
        persist the newly added user/tool messages to memory."""
        messages = []
        think_prompt = '' if self.think else '/nothink'
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt + think_prompt})
        if self.memory:
            messages.extend(self.memory.load())
        # Newly added messages that must be saved to memory.
        new_msgs_to_save = []
        if user_input:  # only when user_input is non-empty
            user_msg = {"role": "user", "content": user_input}
            messages.append(user_msg)
            new_msgs_to_save.append(user_msg)
        if tool_msgs:
            messages.extend(tool_msgs)
            new_msgs_to_save.extend(tool_msgs)
        # Append new messages to memory.
        if self.memory and new_msgs_to_save:
            # Prefer the backend's bulk extend when available, else append one by one.
            if hasattr(self.memory, "extend"):
                self.memory.extend(new_msgs_to_save)
            else:
                for msg in new_msgs_to_save:
                    self.memory.append(msg)
        return messages

    def reload_messages(self):
        """Rebuild the context from the system prompt plus the memory contents,
        without adding anything new."""
        messages = []
        think_prompt = '' if self.think else '/nothink'
        system_prompt = self.system_prompt
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt + think_prompt})
        if self.memory:
            messages.extend(self.memory.load())
        return messages

    @abstractmethod
    def process_stream_chunk(self, delta):
        """
        Subclass hook for parsing one streamed delta.
        Must return: (text, chunk_type).
        """
        pass

    async def stream(self, user_input: str = None, tool_calls: list[dict] = None):
        """Entry point: either start a new turn (user_input) or feed tool
        results back into the model (tool_calls)."""
        if user_input is not None:
            async for event in self.chat_stream(user_input):
                yield event
        elif tool_calls is not None:
            async for event in self.continue_tool_response(tool_calls):
                yield event
        else:
            raise ValueError("必须传入 user_input 或 tool_calls 之一")

    # First turn of a conversation.
    async def chat_stream(self, user_input: str):
        def message_factory():
            return self.build_messages(user_input=user_input, system_prompt=self.system_prompt)
        async for event in self.stream_response_from_messages(message_factory):
            yield event

    # Relay tool results back into the model.
    async def continue_tool_response(self, tool_calls: list[dict], stream_id=None):
        # NOTE(review): the stream_id parameter is ignored — it is overwritten
        # unconditionally below; confirm whether callers rely on passing it.
        tasks = []
        stream_id = str(uuid.uuid4())
        for tool_info in tool_calls:
            raw_args = tool_info["function"].get("arguments", "")
            # Normalize an empty argument string to "{}".
            if not raw_args.strip():
                raw_args = "{}"
            tool_info["function"]["arguments"] = raw_args  # write the fix back
            try:
                args = json.loads(tool_info["function"]["arguments"])
            except json.JSONDecodeError as e:
                print(f"工具参数解析失败: {e}")
                args = {}

            async def timed_call(tool_id, name, args, ):
                # Time each tool call; errors become {"error": ...} instead of
                # propagating, so one failure doesn't cancel the batch.
                start_time = time.perf_counter()
                try:
                    result = await self.tool_manager.call_tool(name, args)
                except Exception as e:
                    result = {"error": str(e)}
                end_time = time.perf_counter()
                duration = end_time - start_time
                print(f"🛠️ 工具 `{name}`(ID: {tool_id})执行完成,用时:{duration:.2f} 秒")
                return tool_id, result

            task = asyncio.create_task(
                timed_call(tool_info["id"], tool_info["function"]["name"], args)
            )
            tasks.append(task)
        # Run all tool calls concurrently.
        results = await asyncio.gather(*tasks)
        # Build one tool-role message per result and surface each as an event.
        tool_msgs = []
        for tool_id, result in results:
            d = {
                "role": "tool",
                "tool_call_id": tool_id,
                "content": json.dumps(result, ensure_ascii=False),
            }
            tool_msgs.append(d)
            yield {
                "stream_id": str(uuid.uuid4()),
                "type": "tool_calls_result",
                "stream_group_id": self.cur_stream_id,
                "info": d
            }
        # Build the refreshed message context.
        def message_factory():
            return self.build_messages(user_input=None, system_prompt=self.system_prompt, tool_msgs=tool_msgs)
        # Continue streaming the model's follow-up response.
        async for event in self.stream_response_from_messages(message_factory):
            yield event

    async def append_system_reflection_message(self, messages: list[dict]):
        """
        Append a system message that asks the model to analyze the current
        context in natural language and plan its tool-call steps, then stream
        that reflection pass; its output is saved to memory as a system turn.
        """
        new_messages = messages + [{
            "role": "system",
            "content": (
                "🧠【你的任务】:你正在进行**思考阶段**,请根据已有对话和工具调用结果,**用自然语言描述你下一步打算做什么**。\n\n"
                "📌【规则】:\n"
                "- ❌ 禁止直接回答用户问题;\n"
                "- ❌ 禁止介绍或列出你有哪些工具;\n"
                "- ❌ 禁止输出代码、JSON 或结构体;\n\n"
                "📉【失败处理】:\n"
                "- 工具失败时,优先分析失败原因;\n"
                "- 参数问题请说明如何修正;\n"
                "- 信息不足请主动向用户提问;\n\n"
                "🚀【执行策略】:\n"
                "- 多个无依赖任务应并行执行;\n"
                "- 有依赖的请说明执行顺序;\n"
                "- 可合并、去重、优化的,请合理规划。\n\n"
                "✅【输出内容】(仅自然语言):\n"
                "- 你对当前任务的理解;\n"
                "- 你的下一步计划(调用什么、顺序、是否并行);\n"
                "- 如果信息不足,你打算问什么。\n\n"
                "⚠️ 请**只输出自然语言描述的计划内容**,禁止包含任何调用结构或代码格式。"
            ) + ('' if self.think else '\n/nothink')
        }]
        # NOTE(review): drops the first message — presumably the original
        # system prompt — before the reflection pass; confirm intended.
        new_messages = new_messages[1:]
        async for result in self._handle_streaming_response(new_messages):
            if result["type"] in ("done", "tool_calls"):
                msg = result.get("msg")
                msg['role']='system'
                self._append_to_memory(msg)
                result['think'] = True
            yield result

    async def stream_response_from_messages(self, messages_factory):
        """Run the reflection pass, rebuild the context from memory, then run
        the real completion pass, persisting assistant output to memory."""
        messages = messages_factory()
        # Reflect on the current context first; the reflection result is
        # stored in memory as a system message.
        async for result in self.append_system_reflection_message(messages):
            yield result
        messages = self.reload_messages()
        # Now run the actual completion over the refreshed context.
        async for result in self._handle_streaming_response(messages):
            if result["type"] in ("done", "tool_calls"):
                msg = result.get("msg") or {
                    "role": "assistant",
                    "tool_calls": result.get("info")
                }
                self._append_to_memory(msg)
            yield result

    def _append_to_memory(self, message: dict):
        """
        Single choke point for memory writes, so storage policy can change in
        one place.
        NOTE(review): assumes self.memory is not None — a memoryless agent
        would crash here; confirm memory is always supplied.
        """
        self.memory.append(message)

    async def _handle_streaming_response(self, messages: list[dict]):
        """Stream one completion; yields "content" chunks, "tool_call_stream"
        progress events, and finally either a "tool_calls" or a "done" event
        carrying the assistant message to persist."""
        stream_id = str(uuid.uuid4())
        tools = await self.tool_manager.list_tools()
        tool_call_active = False
        response = await self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            tools=tools,
            tool_choice="auto",
            temperature=0.7,
            stream=True,
        )
        full_text = ""
        async for chunk in response:
            delta = chunk.choices[0].delta
            text, chunk_type = self.process_stream_chunk(delta)
            if chunk_type == "content":
                full_text += text
                yield {
                    "stream_id": stream_id,
                    "type": "content",
                    "text": text,
                    "stream_group_id": self.cur_stream_id
                }
            elif chunk_type == "tool_call":
                tool_call_active = True
                # Every fragment re-emits the full accumulator (live view).
                yield {
                    "stream_id": stream_id,
                    "type": "tool_call_stream",
                    "info": list(self.tool_calls_acc.values()),
                    "stream_group_id": self.cur_stream_id
                }
        # After the stream ends, emit the final validated tool-call structure.
        if tool_call_active and self.tool_calls_acc:
            tool_calls = list(self.tool_calls_acc.values())
            self.tool_calls_acc = {}
            for tool_call in tool_calls:
                if (
                    tool_call.get("type") == "function"
                    and isinstance(tool_call.get("function"), dict)
                ):
                    raw_args = tool_call["function"].get("arguments", "")
                    try:
                        if not raw_args or not raw_args.strip():
                            raise ValueError("空字符串")
                        json.loads(raw_args)
                    except (json.JSONDecodeError, ValueError) as e:
                        # Invalid/empty JSON arguments are replaced with "{}".
                        print(f"arguments 非法,已替换为 '{{}}':{e},tool_call id: {tool_call.get('id')}")
                        tool_call["function"]["arguments"] = "{}"
            assistant_tool_msg = {
                "role": "assistant",
                "tool_calls": tool_calls
            }
            yield {
                "stream_id": stream_id,
                "type": "tool_calls",
                "info": tool_calls,
                "stream_group_id": self.cur_stream_id,
                "msg": assistant_tool_msg
            }
            return
        # Plain content reply (no tool calls).
        final_msg = {"role": "assistant", "content": full_text}
        yield {
            "stream_id": stream_id,
            "type": "done",
            "stream_group_id": self.cur_stream_id,
            "msg": final_msg
        }

    def extract_tool_call_info(self, delta):
        """Flatten the first tool call of a delta into a simple dict."""
        tc = delta.tool_calls[0]
        return {
            "id": getattr(tc, "id", None),
            "type": getattr(tc, "type", None),
            "function_name": getattr(tc.function, "name", None) if tc.function else None,
            "arguments": getattr(tc.function, "arguments", "") if tc.function else ""
        }

    async def _handle_tool_call(self, tool_call_info):
        """Debug helper: execute one tool call and print its result."""
        print("\n【触发工具调用】")
        print(f"函数名:{tool_call_info['function_name']}")
        print(f"参数:{tool_call_info['arguments']}")
        try:
            args = json.loads(tool_call_info["arguments"])
        except json.JSONDecodeError as e:
            print(f"参数解析失败: {e}")
            args = {}
        result = await self.tool_manager.call_tool(tool_call_info["function_name"], args)
        print(f"\n【工具调用结果】:{result}")

    def handle_special_chunk(self, chunk_type, text):
        """
        Subclasses may override to handle special output types (e.g. <think>).
        """
        pass
class AsyncQwen3Agent(AsyncBaseAgent):
    """Qwen3-flavoured async agent: parses <think> phases out of the content
    stream and accumulates multi-fragment tool calls by index."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.in_think_phase = False  # inside a <think>...</think> span
        # Kept for compatibility; unused by the current parsing logic.
        self.start_tag_buffer = ""
        self.end_tag_buffer = ""

    def process_stream_chunk(self, delta):
        """Classify one streamed delta.

        Returns (text, chunk_type). chunk_type is one of "think_start" /
        "think_continue" / "think_end" / "content" / "tool_call" / None.
        For "tool_call", text is the accumulated tool-call dict for the last
        fragment's index, including a heuristic "stream_status" field.
        """
        text = ""
        chunk_type = None
        # ========= plain text (content or <think> markers) =========
        if hasattr(delta, "content") and delta.content:
            c = delta.content
            if "<think>" in c:
                self.in_think_phase = True
                chunk_type = "think_start"
            elif "</think>" in c:
                self.in_think_phase = False
                chunk_type = "think_end"
            elif self.in_think_phase:
                chunk_type = "think_continue"
            else:
                chunk_type = "content"
            text = c
            return text, chunk_type
        # ========= streamed tool-call fragments =========
        if hasattr(delta, "tool_calls") and delta.tool_calls:
            if not hasattr(self, "tool_calls_acc"):
                self.tool_calls_acc = {}  # defensive: normally set by the base class
            for tool_call in delta.tool_calls:
                index = tool_call.index
                is_new = index not in self.tool_calls_acc
                if is_new:
                    self.tool_calls_acc[index] = {
                        "id": tool_call.id or "",
                        "type": tool_call.type or "function",
                        "function": {
                            "name": "",
                            "arguments": ""
                        }
                    }
                acc = self.tool_calls_acc[index]
                if tool_call.id:
                    acc["id"] = tool_call.id
                if tool_call.type:
                    acc["type"] = tool_call.type
                if tool_call.function:
                    if tool_call.function.name:
                        acc["function"]["name"] += tool_call.function.name
                    if tool_call.function.arguments:
                        acc["function"]["arguments"] += tool_call.function.arguments
                # Heuristic completion check on the accumulated argument text.
                stream_status = "streaming"
                if is_new:
                    stream_status = "start"
                elif acc["function"]["arguments"].strip().endswith(("}", ")", "\"")):
                    stream_status = "complete"
                # Each fragment reports the full accumulated call so far.
                text = {
                    "index": index,
                    "id": acc["id"],
                    "type": acc["type"],
                    "function": {
                        "name": acc["function"]["name"],
                        "arguments": acc["function"]["arguments"]
                    },
                    "stream_status": stream_status
                }
            chunk_type = "tool_call"
            return text, chunk_type
        return text, chunk_type

    def extract_tool_call_info(self, tool_call_info):
        """Return the first accumulated tool call as a flat dict, or None.

        BUGFIX: the original read ``self.tool_call_acc`` (singular), an
        attribute this class never defines — the accumulator here is
        ``self.tool_calls_acc``, a dict keyed by tool-call index — so every
        call raised AttributeError.
        """
        if not self.tool_calls_acc:
            return None
        acc = next(iter(self.tool_calls_acc.values()))
        return {
            "id": acc["id"],
            "type": acc["type"],
            "function_name": acc["function"]["name"],
            "arguments": acc["function"]["arguments"]
        }

    def handle_special_chunk(self, chunk_type, text):
        """Print think-phase transitions for console use."""
        if chunk_type == "think_start":
            print("\n【模型思考阶段开始】", end="", flush=True)
        elif chunk_type == "think_continue":
            print(text, end="", flush=True)
        elif chunk_type == "think_end":
            print(text, flush=True)
            print("【模型思考阶段结束】\n")
from redis import Redis
from agent.llm.llm import AsyncQwen3Agent
from agent.llm.prompt import AGENT_PROMPT
from agent.memory.memory import RedisMemory
from agent.tools.tools import tool_manager, ToolManager, register_medical_tools,register_tools
API_KEY = "your_api_key"
# BASE_URL = "http://180.163.119.106:16031/v1"
BASE_URL = "http://180.163.119.106:16015/v1"
MODEL_NAME = "Fusion2-chat-v2.0"
#
#
# # #
# NOTE(review): a hard-coded DashScope API key ("sk-ba45...") was committed
# here and has been redacted — rotate it and load secrets from the environment.
# BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
# # MODEL_NAME = "qwen3-235b-a22b"  # or a qwen3 model name
# MODEL_NAME = "qwen3-30b-a3b"  # or a qwen3 model name
# Reuse one Redis client instead of creating a new one per call.
# NOTE(review): hard-coded Redis password — move to config/env.
redis_client = Redis(host="localhost", port=6379, password="21221", decode_responses=True)
def create_agent(session_id: str) -> AsyncQwen3Agent:
    """Build a fully wired AsyncQwen3Agent (tools + Redis-backed memory) for
    one session. MCP server entries in CONFIG are all disabled, so only the
    locally registered tools (register_tools) are available."""
    CONFIG = {
        # "FastGPT": {
        #     "url": "https://cloud.fastgpt.cn/api/mcp/app/oj7BgLJmYh45J17YxC7EEFAj/mcp"
        # },
        # "filesystem": {
        #     "args": [
        #         "-y",
        #         "@modelcontextprotocol/server-filesystem",
        #         "/tmp"
        #     ],
        #     "command": "npx"
        # }
        # "redis": {
        #     "command": "npx",
        #     "args": [
        #         "-y",
        #         "@modelcontextprotocol/server-redis",
        #         "redis://:21221@localhost:6379/1"
        #     ]
        # }
    }
    tool_manager = ToolManager(CONFIG)
    # register_medical_tools(tool_manager)
    register_tools(tool_manager)
    memory = RedisMemory(session_id=session_id, agent_id="weather", redis_client=redis_client)
    agent = AsyncQwen3Agent(
        api_key=API_KEY,
        base_url=BASE_URL,
        model_name=MODEL_NAME,
        tool_manager=tool_manager,
        memory=memory,
        sys_prompt=AGENT_PROMPT,
        # sys_prompt='/nothink'
    )
    return agent
import asyncio
async def main(session_id: str):
    """Console driver: run one question through the agent, print the streamed
    events, and follow up once when the model requests tool calls."""
    user_input = "查下病人张三的信息?并且查下上海的天气"
    agent = create_agent(session_id)
    await agent.tool_manager.initialize()
    async for event in agent.stream(user_input=user_input):
        if event["type"] == "content":
            print(event["text"], end="", flush=True)
        elif event["type"] == "think_start":
            print("\n【模型思考开始】", end="", flush=True)
        elif event["type"] == "think_continue":
            print(event["text"], end="", flush=True)
        elif event["type"] == "think_end":
            print(event["text"], flush=True)
            print("【模型思考结束】")
        elif event["type"] == "tool_calls":
            tool_info = event["info"]
            print(f"\n【触发工具调用】函数:{tool_info}")
            # Feed the tool results back in and stream the follow-up answer.
            async for followup in agent.stream(tool_calls=tool_info):
                if followup["type"] == "content":
                    print(followup["text"], end="", flush=True)
                elif followup["type"] == "done":
                    print("\n【对话完成】")
                    return
        elif event["type"] == "done":
            print("\n【对话完成】")
            return
if __name__ == "__main__":
    asyncio.run(main("session423"))
# System prompt for the tool-using agent (Chinese). This is a runtime string
# sent to the model, so its content is kept verbatim.
AGENT_PROMPT = '''
你是一个具备调用外部工具能力的智能助手。你可以访问多个工具,每个工具都有功能描述、输入参数和输出返回值。
你的目标是根据用户请求**自动完成任务**,如数据查询、信息提取、结构分析、图表生成等。
若任务涉及多个工具,请严格按照以下执行流程自动规划和执行。
---
🔁【任务执行流程】:
1. 🎯 用户意图识别:
- 准确理解用户目标,如查询、对比、分析、汇总等;
- 若涉及数据库,重点识别用户意图所需的表、字段、条件、聚合需求;
2. 📊 构建工具依赖图:
- 根据任务目标自动识别所需工具;
- 构建有向依赖图,优先执行无依赖工具;
- 并发执行独立工具,串行执行有依赖链条的工具;
- 示例:获取表结构 ➜ 构造 SQL ➜ 执行 SQL
3. 🔐 数据库结构感知(防止幻觉):
- ⛔【禁止幻觉】:不得构造任何数据库中不存在的表名或字段名;
- ✅【结构验证流程】:
1. 自动执行 `SHOW TABLES` 获取所有表名;
2. 对目标表执行 `DESC 表名` 获取字段结构与类型;
3. 所有 SQL 构造必须基于已验证存在的表与字段;
4. 所有字段名应合法(不含中文、空格或特殊字符);
- 💡 若字段信息不足,不允许猜测,应主动向用户说明并补充请求;
- ✅ 字段类型感知:
- 构造 SQL 条件时,应根据字段类型判断是否加引号;
- 如:数字字段不加引号,字符串字段加引号;
4. 💾 SQL 安全与合法性要求:
- 所有 SQL 构造必须防止注入、拼接错误;
- 所有字段名、表名应转义合法;
- 不允许执行 DROP/TRUNCATE 等危险语句;
- 支持多条 SQL 时,需自动分句,确保并发时不会共享连接(避免 readexactly 错误);
5. 📦 多数据处理与分页:
- 默认在 SQL 中添加 `LIMIT` 限制(如 LIMIT 100);
- 若返回数据超出限制,应仅展示前若干条,并提示用户可分页查询;
- 支持添加分页参数(如 offset / page)构造 SQL;
6. 🔄 工具失败处理与重试策略:
- 工具调用失败时,若错误非结构性或系统错误,可自动重试一次;
- 若因参数错误或结构缺失导致失败,应修正参数后重试;
- 若返回值非结构化字符串,应尝试提取关键内容,必要时降级为提示文字;
7. ❓ 参数缺失时的主动提问:
- 若某工具必需参数缺失,且上下文无法推理,应合并为一句自然语言提问;
- 不要逐一提问多个字段,应尽量合并提示并减少对话轮数;
8. 📬 最终统一输出用户可读结果:
- 工具全部执行完成后,再统一返回最终结果;
- 查询任务应返回总数+样例数据;
- 中间调用过程默认不展示(如需展示需用户明确请求);
---
🧱【JSON 构造与格式规范】:
- ✅ 所有工具参数应为严格合法 JSON 格式;
- ✅ 所有字符串使用双引号 `"` 包裹;
- ✅ 不允许:
- 单引号 `'`;
- 尾逗号;
- 注释;
- 非法值如 `undefined`、`NaN`、`Infinity`;
- ✅ 支持嵌套 JSON,须逐级验证,确保 `JSON.parse()` 可成功解析;
- ❗禁止 Python 风格 dict、代码片段、含函数/表达式的结构;
- ✅ 可在工具调用前输出如:“以下 JSON 字符串已校验通过,可安全传入工具。”
---
🚫【禁止行为】:
- ❌ 跳过工具依赖直接调用;
- ❌ 使用未验证的表或字段构造 SQL(即禁止幻觉);
- ❌ 在结构未获取前构造查询;
- ❌ 返回无法被 JSON 解析的数据;
- ❌ 返回代码结构或工具细节(除非用户请求);
- ❌ 在工具调用失败后立即盲目重复相同调用;
---
📎【工具注册结构(供参考)】:
每个工具应提供如下元信息:
- 名称:唯一识别符;
- 描述:一句话说明功能;
- 输入参数(JSON Schema);
- 输出结构示例(JSON 格式);
---
✅【补充说明】:
- 工具调用过程可并发执行(如批量执行 `DESC`),避免串行拖慢响应;
- 若多个 SQL 查询依赖独立连接,应避免共享同一个连接(使用连接池或多连接并发);
- 所有调用任务均应以任务完成为目标,不应输出 Agent 的工具能力说明;
- 对于结构不满足的请求,应主动告知用户缺失信息,而非猜测生成;
'''
# memory/base.py
from abc import ABC, abstractmethod
from typing import List, Dict
class BaseMemory(ABC):
    """Abstract contract for conversation-history storage backends."""

    @abstractmethod
    def load(self) -> List[Dict]:
        """Return the full stored message list."""

    @abstractmethod
    def save(self, messages: List[Dict]):
        """Overwrite storage with *messages*."""

    @abstractmethod
    def clear(self):
        """Delete all stored messages."""

    @abstractmethod
    def append(self, message: Dict):
        """Add one message to the stored history."""
# memory/redis_memory.py
import json
class RedisMemory(BaseMemory):
    """Conversation history stored as one JSON array under a per-session key."""

    def __init__(self, session_id: str, agent_id: str, redis_client):
        self.key = f"memory:{session_id}:{agent_id}"
        self.redis = redis_client

    def load(self):
        """Return all stored messages; [] when the key is missing or corrupt."""
        raw = self.redis.get(self.key)
        if not raw:
            return []
        try:
            return json.loads(raw)
        except Exception as e:
            print(f"[Memory Load Error] {e}")
            return []

    def save(self, messages):
        """Serialize and overwrite the stored history; errors are logged only."""
        try:
            self.redis.set(self.key, json.dumps(messages, ensure_ascii=False))
        except Exception as e:
            print(f"[Memory Save Error] {e}")

    def clear(self):
        """Drop the whole history key."""
        self.redis.delete(self.key)

    def append(self, message):
        """Append one message via read-modify-write."""
        history = self.load()
        history.append(message)
        self.save(history)

    def extend(self, messages_to_add: list):
        """
        Append several messages in one read-modify-write round trip.
        """
        history = self.load()
        history.extend(messages_to_add)
        self.save(history)

    def __repr__(self):
        return f"<RedisMemory key={self.key} count={len(self.load())}>"
# ✅ Process-wide dict standing in for Redis (shared across instances).
in_memory_store: Dict[str, List[Dict]] = {}
class InMemoryMemory(BaseMemory):
    """BaseMemory backed by a shared module-level dict — no external services."""

    def __init__(self, session_id: str, agent_id: str):
        self.key = f"memory:{session_id}:{agent_id}"

    def load(self) -> List[Dict]:
        # Return a shallow copy so callers can't mutate the store in place.
        return list(in_memory_store.get(self.key, []))

    def save(self, messages: List[Dict]):
        in_memory_store[self.key] = list(messages)

    def clear(self):
        in_memory_store.pop(self.key, None)

    def append(self, message: Dict):
        bucket = in_memory_store.get(self.key, [])
        bucket.append(message)
        in_memory_store[self.key] = bucket

    def extend(self, messages_to_add: List[Dict]):
        bucket = in_memory_store.get(self.key, [])
        bucket.extend(messages_to_add)
        in_memory_store[self.key] = bucket

    def __repr__(self):
        stored = len(in_memory_store.get(self.key, []))
        return f"<InMemoryMemory key={self.key} count={stored}>"
\ No newline at end of file
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
import asyncio
import json
from agent.llm.llm import AsyncQwen3Agent
from agent.llm import main
class SessionManager:
    """Per-session hub: owns the agent, an event queue, and a broadcast
    fan-out over every WebSocket attached to the session."""

    def __init__(self, session_id: str):
        self.session_id = session_id
        self.state = "idle"  # idle -> running -> completed
        self.queue = asyncio.Queue()
        self.main_agent: AsyncQwen3Agent = main.create_agent(session_id)
        # NOTE(review): history is replayed by send_history_and_current but is
        # never populated anywhere in this class — confirm where it should be
        # filled (e.g. from _handle_stream events).
        self.history = []
        self.connections: set[WebSocket] = set()  # all live sockets for this session

    async def connect(self, websocket: WebSocket):
        """Accept a socket and add it to the broadcast set."""
        await websocket.accept()
        self.connections.add(websocket)

    def disconnect(self, websocket: WebSocket):
        self.connections.discard(websocket)

    async def broadcast(self, message: dict):
        """Send one JSON event to every connection, dropping dead sockets."""
        text = json.dumps(message)
        # Iterate over a copy so failed connections can be removed mid-loop.
        for conn in self.connections.copy():
            try:
                await conn.send_text(text)
            except Exception as e:
                print(f"移除连接: {e}")
                self.disconnect(conn)

    async def _handle_stream(self, content: str = None, tool_calls: dict = None):
        """Drive one agent stream, pushing every event to the queue and all
        sockets; recurses when the model requests tool calls."""
        stream = (
            self.main_agent.stream(user_input=content)
            if content else
            self.main_agent.stream(tool_calls=tool_calls)
        )
        async for event in stream:
            event["session_id"] = self.session_id
            # Push the current event to the queue and all listeners.
            await self.queue.put(event)
            await self.broadcast(event)
            # On a tool-call request, recurse to stream the follow-up turn.
            if event.get("type") == "tool_calls":
                tool_info = event["info"]
                print(f"\n【触发工具调用】函数:{tool_info}")
                # Recursively process the tool-call continuation stream.
                await self._handle_stream(tool_calls=tool_info)

    async def start(self, content: str = None, tool_calls: dict = None):
        """Run a full turn; state ends as "completed" whether or not it errored."""
        self.state = "running"
        try:
            await self._handle_stream(content=content, tool_calls=tool_calls)
        except Exception as e:
            print(f"[错误] 处理流时出错: {e}")
            self.state = "completed"
        else:
            self.state = "completed"

    async def send_history_and_current(self, websocket: WebSocket):
        """Replay stored history to one socket (see NOTE on self.history)."""
        for msg in self.history:
            await websocket.send_text(json.dumps(msg))
# mcp_client/base.py
from abc import ABC, abstractmethod
from typing import Any
from abc import ABC, abstractmethod
from typing import Any
from abc import ABC, abstractmethod
from typing import Any
class BaseMcpClient(ABC):
    """Common interface for MCP tool clients (HTTP or stdio transports).

    Provides async-context-manager plumbing and a cached `tools()` accessor;
    transports implement initialize/list_tools/call_tool/cleanup.
    """

    def __init__(self, name: str, config: dict[str, Any]) -> None:
        self.name = name
        self.config = config
        self._tools_cache: list[Any] | None = None  # lazily filled tool list

    async def __aenter__(self) -> "BaseMcpClient":
        await self.initialize()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        await self.cleanup()

    @abstractmethod
    async def initialize(self) -> None:
        ...

    @abstractmethod
    async def list_tools(self) -> list[Any]:
        ...

    @abstractmethod
    async def call_tool(
        self,
        tool_name: str,
        arguments: dict[str, Any],
        retries: int = 2,
        delay: float = 1.0,
    ) -> Any:
        ...

    @abstractmethod
    async def cleanup(self) -> None:
        ...

    async def tools(self) -> list[Any]:
        """Cached accessor over list_tools(); wraps entries as Tool objects."""
        if self._tools_cache is None:
            fetched = await self.list_tools()
            self._tools_cache = [
                Tool(entry.name, entry.description, entry.inputSchema)
                for entry in fetched
            ]
        return self._tools_cache

    def clear_tool_cache(self) -> None:
        """Drop the cached tool list so the next tools() call refetches."""
        self._tools_cache = None
from .stdio_client import StdioMcpClient
from .http_client import HttpMcpClient
from .base import BaseMcpClient
def create_mcp_client(name: str, config: dict[str, Any]) -> BaseMcpClient:
    """Pick the transport-specific MCP client implied by *config*:
    a "url" key selects HTTP; "command" + "args" select stdio."""
    if "url" in config:
        return HttpMcpClient(name, config)
    if "command" in config and "args" in config:
        return StdioMcpClient(name, config)
    raise ValueError(f"Invalid config for MCP client '{name}': {config}")
class Tool:
    """One callable tool: name, description, and JSON-Schema input spec."""

    def __init__(
        self, name: str, description: str, input_schema: dict[str, Any]
    ) -> None:
        self.name: str = name
        self.description: str = description
        self.input_schema: dict[str, Any] = input_schema

    def format_for_llm(self) -> str:
        """Render the tool as a text blurb suitable for prompt injection.

        Returns:
            A formatted string describing the tools.
        """
        required = self.input_schema.get("required", [])
        args_desc = [
            f"- {param_name}: {param_info.get('description', 'No description')}"
            + (" (required)" if param_name in required else "")
            for param_name, param_info in self.input_schema.get("properties", {}).items()
        ]
        return f"""
Tool: {self.name}
Description: {self.description}
Arguments:
{chr(10).join(args_desc)}
"""

    def to_dict(self) -> dict[str, Any]:
        """Serialize to the OpenAI function-calling schema."""
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": self.input_schema
            }
        }
class MCPClientProxy:
    """Adapter exposing a prefetched tool-record list through the same
    list_tools/call_tool surface the agent's tool manager expects."""

    def __init__(self, tools):
        self.tools = tools

    def list_tools(self):
        """Convert raw tool records into OpenAI function-calling schema."""
        return [
            {
                "type": 'function',
                "function": {
                    "name": entry['name'],
                    "description": entry['description'],
                    "parameters": entry['inputSchema']
                }
            }
            for entry in self.tools
        ]

    async def call_tool(self, tool_name, arguments):
        """Dispatch to the client owning *tool_name*'s tool set; None if unknown."""
        from HotCode.core.globals import tool_set, tool_set_lock
        for entry in self.tools:
            if entry['name'] != tool_name:
                continue
            async with tool_set_lock:
                tool_client = tool_set.get(entry['toolSetId'])
                if not tool_client:
                    raise Exception(f"{tool_name} 所属的工具集{entry['toolSetId']}不存在或未发布")
                async with tool_client as client:
                    return await client.call_tool(tool_name, arguments)
        return None
\ No newline at end of file
# mcp_client/http_client.py
import asyncio
import logging
from contextlib import AsyncExitStack
from typing import Any
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from .base import BaseMcpClient
import asyncio
from typing import Any
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from .base import BaseMcpClient
class HttpMcpClient(BaseMcpClient):
    """MCP client using the streamable-HTTP transport.

    A persistent session (``initialize``/``cleanup``) backs ``list_tools``;
    ``call_tool`` deliberately opens a fresh short-lived connection per
    attempt so each retry starts from a clean transport.
    """
    def __init__(self, name: str, config: dict[str, Any]) -> None:
        super().__init__(name, config)
        self._cleanup_lock = asyncio.Lock()
        self._stream_ctx = None
        self._session_ctx = None
        self._stream = None
        # Fix: cleanup() reads self._exit_stack, which previously only came
        # into existence inside initialize(); calling cleanup() on a client
        # that was never initialized raised AttributeError.
        self._exit_stack: AsyncExitStack | None = None
        self.session: ClientSession | None = None
    async def initialize(self) -> None:
        """Open the HTTP stream and the MCP session (idempotent).

        Raises:
            ValueError: when the config has no 'url'.
        """
        if self.session:
            return  # already initialized
        url = self.config.get("url")
        if not url:
            raise ValueError("Missing 'url' in config.")
        # A single AsyncExitStack owns both the transport and the session.
        self._exit_stack = AsyncExitStack()
        stream_ctx = streamablehttp_client(url)
        stream = await self._exit_stack.enter_async_context(stream_ctx)
        read_stream, write_stream, _ = stream
        session_ctx = ClientSession(read_stream, write_stream)
        self.session = await self._exit_stack.enter_async_context(session_ctx)
        await self.session.initialize()
    async def list_tools(self) -> list[Any]:
        """Return the tools advertised by the connected server.

        Raises:
            RuntimeError: if initialize() has not been called.
        """
        if not self.session:
            raise RuntimeError(f"Client {self.name} not initialized")
        tools_response = await self.session.list_tools()
        tools = []
        # The response iterates as (field_name, value) pairs; keep "tools".
        for item in tools_response:
            if isinstance(item, tuple) and item[0] == "tools":
                tools.extend(item[1])
        return tools
    async def call_tool(
        self,
        tool_name: str,
        arguments: dict[str, Any],
        retries: int = 2,
        delay: float = 1.0,
    ) -> Any:
        """Invoke *tool_name* with *arguments*, retrying on any failure.

        Returns the text of the first content item of the tool result.

        Raises:
            ValueError: when the config has no 'url'.
            RuntimeError: when every attempt fails (last error chained).
        """
        url = self.config.get("url")
        if not url:
            raise ValueError("Missing 'url' in config.")
        for attempt in range(1, retries + 1):
            try:
                async with streamablehttp_client(url) as (read_stream, write_stream, _):
                    async with ClientSession(read_stream, write_stream) as session:
                        await session.initialize()
                        return (await session.call_tool(tool_name, arguments)).content[0].text
            except Exception as e:
                if attempt < retries:
                    await asyncio.sleep(delay)
                else:
                    raise RuntimeError(f"Tool call failed after {retries} attempts: {e}") from e
    async def cleanup(self) -> None:
        """Release the session and transport; safe to call repeatedly."""
        async with self._cleanup_lock:
            if self._exit_stack is not None:
                await self._exit_stack.aclose()
                self._exit_stack = None
            self.session = None
import asyncio
import logging
from agent.tools.mcp.base import BaseMcpClient,create_mcp_client
# Logging setup
logging.basicConfig(level=logging.INFO)
# Example configuration: one HTTP client; the Stdio and FusionAI entries
# below are kept as disabled examples.
CONFIG = {
    "mcpServers": {
        "FastGPT-mcp-6825886b058310fae2e58a3d": {
            "url": "https://cloud.fastgpt.cn/api/mcp/app/og8ghYLRVsFMLi8GZ2nTFALn/mcp"
        },
        # "filesystem": {
        #     "command": "npx",
        #     "args": [
        #         "-y",
        #         "@modelcontextprotocol/server-filesystem",
        #         "C:\\Users\\byshe\\Desktop",
        #     ]
        # },
        # "FusionAI-mcp-d92407ae-ab04-4bce-991c-e1a603e5fe67": {
        #     "url": "http://192.168.120.59:4000/api/mcp_service/d92407ae-ab04-4bce-991c-e1a603e5fe67/mcp/"
        # }
    }
}
async def run():
    """Bring up every configured MCP client and print each client's tools."""
    clients: list[BaseMcpClient] = []
    try:
        # One client per configured server.
        for server_name, server_conf in CONFIG["mcpServers"].items():
            mcp_client = create_mcp_client(server_name, server_conf)
            await mcp_client.initialize()
            logging.info(f"[{server_name}] Initialized.")
            clients.append(mcp_client)
        # Dump the tool list of every client.
        for mcp_client in clients:
            listed = await mcp_client.list_tools()
            print(f"\n[{mcp_client.name}] 可用工具列表:")
            for entry in listed:
                print(f"- {entry}")
        # # Example: invoke the first tool (if any).
        # if listed:
        #     tool_name = listed[0].get("tool_name") if isinstance(listed[0], dict) else listed[0]
        #     result = await mcp_client.call_tool(tool_name, arguments={"query": "hello"})
        #     print(f"[{mcp_client.name}] 调用工具 {tool_name} 返回结果:\n{result}")
    except Exception as e:
        logging.error(f"运行中发生错误: {e}")
    finally:
        # Cleanup is intentionally disabled in this demo.
        pass
        # for mcp_client in clients:
        #     await mcp_client.cleanup()
        #     logging.info(f"[{mcp_client.name}] 已清理资源.")
if __name__ == "__main__":
    import logging
    # Silence httpx request logging so our own INFO output stays readable.
    logging.getLogger("httpx").setLevel(logging.WARNING)
    asyncio.run(run())
import asyncio
import logging
import traceback
from typing import Any
from agent.tools.mcp.base import BaseMcpClient, create_mcp_client
logger = logging.getLogger(__name__)
class McpProxy:
    """Aggregates several MCP clients behind a single tool catalogue.

    Tool specs are cached in OpenAI function-calling format and every tool
    name is mapped to the client that owns it, so calls route automatically.
    """
    def __init__(self, config: dict[str, dict[str, Any]]):
        self.config = config
        self.clients: dict[str, BaseMcpClient] = {}
        self.tool_map: dict[str, str] = {}  # tool name -> owning client name
        self.tool_cache: list[dict[str, Any]] = []  # aggregated tool specs
    async def initialize(self):
        """Create and initialize one client per configured server, then cache tools."""
        for client_name, client_conf in self.config.items():
            mcp_client = create_mcp_client(client_name, client_conf)
            await mcp_client.initialize()
            logger.info(f"[{client_name}] 初始化成功")
            self.clients[client_name] = mcp_client
        await self._build_tool_cache()
    async def _build_tool_cache(self):
        """Rebuild the cached tool list and the tool -> client mapping."""
        self.tool_map.clear()
        self.tool_cache.clear()
        for client_name, mcp_client in self.clients.items():
            try:
                for tool in await mcp_client.list_tools():
                    self.tool_map[tool.name] = client_name
                    spec = {
                        "name": tool.name,
                        "description": tool.description,
                        "parameters": tool.inputSchema,
                    }
                    self.tool_cache.append({"type": 'function', "function": spec})
            except Exception as e:
                logger.warning(f"[{client_name}] 获取工具失败: {e}")
    async def list_tools(self) -> list[dict[str, Any]]:
        """Return the cached, aggregated OpenAI-style tool specs."""
        return self.tool_cache
    async def refresh_tools(self):
        """Force a re-fetch of every client's tools and update the cache."""
        await self._build_tool_cache()
        logger.info("工具列表已刷新")
    async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
        """Route a call to whichever client owns *tool_name*."""
        try:
            owner = self.tool_map.get(tool_name)
            if not owner:
                raise ValueError(f"未找到工具: {tool_name}")
            target = self.clients[owner]
            logger.info(f"调用 [{owner}] 的工具: {tool_name}")
            outcome = await target.call_tool(tool_name, arguments)
        except Exception as e:
            traceback.print_exc()
            raise e
        return outcome
    async def cleanup(self):
        """Tear down every managed client."""
        for client_name, mcp_client in self.clients.items():
            await mcp_client.cleanup()
            logger.info(f"[{client_name}] 已清理资源")
# Demo configuration: a single FastGPT MCP endpoint.
CONFIG = {
    "FastGPT": {
        "url": "https://cloud.fastgpt.cn/api/mcp/app/oj7BgLJmYh45J17YxC7EEFAj/mcp"
    }
}
# Module-level singleton proxy used by the demo entry point below.
proxy = McpProxy(CONFIG)
if __name__ == "__main__":
    import anyio
    logging.basicConfig(level=logging.INFO)
    logging.getLogger("httpx").setLevel(logging.WARNING)
    async def main():
        await proxy.initialize()
        # Fetch all tools from the cache
        tools = await proxy.list_tools()
        print(tools)
        # await proxy.initialize()
        # tools = await proxy.list_tools()
        # print(tools)
        # # Call a tool (the owning client is located automatically)
        result = await proxy.call_tool("计算器", {"question": "1+1", "a": 1, "b": 1, "operator": "add"})
        print("调用结果:", result)
        # Refresh the cached tool list
        # await proxy.refresh_tools()
        #
        # await proxy.cleanup()
    anyio.run(main)
import asyncio
import logging
from contextlib import AsyncExitStack
from typing import Any
from mcp import ClientSession, StdioServerParameters, types
from mcp.client.stdio import stdio_client
from .base import BaseMcpClient
class StdioMcpClient(BaseMcpClient):
    """MCP client that spawns the server as a subprocess and talks over stdio."""
    def __init__(self, name: str, config: dict[str, Any]) -> None:
        super().__init__(name, config)
        # Live MCP session; None until initialize() succeeds.
        self.session: ClientSession | None = None
        # Owns the subprocess transport and the session contexts.
        self.exit_stack: AsyncExitStack = AsyncExitStack()
        self._cleanup_lock: asyncio.Lock = asyncio.Lock()
    async def sampling_callback(self, message: types.CreateMessageRequestParams) -> types.CreateMessageResult:
        # Stub sampling handler: replies to server-initiated sampling requests
        # with a canned assistant message instead of calling a real model.
        return types.CreateMessageResult(
            role="assistant",
            content=types.TextContent(type="text", text="模拟模型响应"),
            model="gpt-3.5-turbo",
            stopReason="endTurn",
        )
    async def initialize(self) -> None:
        """Spawn the configured command and perform the MCP handshake.

        Raises:
            ValueError: when 'command' or 'args' is missing from the config.
        """
        command = self.config.get("command")
        args = self.config.get("args")
        env = self.config.get("env", None)
        if not command or not args:
            raise ValueError(f"Invalid stdio config for client '{self.name}'")
        try:
            server_params = StdioServerParameters(command=command, args=args, env=env)
            read_stream, write_stream, *_ = await self.exit_stack.enter_async_context(
                stdio_client(server_params)
            )
            self.session = await self.exit_stack.enter_async_context(
                ClientSession(read_stream, write_stream, sampling_callback=self.sampling_callback)
            )
            await self.session.initialize()
        except Exception as e:
            logging.error(f"Error initializing Stdio client {self.name}: {e}")
            # Failed halfway through: release whatever was already acquired.
            await self.cleanup()
            raise
    async def list_tools(self) -> list[Any]:
        """Return the tools advertised by the server process."""
        if not self.session:
            raise RuntimeError(f"Client {self.name} not initialized")
        tools_response = await self.session.list_tools()
        tools = []
        # The response iterates as (field_name, value) pairs; keep "tools".
        for item in tools_response:
            if isinstance(item, tuple) and item[0] == "tools":
                tools.extend(item[1])
        return tools
    async def call_tool(
        self,
        tool_name: str,
        arguments: dict[str, Any],
        retries: int = 2,
        delay: float = 1.0,
    ) -> Any:
        """Call *tool_name*, retrying up to *retries* times (*delay* seconds
        apart); returns the text of the first content item of the result."""
        if not self.session:
            raise RuntimeError(f"Client {self.name} not initialized")
        attempt = 0
        while attempt < retries:
            try:
                res = await self.session.call_tool(tool_name, arguments)
                return res.content[0].text
            except Exception as e:
                attempt += 1
                logging.warning(f"[{self.name}] Error calling tool '{tool_name}': {e} (attempt {attempt})")
                if attempt < retries:
                    await asyncio.sleep(delay)
                else:
                    raise
    async def cleanup(self) -> None:
        """Release the subprocess and session resources (idempotent)."""
        async with self._cleanup_lock:
            try:
                # Close session-related resources.
                if self.session:
                    # await self.session.close()
                    self.session = None
                # Close everything managed by the exit stack.
                await self.exit_stack.aclose()
            except asyncio.CancelledError:
                # Swallow cancellation during teardown; just drop the session.
                self.session = None
            except Exception as e:
                logging.error(f"Error during cleanup of client {self.name}: {e}")
                self.session = None
    # Async context-manager protocol.
    async def __aenter__(self):
        await self.initialize()
        return self
    async def __aexit__(self, exc_type, exc, tb):
        await self.cleanup()
import asyncio
from mcp.client.streamable_http import streamablehttp_client
from mcp import ClientSession
async def main():
    """Smoke test: connect to a FastGPT MCP endpoint and print its tools."""
    try:
        # Connect to the streamable-HTTP server.
        async with streamablehttp_client("https://cloud.fastgpt.cn/api/mcp/app/og8ghYLRVsFMLi8GZ2nTFALn/mcp") as (
            read_stream,
            write_stream,
            _
        ):
            # Create the client session.
            async with ClientSession(read_stream, write_stream) as session:
                # Perform the MCP handshake.
                await session.initialize()
                # Fix: the result used to be bound to the name `list`,
                # shadowing the builtin.
                tool_list = await session.list_tools()
                print(tool_list)
                #
                # # Call a tool (e.g. the echo tool)
                # tool_result = await session.call_tool("echo", {"message": "hello"})
                #
                # print("Tool result:", tool_result)
    except asyncio.CancelledError:
        print("任务被取消")
    except Exception as e:
        print("发生错误:", e)
if __name__ == "__main__":
    # Entry point for the standalone connectivity test above.
    asyncio.run(main())
import json
# Fix: in a single-quoted Python literal, \" is just ", so the inner quotes
# were NOT escaped in the raw text and json.loads(raw) raised. The inner
# document is a JSON-encoded *string*, so its quotes must survive as \" in
# the raw text — i.e. \\" in the Python source.
raw = '{"option": "{\\"title\\": {\\"text\\": \\"示例柱状图\\"},\\"xAxis\\": {\\"type\\": \\"category\\",\\"data\\": [\\"类别A\\",\\"类别B\\",\\"类别C\\"]},\\"yAxis\\": {\\"type\\": \\"value\\"},\\"series\\": [{\\"type\\": \\"bar\\",\\"data\\": [120,200,150]}]}"}'
data = json.loads(raw)  # parse the outer JSON object
option = json.loads(data['option'])  # then parse the embedded option string
print(option)
from typing import List, Union, Optional, Literal, Dict, Any
from pydantic import BaseModel, root_validator
from pydantic import model_validator
# ===== 基础结构类型 =====
class Title(BaseModel):
    """ECharts `title` component (subset of fields)."""
    text: Optional[str] = None
    left: Optional[str] = None
    top: Optional[str] = None
class Tooltip(BaseModel):
    """ECharts `tooltip` component (subset of fields)."""
    trigger: Optional[str] = None
    axisPointer: Optional[Dict[str, Any]] = None
class Legend(BaseModel):
    """ECharts `legend` component (subset of fields)."""
    orient: Optional[str] = None
    left: Optional[str] = None
    data: Optional[List[str]] = None
class Grid(BaseModel):
    """ECharts `grid` layout component; offsets may be px ints or "%"-strings."""
    left: Optional[Union[str, int]] = None
    right: Optional[Union[str, int]] = None
    top: Optional[Union[str, int]] = None
    bottom: Optional[Union[str, int]] = None
    containLabel: Optional[bool] = None
class XAxis(BaseModel):
    """ECharts x-axis; `data` is only meaningful for category axes."""
    type: Optional[str] = None  # category, value, time, log
    data: Optional[List[Union[str, int, float]]] = None
class YAxis(BaseModel):
    """ECharts y-axis (type only, e.g. "value")."""
    type: Optional[str] = None
class Radar(BaseModel):
    """ECharts `radar` coordinate component."""
    indicator: Optional[List[Dict[str, Union[str, int]]]] = None
    shape: Optional[str] = None
# ===== Series 类型(通用) =====
class BaseSeries(BaseModel):
    """One entry of the `series` array; fields are shared across chart types,
    with several that only apply to specific types (noted inline)."""
    type: Literal['bar', 'line', 'pie', 'scatter', 'radar', 'gauge', 'heatmap']
    name: Optional[str] = None
    data: Optional[List[Union[int, float, Dict[str, Any]]]] = None
    smooth: Optional[bool] = None
    stack: Optional[str] = None
    radius: Optional[Union[str, List[str]]] = None  # for pie
    center: Optional[List[str]] = None  # for pie
    roseType: Optional[str] = None  # for pie
    itemStyle: Optional[Dict[str, Any]] = None
    lineStyle: Optional[Dict[str, Any]] = None
    areaStyle: Optional[Dict[str, Any]] = None
    label: Optional[Dict[str, Any]] = None
    emphasis: Optional[Dict[str, Any]] = None
    min: Optional[float] = None  # for gauge
    max: Optional[float] = None  # for gauge
    pointer: Optional[Dict[str, Any]] = None  # for gauge
# ===== 最终 Option 配置结构 =====
class EChartsOption(BaseModel):
    """Top-level ECharts option; unknown fields are passed through as-is."""
    title: Optional[Title] = None
    tooltip: Optional[Tooltip] = None
    legend: Optional[Legend] = None
    grid: Optional[Union[Grid, List[Grid]]] = None
    xAxis: Optional[Union[XAxis, List[XAxis]]] = None
    yAxis: Optional[Union[YAxis, List[YAxis]]] = None
    radar: Optional[Radar] = None
    series: List[BaseSeries]
    color: Optional[List[str]] = None
    backgroundColor: Optional[Union[str, Dict[str, Any]]] = None
    animation: Optional[bool] = None
    textStyle: Optional[Dict[str, Any]] = None
    toolbox: Optional[Dict[str, Any]] = None
    dataset: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None
    class Config:
        extra = "allow"  # allow undeclared fields to pass through
    @model_validator(mode='after')
    def check_series_exists(self):
        # Fix: pydantic v2 'after' validators run on the built model instance;
        # the previous (cls, values) signature (without @classmethod) is the
        # v1 style and does not bind correctly here.
        if not self.series:
            raise ValueError("配置项中必须包含至少一个 series。")
        return self
import asyncio
import re
from datetime import datetime
from typing import Any, Awaitable, Callable
from agent.tools.mcp.mcp_proxy import McpProxy
from pathlib import Path
from markdown import markdown
from bs4 import BeautifulSoup
from docx import Document
from weasyprint import HTML
import tempfile
import asyncio
import inspect
from typing import (
Any,
Awaitable,
Callable,
get_type_hints,
Union,
get_origin,
get_args,
Literal,
Annotated,
)
from agent.tools.mcp.mcp_proxy import McpProxy
from crawl4ai import AsyncWebCrawler
from googlesearch import search
class ToolManager:
    """Merges MCP-proxied tools with locally registered coroutine tools.

    Local tools are added through :meth:`register_tool` or the :meth:`tool`
    decorator, which derives a JSON schema from the function's signature.
    """
    def __init__(self, config: dict[str, dict[str, Any]]):
        # NOTE(review): magic constant; its purpose is not visible here — confirm.
        self.id = 99999999
        self.proxy = McpProxy(config)
        self.local_tools: dict[str, dict[str, Any]] = {}
        self.local_tool_handlers: dict[str, Callable[..., Awaitable[Any]]] = {}
    async def initialize(self):
        """Initialize the underlying MCP proxy clients."""
        await self.proxy.initialize()
    def _python_type_to_json_schema(self, typ: type) -> dict[str, Any]:
        """Translate a Python type hint into a loose JSON-schema fragment."""
        origin = get_origin(typ)
        args = get_args(typ)
        if origin is Union:
            # Optional[X] collapses to X's schema; wider unions become anyOf.
            non_none_args = [arg for arg in args if arg is not type(None)]
            if len(non_none_args) == 1:
                return self._python_type_to_json_schema(non_none_args[0])
            return {"anyOf": [self._python_type_to_json_schema(arg) for arg in non_none_args]}
        if typ is str:
            return {"type": "string"}
        if typ in [int, float]:
            return {"type": "number"}
        if typ is bool:
            return {"type": "boolean"}
        if origin in [list, tuple]:
            return {
                "type": "array",
                "items": self._python_type_to_json_schema(args[0]) if args else {"type": "string"},
            }
        if origin is dict:
            if len(args) == 2 and args[0] is str:
                return {
                    "type": "object",
                    "additionalProperties": self._python_type_to_json_schema(args[1]),
                }
            return {"type": "object"}
        if origin is Literal:
            return {"enum": list(args)}
        if origin is Annotated:
            return self._python_type_to_json_schema(args[0])
        # Fallback for anything unrecognized.
        return {"type": "string"}
    def register_tool(
        self,
        name: str,
        description: str,
        input_schema: dict[str, Any],
        handler: Callable[..., Awaitable[Any]]
    ):
        """Register *handler* under *name* with an explicit JSON schema."""
        self.local_tools[name] = {
            "type": "function",
            "function": {
                "name": name,
                "description": description,
                "parameters": input_schema,
            }
        }
        self.local_tool_handlers[name] = handler
    def tool(self, func: Callable[..., Awaitable[Any]]):
        """Decorator: derive a schema from *func*'s signature and register it."""
        sig = inspect.signature(func)
        type_hints = get_type_hints(func)
        properties = {}
        required = []
        for name, param in sig.parameters.items():
            hint = type_hints.get(name, str)
            schema = self._python_type_to_json_schema(hint)
            # Fix: a spurious "type": "string" used to be injected even for
            # schemas keyed by "enum"/"anyOf" (Literal/union hints); every
            # branch of _python_type_to_json_schema already yields a complete
            # fragment, so just annotate it with a description.
            properties[name] = {**schema, "description": f"{name}"}
            if param.default is inspect.Parameter.empty:
                required.append(name)
        input_schema = {
            "type": "object",
            "properties": properties,
            "required": required,
        }
        self.register_tool(
            name=func.__name__,
            description=func.__doc__ or "无描述",
            input_schema=input_schema,
            handler=func
        )
        return func
    async def list_tools(self) -> list[dict[str, Any]]:
        """Return MCP tools followed by locally registered tools."""
        mcp_tools = await self.proxy.list_tools()
        return mcp_tools + list(self.local_tools.values())
    async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
        """Dispatch to a local handler first, else to the MCP proxy."""
        if tool_name in self.local_tool_handlers:
            return await self.local_tool_handlers[tool_name](**arguments)
        return await self.proxy.call_tool(tool_name, arguments)
    async def refresh_tools(self):
        """Refresh the MCP proxy's cached tool list."""
        await self.proxy.refresh_tools()
    async def cleanup(self):
        """Tear down the MCP proxy's clients."""
        await self.proxy.cleanup()
from datetime import datetime
def register_medical_tools(tool_manager):
    """Attach the mock medical-record demo tools to *tool_manager*.

    The inner docstrings are runtime data (exposed to the LLM as tool
    descriptions via ``func.__doc__``) and are therefore kept verbatim.
    """
    @tool_manager.tool
    async def get_patient_id_by_name(patient_name: str) -> str:
        """通过姓名查询患者唯一ID"""
        ids_by_name = {"张三": "P001", "李四": "P002"}
        return ids_by_name.get(patient_name, f"未找到名为 {patient_name} 的患者ID")
    @tool_manager.tool
    async def get_patient_profile(patient_id: str) -> str:
        """根据患者ID返回基本信息"""
        profiles = {
            "P001": {"姓名": "张三", "年龄": 35, "性别": "男", "病史": "高血压"},
            "P002": {"姓名": "李四", "年龄": 28, "性别": "女", "病史": "无"},
        }
        info = profiles.get(patient_id)
        if info is None:
            return f"未找到患者ID {patient_id} 的资料"
        return f"患者 {info['姓名']},年龄 {info['年龄']} 岁,性别 {info['性别']},病史:{info['病史']}"
    @tool_manager.tool
    async def get_visit_history(patient_id: str) -> str:
        """查询患者的就诊记录"""
        visit_log = {
            "P001": [
                {"日期": "2025-04-01", "科室": "内科", "摘要": "血压偏高"},
                {"日期": "2025-05-12", "科室": "心内科", "摘要": "建议做动态心电图"},
            ],
            "P002": [
                {"日期": "2025-03-20", "科室": "妇科", "摘要": "常规体检"},
            ],
        }
        entries = visit_log.get(patient_id)
        if not entries:
            return f"未找到患者ID {patient_id} 的就诊记录"
        return "\n".join(f"{v['日期']} {v['科室']}{v['摘要']}" for v in entries)
    @tool_manager.tool
    async def get_recent_lab_results(patient_id: str) -> str:
        """根据患者ID返回最近一次化验结果"""
        labs = {
            "P001": {"血糖": "6.2 mmol/L", "胆固醇": "5.8 mmol/L"},
            "P002": {"血糖": "4.9 mmol/L", "胆固醇": "4.3 mmol/L"},
        }
        values = labs.get(patient_id)
        if not values:
            return f"未找到患者ID {patient_id} 的化验结果"
        return f"血糖:{values['血糖']},胆固醇:{values['胆固醇']}"
    @tool_manager.tool
    async def get_doctor_notes(patient_id: str) -> str:
        """返回医生最近的诊疗记录"""
        note_book = {
            "P001": "建议控制饮食,定期测血压。",
            "P002": "无异常,继续观察。",
        }
        return note_book.get(patient_id, f"未找到患者ID {patient_id} 的医生记录")
    @tool_manager.tool
    async def get_medication_list(patient_id: str) -> str:
        """返回患者当前用药列表"""
        prescriptions = {
            "P001": ["降压药A", "降压药B"],
            "P002": ["维生素C"],
        }
        current = prescriptions.get(patient_id)
        if not current:
            return f"{patient_id} 当前无记录用药"
        return f"当前用药:{', '.join(current)}"
    @tool_manager.tool
    async def schedule_appointment(patient_id: str, department: str, date: str) -> str:
        """为患者安排指定科室的预约"""
        try:
            datetime.strptime(date, "%Y-%m-%d")
        except ValueError:
            return "日期格式错误,请使用 YYYY-MM-DD 格式"
        return f"已为患者 {patient_id} 安排在 {date}{department} 科室就诊"
def register_tools(tool_manager):
    """Register the general-purpose local tools on *tool_manager*.

    The inner docstrings are runtime data (exposed to the LLM as tool
    descriptions via ``func.__doc__``) and are therefore kept verbatim.
    """
    @tool_manager.tool
    async def execute_sql(sql_list: list[str]) -> str:
        """
        并发执行传入的多条 SQL 语句,每条语句使用独立连接,避免并发冲突。
        """
        import aiomysql
        import asyncio
        if not sql_list or not isinstance(sql_list, list):
            return "[无有效 SQL 语句列表]"
        async def run_single_sql(sql_stmt: str):
            # One dedicated connection per statement so concurrent coroutines
            # never share a connection.
            conn = None
            try:
                # SECURITY(review): database credentials are hard-coded here;
                # move them to configuration / secrets management.
                conn = await aiomysql.connect(
                    host='10.10.0.2',
                    port=6606,
                    user='root',
                    password='yiPmJHHlRa',
                    db='crm',
                    autocommit=True
                )
                async with conn.cursor() as cur:
                    await cur.execute(sql_stmt)
                    result = await cur.fetchall()
                    columns = [desc[0] for desc in cur.description] if cur.description else []
                if not result:
                    return f"[SQL 执行成功] 无结果: {sql_stmt}"
                formatted = "\n".join(
                    [", ".join(columns)] +
                    [", ".join(str(item) for item in row) for row in result]
                )
                return f"[SQL 执行成功] {sql_stmt}\n{formatted}"
            except Exception as e:
                return f"[SQL 执行错误] {sql_stmt}\n错误: {e}"
            finally:
                # Fix: the connection used to leak when execute/fetch raised.
                if conn is not None:
                    await conn.ensure_closed()
        # Run all statements concurrently, skipping blank entries.
        results = await asyncio.gather(
            *(run_single_sql(stmt.strip()) for stmt in sql_list if stmt.strip())
        )
        return "\n\n".join(results)
    # @tool_manager.tool
    async def show_rendered_element(code: str) -> str:
        """
        用于向用户展示一个html元素,你可以使用tailwind
        例如:
        <div class="p-6 bg-white dark:bg-gray-900 text-gray-900 dark:text-gray-100 rounded-lg shadow-md">
        <h1 class="text-3xl font-bold mb-4">欢迎使用 Tailwind CSS</h1>
        </div>
        """
        # Optionally sanitize / validate the HTML here before echoing success.
        return '用户已经成功看到了渲染过后的元素'
    @tool_manager.tool
    async def show_rendered_chart(option: str):
        """
        功能说明:
        接收用户传入的 ECharts 配置项(合法的 YAML 格式字符串),
        解析为 dict 后用于前端渲染任意类型的 ECharts 图表。
        注意:
        1. 不同图表类型的配置结构差异较大,必须严格按照 ECharts 官方文档对应图表类型的配置规范来填写。
        2. 参数中必须包含至少一个 series,且 series 中的每个元素 type 字段需对应 ECharts 支持的图表类型。
        3. 根据 series 中 type 的不同,必须校验对应配置项的必填字段和数据格式:
        - **bar/line/scatter** 等类目轴图:需配置有效的 xAxis、yAxis,series.data 为数组。
        - **pie/funnel**:无坐标轴,series.data 需为数组对象,每个对象需有 name 和 value。
        - **radar**:需配置 radar 指标,series.data 为多维数组。
        - **gauge**:需包含指针和刻度配置,series.data 通常为单个数值。
        - 其他图表请参考 [ECharts 官方文档](https://echarts.apache.org/zh/option.html)。
        4. 若配置不符合规范,返回详细的错误提示,指出具体缺失或格式错误的字段。
        5. 解析 YAML 时请确保格式正确,否则提示格式错误。
        6. 所有颜色值(如 color 字段)必须是字符串形式,必须用引号包裹,例如 "#5470c6"、"rgba(255, 0, 0, 0.5)" 等,避免被 YAML 解析器识别为注释。
        参数:
        - option: str,ECharts 配置的 YAML 格式字符串。
        返回:
        - 成功时返回提示“图表已展示”。
        - 出错时返回错误信息字符串。
        示例(柱状图和折线图混合):
        title:
        text: 示例多图表
        tooltip: {}
        xAxis:
        type: category
        data: ["A", "B", "C"]
        yAxis:
        type: value
        series:
        type: bar
        data: [120, 200, 150]
        type: line
        data: [100, 180, 90]
        """
        import json
        import yaml
        try:
            if isinstance(option, str):
                option = yaml.safe_load(option)
        except Exception as e:
            return {
                "success": False,
                "error": f"参数 option YAML 解析失败: {str(e)}"
            }
        if not isinstance(option, dict):
            return {
                "success": False,
                "error": "参数格式错误:option 必须是 dict 或合法的 YAML 字符串"
            }
        # Further validation (e.g. presence of "series") could be added here.
        return {
            "success": True,
            "message": "Echart已经成功向用户展示",
            "option": json.dumps(option, ensure_ascii=False)
        }
    @tool_manager.tool
    async def search_and_extract(query: str, top_k: int = 5) -> list[dict]:
        """
        搜索关键词并提取清洗后的网页正文,返回格式:
        [{"url": "xxx", "text": "正文内容"}, ...]
        参数:
        - query: 搜索关键词
        - top_k: 返回前 N 个搜索结果
        """
        def clean_markdown(md_text: str) -> str:
            # Strip images, links, whitespace and markup; keep the text only
            # when at least 30% of it is CJK (i.e. it looks like real prose).
            md_text = re.sub(r'!\[.*?\]\(.*?\)', '', md_text)
            md_text = re.sub(r'\[.*?\]', '', md_text)
            md_text = re.sub(r'\(.*?\)', '', md_text)
            md_text = re.sub(r'[ \t]+', '', md_text)
            md_text = re.sub(r'\n\s*\n+', '', md_text)
            md_text = re.sub(r'\n', '', md_text)
            md_text = re.sub(r'[*#]', '', md_text)
            md_text = re.sub(r'^\s*[\*\-+]\s+', '', md_text, flags=re.MULTILINE)
            md_text = md_text.strip()
            chinese_chars = re.findall(r'[\u4e00-\u9fff]', md_text)
            ratio = len(chinese_chars) / len(md_text) if md_text else 0
            return md_text if ratio >= 0.3 else ""
        async def fetch_and_clean(url: str) -> dict:
            # Crawl one URL; failures are reported inline, never raised.
            try:
                async with AsyncWebCrawler() as crawler:
                    result = await crawler.arun(url=url, strategy="auto")
                markdown = result.markdown if result else ""
                clean_text = clean_markdown(markdown)
                return {"url": url, "text": clean_text or "[正文内容为空]"}
            except Exception as e:
                return {"url": url, "text": f"[抓取失败] {e}"}
        try:
            urls = list(search(query, num_results=top_k))
        except Exception as e:
            # No URLs to crawl; report the search failure itself.
            return [{"url": "", "text": f"[搜索失败] {e}"}]
        tasks = [fetch_and_clean(url) for url in urls]
        results = await asyncio.gather(*tasks)
        return results
    @tool_manager.tool
    async def save_markdown_file(content: str, filename: str = "output", format: str = "pdf") -> dict:
        """
        📄 将 Markdown 文本保存为 PDF 或 DOCX 文件。
        🧠【用途说明】:
        - 当任务涉及 **论文撰写**、**总结报告**、**生成文档内容** 等场景时,应调用此工具将 markdown 内容转换为文档格式进行保存;
        - 可用于将系统生成的内容导出为 PDF 或 Word 文件,以供后续查看或下载。
        ⚙️【参数说明】:
        - content: markdown 格式的文本内容;
        - filename: 保存的文件名(不带后缀);
        - format: 保存格式,支持 'pdf' 或 'docx'。
        ✅【返回结构】:
        {
        "path": "完整文件路径",
        "message": "成功或失败信息"
        }
        🗂️【存储位置】:固定保存至 D:/tmp/markdown_outputs 目录下。
        """
        try:
            # Fixed output directory (see the docstring above).
            output_dir = Path("D:/tmp/markdown_outputs")
            output_dir.mkdir(parents=True, exist_ok=True)
            # Fix: the `filename` argument used to be ignored entirely.
            file_path = output_dir / f"{filename}.{format}"
            # Convert Markdown to HTML first; both backends start from HTML.
            html_content = markdown(content)
            if format == "pdf":
                # Render the HTML to PDF via weasyprint.
                HTML(string=html_content).write_pdf(str(file_path))
            elif format == "docx":
                # Build a simple flat document with python-docx.
                soup = BeautifulSoup(html_content, "html.parser")
                doc = Document()
                for elem in soup.find_all(['p', 'h1', 'h2', 'h3', 'li']):
                    text = elem.get_text(strip=True)
                    if text:
                        doc.add_paragraph(text)
                doc.save(str(file_path))
            else:
                return {"path": "", "message": f"[失败] 不支持的格式: {format}"}
            return {"path": str(file_path), "message": "[成功] 文件已保存"}
        except Exception as e:
            return {"path": "", "message": f"[失败] {e}"}
# Tool-manager configuration; the FastGPT MCP entry is currently disabled
# and only the npx filesystem server remains active.
CONFIG = {
    # "FastGPT": {
    #     "url": "https://cloud.fastgpt.cn/api/mcp/app/oj7BgLJmYh45J17YxC7EEFAj/mcp"
    # },
    "filesystem": {
        "args": [
            "-y",
            "@modelcontextprotocol/server-filesystem",
            "/"
        ],
        "command": "npx"
    }
}
# Module-level singleton used by the demo entry point below.
tool_manager = ToolManager(CONFIG)
async def main():
    # Demo: register the mock medical tools and dump the aggregated tool list.
    register_medical_tools(tool_manager)
    # await tool_manager.initialize()
    tools = await tool_manager.list_tools()
    print(tools)
if __name__ == '__main__':
    asyncio.run(main())
from .chat.router import router as chat_router
import uuid
from typing import Dict
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
import asyncio
import json
from agent.session import session
import importlib
router = APIRouter()
# One SessionManager per session id, kept for the lifetime of the process.
session_managers: Dict[str, session.SessionManager] = {}
@router.websocket("/ws/{session_id}")
async def websocket_handler(websocket: WebSocket, session_id: str):
    """WebSocket endpoint: attach the client to the SessionManager for
    *session_id* (creating it on first use) and dispatch incoming commands.

    Protocol: the client sends JSON messages; {"action": "start",
    "content": ...} launches a background task, anything else is rejected.
    """
    if session_id not in session_managers:
        # NOTE(review): reloading the session module on first connection looks
        # like a hot-reload development hack — confirm before relying on it.
        importlib.reload(session)
        session_managers[session_id] = session.SessionManager(session_id)
        await session_managers[session_id].main_agent.tool_manager.initialize()
    manager = session_managers[session_id]
    await manager.connect(websocket)
    try:
        await manager.send_history_and_current(websocket)
        while True:
            text = await websocket.receive_text()
            try:
                data = json.loads(text)
            except Exception:
                # Not valid JSON: report and wait for the next message.
                await websocket.send_text(json.dumps({"error": "消息格式错误,必须是JSON"}))
                continue
            if manager.state == "running":
                # Reject new commands while a task is already executing.
                await websocket.send_text(json.dumps({"error": "任务正在执行中,请稍后再试"}))
                continue
            if data.get("action") == "start":
                content = data.get("content", "默认内容")
                # think = data.get("think", True)
                # Fresh stream id per run so the client can correlate chunks.
                manager.main_agent.cur_stream_id = uuid.uuid4().hex[:8]
                asyncio.create_task(manager.start(content = content))
            else:
                await websocket.send_text(json.dumps({"error": "不支持的操作"}))
    except WebSocketDisconnect:
        print(f"[WebSocket] 断开: {session_id}")
        manager.disconnect(websocket)
\ No newline at end of file
from fastapi import FastAPI
# Fix: replaced the `from app import *` wildcard with an explicit import —
# the chat router is the only name this module actually needs.
from app import chat_router

app = FastAPI()
app.include_router(chat_router, prefix="/chat", tags=["chat"])

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
\ No newline at end of file
# Test your FastAPI endpoints
GET http://127.0.0.1:8000/
Accept: application/json
###
GET http://127.0.0.1:8000/hello/User
Accept: application/json
###
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment