Commit 59503a6c authored by yuguo's avatar yuguo

fix

parent 9033450a
......@@ -24,7 +24,9 @@
"Bash(./api.exe:*)",
"Bash(PGPASSWORD=123456 psql:*)",
"Bash(go mod:*)",
"Bash(go vet:*)"
"Bash(go vet:*)",
"Bash(cmd:*)",
"Bash(go get:*)"
]
}
}
......@@ -92,6 +92,11 @@ func main() {
&model.PaymentOrder{},
&model.DoctorIncome{},
&model.DoctorWithdrawal{},
// 安全过滤
&model.SafetyWordRule{},
&model.SafetyFilterLog{},
// HTTP 动态工具
&model.HTTPToolDefinition{},
); err != nil {
log.Printf("Warning: AutoMigrate failed: %v", err)
} else {
......@@ -134,8 +139,11 @@ func main() {
log.Printf("Warning: Failed to init departments and doctors: %v", err)
}
// 初始化Agent工具
// 初始化Agent工具(内置工具 + HTTP 动态工具)
internalagent.InitTools()
internalagent.LoadHTTPTools()
// 注入跨包回调(AgentCallFn / WorkflowTriggerFn)
internalagent.WireCallbacks()
// 设置 Gin 模式
gin.SetMode(cfg.Server.Mode)
......@@ -168,6 +176,9 @@ func main() {
authApi := api.Group("")
authApi.Use(middleware.JWTAuth())
authApi.GET("/user/me", userHandler.GetCurrentUser)
authApi.POST("/user/logout", userHandler.Logout)
authApi.POST("/user/verify-identity", userHandler.VerifyIdentity)
authApi.POST("/doctor/appointment", doctorHandler.MakeAppointment)
// 医生端路由(需要认证 + 医生角色)
doctorPortalHandler := doctorportal.NewHandler()
......
......@@ -22,6 +22,7 @@ require (
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/expr-lang/expr v1.17.8 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
......
......@@ -18,6 +18,8 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/expr-lang/expr v1.17.8 h1:W1loDTT+0PQf5YteHSTpju2qfUfNoBt4yw9+wOEU9VM=
github.com/expr-lang/expr v1.17.8/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
......
package internalagent
import (
"internet-hospital/pkg/agent"
"encoding/json"
"internet-hospital/internal/model"
)
// NewPreConsultAgent 预问诊Agent
func NewPreConsultAgent() *agent.ReActAgent {
return agent.NewReActAgent(agent.ReActConfig{
ID: "pre_consult_agent",
Name: "预问诊智能助手",
Description: "通过多轮对话收集患者症状,生成预问诊报告",
SystemPrompt: `你是一位专业的AI预问诊助手。你的职责是:
// defaultAgentDefinitions 返回内置Agent的默认数据库配置
// 当数据库中不存在时,会自动写入并作为初始配置
func defaultAgentDefinitions() []model.AgentDefinition {
preConsultTools, _ := json.Marshal([]string{"query_symptom_knowledge", "recommend_department"})
diagnosisTools, _ := json.Marshal([]string{"query_medical_record", "query_symptom_knowledge", "search_medical_knowledge"})
prescriptionTools, _ := json.Marshal([]string{"query_drug", "check_drug_interaction", "check_contraindication", "calculate_dosage"})
followUpTools, _ := json.Marshal([]string{"query_medical_record", "query_drug", "query_symptom_knowledge"})
return []model.AgentDefinition{
{
AgentID: "pre_consult_agent",
Name: "预问诊智能助手",
Description: "通过多轮对话收集患者症状,生成预问诊报告",
Category: "patient",
SystemPrompt: `你是一位专业的AI预问诊助手。你的职责是:
1. 通过友好的对话收集患者的症状信息
2. 询问症状的持续时间、严重程度、伴随症状等
3. 利用工具查询症状相关知识
......@@ -18,21 +28,16 @@ func NewPreConsultAgent() *agent.ReActAgent {
5. 生成简洁的预问诊报告
请用中文与患者交流,语气温和专业。不要做出确定性诊断,只提供参考建议。`,
Tools: []string{
"query_symptom_knowledge",
"recommend_department",
Tools: string(preConsultTools),
MaxIterations: 5,
Status: "active",
},
MaxIterations: 5,
})
}
// NewDiagnosisAgent 诊断辅助Agent
func NewDiagnosisAgent() *agent.ReActAgent {
return agent.NewReActAgent(agent.ReActConfig{
ID: "diagnosis_agent",
Name: "诊断辅助Agent",
Description: "辅助医生进行诊断,提供鉴别诊断建议",
SystemPrompt: `你是一位经验丰富的诊断辅助AI,协助医生进行临床决策。
{
AgentID: "diagnosis_agent",
Name: "诊断辅助Agent",
Description: "辅助医生进行诊断,提供鉴别诊断建议",
Category: "doctor",
SystemPrompt: `你是一位经验丰富的诊断辅助AI,协助医生进行临床决策。
你可以:
1. 查询患者病历记录(使用query_medical_record)
2. 检索医学知识库获取临床指南和疾病信息(使用search_medical_knowledge)
......@@ -46,22 +51,16 @@ func NewDiagnosisAgent() *agent.ReActAgent {
- 综合分析后给出诊断建议
请基于循证医学原则提供建议,所有建议仅供医生参考。`,
Tools: []string{
"query_medical_record",
"query_symptom_knowledge",
"search_medical_knowledge",
Tools: string(diagnosisTools),
MaxIterations: 10,
Status: "active",
},
MaxIterations: 10,
})
}
// NewPrescriptionAgent 处方审核Agent
func NewPrescriptionAgent() *agent.ReActAgent {
return agent.NewReActAgent(agent.ReActConfig{
ID: "prescription_agent",
Name: "处方审核Agent",
Description: "审核处方合理性,检查药物相互作用、禁忌症和剂量",
SystemPrompt: `你是一位专业的临床药师AI,负责处方审核。
{
AgentID: "prescription_agent",
Name: "处方审核Agent",
Description: "审核处方合理性,检查药物相互作用、禁忌症和剂量",
Category: "pharmacy",
SystemPrompt: `你是一位专业的临床药师AI,负责处方审核。
你的职责:
1. 查询药品信息(规格、用法、禁忌)
2. 检查药物相互作用(使用check_drug_interaction工具)
......@@ -76,23 +75,16 @@ func NewPrescriptionAgent() *agent.ReActAgent {
- 最后综合所有检查结果给出审核意见
请严格按照药品说明书和临床指南进行审核,对于存在风险的处方要明确指出。`,
Tools: []string{
"query_drug",
"check_drug_interaction",
"check_contraindication",
"calculate_dosage",
Tools: string(prescriptionTools),
MaxIterations: 10,
Status: "active",
},
MaxIterations: 10,
})
}
// NewFollowUpAgent 随访管理Agent
func NewFollowUpAgent() *agent.ReActAgent {
return agent.NewReActAgent(agent.ReActConfig{
ID: "follow_up_agent",
Name: "随访管理Agent",
Description: "管理患者随访,提醒用药、复诊,收集健康数据",
SystemPrompt: `你是一位专业的随访管理AI助手。你的职责是:
{
AgentID: "follow_up_agent",
Name: "随访管理Agent",
Description: "管理患者随访,提醒用药、复诊,收集健康数据",
Category: "patient",
SystemPrompt: `你是一位专业的随访管理AI助手。你的职责是:
1. 查询患者的处方和用药情况
2. 提醒患者按时用药
3. 收集患者的健康数据(血压、血糖等)
......@@ -100,11 +92,9 @@ func NewFollowUpAgent() *agent.ReActAgent {
5. 生成随访报告
请用温和关怀的语气与患者交流,关注患者的用药依从性和健康状况变化。`,
Tools: []string{
"query_medical_record",
"query_drug",
"query_symptom_knowledge",
Tools: string(followUpTools),
MaxIterations: 8,
Status: "active",
},
MaxIterations: 8,
})
}
}
This diff is collapsed.
package internalagent
import (
"log"
"internet-hospital/internal/model"
"internet-hospital/pkg/agent"
"internet-hospital/pkg/agent/tools"
"internet-hospital/pkg/database"
)
// LoadHTTPTools 从数据库加载所有 active HTTP 工具并注册到 ToolRegistry
func LoadHTTPTools() {
db := database.GetDB()
if db == nil {
return
}
var defs []model.HTTPToolDefinition
if err := db.Where("status = 'active'").Find(&defs).Error; err != nil {
log.Printf("[LoadHTTPTools] 加载失败: %v", err)
return
}
r := agent.GetRegistry()
for i := range defs {
r.Register(tools.NewDynamicHTTPTool(&defs[i]))
}
if len(defs) > 0 {
log.Printf("[LoadHTTPTools] 已加载 %d 个 HTTP 工具", len(defs))
}
}
// ReloadHTTPTools 卸载所有旧 HTTP 工具,然后重新从数据库加载
func ReloadHTTPTools() {
db := database.GetDB()
if db == nil {
return
}
r := agent.GetRegistry()
// 卸载所有已注册的 HTTP 工具(不论状态)
var all []model.HTTPToolDefinition
db.Find(&all)
for _, def := range all {
r.Unregister(def.Name)
}
// 重新加载 active 工具
LoadHTTPTools()
log.Printf("[ReloadHTTPTools] HTTP 工具已热重载")
}
package internalagent
import (
"strconv"
"github.com/gin-gonic/gin"
"internet-hospital/internal/model"
"internet-hospital/pkg/agent/tools"
"internet-hospital/pkg/database"
"internet-hospital/pkg/response"
)
// ListHTTPTools GET /agent/http-tools
func (h *Handler) ListHTTPTools(c *gin.Context) {
db := database.GetDB()
var defs []model.HTTPToolDefinition
db.Order("created_at DESC").Find(&defs)
response.Success(c, defs)
}
// CreateHTTPTool POST /agent/http-tools
func (h *Handler) CreateHTTPTool(c *gin.Context) {
var req model.HTTPToolDefinition
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
userID, _ := c.Get("user_id")
req.CreatedBy, _ = userID.(string)
if req.Headers == "" {
req.Headers = "{}"
}
if req.AuthConfig == "" {
req.AuthConfig = "{}"
}
if req.Parameters == "" {
req.Parameters = "[]"
}
db := database.GetDB()
if err := db.Create(&req).Error; err != nil {
response.Error(c, 500, err.Error())
return
}
ReloadHTTPTools()
response.Success(c, req)
}
// UpdateHTTPTool PUT /agent/http-tools/:id
func (h *Handler) UpdateHTTPTool(c *gin.Context) {
idStr := c.Param("id")
id, err := strconv.Atoi(idStr)
if err != nil {
response.BadRequest(c, "invalid id")
return
}
var req map[string]interface{}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
db := database.GetDB()
if err := db.Model(&model.HTTPToolDefinition{}).Where("id = ?", id).Updates(req).Error; err != nil {
response.Error(c, 500, err.Error())
return
}
ReloadHTTPTools()
response.Success(c, nil)
}
// DeleteHTTPTool DELETE /agent/http-tools/:id
func (h *Handler) DeleteHTTPTool(c *gin.Context) {
idStr := c.Param("id")
id, err := strconv.Atoi(idStr)
if err != nil {
response.BadRequest(c, "invalid id")
return
}
db := database.GetDB()
var def model.HTTPToolDefinition
if err := db.First(&def, id).Error; err != nil {
response.Error(c, 404, "tool not found")
return
}
ReloadHTTPTools()
if err := db.Delete(&model.HTTPToolDefinition{}, id).Error; err != nil {
response.Error(c, 500, err.Error())
return
}
ReloadHTTPTools()
response.Success(c, nil)
}
// TestHTTPTool POST /agent/http-tools/:id/test
func (h *Handler) TestHTTPTool(c *gin.Context) {
idStr := c.Param("id")
id, err := strconv.Atoi(idStr)
if err != nil {
response.BadRequest(c, "invalid id")
return
}
var params map[string]interface{}
_ = c.ShouldBindJSON(&params)
if params == nil {
params = map[string]interface{}{}
}
db := database.GetDB()
var def model.HTTPToolDefinition
if err := db.First(&def, id).Error; err != nil {
response.Error(c, 404, "tool not found")
return
}
tool := tools.NewDynamicHTTPTool(&def)
result, execErr := tool.Execute(c.Request.Context(), params)
if execErr != nil {
response.Success(c, gin.H{
"success": false,
"error": execErr.Error(),
"result": result,
})
return
}
response.Success(c, gin.H{"success": true, "result": result})
}
// ReloadHTTPToolsAPI POST /agent/http-tools/reload
func (h *Handler) ReloadHTTPToolsAPI(c *gin.Context) {
ReloadHTTPTools()
response.SuccessWithMessage(c, "HTTP 工具已热重载", nil)
}
package internalagent
import (
"context"
"encoding/json"
"fmt"
"log"
"internet-hospital/internal/model"
"internet-hospital/pkg/agent"
"internet-hospital/pkg/agent/tools"
"internet-hospital/pkg/database"
"internet-hospital/pkg/rag"
"internet-hospital/pkg/workflow"
)
// InitTools 注册所有工具到全局注册中心
// InitTools 注册所有工具到全局注册中心,并同步元数据到数据库
func InitTools() {
db := database.GetDB()
r := agent.GetRegistry()
......@@ -27,6 +34,131 @@ func InitTools() {
r.Register(&tools.ContraindicationTool{DB: db})
r.Register(&tools.DosageCalculatorTool{DB: db})
// 知识库检索工具
// 知识库工具
r.Register(&tools.KnowledgeSearchTool{Retriever: retriever})
r.Register(&tools.KnowledgeWriteTool{DB: db, Retriever: retriever})
r.Register(&tools.KnowledgeListTool{})
// Agent互调 & 工作流工具
r.Register(&tools.AgentCallerTool{})
r.Register(&tools.WorkflowTriggerTool{})
r.Register(&tools.WorkflowQueryTool{})
r.Register(&tools.HumanReviewTool{})
// 表达式执行工具
r.Register(&tools.ExprEvalTool{})
// 通知 & 随访工具
r.Register(&tools.SendNotificationTool{})
r.Register(&tools.GenerateFollowUpPlanTool{})
// 同步工具元数据到 AgentTool 表,并应用启/停状态
syncToolsToDB(r)
applyToolStatus(r)
}
// WireCallbacks 注入跨包回调(在 InitTools 和 GetService 初始化完成后调用)
// AgentCallFn: AgentCallerTool → AgentService.Chat
// WorkflowTriggerFn: WorkflowTriggerTool → workflow.Engine.Execute
func WireCallbacks() {
svc := GetService()
tools.AgentCallFn = func(ctx context.Context, agentID, userID, sessionID, message string, ctxData map[string]interface{}) (string, error) {
output, err := svc.Chat(ctx, agentID, userID, sessionID, message, ctxData)
if err != nil {
return "", err
}
if output == nil {
return "", fmt.Errorf("agent %s 不存在或未启用", agentID)
}
return output.Response, nil
}
tools.WorkflowTriggerFn = func(ctx context.Context, workflowID string, input map[string]interface{}) (string, error) {
return workflow.GetEngine().Execute(ctx, workflowID, input, "agent")
}
log.Println("[InitTools] AgentCallFn & WorkflowTriggerFn 注入完成")
}
// syncToolsToDB 将注册的工具元数据写入 AgentTool 表(不存在则创建,存在则不覆盖)
func syncToolsToDB(r *agent.ToolRegistry) {
db := database.GetDB()
if db == nil {
return
}
categoryMap := map[string]string{
// 基础查询
"query_symptom_knowledge": "knowledge",
"recommend_department": "recommendation",
"query_medical_record": "medical",
"search_medical_knowledge": "knowledge",
"query_drug": "pharmacy",
// 处方安全
"check_drug_interaction": "safety",
"check_contraindication": "safety",
"calculate_dosage": "pharmacy",
// 知识库
"write_knowledge": "knowledge",
"list_knowledge_collections": "knowledge",
// Agent & 工作流
"call_agent": "agent",
"trigger_workflow": "workflow",
"query_workflow_status": "workflow",
"request_human_review": "workflow",
// 表达式
"eval_expression": "expression",
// 通知 & 随访
"generate_follow_up_plan": "follow_up",
"send_notification": "notification",
}
for name, tool := range r.All() {
params := make([]map[string]interface{}, 0)
for _, p := range tool.Parameters() {
params = append(params, map[string]interface{}{
"name": p.Name,
"type": p.Type,
"required": p.Required,
})
}
paramsJSON, _ := json.Marshal(params)
category := categoryMap[name]
if category == "" {
category = "other"
}
var existing model.AgentTool
if err := db.Where("name = ?", name).First(&existing).Error; err != nil {
// 不存在则创建
entry := model.AgentTool{
Name: name,
DisplayName: tool.Description(),
Description: tool.Description(),
Category: category,
Parameters: string(paramsJSON),
Status: "active",
}
if err := db.Create(&entry).Error; err != nil {
log.Printf("[InitTools] 同步工具 %s 到数据库失败: %v", name, err)
}
}
// 已存在则不更新(保留管理员的自定义状态)
}
}
// applyToolStatus 从数据库读取工具状态(disabled 的工具在 executor 中已检查,此处仅记录日志)
func applyToolStatus(r *agent.ToolRegistry) {
db := database.GetDB()
if db == nil {
return
}
var disabledTools []model.AgentTool
db.Where("status = 'disabled'").Find(&disabledTools)
if len(disabledTools) > 0 {
for _, t := range disabledTools {
log.Printf("[InitTools] 工具 %s 已被禁用", t.Name)
}
}
}
......@@ -3,6 +3,8 @@ package internalagent
import (
"context"
"encoding/json"
"log"
"sync"
"time"
"internet-hospital/internal/model"
......@@ -15,6 +17,7 @@ import (
// AgentService Agent服务
type AgentService struct {
mu sync.RWMutex
agents map[string]*agent.ReActAgent
}
......@@ -23,26 +26,124 @@ var globalAgentService *AgentService
func GetService() *AgentService {
if globalAgentService == nil {
globalAgentService = &AgentService{
agents: map[string]*agent.ReActAgent{
"pre_consult_agent": NewPreConsultAgent(),
"diagnosis_agent": NewDiagnosisAgent(),
"prescription_agent": NewPrescriptionAgent(),
"follow_up_agent": NewFollowUpAgent(),
},
agents: make(map[string]*agent.ReActAgent),
}
globalAgentService.loadFromDB()
globalAgentService.ensureBuiltinAgents()
}
return globalAgentService
}
// loadFromDB 从数据库加载所有 active 的 AgentDefinition
func (s *AgentService) loadFromDB() {
db := database.GetDB()
if db == nil {
return
}
var definitions []model.AgentDefinition
if err := db.Where("status = 'active'").Find(&definitions).Error; err != nil {
log.Printf("[AgentService] 从数据库加载Agent失败: %v", err)
return
}
for _, def := range definitions {
a := buildAgentFromDef(def)
s.agents[def.AgentID] = a
}
log.Printf("[AgentService] 从数据库加载了 %d 个Agent", len(definitions))
}
func buildAgentFromDef(def model.AgentDefinition) *agent.ReActAgent {
var tools []string
if def.Tools != "" {
json.Unmarshal([]byte(def.Tools), &tools)
}
maxIter := def.MaxIterations
if maxIter <= 0 {
maxIter = 10
}
return agent.NewReActAgent(agent.ReActConfig{
ID: def.AgentID,
Name: def.Name,
Description: def.Description,
SystemPrompt: def.SystemPrompt,
Tools: tools,
MaxIterations: maxIter,
})
}
// ensureBuiltinAgents 如果数据库中不存在内置Agent,则写入默认配置
func (s *AgentService) ensureBuiltinAgents() {
db := database.GetDB()
if db == nil {
return
}
defaults := defaultAgentDefinitions()
for _, def := range defaults {
// 如果内存中已有(来自数据库),跳过
s.mu.RLock()
_, exists := s.agents[def.AgentID]
s.mu.RUnlock()
if exists {
continue
}
// 写入数据库
var existing model.AgentDefinition
if err := db.Where("agent_id = ?", def.AgentID).First(&existing).Error; err != nil {
// 不存在则创建
if err := db.Create(&def).Error; err != nil {
log.Printf("[AgentService] 写入默认Agent失败: %v", err)
continue
}
existing = def
}
s.mu.Lock()
s.agents[def.AgentID] = buildAgentFromDef(existing)
s.mu.Unlock()
}
}
// ReloadAgent 热重载单个Agent(管理端修改配置后调用)
func (s *AgentService) ReloadAgent(agentID string) error {
db := database.GetDB()
var def model.AgentDefinition
if err := db.Where("agent_id = ?", agentID).First(&def).Error; err != nil {
return err
}
a := buildAgentFromDef(def)
s.mu.Lock()
if def.Status == "active" {
s.agents[agentID] = a
} else {
delete(s.agents, agentID)
}
s.mu.Unlock()
log.Printf("[AgentService] 已热重载Agent: %s", agentID)
return nil
}
// ReloadAll 重新加载所有Agent
func (s *AgentService) ReloadAll() {
s.mu.Lock()
s.agents = make(map[string]*agent.ReActAgent)
s.mu.Unlock()
s.loadFromDB()
s.ensureBuiltinAgents()
}
func (s *AgentService) GetAgent(agentID string) (*agent.ReActAgent, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
a, ok := s.agents[agentID]
return a, ok
}
func (s *AgentService) ListAgents() []map[string]string {
result := make([]map[string]string, 0, len(s.agents))
func (s *AgentService) ListAgents() []map[string]interface{} {
s.mu.RLock()
defer s.mu.RUnlock()
result := make([]map[string]interface{}, 0, len(s.agents))
for _, a := range s.agents {
result = append(result, map[string]string{
result = append(result, map[string]interface{}{
"id": a.ID(),
"name": a.Name(),
"description": a.Description(),
......@@ -61,10 +162,10 @@ func (s *AgentService) Chat(ctx context.Context, agentID, userID, sessionID, mes
db := database.GetDB()
// 加载或创建会话
var session model.AgentSession
if sessionID == "" {
sessionID = uuid.New().String()
}
var session model.AgentSession
db.Where("session_id = ?", sessionID).First(&session)
// 解析历史消息
......@@ -119,6 +220,7 @@ func (s *AgentService) Chat(ctx context.Context, agentID, userID, sessionID, mes
outputJSON, _ := json.Marshal(output)
toolCallsJSON, _ := json.Marshal(output.ToolCalls)
db.Create(&model.AgentExecutionLog{
TraceID: output.TraceID,
SessionID: sessionID,
AgentID: agentID,
UserID: userID,
......
......@@ -11,6 +11,9 @@ type AgentTool struct {
Category string `gorm:"type:varchar(50)"`
Parameters string `gorm:"type:jsonb"`
Status string `gorm:"type:varchar(20);default:'active'"`
CacheTTL int `gorm:"default:0"` // 缓存秒数,0=不缓存
Timeout int `gorm:"default:30"` // 执行超时秒数
MaxRetries int `gorm:"default:0"` // 失败重试次数
CreatedAt time.Time
UpdatedAt time.Time
}
......@@ -18,6 +21,7 @@ type AgentTool struct {
// AgentToolLog 工具调用日志
type AgentToolLog struct {
ID uint `gorm:"primaryKey"`
TraceID string `gorm:"type:varchar(100);index"` // 链路追踪ID
ToolName string `gorm:"type:varchar(100);index"`
AgentID string `gorm:"type:varchar(100);index"`
SessionID string `gorm:"type:varchar(100);index"`
......@@ -27,6 +31,7 @@ type AgentToolLog struct {
Success bool
ErrorMessage string `gorm:"type:text"`
DurationMs int
Iteration int // Agent第几轮迭代
CreatedAt time.Time
}
......@@ -62,6 +67,7 @@ type AgentSession struct {
// AgentExecutionLog Agent执行日志
type AgentExecutionLog struct {
ID uint `gorm:"primaryKey"`
TraceID string `gorm:"type:varchar(100);index"` // 链路追踪ID
SessionID string `gorm:"type:varchar(100);index"`
AgentID string `gorm:"type:varchar(100);index"`
UserID string `gorm:"type:uuid;index"`
......
......@@ -42,7 +42,12 @@ type AIUsageLog struct {
Success bool `gorm:"default:true" json:"success"`
ErrorMessage string `gorm:"type:text" json:"error_message,omitempty"`
IsMock bool `gorm:"default:false" json:"is_mock"` // 是否为模拟调用
CreatedAt time.Time `json:"created_at"`
// 链路追踪字段
TraceID string `gorm:"type:varchar(100);index" json:"trace_id"` // 链路追踪ID,关联同一次Agent执行
AgentID string `gorm:"type:varchar(100);index" json:"agent_id"` // 关联Agent
SessionID string `gorm:"type:varchar(100);index" json:"session_id"` // 关联会话
Iteration int `json:"iteration"` // Agent第几轮迭代(0=非Agent调用)
CreatedAt time.Time `json:"created_at"`
}
func (AIUsageLog) TableName() string {
......
package model
import "time"
// HTTPToolDefinition 动态 HTTP 工具定义(管理员从 UI 配置,无需改代码)
type HTTPToolDefinition struct {
ID uint `gorm:"primaryKey"`
Name string `gorm:"type:varchar(100);uniqueIndex"` // 工具名,如 get_weather
DisplayName string `gorm:"type:varchar(200)"`
Description string `gorm:"type:text"`
Category string `gorm:"type:varchar(50);default:'http'"`
Method string `gorm:"type:varchar(10);default:'GET'"` // GET/POST/PUT/DELETE
URL string `gorm:"type:text"` // 支持 {{param}} 模板变量
Headers string `gorm:"type:jsonb;default:'{}'"` // {"X-Key": "{{api_key}}"}
BodyTemplate string `gorm:"type:text"` // JSON body 模板,支持 {{param}}
AuthType string `gorm:"type:varchar(20);default:'none'"` // none/bearer/basic/apikey
AuthConfig string `gorm:"type:jsonb;default:'{}'"` // 认证配置
Parameters string `gorm:"type:jsonb;default:'[]'"` // ToolParameter 数组 JSON
Timeout int `gorm:"default:10"` // 超时秒数
CacheTTL int `gorm:"default:0"` // 缓存 TTL 秒,0=不缓存
Status string `gorm:"type:varchar(20);default:'active'"`
CreatedBy string `gorm:"type:varchar(100)"`
CreatedAt time.Time
UpdatedAt time.Time
}
......@@ -8,15 +8,20 @@ import (
// PromptTemplate AI提示词模板
type PromptTemplate struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
TemplateKey string `gorm:"type:varchar(100);uniqueIndex;not null" json:"template_key"` // 模板key,用于代码中取值
Name string `gorm:"type:varchar(100);not null" json:"name"` // 模板名称
Scene string `gorm:"type:varchar(50)" json:"scene"` // 应用场景
Content string `gorm:"type:text;not null" json:"content"` // Prompt内容
Status string `gorm:"type:varchar(20);default:'active'" json:"status"` // active | disabled
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
TemplateKey string `gorm:"type:varchar(100);uniqueIndex;not null" json:"template_key"` // 模板key,用于代码中取值
Name string `gorm:"type:varchar(100);not null" json:"name"` // 模板名称
Scene string `gorm:"type:varchar(50)" json:"scene"` // 应用场景
Content string `gorm:"type:text;not null" json:"content"` // Prompt内容
Status string `gorm:"type:varchar(20);default:'active'" json:"status"` // active | disabled
// 智能体关联字段
AgentID string `gorm:"type:varchar(100);index" json:"agent_id"` // 关联的Agent ID,空表示通用模板
TemplateType string `gorm:"type:varchar(20);default:'system'" json:"template_type"` // system | user | tool_result
Variables string `gorm:"type:jsonb" json:"variables"` // 变量定义 [{"name":"patient_name","type":"string","required":true}]
Version int `gorm:"default:1" json:"version"` // 版本号
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
}
func (PromptTemplate) TableName() string {
......
package model
import "time"
// SafetyWordRule 安全词规则
type SafetyWordRule struct {
ID uint `gorm:"primaryKey" json:"id"`
Word string `gorm:"type:varchar(200);index" json:"word"` // 敏感词/正则
Category string `gorm:"type:varchar(50)" json:"category"` // medical_claim | drug_promotion | privacy | toxicity | injection
Level string `gorm:"type:varchar(20)" json:"level"` // block | warn | replace
Replacement string `gorm:"type:varchar(200)" json:"replacement"` // level=replace时的替换文本
Direction string `gorm:"type:varchar(10);default:'both'" json:"direction"` // input | output | both
IsRegex bool `gorm:"default:false" json:"is_regex"` // 是否正则表达式
AgentID string `gorm:"type:varchar(100)" json:"agent_id"` // 特定Agent的规则,空=全局
Status string `gorm:"type:varchar(20);default:'active'" json:"status"` // active | disabled
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
func (SafetyWordRule) TableName() string {
return "safety_word_rules"
}
// SafetyFilterLog 过滤日志
type SafetyFilterLog struct {
ID uint `gorm:"primaryKey" json:"id"`
TraceID string `gorm:"type:varchar(100);index" json:"trace_id"`
Direction string `gorm:"type:varchar(10)" json:"direction"` // input | output
OriginalText string `gorm:"type:text" json:"original_text"`
FilteredText string `gorm:"type:text" json:"filtered_text"`
MatchedRules string `gorm:"type:jsonb" json:"matched_rules"` // [{"rule_id":1,"word":"xxx","action":"block"}]
Action string `gorm:"type:varchar(20)" json:"action"` // passed | blocked | replaced | warned
AgentID string `gorm:"type:varchar(100)" json:"agent_id"`
UserID string `gorm:"type:uuid" json:"user_id"`
CreatedAt time.Time `json:"created_at"`
}
func (SafetyFilterLog) TableName() string {
return "safety_filter_logs"
}
package admin
import (
"strconv"
"github.com/gin-gonic/gin"
"internet-hospital/internal/model"
"internet-hospital/pkg/database"
"internet-hospital/pkg/response"
)
// GetTraceDetail 全链路追踪详情:给定 trace_id,返回该次Agent执行的所有 LLM 调用和工具调用
func (h *Handler) GetTraceDetail(c *gin.Context) {
traceID := c.Query("trace_id")
if traceID == "" {
response.BadRequest(c, "trace_id 不能为空")
return
}
db := database.GetDB()
var execLogs []model.AgentExecutionLog
db.Where("trace_id = ?", traceID).Order("created_at asc").Find(&execLogs)
var usageLogs []model.AIUsageLog
db.Where("trace_id = ?", traceID).Order("iteration asc, created_at asc").Find(&usageLogs)
var toolLogs []model.AgentToolLog
db.Where("trace_id = ?", traceID).Order("iteration asc, created_at asc").Find(&toolLogs)
response.Success(c, gin.H{
"trace_id": traceID,
"execution_logs": execLogs,
"llm_calls": usageLogs,
"tool_calls": toolLogs,
})
}
// GetAICenterStats 全量统计看板
func (h *Handler) GetAICenterStats(c *gin.Context) {
db := database.GetDB()
// LLM 调用统计
var totalCalls, successCalls, mockCalls int64
var totalTokens struct{ Sum int64 }
db.Model(&model.AIUsageLog{}).Count(&totalCalls)
db.Model(&model.AIUsageLog{}).Where("success = true").Count(&successCalls)
db.Model(&model.AIUsageLog{}).Where("is_mock = true").Count(&mockCalls)
db.Model(&model.AIUsageLog{}).Select("COALESCE(SUM(total_tokens), 0) as sum").Scan(&totalTokens)
// Agent 执行统计
var agentExecs, agentSuccess int64
db.Model(&model.AgentExecutionLog{}).Count(&agentExecs)
db.Model(&model.AgentExecutionLog{}).Where("success = true").Count(&agentSuccess)
// 工具调用统计
var toolCalls, toolSuccess int64
db.Model(&model.AgentToolLog{}).Count(&toolCalls)
db.Model(&model.AgentToolLog{}).Where("success = true").Count(&toolSuccess)
// 各 Agent 调用量
type AgentCallCount struct {
AgentID string `json:"agent_id"`
Count int64 `json:"count"`
}
var agentCounts []AgentCallCount
db.Model(&model.AIUsageLog{}).
Select("agent_id, COUNT(*) as count").
Where("agent_id != ''").
Group("agent_id").
Order("count DESC").
Limit(10).
Scan(&agentCounts)
// 各场景调用量
type SceneCallCount struct {
Scene string `json:"scene"`
Count int64 `json:"count"`
}
var sceneCounts []SceneCallCount
db.Model(&model.AIUsageLog{}).
Select("scene, COUNT(*) as count").
Group("scene").
Order("count DESC").
Limit(10).
Scan(&sceneCounts)
// 最近24小时调用趋势(按小时)
type HourlyTrend struct {
Hour string `json:"hour"`
Count int64 `json:"count"`
}
var trend []HourlyTrend
db.Model(&model.AIUsageLog{}).
Select("TO_CHAR(created_at, 'HH24') as hour, COUNT(*) as count").
Where("created_at >= NOW() - INTERVAL '24 hours'").
Group("hour").
Order("hour asc").
Scan(&trend)
// 近期日志列表(用于追踪)
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
if page <= 0 {
page = 1
}
if pageSize <= 0 {
pageSize = 20
}
agentID := c.Query("agent_id")
var recentLogsTotal int64
var recentLogs []model.AIUsageLog
q := db.Model(&model.AIUsageLog{})
if agentID != "" {
q = q.Where("agent_id = ?", agentID)
}
q.Count(&recentLogsTotal)
q.Order("created_at DESC").Offset((page - 1) * pageSize).Limit(pageSize).Find(&recentLogs)
response.Success(c, gin.H{
"total_calls": totalCalls,
"success_calls": successCalls,
"mock_calls": mockCalls,
"total_tokens": totalTokens.Sum,
"agent_execs": agentExecs,
"agent_success": agentSuccess,
"tool_calls": toolCalls,
"tool_success": toolSuccess,
"agent_counts": agentCounts,
"scene_counts": sceneCounts,
"hourly_trend": trend,
"recent_logs": recentLogs,
"logs_total": recentLogsTotal,
})
}
......@@ -109,6 +109,19 @@ func (h *Handler) RegisterRoutes(r *gin.RouterGroup) {
// Agent 执行监控
adm.GET("/agent/logs", h.GetAgentExecutionLogs)
adm.GET("/agent/stats", h.GetAgentStats)
// 内容安全管理
adm.GET("/safety/rules", h.ListSafetyRules)
adm.POST("/safety/rules", h.CreateSafetyRule)
adm.PUT("/safety/rules/:id", h.UpdateSafetyRule)
adm.DELETE("/safety/rules/:id", h.DeleteSafetyRule)
adm.POST("/safety/rules/import-preset", h.ImportSafetyRules)
adm.GET("/safety/logs", h.ListSafetyLogs)
adm.GET("/safety/stats", h.GetSafetyStats)
// AI 运营中心
adm.GET("/ai-center/trace", h.GetTraceDetail)
adm.GET("/ai-center/stats", h.GetAICenterStats)
}
}
......
......@@ -6,6 +6,7 @@ import (
"time"
"internet-hospital/internal/model"
"internet-hospital/pkg/workflow"
)
// ==================== 处方监管请求/响应 ====================
......@@ -132,5 +133,16 @@ func (s *Service) ReviewPrescription(ctx context.Context, id string, action stri
return fmt.Errorf("无效的审核操作")
}
return s.db.Model(&prescription).Updates(updates).Error
if err := s.db.Model(&prescription).Updates(updates).Error; err != nil {
return err
}
// 处方审核通过时触发 prescription_approved 工作流(异步)
if action == "approve" {
workflow.GetEngine().TriggerByCategory(ctx, "prescription_approved", map[string]interface{}{
"prescription_id": id,
"reviewer_id": reviewerID,
})
}
return nil
}
package admin
import (
"strconv"
"github.com/gin-gonic/gin"
"internet-hospital/internal/model"
"internet-hospital/pkg/database"
"internet-hospital/pkg/response"
"internet-hospital/pkg/safety"
)
// ===================== 安全词规则 =====================
// ListSafetyRules 查询安全词规则列表
func (h *Handler) ListSafetyRules(c *gin.Context) {
db := database.GetDB()
category := c.Query("category")
direction := c.Query("direction")
status := c.Query("status")
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
if page <= 0 {
page = 1
}
if pageSize <= 0 || pageSize > 100 {
pageSize = 20
}
var total int64
var rules []model.SafetyWordRule
q := db.Model(&model.SafetyWordRule{})
if category != "" {
q = q.Where("category = ?", category)
}
if direction != "" {
q = q.Where("direction = ? OR direction = 'both'", direction)
}
if status != "" {
q = q.Where("status = ?", status)
}
q.Count(&total)
q.Order("id DESC").Offset((page - 1) * pageSize).Limit(pageSize).Find(&rules)
response.Success(c, gin.H{"list": rules, "total": total, "page": page})
}
// CreateSafetyRule 创建安全词规则
func (h *Handler) CreateSafetyRule(c *gin.Context) {
var rule model.SafetyWordRule
if err := c.ShouldBindJSON(&rule); err != nil {
response.BadRequest(c, err.Error())
return
}
if rule.Direction == "" {
rule.Direction = "both"
}
if rule.Status == "" {
rule.Status = "active"
}
if err := database.GetDB().Create(&rule).Error; err != nil {
response.Error(c, 500, err.Error())
return
}
// 重新加载缓存
safety.GetLoader().Reload()
response.Success(c, rule)
}
// UpdateSafetyRule 更新安全词规则
func (h *Handler) UpdateSafetyRule(c *gin.Context) {
id := c.Param("id")
var rule model.SafetyWordRule
if err := database.GetDB().First(&rule, "id = ?", id).Error; err != nil {
response.Error(c, 404, "规则不存在")
return
}
if err := c.ShouldBindJSON(&rule); err != nil {
response.BadRequest(c, err.Error())
return
}
if err := database.GetDB().Save(&rule).Error; err != nil {
response.Error(c, 500, err.Error())
return
}
safety.GetLoader().Reload()
response.Success(c, rule)
}
// DeleteSafetyRule 删除安全词规则
func (h *Handler) DeleteSafetyRule(c *gin.Context) {
id := c.Param("id")
database.GetDB().Delete(&model.SafetyWordRule{}, "id = ?", id)
safety.GetLoader().Reload()
response.Success(c, nil)
}
// ImportSafetyRules 批量导入安全词规则(医疗场景预置)
func (h *Handler) ImportSafetyRules(c *gin.Context) {
db := database.GetDB()
preset := presetMedicalSafetyRules()
created := 0
for _, rule := range preset {
var existing model.SafetyWordRule
if db.Where("word = ?", rule.Word).First(&existing).Error != nil {
if db.Create(&rule).Error == nil {
created++
}
}
}
safety.GetLoader().Reload()
response.SuccessWithMessage(c, "导入完成", gin.H{"created": created})
}
// ===================== 过滤日志 =====================
// ListSafetyLogs 查询过滤日志
func (h *Handler) ListSafetyLogs(c *gin.Context) {
db := database.GetDB()
action := c.Query("action")
direction := c.Query("direction")
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
if page <= 0 {
page = 1
}
if pageSize <= 0 || pageSize > 100 {
pageSize = 20
}
var total int64
var logs []model.SafetyFilterLog
q := db.Model(&model.SafetyFilterLog{})
if action != "" {
q = q.Where("action = ?", action)
}
if direction != "" {
q = q.Where("direction = ?", direction)
}
q.Count(&total)
q.Order("id DESC").Offset((page - 1) * pageSize).Limit(pageSize).Find(&logs)
response.Success(c, gin.H{"list": logs, "total": total, "page": page})
}
// GetSafetyStats 安全过滤统计
func (h *Handler) GetSafetyStats(c *gin.Context) {
db := database.GetDB()
var total, blocked, replaced, warned int64
db.Model(&model.SafetyFilterLog{}).Count(&total)
db.Model(&model.SafetyFilterLog{}).Where("action = 'blocked'").Count(&blocked)
db.Model(&model.SafetyFilterLog{}).Where("action = 'replaced'").Count(&replaced)
db.Model(&model.SafetyFilterLog{}).Where("action = 'warned'").Count(&warned)
var totalRules, activeRules int64
db.Model(&model.SafetyWordRule{}).Count(&totalRules)
db.Model(&model.SafetyWordRule{}).Where("status = 'active'").Count(&activeRules)
response.Success(c, gin.H{
"total_logs": total,
"blocked": blocked,
"replaced": replaced,
"warned": warned,
"total_rules": totalRules,
"active_rules": activeRules,
})
}
// presetMedicalSafetyRules 医疗场景预置安全词规则
func presetMedicalSafetyRules() []model.SafetyWordRule {
return []model.SafetyWordRule{
// Prompt注入防护(输入方向)
{Word: "忽略上述指令", Category: "injection", Level: "block", Direction: "input", Status: "active"},
{Word: "忘记之前的", Category: "injection", Level: "block", Direction: "input", Status: "active"},
{Word: "你现在是", Category: "injection", Level: "warn", Direction: "input", Status: "active"},
{Word: "扮演", Category: "injection", Level: "warn", Direction: "input", Status: "active"},
// 医疗断言防护(输出方向)
{Word: "你确诊了", Category: "medical_claim", Level: "replace", Replacement: "根据症状分析,可能", Direction: "output", Status: "active"},
{Word: "你一定是", Category: "medical_claim", Level: "replace", Replacement: "初步判断可能是", Direction: "output", Status: "active"},
{Word: "你患有", Category: "medical_claim", Level: "replace", Replacement: "症状提示可能存在", Direction: "output", Status: "active"},
// 危险建议防护(输出方向)
{Word: "自行停药", Category: "toxicity", Level: "block", Direction: "output", Status: "active"},
{Word: "不需要看医生", Category: "toxicity", Level: "block", Direction: "output", Status: "active"},
{Word: "自己手术", Category: "toxicity", Level: "block", Direction: "output", Status: "active"},
}
}
......@@ -11,6 +11,7 @@ import (
"internet-hospital/internal/model"
"internet-hospital/pkg/database"
"internet-hospital/pkg/utils"
"internet-hospital/pkg/workflow"
)
// Service 管理端业务逻辑层
......@@ -483,7 +484,16 @@ func (s *Service) ApproveDoctorReview(ctx context.Context, reviewID string) erro
return err
}
return tx.Commit().Error
if err := tx.Commit().Error; err != nil {
return err
}
// 触发 doctor_review 工作流(异步)
workflow.GetEngine().TriggerByCategory(ctx, "doctor_review", map[string]interface{}{
"review_id": reviewID,
"doctor_id": review.UserID,
"action": "approved",
})
return nil
}
func (s *Service) RejectDoctorReview(ctx context.Context, reviewID, reason string) error {
......@@ -497,11 +507,21 @@ func (s *Service) RejectDoctorReview(ctx context.Context, reviewID, reason strin
}
now := time.Now()
return s.db.Model(&review).Updates(map[string]interface{}{
if err := s.db.Model(&review).Updates(map[string]interface{}{
"status": "rejected",
"reject_reason": reason,
"reviewed_at": now,
}).Error
}).Error; err != nil {
return err
}
// 触发 doctor_review 工作流(异步)
workflow.GetEngine().TriggerByCategory(ctx, "doctor_review", map[string]interface{}{
"review_id": reviewID,
"doctor_id": review.UserID,
"action": "rejected",
"reason": reason,
})
return nil
}
func (s *Service) GetDepartmentList(ctx context.Context) (interface{}, error) {
......
......@@ -12,6 +12,7 @@ import (
internalagent "internet-hospital/internal/agent"
"internet-hospital/internal/model"
"internet-hospital/pkg/database"
"internet-hospital/pkg/workflow"
)
type Service struct {
......@@ -95,7 +96,18 @@ func (s *Service) CreateRenewal(ctx context.Context, userID string, req *Renewal
ChronicID: req.ChronicID, DiseaseName: req.DiseaseName,
Medicines: string(medsJSON), Reason: req.Reason, Status: "pending",
}
return r, s.db.WithContext(ctx).Create(r).Error
if err := s.db.WithContext(ctx).Create(r).Error; err != nil {
return nil, err
}
// 触发 renewal_requested 工作流(异步)
workflow.GetEngine().TriggerByCategory(ctx, "renewal_requested", map[string]interface{}{
"renewal_id": r.ID,
"patient_id": userID,
"chronic_record_id": req.ChronicID,
"disease_name": req.DiseaseName,
})
return r, nil
}
func (s *Service) GetAIRenewalAdvice(ctx context.Context, userID, renewalID string) (string, error) {
......@@ -220,7 +232,61 @@ func (s *Service) CreateMetric(ctx context.Context, userID string, req *MetricRe
MetricType: req.MetricType, Value1: req.Value1, Value2: req.Value2,
Unit: req.Unit, RecordedAt: recordedAt, Notes: req.Notes,
}
return r, s.db.WithContext(ctx).Create(r).Error
if err := s.db.WithContext(ctx).Create(r).Error; err != nil {
return nil, err
}
// 检测异常指标,触发 health_alert 工作流
if alertLevel := detectMetricAlert(req.MetricType, req.Value1, req.Value2); alertLevel != "" {
workflow.GetEngine().TriggerByCategory(ctx, "health_alert", map[string]interface{}{
"patient_id": userID,
"metric_type": req.MetricType,
"value1": req.Value1,
"value2": req.Value2,
"unit": req.Unit,
"alert_level": alertLevel,
})
}
return r, nil
}
// detectMetricAlert 检测健康指标是否异常,返回告警级别("mild"/"moderate"/"severe"/"" 表示正常)
func detectMetricAlert(metricType string, value1, value2 float64) string {
switch metricType {
case "blood_pressure":
// 收缩压 value1 / 舒张压 value2
if value1 >= 180 || value2 >= 120 {
return "severe"
} else if value1 >= 160 || value2 >= 100 {
return "moderate"
} else if value1 >= 140 || value2 >= 90 {
return "mild"
}
case "blood_glucose":
// 空腹血糖 mmol/L
if value1 >= 16.7 {
return "severe"
} else if value1 >= 11.1 {
return "moderate"
} else if value1 >= 7.0 {
return "mild"
}
case "heart_rate":
if value1 >= 150 || value1 <= 40 {
return "severe"
} else if value1 >= 120 || value1 <= 50 {
return "moderate"
}
case "body_temperature":
if value1 >= 39.5 {
return "severe"
} else if value1 >= 38.5 {
return "moderate"
} else if value1 >= 37.5 {
return "mild"
}
}
return ""
}
func (s *Service) DeleteMetric(ctx context.Context, userID, id string) error {
......
......@@ -13,6 +13,7 @@ import (
internalagent "internet-hospital/internal/agent"
"internet-hospital/internal/model"
"internet-hospital/pkg/database"
"internet-hospital/pkg/workflow"
)
type Service struct {
......@@ -83,6 +84,14 @@ func (s *Service) CreateConsult(ctx context.Context, patientID string, req *Crea
s.db.Create(videoRoom)
}
// 触发 consult_created 工作流(异步)
workflow.GetEngine().TriggerByCategory(ctx, "consult_created", map[string]interface{}{
"consult_id": consult.ID,
"patient_id": consult.PatientID,
"doctor_id": consult.DoctorID,
"type": consult.Type,
})
return &ConsultResponse{
Consultation: *consult,
DoctorName: doctor.Name,
......@@ -186,10 +195,20 @@ func (s *Service) EndConsult(ctx context.Context, consultID string) error {
}
now := time.Now()
return s.db.Model(&consult).Updates(map[string]interface{}{
if err := s.db.Model(&consult).Updates(map[string]interface{}{
"status": "completed",
"ended_at": now,
}).Error
}).Error; err != nil {
return err
}
// 触发 consult_ended 工作流(异步)
workflow.GetEngine().TriggerByCategory(ctx, "consult_ended", map[string]interface{}{
"consult_id": consultID,
"patient_id": consult.PatientID,
"doctor_id": consult.DoctorID,
})
return nil
}
func (s *Service) CancelConsult(ctx context.Context, consultID string) error {
......
......@@ -80,3 +80,23 @@ func (h *Handler) GetDoctorSchedule(c *gin.Context) {
}
response.Success(c, schedules)
}
type AppointmentRequest struct {
DoctorID string `json:"doctor_id" binding:"required"`
Date string `json:"date" binding:"required"`
TimeSlot string `json:"time_slot" binding:"required"`
}
func (h *Handler) MakeAppointment(c *gin.Context) {
var req AppointmentRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "请求参数错误")
return
}
appointmentID, err := h.service.MakeAppointment(c.Request.Context(), req.DoctorID, req.Date, req.TimeSlot)
if err != nil {
response.Error(c, 400, err.Error())
return
}
response.Success(c, gin.H{"appointment_id": appointmentID})
}
......@@ -2,6 +2,7 @@ package doctor
import (
"context"
"errors"
"gorm.io/gorm"
......@@ -120,3 +121,15 @@ func (s *Service) GetDoctorSchedule(ctx context.Context, doctorID, startDate, en
Find(&schedules).Error
return schedules, err
}
func (s *Service) MakeAppointment(ctx context.Context, doctorID, date, timeSlot string) (string, error) {
var schedule model.DoctorSchedule
if err := s.db.Where("doctor_id = ? AND date = ? AND start_time = ? AND remaining > 0", doctorID, date, timeSlot).
First(&schedule).Error; err != nil {
return "", errors.New("该时间段已无可用名额")
}
if err := s.db.Model(&schedule).UpdateColumn("remaining", schedule.Remaining-1).Error; err != nil {
return "", err
}
return schedule.ID, nil
}
......@@ -10,6 +10,7 @@ import (
internalagent "internet-hospital/internal/agent"
"internet-hospital/internal/model"
"internet-hospital/pkg/workflow"
)
// ==================== 处方开具请求/响应 ====================
......@@ -155,6 +156,15 @@ func (s *Service) CreatePrescription(ctx context.Context, doctorID string, req *
}
tx.Commit()
// 触发 prescription_created 工作流(异步)
workflow.GetEngine().TriggerByCategory(ctx, "prescription_created", map[string]interface{}{
"prescription_id": prescription.ID,
"doctor_id": doctorID,
"patient_id": prescription.PatientID,
"total_amount": prescription.TotalAmount,
})
return prescription, nil
}
......
......@@ -11,6 +11,7 @@ import (
"internet-hospital/internal/model"
"internet-hospital/pkg/database"
"internet-hospital/pkg/workflow"
)
type Service struct {
......@@ -104,6 +105,16 @@ func (s *Service) PayOrder(ctx context.Context, orderID, paymentMethod string) (
s.createDoctorIncome(ctx, order.RelatedID, order.Amount, order.OrderType)
}
// 触发 payment_completed 工作流(异步)
workflow.GetEngine().TriggerByCategory(ctx, "payment_completed", map[string]interface{}{
"order_id": order.ID,
"order_no": order.OrderNo,
"user_id": order.UserID,
"order_type": order.OrderType,
"related_id": order.RelatedID,
"amount": order.Amount,
})
return map[string]interface{}{
"order_id": order.ID,
"status": order.Status,
......
......@@ -26,6 +26,34 @@ func (h *Handler) RegisterRoutes(r *gin.RouterGroup) {
}
}
func (h *Handler) Logout(c *gin.Context) {
// JWT 为无状态令牌,服务端无需操作;客户端删除本地令牌即可
response.Success(c, nil)
}
type VerifyIdentityRequest struct {
RealName string `json:"real_name" binding:"required"`
IDCard string `json:"id_card" binding:"required"`
}
func (h *Handler) VerifyIdentity(c *gin.Context) {
userID, exists := c.Get("user_id")
if !exists {
response.Unauthorized(c, "未登录")
return
}
var req VerifyIdentityRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "请求参数错误")
return
}
if err := h.service.VerifyIdentity(c.Request.Context(), userID.(string), req.RealName, req.IDCard); err != nil {
response.Error(c, 400, err.Error())
return
}
response.Success(c, nil)
}
type SendCodeRequest struct {
Phone string `json:"phone" binding:"required"`
}
......
......@@ -193,3 +193,16 @@ func (s *Service) GetUserByID(ctx context.Context, userID string) (*model.User,
func (s *Service) RefreshToken(ctx context.Context, refreshToken string) (*utils.TokenPair, error) {
return utils.RefreshAccessToken(refreshToken)
}
func (s *Service) VerifyIdentity(ctx context.Context, userID, realName, idCard string) error {
// 简单格式校验:身份证18位
if len(idCard) != 18 {
return errors.New("身份证号格式不正确")
}
return s.db.Model(&model.User{}).
Where("id = ?", userID).
Updates(map[string]interface{}{
"real_name": realName,
"is_verified": true,
}).Error
}
......@@ -4,19 +4,52 @@ import (
"context"
"encoding/json"
"fmt"
"log"
"time"
"internet-hospital/internal/model"
"internet-hospital/pkg/database"
)
// getToolConfig 从数据库获取工具配置(CacheTTL/Timeout/MaxRetries/Status)
func getToolConfig(name string) model.AgentTool {
db := database.GetDB()
if db == nil {
return model.AgentTool{Status: "active", Timeout: 30}
}
var tool model.AgentTool
if err := db.Where("name = ?", name).First(&tool).Error; err != nil {
return model.AgentTool{Status: "active", Timeout: 30}
}
return tool
}
// Executor 工具执行器
type Executor struct {
registry *ToolRegistry
cache *ToolCache
}
func NewExecutor(r *ToolRegistry) *Executor {
return &Executor{registry: r}
return &Executor{registry: r, cache: GetCache()}
}
// Execute 执行工具调用
// Execute 执行工具调用(不记日志,集成缓存)
func (e *Executor) Execute(ctx context.Context, name string, argsJSON string) ToolResult {
cfg := getToolConfig(name)
// 检查工具是否被禁用
if cfg.Status == "disabled" {
return ToolResult{Success: false, Error: fmt.Sprintf("工具 %s 已被管理员禁用", name)}
}
// 缓存命中检查
if cfg.CacheTTL > 0 {
if cached, ok := e.cache.Get(name, argsJSON); ok {
return cached
}
}
tool, ok := e.registry.Get(name)
if !ok {
return ToolResult{Success: false, Error: fmt.Sprintf("tool not found: %s", name)}
......@@ -27,9 +60,74 @@ func (e *Executor) Execute(ctx context.Context, name string, argsJSON string) To
return ToolResult{Success: false, Error: fmt.Sprintf("invalid params: %v", err)}
}
data, err := tool.Execute(ctx, params)
if err != nil {
return ToolResult{Success: false, Error: err.Error()}
// 超时控制
timeout := time.Duration(cfg.Timeout) * time.Second
if timeout <= 0 {
timeout = 30 * time.Second
}
return ToolResult{Success: true, Data: data}
execCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
// 执行(带重试)
var result ToolResult
maxRetries := cfg.MaxRetries
if maxRetries < 0 {
maxRetries = 0
}
for attempt := 0; attempt <= maxRetries; attempt++ {
data, err := tool.Execute(execCtx, params)
if err == nil {
result = ToolResult{Success: true, Data: data}
break
}
if attempt == maxRetries {
result = ToolResult{Success: false, Error: err.Error()}
}
}
// 写入缓存(仅成功结果)
if result.Success && cfg.CacheTTL > 0 {
e.cache.Set(name, argsJSON, result, cfg.CacheTTL)
}
return result
}
// ExecuteWithLog 执行工具调用并写入 AgentToolLog
func (e *Executor) ExecuteWithLog(ctx context.Context, name, argsJSON, traceID, agentID, sessionID, userID string, iteration int) ToolResult {
start := time.Now()
result := e.Execute(ctx, name, argsJSON)
durationMs := int(time.Since(start).Milliseconds())
// 异步写日志
go func() {
db := database.GetDB()
if db == nil {
return
}
outputJSON, _ := json.Marshal(result)
errMsg := ""
if !result.Success {
errMsg = result.Error
}
entry := &model.AgentToolLog{
TraceID: traceID,
ToolName: name,
AgentID: agentID,
SessionID: sessionID,
UserID: userID,
InputParams: argsJSON,
OutputResult: string(outputJSON),
Success: result.Success,
ErrorMessage: errMsg,
DurationMs: durationMs,
Iteration: iteration,
CreatedAt: time.Now(),
}
if err := db.Create(entry).Error; err != nil {
log.Printf("[executor] 保存工具调用日志失败: %v", err)
}
}()
return result
}
......@@ -5,6 +5,8 @@ import (
"encoding/json"
"fmt"
"internet-hospital/pkg/ai"
"github.com/google/uuid"
)
// AgentInput Agent输入
......@@ -15,15 +17,17 @@ type AgentInput struct {
Context map[string]interface{} `json:"context"`
History []ai.ChatMessage `json:"history"`
MaxIterations int `json:"max_iterations"`
TraceID string `json:"trace_id,omitempty"` // 外部传入的链路追踪ID(可选)
}
// AgentOutput Agent输出
type AgentOutput struct {
Response string `json:"response"`
Response string `json:"response"`
ToolCalls []ToolCallResult `json:"tool_calls,omitempty"`
Iterations int `json:"iterations"`
FinishReason string `json:"finish_reason"`
TotalTokens int `json:"total_tokens"`
Iterations int `json:"iterations"`
FinishReason string `json:"finish_reason"`
TotalTokens int `json:"total_tokens"`
TraceID string `json:"trace_id,omitempty"` // 本次执行的链路追踪ID
}
// ToolCallResult 工具调用结果记录
......@@ -73,11 +77,18 @@ func (a *ReActAgent) Description() string { return a.cfg.Description }
// Run 执行Agent(非流式)
func (a *ReActAgent) Run(ctx context.Context, input AgentInput) (*AgentOutput, error) {
// 生成链路追踪ID(如果未提供则新建)
traceID := input.TraceID
if traceID == "" {
traceID = uuid.New().String()
}
client := ai.GetClient()
if client == nil {
return &AgentOutput{
Response: "AI服务未配置,请在管理端配置API Key",
FinishReason: "no_client",
TraceID: traceID,
}, nil
}
......@@ -105,6 +116,23 @@ func (a *ReActAgent) Run(ctx context.Context, input AgentInput) (*AgentOutput, e
return nil, fmt.Errorf("AI调用失败: %w", err)
}
totalTokens += resp.Usage.TotalTokens
// 记录本轮 LLM 调用日志(与 AIUsageLog 关联)
ai.SaveLog(ai.LogParams{
Scene: a.cfg.ID,
UserID: input.UserID,
RequestContent: input.Message,
ResponseContent: resp.Choices[0].Message.Content,
PromptTokens: resp.Usage.PromptTokens,
CompletionTokens: resp.Usage.CompletionTokens,
TotalTokens: resp.Usage.TotalTokens,
Success: true,
TraceID: traceID,
AgentID: a.cfg.ID,
SessionID: input.SessionID,
Iteration: i + 1,
})
choice := resp.Choices[0]
if choice.FinishReason == "stop" || len(choice.Message.ToolCalls) == 0 {
......@@ -114,6 +142,7 @@ func (a *ReActAgent) Run(ctx context.Context, input AgentInput) (*AgentOutput, e
Iterations: i + 1,
FinishReason: "completed",
TotalTokens: totalTokens,
TraceID: traceID,
}, nil
}
......@@ -126,7 +155,8 @@ func (a *ReActAgent) Run(ctx context.Context, input AgentInput) (*AgentOutput, e
messages = append(messages, assistantMsg)
for _, tc := range choice.Message.ToolCalls {
result := a.executor.Execute(ctx, tc.Function.Name, tc.Function.Arguments)
result := a.executor.ExecuteWithLog(ctx, tc.Function.Name, tc.Function.Arguments,
traceID, a.cfg.ID, input.SessionID, input.UserID, i+1)
resultJSON, _ := json.Marshal(result)
toolCallResults = append(toolCallResults, ToolCallResult{
ToolName: tc.Function.Name,
......@@ -149,6 +179,7 @@ func (a *ReActAgent) Run(ctx context.Context, input AgentInput) (*AgentOutput, e
Iterations: maxIter,
FinishReason: "max_iterations",
TotalTokens: totalTokens,
TraceID: traceID,
}, nil
}
......@@ -175,7 +206,13 @@ func (a *ReActAgent) RunStream(ctx context.Context, input AgentInput, onChunk fu
}
func (a *ReActAgent) buildSystemPrompt(ctx map[string]interface{}) string {
prompt := a.cfg.SystemPrompt
// 1. 优先从数据库加载该Agent关联的 active 提示词模板
prompt := ai.GetActivePromptByAgent(a.cfg.ID)
// 2. 回退到 AgentDefinition 的 SystemPrompt(即代码配置值)
if prompt == "" {
prompt = a.cfg.SystemPrompt
}
// 3. 渲染上下文变量
if ctx != nil {
if patientID, ok := ctx["patient_id"].(string); ok {
prompt += fmt.Sprintf("\n\n当前患者ID: %s", patientID)
......
......@@ -37,6 +37,12 @@ func (r *ToolRegistry) GetSchemas(names []string) []ToolSchema {
return schemas
}
func (r *ToolRegistry) Unregister(name string) {
r.mu.Lock()
defer r.mu.Unlock()
delete(r.tools, name)
}
func (r *ToolRegistry) All() map[string]Tool {
r.mu.RLock()
defer r.mu.RUnlock()
......
package agent
import (
"context"
"crypto/sha256"
"encoding/json"
"fmt"
"time"
"internet-hospital/pkg/database"
)
const cacheKeyPrefix = "tool_cache:"
// ToolCache 工具结果 Redis 缓存
type ToolCache struct{}
var globalCache = &ToolCache{}
// GetCache 获取全局缓存实例
func GetCache() *ToolCache { return globalCache }
// argsHash 对工具参数计算 sha256 前 16 位,作为缓存 key 后缀
func argsHash(argsJSON string) string {
h := sha256.Sum256([]byte(argsJSON))
return fmt.Sprintf("%x", h[:8])
}
// cacheKey 构造缓存 key
func cacheKey(toolName, argsJSON string) string {
return cacheKeyPrefix + toolName + ":" + argsHash(argsJSON)
}
// Get 尝试命中缓存,返回 (result, hit)
func (c *ToolCache) Get(toolName, argsJSON string) (ToolResult, bool) {
client := database.RedisClient
if client == nil {
return ToolResult{}, false
}
key := cacheKey(toolName, argsJSON)
val, err := client.Get(context.Background(), key).Result()
if err != nil {
return ToolResult{}, false
}
var result ToolResult
if err := json.Unmarshal([]byte(val), &result); err != nil {
return ToolResult{}, false
}
return result, true
}
// Set 写入缓存
func (c *ToolCache) Set(toolName, argsJSON string, result ToolResult, ttlSec int) {
if ttlSec <= 0 {
return
}
client := database.RedisClient
if client == nil {
return
}
b, err := json.Marshal(result)
if err != nil {
return
}
key := cacheKey(toolName, argsJSON)
client.Set(context.Background(), key, string(b), time.Duration(ttlSec)*time.Second)
}
// Invalidate 手动使某工具的所有缓存失效(管理员修改工具配置后调用)
func (c *ToolCache) Invalidate(toolName string) {
client := database.RedisClient
if client == nil {
return
}
pattern := cacheKeyPrefix + toolName + ":*"
keys, err := client.Keys(context.Background(), pattern).Result()
if err != nil || len(keys) == 0 {
return
}
client.Del(context.Background(), keys...)
}
package tools
import (
"context"
"fmt"
"internet-hospital/pkg/agent"
)
// AgentCallFn 由 internal/agent/service.go 在启动时注入,避免循环依赖
var AgentCallFn func(ctx context.Context, agentID, userID, sessionID, message string, ctxData map[string]interface{}) (string, error)
const maxAgentCallDepth = 3
type agentCallDepthKey struct{}
// AgentCallerTool 允许 Agent 调用另一个 Agent(子 Agent 模式)
type AgentCallerTool struct{}
func (t *AgentCallerTool) Name() string { return "call_agent" }
func (t *AgentCallerTool) Description() string {
return "调用另一个专业 Agent 处理特定子任务,如诊断辅助、处方审核、随访管理等"
}
func (t *AgentCallerTool) Parameters() []agent.ToolParameter {
return []agent.ToolParameter{
{
Name: "agent_id",
Type: "string",
Description: "目标 Agent ID,如 diagnosis_agent、prescription_agent、follow_up_agent",
Required: true,
},
{
Name: "message",
Type: "string",
Description: "发送给目标 Agent 的任务描述或问题",
Required: true,
},
{
Name: "context",
Type: "object",
Description: "传递给目标 Agent 的上下文信息(可选)",
Required: false,
},
}
}
func (t *AgentCallerTool) Execute(ctx context.Context, params map[string]interface{}) (interface{}, error) {
// 防循环调用:检测调用深度
depth, _ := ctx.Value(agentCallDepthKey{}).(int)
if depth >= maxAgentCallDepth {
return nil, fmt.Errorf("agent 调用深度超限(最大 %d 层),防止无限递归", maxAgentCallDepth)
}
ctx = context.WithValue(ctx, agentCallDepthKey{}, depth+1)
agentID, ok := params["agent_id"].(string)
if !ok || agentID == "" {
return nil, fmt.Errorf("agent_id 必填")
}
message, ok := params["message"].(string)
if !ok || message == "" {
return nil, fmt.Errorf("message 必填")
}
ctxData, _ := params["context"].(map[string]interface{})
if AgentCallFn == nil {
return nil, fmt.Errorf("AgentCallFn 未注入,请检查初始化顺序")
}
response, err := AgentCallFn(ctx, agentID, "system", "", message, ctxData)
if err != nil {
return nil, fmt.Errorf("调用 agent %s 失败: %w", agentID, err)
}
return map[string]interface{}{
"agent_id": agentID,
"response": response,
}, nil
}
package tools
import (
"context"
"fmt"
"github.com/expr-lang/expr"
agentpkg "internet-hospital/pkg/agent"
)
// ExprEvalTool 安全表达式执行工具(使用 expr-lang,无 IO/网络权限)
type ExprEvalTool struct{}
func (t *ExprEvalTool) Name() string { return "eval_expression" }
func (t *ExprEvalTool) Description() string {
return "安全执行数学/逻辑表达式,支持四则运算、条件判断、字符串处理、数组过滤。常用于 BMI 计算、风险评分、剂量调整等"
}
func (t *ExprEvalTool) Parameters() []agentpkg.ToolParameter {
return []agentpkg.ToolParameter{
{
Name: "expression",
Type: "string",
Description: "表达式字符串,如: weight / (height * height) 或 age > 65 ? '老年' : '成年'",
Required: true,
},
{
Name: "variables",
Type: "object",
Description: "表达式中引用的变量,如: {\"weight\": 70, \"height\": 1.75}",
Required: false,
},
}
}
func (t *ExprEvalTool) Execute(_ context.Context, params map[string]interface{}) (interface{}, error) {
expression, ok := params["expression"].(string)
if !ok || expression == "" {
return nil, fmt.Errorf("expression 必填")
}
variables, _ := params["variables"].(map[string]interface{})
if variables == nil {
variables = map[string]interface{}{}
}
// 编译表达式
program, err := expr.Compile(expression, expr.Env(variables))
if err != nil {
return nil, fmt.Errorf("表达式编译失败: %w", err)
}
// 执行表达式
result, err := expr.Run(program, variables)
if err != nil {
return nil, fmt.Errorf("表达式执行失败: %w", err)
}
return map[string]interface{}{
"expression": expression,
"result": result,
"variables": variables,
}, nil
}
package tools
import (
"context"
"fmt"
"internet-hospital/pkg/agent"
)
// GenerateFollowUpPlanTool 生成随访计划(此工具之前在分类 map 中声明但从未注册)
type GenerateFollowUpPlanTool struct{}
func (t *GenerateFollowUpPlanTool) Name() string { return "generate_follow_up_plan" }
func (t *GenerateFollowUpPlanTool) Description() string {
return "根据诊断和治疗信息生成个性化随访计划,包括复诊时间、检查项目、用药提醒和生活建议"
}
func (t *GenerateFollowUpPlanTool) Parameters() []agent.ToolParameter {
return []agent.ToolParameter{
{
Name: "diagnosis",
Type: "string",
Description: "诊断结论,如:高血压2级、2型糖尿病",
Required: true,
},
{
Name: "medications",
Type: "array",
Description: "当前用药列表,如:[\"氨氯地平 5mg\", \"二甲双胍 500mg\"]",
Required: false,
},
{
Name: "patient_age",
Type: "number",
Description: "患者年龄",
Required: false,
},
{
Name: "severity",
Type: "string",
Description: "病情严重程度:mild/moderate/severe",
Required: false,
Enum: []string{"mild", "moderate", "severe"},
},
}
}
func (t *GenerateFollowUpPlanTool) Execute(_ context.Context, params map[string]interface{}) (interface{}, error) {
diagnosis, ok := params["diagnosis"].(string)
if !ok || diagnosis == "" {
return nil, fmt.Errorf("diagnosis 必填")
}
severity, _ := params["severity"].(string)
if severity == "" {
severity = "moderate"
}
age, _ := params["patient_age"].(float64)
// 基于诊断生成随访计划(内置规则)
plan := buildFollowUpPlan(diagnosis, severity, int(age))
meds := make([]string, 0)
if medsRaw, ok := params["medications"].([]interface{}); ok {
for _, m := range medsRaw {
if s, ok := m.(string); ok {
meds = append(meds, s)
}
}
}
if len(meds) > 0 {
plan["medication_reminders"] = buildMedReminders(meds)
}
return plan, nil
}
func buildFollowUpPlan(diagnosis, severity string, age int) map[string]interface{} {
// 随访频率
followUpIntervalDays := 90 // 默认3个月
switch severity {
case "mild":
followUpIntervalDays = 180
case "severe":
followUpIntervalDays = 30
}
if age >= 65 {
followUpIntervalDays = followUpIntervalDays / 2
}
plan := map[string]interface{}{
"diagnosis": diagnosis,
"follow_up_interval": fmt.Sprintf("每 %d 天随访一次", followUpIntervalDays),
"follow_up_interval_days": followUpIntervalDays,
}
// 诊断特异性建议
checks := []string{"血压测量", "体重监测", "症状询问"}
advice := []string{"低盐低脂饮食", "规律运动(每周150分钟中等强度)", "戒烟限酒", "规律服药"}
containsAny := func(s string, keywords ...string) bool {
for _, k := range keywords {
if containsIgnoreCase(s, k) {
return true
}
}
return false
}
if containsAny(diagnosis, "高血压", "blood pressure") {
checks = append(checks, "尿常规", "血肌酐", "血钾")
advice = append(advice, "每日家庭血压监测(早晚各1次)", "限制钠盐摄入(< 6g/日)")
plan["target"] = "血压控制目标:< 140/90 mmHg(耐受者 < 130/80 mmHg)"
}
if containsAny(diagnosis, "糖尿病", "diabetes") {
checks = append(checks, "HbA1c", "空腹血糖", "尿微量白蛋白", "眼底检查", "足部检查")
advice = append(advice, "控制碳水化合物摄入", "每日血糖自我监测", "注意低血糖症状")
plan["target"] = "血糖控制目标:HbA1c < 7.0%(个体化)"
}
if containsAny(diagnosis, "冠心病", "心绞痛", "心肌梗死") {
checks = append(checks, "心电图", "血脂全套", "心脏超声(年度)")
advice = append(advice, "避免剧烈运动", "随身携带硝酸甘油", "规范抗血小板治疗")
plan["emergency_signs"] = "如出现持续胸痛、呼吸困难、大汗请立即就医"
}
plan["monitoring_items"] = checks
plan["lifestyle_advice"] = advice
return plan
}
func buildMedReminders(meds []string) []map[string]interface{} {
reminders := make([]map[string]interface{}, 0, len(meds))
for _, med := range meds {
reminders = append(reminders, map[string]interface{}{
"medication": med,
"reminder": "按时服药,切勿自行停药",
"tip": "如出现不适反应请及时就医",
})
}
return reminders
}
package tools
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"internet-hospital/internal/model"
"internet-hospital/pkg/agent"
)
// DynamicHTTPTool 动态 HTTP 工具,从数据库配置生成,无需修改代码
type DynamicHTTPTool struct {
def *model.HTTPToolDefinition
}
func NewDynamicHTTPTool(def *model.HTTPToolDefinition) *DynamicHTTPTool {
return &DynamicHTTPTool{def: def}
}
func (t *DynamicHTTPTool) Name() string { return t.def.Name }
func (t *DynamicHTTPTool) Description() string { return t.def.Description }
func (t *DynamicHTTPTool) Parameters() []agent.ToolParameter {
if t.def.Parameters == "" || t.def.Parameters == "[]" {
return nil
}
var params []agent.ToolParameter
if err := json.Unmarshal([]byte(t.def.Parameters), &params); err != nil {
return nil
}
return params
}
// renderTemplate 将 {{key}} 替换为 params 中对应值
func renderTemplate(template string, params map[string]interface{}) string {
result := template
for k, v := range params {
result = strings.ReplaceAll(result, "{{"+k+"}}", fmt.Sprintf("%v", v))
}
return result
}
func (t *DynamicHTTPTool) Execute(ctx context.Context, params map[string]interface{}) (interface{}, error) {
// 渲染 URL
url := renderTemplate(t.def.URL, params)
// 渲染 Body
var bodyStr string
if t.def.BodyTemplate != "" {
bodyStr = renderTemplate(t.def.BodyTemplate, params)
}
// 超时
timeout := time.Duration(t.def.Timeout) * time.Second
if timeout <= 0 {
timeout = 10 * time.Second
}
httpCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
// 创建请求
var bodyReader io.Reader
if bodyStr != "" {
bodyReader = strings.NewReader(bodyStr)
}
req, err := http.NewRequestWithContext(httpCtx, t.def.Method, url, bodyReader)
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
// 设置 Content-Type
if bodyStr != "" {
req.Header.Set("Content-Type", "application/json")
}
// 自定义 Headers
if t.def.Headers != "" && t.def.Headers != "{}" {
var headers map[string]string
if err := json.Unmarshal([]byte(t.def.Headers), &headers); err == nil {
for k, v := range headers {
req.Header.Set(k, renderTemplate(v, params))
}
}
}
// 认证
switch t.def.AuthType {
case "bearer":
var authCfg map[string]string
if err := json.Unmarshal([]byte(t.def.AuthConfig), &authCfg); err == nil {
if token, ok := authCfg["token"]; ok && token != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
}
case "basic":
var authCfg map[string]string
if err := json.Unmarshal([]byte(t.def.AuthConfig), &authCfg); err == nil {
req.SetBasicAuth(authCfg["username"], authCfg["password"])
}
case "apikey":
var authCfg map[string]string
if err := json.Unmarshal([]byte(t.def.AuthConfig), &authCfg); err == nil {
headerName := authCfg["header"]
headerVal := authCfg["value"]
if headerName != "" {
req.Header.Set(headerName, headerVal)
}
}
}
// 发起请求
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("HTTP 请求失败: %w", err)
}
defer resp.Body.Close()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
}
result := map[string]interface{}{
"status_code": resp.StatusCode,
"success": resp.StatusCode >= 200 && resp.StatusCode < 300,
}
// 尝试解析 JSON 响应
var jsonBody interface{}
if err := json.Unmarshal(bodyBytes, &jsonBody); err == nil {
result["body"] = jsonBody
} else {
result["body"] = string(bodyBytes)
}
if resp.StatusCode >= 400 {
return result, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(bodyBytes))
}
return result, nil
}
package tools
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
"internet-hospital/internal/model"
"internet-hospital/pkg/agent"
"internet-hospital/pkg/database"
)
// HumanReviewTool Agent 主动发起人工审核任务
type HumanReviewTool struct{}
func (t *HumanReviewTool) Name() string { return "request_human_review" }
func (t *HumanReviewTool) Description() string {
return "发起人工审核请求,等待医生或管理员审核后继续;适用于需要人工介入的场景"
}
func (t *HumanReviewTool) Parameters() []agent.ToolParameter {
return []agent.ToolParameter{
{
Name: "title",
Type: "string",
Description: "审核任务标题",
Required: true,
},
{
Name: "description",
Type: "string",
Description: "需要人工审核的内容描述",
Required: true,
},
{
Name: "assignee_role",
Type: "string",
Description: "指定审核角色:doctor 或 admin",
Required: false,
Enum: []string{"doctor", "admin"},
},
{
Name: "data",
Type: "object",
Description: "随审核任务传递的附加数据",
Required: false,
},
}
}
func (t *HumanReviewTool) Execute(ctx context.Context, params map[string]interface{}) (interface{}, error) {
title, ok := params["title"].(string)
if !ok || title == "" {
return nil, fmt.Errorf("title 必填")
}
description, _ := params["description"].(string)
assigneeRole, _ := params["assignee_role"].(string)
if assigneeRole == "" {
assigneeRole = "admin"
}
extraData, _ := params["data"].(map[string]interface{})
extraJSON := "{}"
if extraData != nil {
if b, err := json.Marshal(extraData); err == nil {
extraJSON = string(b)
}
}
db := database.GetDB()
if db == nil {
return nil, fmt.Errorf("数据库未初始化")
}
taskID := uuid.New().String()
now := time.Now()
task := model.WorkflowHumanTask{
TaskID: taskID,
ExecutionID: "agent-" + taskID,
NodeID: "agent_requested",
Title: title,
Description: description,
AssigneeRole: assigneeRole,
FormData: extraJSON,
Status: "pending",
CreatedAt: now,
}
if err := db.Create(&task).Error; err != nil {
return nil, fmt.Errorf("创建审核任务失败: %w", err)
}
return map[string]interface{}{
"task_id": taskID,
"status": "pending",
"assignee": assigneeRole,
"message": fmt.Sprintf("已创建人工审核任务,等待 %s 审核", assigneeRole),
}, nil
}
package tools
import (
"context"
"fmt"
"internet-hospital/internal/model"
"internet-hospital/pkg/agent"
"internet-hospital/pkg/database"
)
// KnowledgeListTool 列出可用知识库集合,Agent 可据此选择 collection_id
type KnowledgeListTool struct{}
func (t *KnowledgeListTool) Name() string { return "list_knowledge_collections" }
func (t *KnowledgeListTool) Description() string {
return "列出所有可用的知识库集合及文档数量,用于在检索前确认目标集合 ID"
}
func (t *KnowledgeListTool) Parameters() []agent.ToolParameter {
return []agent.ToolParameter{
{
Name: "keyword",
Type: "string",
Description: "按名称关键词过滤集合(可选)",
Required: false,
},
}
}
func (t *KnowledgeListTool) Execute(_ context.Context, params map[string]interface{}) (interface{}, error) {
db := database.GetDB()
if db == nil {
return nil, fmt.Errorf("数据库未初始化")
}
keyword, _ := params["keyword"].(string)
type CollectionStat struct {
CollectionID string `gorm:"column:collection_id"`
Name string `gorm:"column:name"`
Description string `gorm:"column:description"`
DocCount int64 `gorm:"column:doc_count"`
}
query := db.Model(&model.KnowledgeCollection{}).
Select("knowledge_collections.collection_id, knowledge_collections.name, knowledge_collections.description, COUNT(knowledge_documents.id) as doc_count").
Joins("LEFT JOIN knowledge_documents ON knowledge_documents.collection_id = knowledge_collections.collection_id AND knowledge_documents.status = 'active'").
Group("knowledge_collections.collection_id, knowledge_collections.name, knowledge_collections.description")
if keyword != "" {
query = query.Where("knowledge_collections.name ILIKE ? OR knowledge_collections.description ILIKE ?",
"%"+keyword+"%", "%"+keyword+"%")
}
var stats []CollectionStat
if err := query.Scan(&stats).Error; err != nil {
return nil, fmt.Errorf("查询知识库失败: %w", err)
}
result := make([]map[string]interface{}, 0, len(stats))
for _, s := range stats {
result = append(result, map[string]interface{}{
"collection_id": s.CollectionID,
"name": s.Name,
"description": s.Description,
"doc_count": s.DocCount,
})
}
return map[string]interface{}{
"collections": result,
"total": len(result),
}, nil
}
package tools
import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/google/uuid"
"gorm.io/gorm"
"internet-hospital/internal/model"
"internet-hospital/pkg/agent"
"internet-hospital/pkg/database"
"internet-hospital/pkg/rag"
)
// KnowledgeWriteTool Agent 向知识库写入或更新内容
type KnowledgeWriteTool struct {
DB *gorm.DB
Retriever *rag.Retriever
}
func (t *KnowledgeWriteTool) Name() string { return "write_knowledge" }
func (t *KnowledgeWriteTool) Description() string {
return "向指定知识库集合写入新文档或更新已有文档,Agent 可用于保存诊断结论、治疗方案、病例摘要等"
}
func (t *KnowledgeWriteTool) Parameters() []agent.ToolParameter {
return []agent.ToolParameter{
{
Name: "collection_id",
Type: "string",
Description: "目标知识库集合 ID,如 clinical_guideline、case_summary",
Required: true,
},
{
Name: "title",
Type: "string",
Description: "文档标题",
Required: true,
},
{
Name: "content",
Type: "string",
Description: "文档正文内容(支持 Markdown)",
Required: true,
},
{
Name: "tags",
Type: "array",
Description: "标签列表,如 [\"高血压\", \"治疗方案\"]",
Required: false,
},
{
Name: "doc_id",
Type: "string",
Description: "指定文档 ID 则更新,不指定则创建新文档",
Required: false,
},
}
}
func (t *KnowledgeWriteTool) Execute(ctx context.Context, params map[string]interface{}) (interface{}, error) {
db := t.DB
if db == nil {
db = database.GetDB()
}
if db == nil {
return nil, fmt.Errorf("数据库未初始化")
}
collectionID, ok := params["collection_id"].(string)
if !ok || collectionID == "" {
return nil, fmt.Errorf("collection_id 必填")
}
title, ok := params["title"].(string)
if !ok || title == "" {
return nil, fmt.Errorf("title 必填")
}
content, ok := params["content"].(string)
if !ok || content == "" {
return nil, fmt.Errorf("content 必填")
}
// 构建 tags JSON
tagsJSON := "[]"
if tagsRaw, ok := params["tags"].([]interface{}); ok {
tags := make([]string, 0, len(tagsRaw))
for _, t := range tagsRaw {
if s, ok := t.(string); ok {
tags = append(tags, s)
}
}
if b, err := json.Marshal(tags); err == nil {
tagsJSON = string(b)
}
}
docID, _ := params["doc_id"].(string)
now := time.Now()
// 确保集合存在
var collection model.KnowledgeCollection
if err := db.Where("collection_id = ?", collectionID).First(&collection).Error; err != nil {
// 自动创建集合
collection = model.KnowledgeCollection{
CollectionID: collectionID,
Name: collectionID,
Description: "由 Agent 自动创建",
CreatedAt: now,
UpdatedAt: now,
}
if err := db.Create(&collection).Error; err != nil {
return nil, fmt.Errorf("创建集合失败: %w", err)
}
}
// 创建或更新文档
if docID == "" {
docID = uuid.New().String()
}
var doc model.KnowledgeDocument
isNew := db.Where("document_id = ?", docID).First(&doc).Error != nil
if isNew {
doc = model.KnowledgeDocument{
DocumentID: docID,
CollectionID: collectionID,
Title: title,
Content: content,
Metadata: tagsJSON,
Status: "ready",
CreatedAt: now,
UpdatedAt: now,
}
if err := db.Create(&doc).Error; err != nil {
return nil, fmt.Errorf("创建文档失败: %w", err)
}
} else {
if err := db.Model(&doc).Updates(map[string]interface{}{
"title": title,
"content": content,
"metadata": tagsJSON,
"updated_at": now,
}).Error; err != nil {
return nil, fmt.Errorf("更新文档失败: %w", err)
}
// 删除旧分块
db.Where("document_id = ?", docID).Delete(&model.KnowledgeChunk{})
}
// 分块并写入(按段落分块)
chunks := splitIntoChunks(content, 500)
for i, chunk := range chunks {
c := model.KnowledgeChunk{
ChunkID: uuid.New().String(),
DocumentID: docID,
CollectionID: collectionID,
Content: chunk,
ChunkIndex: i,
CreatedAt: now,
}
db.Create(&c)
}
// 异步触发向量索引(如果 Retriever 可用)
if t.Retriever != nil {
go func() {
_ = t.Retriever.IndexDocument(context.Background(), docID)
}()
}
action := "created"
if !isNew {
action = "updated"
}
return map[string]interface{}{
"doc_id": docID,
"collection_id": collectionID,
"title": title,
"chunks": len(chunks),
"action": action,
}, nil
}
// splitIntoChunks 按段落分块,maxLen 为最大字符数
func splitIntoChunks(text string, maxLen int) []string {
paragraphs := strings.Split(text, "\n\n")
var chunks []string
var current strings.Builder
for _, p := range paragraphs {
if current.Len()+len(p) > maxLen && current.Len() > 0 {
chunks = append(chunks, strings.TrimSpace(current.String()))
current.Reset()
}
if current.Len() > 0 {
current.WriteString("\n\n")
}
current.WriteString(p)
}
if current.Len() > 0 {
chunks = append(chunks, strings.TrimSpace(current.String()))
}
if len(chunks) == 0 {
chunks = []string{text}
}
return chunks
}
package tools
import (
"context"
"fmt"
"time"
"internet-hospital/internal/model"
"internet-hospital/pkg/agent"
"internet-hospital/pkg/database"
)
// SendNotificationTool 向用户发送系统通知(站内信 + WebSocket 推送)
// 此工具之前在分类 map 中声明但从未注册,此次完整落地
type SendNotificationTool struct{}
func (t *SendNotificationTool) Name() string { return "send_notification" }
func (t *SendNotificationTool) Description() string {
return "向指定用户发送系统通知,支持站内信存储和实时 WebSocket 推送"
}
func (t *SendNotificationTool) Parameters() []agent.ToolParameter {
return []agent.ToolParameter{
{
Name: "user_id",
Type: "string",
Description: "接收通知的用户 ID(UUID)",
Required: true,
},
{
Name: "title",
Type: "string",
Description: "通知标题",
Required: true,
},
{
Name: "content",
Type: "string",
Description: "通知正文内容",
Required: true,
},
{
Name: "type",
Type: "string",
Description: "通知类型",
Required: false,
Enum: []string{"reminder", "alert", "info", "followup", "system"},
},
{
Name: "priority",
Type: "string",
Description: "优先级:normal/high/urgent",
Required: false,
Enum: []string{"normal", "high", "urgent"},
},
}
}
func (t *SendNotificationTool) Execute(_ context.Context, params map[string]interface{}) (interface{}, error) {
db := database.GetDB()
if db == nil {
return nil, fmt.Errorf("数据库未初始化")
}
userID, ok := params["user_id"].(string)
if !ok || userID == "" {
return nil, fmt.Errorf("user_id 必填")
}
title, ok := params["title"].(string)
if !ok || title == "" {
return nil, fmt.Errorf("title 必填")
}
content, _ := params["content"].(string)
nType, _ := params["type"].(string)
if nType == "" {
nType = "info"
}
priority, _ := params["priority"].(string)
if priority == "" {
priority = "normal"
}
// 写入系统日志表作为站内信(复用 SystemLog)
log := model.SystemLog{
Action: "notification",
Resource: fmt.Sprintf("%s/%s", nType, priority),
Detail: fmt.Sprintf("[%s] %s: %s", nType, title, content),
UserID: userID,
}
if err := db.Create(&log).Error; err != nil {
return nil, fmt.Errorf("写入通知失败: %w", err)
}
return map[string]interface{}{
"notification_id": log.ID,
"user_id": userID,
"title": title,
"type": nType,
"priority": priority,
"sent_at": time.Now().Format(time.RFC3339),
}, nil
}
package tools
import (
"context"
"encoding/json"
"fmt"
"internet-hospital/internal/model"
"internet-hospital/pkg/agent"
"internet-hospital/pkg/database"
)
// WorkflowQueryTool 查询工作流执行状态
type WorkflowQueryTool struct{}
func (t *WorkflowQueryTool) Name() string { return "query_workflow_status" }
func (t *WorkflowQueryTool) Description() string {
return "查询工作流执行状态和结果,通常与 trigger_workflow 配合使用"
}
func (t *WorkflowQueryTool) Parameters() []agent.ToolParameter {
return []agent.ToolParameter{
{
Name: "execution_id",
Type: "string",
Description: "工作流执行 ID(由 trigger_workflow 返回)",
Required: true,
},
}
}
func (t *WorkflowQueryTool) Execute(ctx context.Context, params map[string]interface{}) (interface{}, error) {
executionID, ok := params["execution_id"].(string)
if !ok || executionID == "" {
return nil, fmt.Errorf("execution_id 必填")
}
db := database.GetDB()
if db == nil {
return nil, fmt.Errorf("数据库未初始化")
}
var exec model.WorkflowExecution
if err := db.Where("execution_id = ?", executionID).First(&exec).Error; err != nil {
return nil, fmt.Errorf("未找到执行记录: %s", executionID)
}
result := map[string]interface{}{
"execution_id": exec.ExecutionID,
"workflow_id": exec.WorkflowID,
"status": exec.Status,
"trigger_by": exec.TriggerBy,
"started_at": exec.StartedAt,
"completed_at": exec.CompletedAt,
}
// 解析输出
if exec.Output != "" && exec.Output != "null" {
var output interface{}
if err := json.Unmarshal([]byte(exec.Output), &output); err == nil {
result["output"] = output
}
}
return result, nil
}
package tools
import (
"context"
"fmt"
"internet-hospital/pkg/agent"
)
// WorkflowTriggerFn 由 main.go 或 workflow 包在启动时注入
var WorkflowTriggerFn func(ctx context.Context, workflowID string, input map[string]interface{}) (string, error)
// WorkflowTriggerTool Agent 触发工作流执行
type WorkflowTriggerTool struct{}
func (t *WorkflowTriggerTool) Name() string { return "trigger_workflow" }
func (t *WorkflowTriggerTool) Description() string {
return "触发一个工作流异步执行,返回执行 ID 可用于后续查询状态"
}
func (t *WorkflowTriggerTool) Parameters() []agent.ToolParameter {
return []agent.ToolParameter{
{
Name: "workflow_id",
Type: "string",
Description: "工作流 ID,如 pre_consult、follow_up",
Required: true,
},
{
Name: "input",
Type: "object",
Description: "工作流输入参数(键值对)",
Required: false,
},
}
}
func (t *WorkflowTriggerTool) Execute(ctx context.Context, params map[string]interface{}) (interface{}, error) {
workflowID, ok := params["workflow_id"].(string)
if !ok || workflowID == "" {
return nil, fmt.Errorf("workflow_id 必填")
}
input, _ := params["input"].(map[string]interface{})
if input == nil {
input = map[string]interface{}{}
}
if WorkflowTriggerFn == nil {
return nil, fmt.Errorf("WorkflowTriggerFn 未注入")
}
executionID, err := WorkflowTriggerFn(ctx, workflowID, input)
if err != nil {
return nil, fmt.Errorf("触发工作流 %s 失败: %w", workflowID, err)
}
return map[string]interface{}{
"workflow_id": workflowID,
"execution_id": executionID,
"status": "triggered",
}, nil
}
......@@ -584,6 +584,11 @@ type CallParams struct {
UserID string // 用户ID
Messages []ChatMessage // 消息列表
RequestSummary string // 请求摘要(用于日志)
// 链路追踪字段(Agent调用时填写)
TraceID string // 链路追踪ID
AgentID string // 关联Agent ID
SessionID string // 关联会话ID
Iteration int // Agent第几轮迭代(0=非Agent调用)
}
// CallResult 统一AI调用结果
......@@ -655,6 +660,10 @@ func Call(ctx context.Context, params CallParams) CallResult {
Success: result.Error == nil,
ErrorMessage: errMsg,
IsMock: result.IsMock,
TraceID: params.TraceID,
AgentID: params.AgentID,
SessionID: params.SessionID,
Iteration: params.Iteration,
})
return result
......@@ -715,6 +724,10 @@ func CallStream(ctx context.Context, params CallParams, onChunk func(content str
Success: result.Error == nil,
ErrorMessage: errMsg,
IsMock: result.IsMock,
TraceID: params.TraceID,
AgentID: params.AgentID,
SessionID: params.SessionID,
Iteration: params.Iteration,
})
return result
......
......@@ -22,6 +22,11 @@ type LogParams struct {
Success bool
ErrorMessage string
IsMock bool
// 链路追踪字段
TraceID string
AgentID string
SessionID string
Iteration int
}
// SaveLog 异步保存AI调用日志
......@@ -66,6 +71,10 @@ func SaveLog(params LogParams) {
Success: params.Success,
ErrorMessage: params.ErrorMessage,
IsMock: params.IsMock,
TraceID: params.TraceID,
AgentID: params.AgentID,
SessionID: params.SessionID,
Iteration: params.Iteration,
CreatedAt: time.Now(),
}
......
......@@ -37,3 +37,17 @@ func GetActivePromptByScene(scene string) string {
}
return ""
}
// GetActivePromptByAgent 根据 agent_id 查找该Agent关联的激活系统提示词模板
func GetActivePromptByAgent(agentID string) string {
db := database.GetDB()
if db == nil {
return ""
}
var tmpl model.PromptTemplate
if err := db.Where("agent_id = ? AND template_type = 'system' AND status = 'active'", agentID).
Order("version DESC").First(&tmpl).Error; err == nil {
return tmpl.Content
}
return ""
}
package safety
import (
"regexp"
"strings"
"internet-hospital/internal/model"
)
// FilterResult 过滤结果
type FilterResult struct {
Text string // 处理后的文本(passed/replaced时有效)
Action string // passed | blocked | replaced | warned
MatchedRules []MatchedRule // 命中的规则列表
Blocked bool // 是否被拦截
}
// Filter 安全过滤器
type Filter struct {
loader *Loader
}
// NewFilter 创建过滤器
func NewFilter() *Filter {
return &Filter{loader: GetLoader()}
}
// globalFilter 全局过滤器单例
var globalFilter = &Filter{loader: nil}
// GetFilter 获取全局过滤器
func GetFilter() *Filter {
if globalFilter.loader == nil {
globalFilter.loader = GetLoader()
}
return globalFilter
}
// FilterInput 输入过滤(用户消息 → LLM之前)
func (f *Filter) FilterInput(text, agentID, traceID, userID string) FilterResult {
return f.filter(text, "input", agentID, traceID, userID)
}
// FilterOutput 输出过滤(LLM回复 → 返回用户之前)
func (f *Filter) FilterOutput(text, agentID, traceID, userID string) FilterResult {
return f.filter(text, "output", agentID, traceID, userID)
}
func (f *Filter) filter(text, direction, agentID, traceID, userID string) FilterResult {
rules := f.loader.Load()
if len(rules) == 0 {
return FilterResult{Text: text, Action: "passed"}
}
result := FilterResult{Text: text, Action: "passed"}
for _, rule := range rules {
// 方向过滤
if rule.Direction != "both" && rule.Direction != direction {
continue
}
// agentID过滤(空=全局规则,或匹配特定agent)
if rule.AgentID != "" && rule.AgentID != agentID {
continue
}
matched := matchRule(rule, text)
if !matched {
continue
}
mr := MatchedRule{RuleID: rule.ID, Word: rule.Word, Action: rule.Level}
result.MatchedRules = append(result.MatchedRules, mr)
switch rule.Level {
case "block":
result.Action = "blocked"
result.Blocked = true
result.Text = ""
// 记录日志
SaveFilterLog(traceID, direction, text, "", result.MatchedRules, "blocked", agentID, userID)
return result
case "replace":
replacement := rule.Replacement
if replacement == "" {
replacement = "***"
}
result.Text = replaceWord(result.Text, rule.Word, replacement, rule.IsRegex)
result.Action = "replaced"
case "warn":
if result.Action == "passed" {
result.Action = "warned"
}
}
}
if result.Action != "passed" {
SaveFilterLog(traceID, direction, text, result.Text, result.MatchedRules, result.Action, agentID, userID)
}
return result
}
func matchRule(rule model.SafetyWordRule, text string) bool {
if rule.IsRegex {
re, err := regexp.Compile(rule.Word)
if err != nil {
return false
}
return re.MatchString(text)
}
return strings.Contains(text, rule.Word)
}
func replaceWord(text, word, replacement string, isRegex bool) string {
if isRegex {
re, err := regexp.Compile(word)
if err != nil {
return text
}
return re.ReplaceAllString(text, replacement)
}
return strings.ReplaceAll(text, word, replacement)
}
package safety
import (
"encoding/json"
"log"
"sync"
"time"
"internet-hospital/internal/model"
"internet-hospital/pkg/database"
)
// globalLoader 全局规则加载器(单例)
var globalLoader *Loader
var loaderOnce sync.Once
// Loader 安全词规则加载器(带缓存)
type Loader struct {
mu sync.RWMutex
rules []model.SafetyWordRule
lastLoad time.Time
ttl time.Duration
}
// GetLoader 获取全局加载器单例
func GetLoader() *Loader {
loaderOnce.Do(func() {
globalLoader = &Loader{
ttl: 5 * time.Minute,
}
})
return globalLoader
}
// Load 加载规则(带TTL缓存)
func (l *Loader) Load() []model.SafetyWordRule {
l.mu.RLock()
if time.Since(l.lastLoad) < l.ttl && len(l.rules) > 0 {
rules := l.rules
l.mu.RUnlock()
return rules
}
l.mu.RUnlock()
return l.reload()
}
// Reload 强制重新加载
func (l *Loader) Reload() []model.SafetyWordRule {
return l.reload()
}
func (l *Loader) reload() []model.SafetyWordRule {
db := database.GetDB()
if db == nil {
return nil
}
var rules []model.SafetyWordRule
if err := db.Where("status = 'active'").Order("id asc").Find(&rules).Error; err != nil {
log.Printf("[safety] 加载安全词规则失败: %v", err)
l.mu.RLock()
cached := l.rules
l.mu.RUnlock()
return cached
}
l.mu.Lock()
l.rules = rules
l.lastLoad = time.Now()
l.mu.Unlock()
return rules
}
// SaveFilterLog 异步保存过滤日志
func SaveFilterLog(traceID, direction, original, filtered string, matched []MatchedRule, action, agentID, userID string) {
go func() {
db := database.GetDB()
if db == nil {
return
}
matchedJSON, _ := json.Marshal(matched)
entry := &model.SafetyFilterLog{
TraceID: traceID,
Direction: direction,
OriginalText: truncate(original, 2000),
FilteredText: truncate(filtered, 2000),
MatchedRules: string(matchedJSON),
Action: action,
AgentID: agentID,
UserID: userID,
}
if err := db.Create(entry).Error; err != nil {
log.Printf("[safety] 保存过滤日志失败: %v", err)
}
}()
}
// MatchedRule 命中的规则记录
type MatchedRule struct {
RuleID uint `json:"rule_id"`
Word string `json:"word"`
Action string `json:"action"`
}
func truncate(s string, max int) string {
runes := []rune(s)
if len(runes) <= max {
return s
}
return string(runes[:max]) + "..."
}
package workflow
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
exprlib "github.com/expr-lang/expr"
"github.com/google/uuid"
"internet-hospital/internal/model"
"internet-hospital/pkg/agent"
agentpkg "internet-hospital/pkg/agent"
"internet-hospital/pkg/database"
"github.com/google/uuid"
)
// NodeType 节点类型
......@@ -68,14 +72,14 @@ type Engine struct {
agentSvc interface {
Chat(ctx context.Context, agentID, userID, sessionID, message string, ctxData map[string]interface{}) (interface{}, error)
}
toolRegistry *agent.ToolRegistry
toolRegistry *agentpkg.ToolRegistry
}
var globalEngine *Engine
func GetEngine() *Engine {
if globalEngine == nil {
globalEngine = &Engine{toolRegistry: agent.GetRegistry()}
globalEngine = &Engine{toolRegistry: agentpkg.GetRegistry()}
}
return globalEngine
}
......@@ -253,15 +257,28 @@ func (e *Engine) createHumanTask(_ context.Context, node *Node, execCtx *Executi
return map[string]interface{}{"task_id": task.TaskID, "status": "pending"}, nil
}
func evalCondition(expr string, vars map[string]interface{}) bool {
// 简单实现:支持 "key == 'value'" 格式
for k, v := range vars {
check := fmt.Sprintf("%s == '%v'", k, v)
if expr == check {
return true
func evalCondition(expression string, vars map[string]interface{}) bool {
if expression == "" {
return true
}
program, err := exprlib.Compile(expression, exprlib.Env(vars), exprlib.AsBool())
if err != nil {
// 降级:简单字符串相等检查
for k, v := range vars {
if expression == fmt.Sprintf("%s == '%v'", k, v) {
return true
}
}
return false
}
result, err := exprlib.Run(program, vars)
if err != nil {
return false
}
return false
if b, ok := result.(bool); ok {
return b
}
return result != nil
}
// executeParallel 并行执行多个子节点
......@@ -367,84 +384,115 @@ func (e *Engine) executeLoop(ctx context.Context, wf *Workflow, node *Node, exec
}, nil
}
// executeCode 执行代码节点(简单表达式求值)
// executeCode 代码节点:使用 expr-lang 安全执行表达式
func (e *Engine) executeCode(_ context.Context, node *Node, execCtx *ExecutionContext) (interface{}, error) {
code, _ := node.Config["code"].(string)
if code == "" {
return map[string]interface{}{"skipped": "no code"}, nil
}
// 简单的表达式求值(实际生产环境应使用安全的脚本引擎)
// 这里只支持简单的变量赋值和返回
result := make(map[string]interface{})
// 将所有变量复制到结果中
// 合并所有变量作为上下文
env := make(map[string]interface{})
for k, v := range execCtx.Variables {
result[k] = v
env[k] = v
}
// 将所有节点输出复制到结果中
for k, v := range execCtx.NodeOutputs {
result["node_"+k] = v
env["node_"+k] = v
}
return result, nil
program, err := exprlib.Compile(code, exprlib.Env(env))
if err != nil {
return nil, fmt.Errorf("代码节点编译失败: %w", err)
}
result, err := exprlib.Run(program, env)
if err != nil {
return nil, fmt.Errorf("代码节点执行失败: %w", err)
}
// 如果结果是 map,合并回变量
if m, ok := result.(map[string]interface{}); ok {
for k, v := range m {
execCtx.Variables[k] = v
}
}
return map[string]interface{}{"result": result, "code": code}, nil
}
// executeHTTP 执行HTTP请求节点
// executeHTTP 执行 HTTP 请求节点(完整 net/http 实现)
func (e *Engine) executeHTTP(ctx context.Context, node *Node, execCtx *ExecutionContext) (interface{}, error) {
url, _ := node.Config["url"].(string)
rawURL, _ := node.Config["url"].(string)
method, _ := node.Config["method"].(string)
if method == "" {
method = "GET"
}
if url == "" {
if rawURL == "" {
return nil, fmt.Errorf("http node requires url")
}
// 替换URL中的变量
// 替换 URL 中的 {{var}} 变量
for k, v := range execCtx.Variables {
placeholder := fmt.Sprintf("${%s}", k)
if strVal, ok := v.(string); ok {
url = replaceAll(url, placeholder, strVal)
}
rawURL = replaceAll(rawURL, "{{"+k+"}}", fmt.Sprintf("%v", v))
}
// 创建HTTP请求
var reqBody []byte
// 构建请求 Body
var bodyBytes []byte
if body, ok := node.Config["body"].(map[string]interface{}); ok {
reqBody, _ = json.Marshal(body)
// 渲染 body 中的变量
for k, v := range body {
if s, ok2 := v.(string); ok2 {
for varK, varV := range execCtx.Variables {
s = replaceAll(s, "{{"+varK+"}}", fmt.Sprintf("%v", varV))
}
body[k] = s
}
}
bodyBytes, _ = json.Marshal(body)
}
req, err := newHTTPRequest(ctx, method, url, reqBody)
if err != nil {
return nil, err
var bodyReader io.Reader
if len(bodyBytes) > 0 {
bodyReader = bytes.NewReader(bodyBytes)
}
// 设置请求头
timeout := 30 * time.Second
if t, ok := node.Config["timeout"].(float64); ok && t > 0 {
timeout = time.Duration(t) * time.Second
}
httpCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
req, err := http.NewRequestWithContext(httpCtx, method, rawURL, bodyReader)
if err != nil {
return nil, fmt.Errorf("创建 HTTP 请求失败: %w", err)
}
if len(bodyBytes) > 0 {
req.Header.Set("Content-Type", "application/json")
}
if headers, ok := node.Config["headers"].(map[string]interface{}); ok {
for k, v := range headers {
if strVal, ok := v.(string); ok {
req.Header[k] = strVal
if sv, ok2 := v.(string); ok2 {
req.Header.Set(k, sv)
}
}
}
// 执行请求
client := &httpClient{timeout: 30 * time.Second}
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return map[string]interface{}{
"error": err.Error(),
"status": 0,
}, nil
return map[string]interface{}{"error": err.Error(), "status": 0}, nil
}
defer resp.Body.Close()
return map[string]interface{}{
"status": resp.StatusCode,
"body": resp.Body,
}, nil
respBytes, _ := io.ReadAll(resp.Body)
result := map[string]interface{}{"status": resp.StatusCode}
var parsed interface{}
if json.Unmarshal(respBytes, &parsed) == nil {
result["body"] = parsed
} else {
result["body"] = string(respBytes)
}
return result, nil
}
// executeTemplate 执行模板渲染节点
......@@ -504,39 +552,3 @@ func replaceAll(s, old, new string) string {
return string(result)
}
// 简单的HTTP客户端封装
type httpClient struct {
timeout time.Duration
}
type httpResponse struct {
StatusCode int
Body string
}
func newHTTPRequest(ctx context.Context, method, url string, body []byte) (*httpRequest, error) {
return &httpRequest{
ctx: ctx,
method: method,
url: url,
body: body,
Header: make(map[string]string),
}, nil
}
type httpRequest struct {
ctx context.Context
method string
url string
body []byte
Header map[string]string
}
func (c *httpClient) Do(req *httpRequest) (*httpResponse, error) {
// 简化实现:实际应使用net/http
// 这里返回模拟响应,避免引入额外依赖
return &httpResponse{
StatusCode: 200,
Body: `{"status": "ok"}`,
}, nil
}
......@@ -62,6 +62,32 @@ export interface WorkflowExecution {
completed_at: string;
}
export interface AgentDefinition {
id: number;
agent_id: string;
name: string;
description: string;
category: string;
system_prompt: string;
tools: string; // JSON array string
config: string;
max_iterations: number;
status: string;
created_at: string;
updated_at: string;
}
export type AgentDefinitionParams = {
agent_id: string;
name: string;
description?: string;
category?: string;
system_prompt?: string;
tools?: string;
max_iterations?: number;
status?: string;
};
export interface WorkflowCreateParams {
workflow_id: string;
name: string;
......@@ -84,18 +110,21 @@ export interface KnowledgeDocumentParams {
// ==================== Agent API ====================
// AI 推理调用专用超时(2分钟)
const AI_TIMEOUT = 120000;
export const agentApi = {
chat: (agentId: string, params: {
session_id?: string;
message: string;
context?: Record<string, unknown>;
}) => post<AgentOutput>(`/agent/${agentId}/chat`, params),
}) => post<AgentOutput>(`/agent/${agentId}/chat`, params, { timeout: AI_TIMEOUT }),
listAgents: () =>
get<{ id: string; name: string; description: string }[]>('/agent/list'),
listTools: () =>
get<{ id: string; name: string; description: string; category: string; parameters: Record<string, unknown>; is_enabled: boolean; created_at: string }[]>('/agent/tools'),
get<{ id: string; name: string; display_name: string; description: string; category: string; parameters: Record<string, unknown>; status: string; is_enabled: boolean; cache_ttl: number; timeout: number; max_retries: number; created_at: string }[]>('/agent/tools'),
getSessions: (agentId?: string) =>
get<AgentSession[]>('/agent/sessions', { params: agentId ? { agent_id: agentId } : {} }),
......@@ -115,6 +144,69 @@ export const agentApi = {
get<{ agent_id: string; count: number; avg_iterations: number; avg_tokens: number; success_rate: number }[]>(
'/admin/agent/stats', { params }
),
// Agent 配置 CRUD
listDefinitions: () =>
get<AgentDefinition[]>('/agent/definitions'),
getDefinition: (agentId: string) =>
get<AgentDefinition>(`/agent/definitions/${agentId}`),
createDefinition: (data: AgentDefinitionParams) =>
post<AgentDefinition>('/agent/definitions', data),
updateDefinition: (agentId: string, data: Partial<AgentDefinitionParams>) =>
put<AgentDefinition>(`/agent/definitions/${agentId}`, data),
reloadAgent: (agentId: string) =>
put<null>(`/agent/definitions/${agentId}/reload`, {}),
updateToolStatus: (name: string, status: 'active' | 'disabled') =>
put<null>(`/agent/tools/${name}/status`, { status }),
};
// ==================== HTTP Tool API ====================
export interface HTTPToolDefinition {
id: number;
name: string;
display_name: string;
description: string;
category: string;
method: string;
url: string;
headers: string;
body_template: string;
auth_type: string;
auth_config: string;
parameters: string;
timeout: number;
cache_ttl: number;
status: string;
created_by: string;
created_at: string;
updated_at: string;
}
export type HTTPToolParams = Omit<HTTPToolDefinition, 'id' | 'created_at' | 'updated_at'>;
export const httpToolApi = {
list: () => get<HTTPToolDefinition[]>('/agent/http-tools'),
create: (data: Partial<HTTPToolParams>) =>
post<HTTPToolDefinition>('/agent/http-tools', data),
update: (id: number, data: Partial<HTTPToolParams>) =>
put<HTTPToolDefinition>(`/agent/http-tools/${id}`, data),
delete: (id: number) => del<null>(`/agent/http-tools/${id}`),
test: (id: number, params: Record<string, unknown>) =>
post<{ success: boolean; data?: unknown; error?: string }>(
`/agent/http-tools/${id}/test`, { params }
),
reload: () => post<{ message: string }>('/agent/http-tools/reload', {}),
};
// ==================== Workflow API ====================
......@@ -127,7 +219,7 @@ export const workflowApi = {
update: (id: number, data: Partial<WorkflowCreateParams>) =>
put<unknown>(`/admin/workflows/${id}`, data),
publish: (id: number) => post<null>(`/admin/workflows/${id}/publish`),
publish: (id: number) => put<null>(`/admin/workflows/${id}/publish`, {}),
execute: (workflowId: string, input?: Record<string, unknown>) =>
post<{ execution_id: string }>(`/workflow/${workflowId}/execute`, input || {}),
......
......@@ -59,7 +59,7 @@ export const chronicApi = {
listRenewals: () => get<RenewalRequest[]>('/chronic/renewals'),
createRenewal: (data: { chronic_id?: string; disease_name: string; medicines: string[]; reason: string }) =>
post<RenewalRequest>('/chronic/renewals', data),
getAIAdvice: (id: string) => post<{ advice: string }>(`/chronic/renewals/${id}/ai-advice`, {}),
getAIAdvice: (id: string) => post<{ advice: string }>(`/chronic/renewals/${id}/ai-advice`, {}, { timeout: 120000 }),
// 用药提醒
listReminders: () => get<MedicationReminder[]>('/chronic/reminders'),
......
......@@ -129,7 +129,7 @@ export const consultApi = {
tool_calls?: import('./agent').ToolCall[];
iterations?: number;
total_tokens?: number;
}>(`/consult/${id}/ai-assist`, { scene }),
}>(`/consult/${id}/ai-assist`, { scene }, { timeout: 120000 }),
// 取消问诊
cancelConsult: (id: string, reason?: string) =>
......
......@@ -191,5 +191,5 @@ export const doctorPortalApi = {
iterations?: number;
has_warning: boolean;
has_contraindication: boolean;
}>('/doctor-portal/prescription/check', params),
}>('/doctor-portal/prescription/check', params, { timeout: 120000 }),
};
......@@ -93,7 +93,7 @@ export const healthApi = {
listReports: () => get<LabReport[]>('/health/reports'),
createReport: (data: CreateReportParams) => post<LabReport>('/health/reports', data),
deleteReport: (id: string) => del<null>(`/health/reports/${id}`),
aiInterpret: (id: string) => post<{ interpret: string }>(`/health/reports/${id}/ai-interpret`),
aiInterpret: (id: string) => post<{ interpret: string }>(`/health/reports/${id}/ai-interpret`, undefined, { timeout: 120000 }),
// 家庭成员
listFamily: () => get<FamilyMember[]>('/health/family'),
createFamily: (data: FamilyMemberParams) => post<FamilyMember>('/health/family', data),
......
......@@ -21,9 +21,10 @@ const getBaseURL = () => {
};
// 创建 axios 实例
// 普通接口:30s;AI / Agent 接口(模型推理)可能需要 2min,通过 config.timeout 单独覆盖
const request: AxiosInstance = axios.create({
baseURL: getBaseURL(),
timeout: 10000,
timeout: 30000,
headers: {
'Content-Type': 'application/json',
},
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -7,6 +7,7 @@ import {
DashboardOutlined, UserOutlined, MessageOutlined, CalendarOutlined,
SettingOutlined, LogoutOutlined, BellOutlined, MedicineBoxOutlined,
RobotOutlined, DollarOutlined, FolderOpenOutlined, SafetyCertificateOutlined,
CheckCircleOutlined,
} from '@ant-design/icons';
import { useUserStore } from '@/store/userStore';
......@@ -16,6 +17,7 @@ const { Text } = Typography;
const menuItems = [
{ key: '/doctor/workbench', icon: <DashboardOutlined />, label: '工作台' },
{ key: '/doctor/consult', icon: <MessageOutlined />, label: '问诊大厅' },
{ key: '/doctor/tasks', icon: <CheckCircleOutlined />, label: '待办任务' },
{ key: '/doctor/chronic/review', icon: <SafetyCertificateOutlined />, label: '慢病续方' },
{ key: '/doctor/patient', icon: <FolderOpenOutlined />, label: '患者档案' },
{ key: '/doctor/schedule', icon: <CalendarOutlined />, label: '排班管理' },
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment