Commit 59503a6c authored by yuguo's avatar yuguo

fix

parent 9033450a
......@@ -24,7 +24,9 @@
"Bash(./api.exe:*)",
"Bash(PGPASSWORD=123456 psql:*)",
"Bash(go mod:*)",
"Bash(go vet:*)"
"Bash(go vet:*)",
"Bash(cmd:*)",
"Bash(go get:*)"
]
}
}
......@@ -92,6 +92,11 @@ func main() {
&model.PaymentOrder{},
&model.DoctorIncome{},
&model.DoctorWithdrawal{},
// 安全过滤
&model.SafetyWordRule{},
&model.SafetyFilterLog{},
// HTTP 动态工具
&model.HTTPToolDefinition{},
); err != nil {
log.Printf("Warning: AutoMigrate failed: %v", err)
} else {
......@@ -134,8 +139,11 @@ func main() {
log.Printf("Warning: Failed to init departments and doctors: %v", err)
}
// 初始化Agent工具
// 初始化Agent工具(内置工具 + HTTP 动态工具)
internalagent.InitTools()
internalagent.LoadHTTPTools()
// 注入跨包回调(AgentCallFn / WorkflowTriggerFn)
internalagent.WireCallbacks()
// 设置 Gin 模式
gin.SetMode(cfg.Server.Mode)
......@@ -168,6 +176,9 @@ func main() {
authApi := api.Group("")
authApi.Use(middleware.JWTAuth())
authApi.GET("/user/me", userHandler.GetCurrentUser)
authApi.POST("/user/logout", userHandler.Logout)
authApi.POST("/user/verify-identity", userHandler.VerifyIdentity)
authApi.POST("/doctor/appointment", doctorHandler.MakeAppointment)
// 医生端路由(需要认证 + 医生角色)
doctorPortalHandler := doctorportal.NewHandler()
......
......@@ -22,6 +22,7 @@ require (
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/expr-lang/expr v1.17.8 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
......
......@@ -18,6 +18,8 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/expr-lang/expr v1.17.8 h1:W1loDTT+0PQf5YteHSTpju2qfUfNoBt4yw9+wOEU9VM=
github.com/expr-lang/expr v1.17.8/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
......
package internalagent
import (
"internet-hospital/pkg/agent"
"encoding/json"
"internet-hospital/internal/model"
)
// NewPreConsultAgent 预问诊Agent
func NewPreConsultAgent() *agent.ReActAgent {
return agent.NewReActAgent(agent.ReActConfig{
ID: "pre_consult_agent",
// defaultAgentDefinitions 返回内置Agent的默认数据库配置
// 当数据库中不存在时,会自动写入并作为初始配置
func defaultAgentDefinitions() []model.AgentDefinition {
preConsultTools, _ := json.Marshal([]string{"query_symptom_knowledge", "recommend_department"})
diagnosisTools, _ := json.Marshal([]string{"query_medical_record", "query_symptom_knowledge", "search_medical_knowledge"})
prescriptionTools, _ := json.Marshal([]string{"query_drug", "check_drug_interaction", "check_contraindication", "calculate_dosage"})
followUpTools, _ := json.Marshal([]string{"query_medical_record", "query_drug", "query_symptom_knowledge"})
return []model.AgentDefinition{
{
AgentID: "pre_consult_agent",
Name: "预问诊智能助手",
Description: "通过多轮对话收集患者症状,生成预问诊报告",
Category: "patient",
SystemPrompt: `你是一位专业的AI预问诊助手。你的职责是:
1. 通过友好的对话收集患者的症状信息
2. 询问症状的持续时间、严重程度、伴随症状等
......@@ -18,20 +28,15 @@ func NewPreConsultAgent() *agent.ReActAgent {
5. 生成简洁的预问诊报告
请用中文与患者交流,语气温和专业。不要做出确定性诊断,只提供参考建议。`,
Tools: []string{
"query_symptom_knowledge",
"recommend_department",
},
Tools: string(preConsultTools),
MaxIterations: 5,
})
}
// NewDiagnosisAgent 诊断辅助Agent
func NewDiagnosisAgent() *agent.ReActAgent {
return agent.NewReActAgent(agent.ReActConfig{
ID: "diagnosis_agent",
Status: "active",
},
{
AgentID: "diagnosis_agent",
Name: "诊断辅助Agent",
Description: "辅助医生进行诊断,提供鉴别诊断建议",
Category: "doctor",
SystemPrompt: `你是一位经验丰富的诊断辅助AI,协助医生进行临床决策。
你可以:
1. 查询患者病历记录(使用query_medical_record)
......@@ -46,21 +51,15 @@ func NewDiagnosisAgent() *agent.ReActAgent {
- 综合分析后给出诊断建议
请基于循证医学原则提供建议,所有建议仅供医生参考。`,
Tools: []string{
"query_medical_record",
"query_symptom_knowledge",
"search_medical_knowledge",
},
Tools: string(diagnosisTools),
MaxIterations: 10,
})
}
// NewPrescriptionAgent 处方审核Agent
func NewPrescriptionAgent() *agent.ReActAgent {
return agent.NewReActAgent(agent.ReActConfig{
ID: "prescription_agent",
Status: "active",
},
{
AgentID: "prescription_agent",
Name: "处方审核Agent",
Description: "审核处方合理性,检查药物相互作用、禁忌症和剂量",
Category: "pharmacy",
SystemPrompt: `你是一位专业的临床药师AI,负责处方审核。
你的职责:
1. 查询药品信息(规格、用法、禁忌)
......@@ -76,22 +75,15 @@ func NewPrescriptionAgent() *agent.ReActAgent {
- 最后综合所有检查结果给出审核意见
请严格按照药品说明书和临床指南进行审核,对于存在风险的处方要明确指出。`,
Tools: []string{
"query_drug",
"check_drug_interaction",
"check_contraindication",
"calculate_dosage",
},
Tools: string(prescriptionTools),
MaxIterations: 10,
})
}
// NewFollowUpAgent 随访管理Agent
func NewFollowUpAgent() *agent.ReActAgent {
return agent.NewReActAgent(agent.ReActConfig{
ID: "follow_up_agent",
Status: "active",
},
{
AgentID: "follow_up_agent",
Name: "随访管理Agent",
Description: "管理患者随访,提醒用药、复诊,收集健康数据",
Category: "patient",
SystemPrompt: `你是一位专业的随访管理AI助手。你的职责是:
1. 查询患者的处方和用药情况
2. 提醒患者按时用药
......@@ -100,11 +92,9 @@ func NewFollowUpAgent() *agent.ReActAgent {
5. 生成随访报告
请用温和关怀的语气与患者交流,关注患者的用药依从性和健康状况变化。`,
Tools: []string{
"query_medical_record",
"query_drug",
"query_symptom_knowledge",
},
Tools: string(followUpTools),
MaxIterations: 8,
})
Status: "active",
},
}
}
This diff is collapsed.
package internalagent
import (
"log"
"internet-hospital/internal/model"
"internet-hospital/pkg/agent"
"internet-hospital/pkg/agent/tools"
"internet-hospital/pkg/database"
)
// LoadHTTPTools 从数据库加载所有 active HTTP 工具并注册到 ToolRegistry
func LoadHTTPTools() {
db := database.GetDB()
if db == nil {
return
}
var defs []model.HTTPToolDefinition
if err := db.Where("status = 'active'").Find(&defs).Error; err != nil {
log.Printf("[LoadHTTPTools] 加载失败: %v", err)
return
}
r := agent.GetRegistry()
for i := range defs {
r.Register(tools.NewDynamicHTTPTool(&defs[i]))
}
if len(defs) > 0 {
log.Printf("[LoadHTTPTools] 已加载 %d 个 HTTP 工具", len(defs))
}
}
// ReloadHTTPTools 卸载所有旧 HTTP 工具,然后重新从数据库加载
func ReloadHTTPTools() {
db := database.GetDB()
if db == nil {
return
}
r := agent.GetRegistry()
// 卸载所有已注册的 HTTP 工具(不论状态)
var all []model.HTTPToolDefinition
db.Find(&all)
for _, def := range all {
r.Unregister(def.Name)
}
// 重新加载 active 工具
LoadHTTPTools()
log.Printf("[ReloadHTTPTools] HTTP 工具已热重载")
}
package internalagent
import (
"strconv"
"github.com/gin-gonic/gin"
"internet-hospital/internal/model"
"internet-hospital/pkg/agent/tools"
"internet-hospital/pkg/database"
"internet-hospital/pkg/response"
)
// ListHTTPTools GET /agent/http-tools
func (h *Handler) ListHTTPTools(c *gin.Context) {
db := database.GetDB()
var defs []model.HTTPToolDefinition
db.Order("created_at DESC").Find(&defs)
response.Success(c, defs)
}
// CreateHTTPTool POST /agent/http-tools
func (h *Handler) CreateHTTPTool(c *gin.Context) {
var req model.HTTPToolDefinition
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
userID, _ := c.Get("user_id")
req.CreatedBy, _ = userID.(string)
if req.Headers == "" {
req.Headers = "{}"
}
if req.AuthConfig == "" {
req.AuthConfig = "{}"
}
if req.Parameters == "" {
req.Parameters = "[]"
}
db := database.GetDB()
if err := db.Create(&req).Error; err != nil {
response.Error(c, 500, err.Error())
return
}
ReloadHTTPTools()
response.Success(c, req)
}
// UpdateHTTPTool PUT /agent/http-tools/:id
func (h *Handler) UpdateHTTPTool(c *gin.Context) {
idStr := c.Param("id")
id, err := strconv.Atoi(idStr)
if err != nil {
response.BadRequest(c, "invalid id")
return
}
var req map[string]interface{}
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, err.Error())
return
}
db := database.GetDB()
if err := db.Model(&model.HTTPToolDefinition{}).Where("id = ?", id).Updates(req).Error; err != nil {
response.Error(c, 500, err.Error())
return
}
ReloadHTTPTools()
response.Success(c, nil)
}
// DeleteHTTPTool DELETE /agent/http-tools/:id
func (h *Handler) DeleteHTTPTool(c *gin.Context) {
idStr := c.Param("id")
id, err := strconv.Atoi(idStr)
if err != nil {
response.BadRequest(c, "invalid id")
return
}
db := database.GetDB()
var def model.HTTPToolDefinition
if err := db.First(&def, id).Error; err != nil {
response.Error(c, 404, "tool not found")
return
}
ReloadHTTPTools()
if err := db.Delete(&model.HTTPToolDefinition{}, id).Error; err != nil {
response.Error(c, 500, err.Error())
return
}
ReloadHTTPTools()
response.Success(c, nil)
}
// TestHTTPTool POST /agent/http-tools/:id/test
func (h *Handler) TestHTTPTool(c *gin.Context) {
idStr := c.Param("id")
id, err := strconv.Atoi(idStr)
if err != nil {
response.BadRequest(c, "invalid id")
return
}
var params map[string]interface{}
_ = c.ShouldBindJSON(&params)
if params == nil {
params = map[string]interface{}{}
}
db := database.GetDB()
var def model.HTTPToolDefinition
if err := db.First(&def, id).Error; err != nil {
response.Error(c, 404, "tool not found")
return
}
tool := tools.NewDynamicHTTPTool(&def)
result, execErr := tool.Execute(c.Request.Context(), params)
if execErr != nil {
response.Success(c, gin.H{
"success": false,
"error": execErr.Error(),
"result": result,
})
return
}
response.Success(c, gin.H{"success": true, "result": result})
}
// ReloadHTTPToolsAPI POST /agent/http-tools/reload
func (h *Handler) ReloadHTTPToolsAPI(c *gin.Context) {
ReloadHTTPTools()
response.SuccessWithMessage(c, "HTTP 工具已热重载", nil)
}
package internalagent
import (
"context"
"encoding/json"
"fmt"
"log"
"internet-hospital/internal/model"
"internet-hospital/pkg/agent"
"internet-hospital/pkg/agent/tools"
"internet-hospital/pkg/database"
"internet-hospital/pkg/rag"
"internet-hospital/pkg/workflow"
)
// InitTools 注册所有工具到全局注册中心
// InitTools 注册所有工具到全局注册中心,并同步元数据到数据库
func InitTools() {
db := database.GetDB()
r := agent.GetRegistry()
......@@ -27,6 +34,131 @@ func InitTools() {
r.Register(&tools.ContraindicationTool{DB: db})
r.Register(&tools.DosageCalculatorTool{DB: db})
// 知识库检索工具
// 知识库工具
r.Register(&tools.KnowledgeSearchTool{Retriever: retriever})
r.Register(&tools.KnowledgeWriteTool{DB: db, Retriever: retriever})
r.Register(&tools.KnowledgeListTool{})
// Agent互调 & 工作流工具
r.Register(&tools.AgentCallerTool{})
r.Register(&tools.WorkflowTriggerTool{})
r.Register(&tools.WorkflowQueryTool{})
r.Register(&tools.HumanReviewTool{})
// 表达式执行工具
r.Register(&tools.ExprEvalTool{})
// 通知 & 随访工具
r.Register(&tools.SendNotificationTool{})
r.Register(&tools.GenerateFollowUpPlanTool{})
// 同步工具元数据到 AgentTool 表,并应用启/停状态
syncToolsToDB(r)
applyToolStatus(r)
}
// WireCallbacks 注入跨包回调(在 InitTools 和 GetService 初始化完成后调用)
// AgentCallFn: AgentCallerTool → AgentService.Chat
// WorkflowTriggerFn: WorkflowTriggerTool → workflow.Engine.Execute
func WireCallbacks() {
svc := GetService()
tools.AgentCallFn = func(ctx context.Context, agentID, userID, sessionID, message string, ctxData map[string]interface{}) (string, error) {
output, err := svc.Chat(ctx, agentID, userID, sessionID, message, ctxData)
if err != nil {
return "", err
}
if output == nil {
return "", fmt.Errorf("agent %s 不存在或未启用", agentID)
}
return output.Response, nil
}
tools.WorkflowTriggerFn = func(ctx context.Context, workflowID string, input map[string]interface{}) (string, error) {
return workflow.GetEngine().Execute(ctx, workflowID, input, "agent")
}
log.Println("[InitTools] AgentCallFn & WorkflowTriggerFn 注入完成")
}
// syncToolsToDB 将注册的工具元数据写入 AgentTool 表(不存在则创建,存在则不覆盖)
func syncToolsToDB(r *agent.ToolRegistry) {
db := database.GetDB()
if db == nil {
return
}
categoryMap := map[string]string{
// 基础查询
"query_symptom_knowledge": "knowledge",
"recommend_department": "recommendation",
"query_medical_record": "medical",
"search_medical_knowledge": "knowledge",
"query_drug": "pharmacy",
// 处方安全
"check_drug_interaction": "safety",
"check_contraindication": "safety",
"calculate_dosage": "pharmacy",
// 知识库
"write_knowledge": "knowledge",
"list_knowledge_collections": "knowledge",
// Agent & 工作流
"call_agent": "agent",
"trigger_workflow": "workflow",
"query_workflow_status": "workflow",
"request_human_review": "workflow",
// 表达式
"eval_expression": "expression",
// 通知 & 随访
"generate_follow_up_plan": "follow_up",
"send_notification": "notification",
}
for name, tool := range r.All() {
params := make([]map[string]interface{}, 0)
for _, p := range tool.Parameters() {
params = append(params, map[string]interface{}{
"name": p.Name,
"type": p.Type,
"required": p.Required,
})
}
paramsJSON, _ := json.Marshal(params)
category := categoryMap[name]
if category == "" {
category = "other"
}
var existing model.AgentTool
if err := db.Where("name = ?", name).First(&existing).Error; err != nil {
// 不存在则创建
entry := model.AgentTool{
Name: name,
DisplayName: tool.Description(),
Description: tool.Description(),
Category: category,
Parameters: string(paramsJSON),
Status: "active",
}
if err := db.Create(&entry).Error; err != nil {
log.Printf("[InitTools] 同步工具 %s 到数据库失败: %v", name, err)
}
}
// 已存在则不更新(保留管理员的自定义状态)
}
}
// applyToolStatus 从数据库读取工具状态(disabled 的工具在 executor 中已检查,此处仅记录日志)
func applyToolStatus(r *agent.ToolRegistry) {
db := database.GetDB()
if db == nil {
return
}
var disabledTools []model.AgentTool
db.Where("status = 'disabled'").Find(&disabledTools)
if len(disabledTools) > 0 {
for _, t := range disabledTools {
log.Printf("[InitTools] 工具 %s 已被禁用", t.Name)
}
}
}
......@@ -3,6 +3,8 @@ package internalagent
import (
"context"
"encoding/json"
"log"
"sync"
"time"
"internet-hospital/internal/model"
......@@ -15,6 +17,7 @@ import (
// AgentService Agent服务
type AgentService struct {
mu sync.RWMutex
agents map[string]*agent.ReActAgent
}
......@@ -23,26 +26,124 @@ var globalAgentService *AgentService
func GetService() *AgentService {
if globalAgentService == nil {
globalAgentService = &AgentService{
agents: map[string]*agent.ReActAgent{
"pre_consult_agent": NewPreConsultAgent(),
"diagnosis_agent": NewDiagnosisAgent(),
"prescription_agent": NewPrescriptionAgent(),
"follow_up_agent": NewFollowUpAgent(),
},
agents: make(map[string]*agent.ReActAgent),
}
globalAgentService.loadFromDB()
globalAgentService.ensureBuiltinAgents()
}
return globalAgentService
}
// loadFromDB 从数据库加载所有 active 的 AgentDefinition
func (s *AgentService) loadFromDB() {
db := database.GetDB()
if db == nil {
return
}
var definitions []model.AgentDefinition
if err := db.Where("status = 'active'").Find(&definitions).Error; err != nil {
log.Printf("[AgentService] 从数据库加载Agent失败: %v", err)
return
}
for _, def := range definitions {
a := buildAgentFromDef(def)
s.agents[def.AgentID] = a
}
log.Printf("[AgentService] 从数据库加载了 %d 个Agent", len(definitions))
}
func buildAgentFromDef(def model.AgentDefinition) *agent.ReActAgent {
var tools []string
if def.Tools != "" {
json.Unmarshal([]byte(def.Tools), &tools)
}
maxIter := def.MaxIterations
if maxIter <= 0 {
maxIter = 10
}
return agent.NewReActAgent(agent.ReActConfig{
ID: def.AgentID,
Name: def.Name,
Description: def.Description,
SystemPrompt: def.SystemPrompt,
Tools: tools,
MaxIterations: maxIter,
})
}
// ensureBuiltinAgents 如果数据库中不存在内置Agent,则写入默认配置
func (s *AgentService) ensureBuiltinAgents() {
db := database.GetDB()
if db == nil {
return
}
defaults := defaultAgentDefinitions()
for _, def := range defaults {
// 如果内存中已有(来自数据库),跳过
s.mu.RLock()
_, exists := s.agents[def.AgentID]
s.mu.RUnlock()
if exists {
continue
}
// 写入数据库
var existing model.AgentDefinition
if err := db.Where("agent_id = ?", def.AgentID).First(&existing).Error; err != nil {
// 不存在则创建
if err := db.Create(&def).Error; err != nil {
log.Printf("[AgentService] 写入默认Agent失败: %v", err)
continue
}
existing = def
}
s.mu.Lock()
s.agents[def.AgentID] = buildAgentFromDef(existing)
s.mu.Unlock()
}
}
// ReloadAgent 热重载单个Agent(管理端修改配置后调用)
func (s *AgentService) ReloadAgent(agentID string) error {
db := database.GetDB()
var def model.AgentDefinition
if err := db.Where("agent_id = ?", agentID).First(&def).Error; err != nil {
return err
}
a := buildAgentFromDef(def)
s.mu.Lock()
if def.Status == "active" {
s.agents[agentID] = a
} else {
delete(s.agents, agentID)
}
s.mu.Unlock()
log.Printf("[AgentService] 已热重载Agent: %s", agentID)
return nil
}
// ReloadAll 重新加载所有Agent
func (s *AgentService) ReloadAll() {
s.mu.Lock()
s.agents = make(map[string]*agent.ReActAgent)
s.mu.Unlock()
s.loadFromDB()
s.ensureBuiltinAgents()
}
func (s *AgentService) GetAgent(agentID string) (*agent.ReActAgent, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
a, ok := s.agents[agentID]
return a, ok
}
func (s *AgentService) ListAgents() []map[string]string {
result := make([]map[string]string, 0, len(s.agents))
func (s *AgentService) ListAgents() []map[string]interface{} {
s.mu.RLock()
defer s.mu.RUnlock()
result := make([]map[string]interface{}, 0, len(s.agents))
for _, a := range s.agents {
result = append(result, map[string]string{
result = append(result, map[string]interface{}{
"id": a.ID(),
"name": a.Name(),
"description": a.Description(),
......@@ -61,10 +162,10 @@ func (s *AgentService) Chat(ctx context.Context, agentID, userID, sessionID, mes
db := database.GetDB()
// 加载或创建会话
var session model.AgentSession
if sessionID == "" {
sessionID = uuid.New().String()
}
var session model.AgentSession
db.Where("session_id = ?", sessionID).First(&session)
// 解析历史消息
......@@ -119,6 +220,7 @@ func (s *AgentService) Chat(ctx context.Context, agentID, userID, sessionID, mes
outputJSON, _ := json.Marshal(output)
toolCallsJSON, _ := json.Marshal(output.ToolCalls)
db.Create(&model.AgentExecutionLog{
TraceID: output.TraceID,
SessionID: sessionID,
AgentID: agentID,
UserID: userID,
......
......@@ -11,6 +11,9 @@ type AgentTool struct {
Category string `gorm:"type:varchar(50)"`
Parameters string `gorm:"type:jsonb"`
Status string `gorm:"type:varchar(20);default:'active'"`
CacheTTL int `gorm:"default:0"` // 缓存秒数,0=不缓存
Timeout int `gorm:"default:30"` // 执行超时秒数
MaxRetries int `gorm:"default:0"` // 失败重试次数
CreatedAt time.Time
UpdatedAt time.Time
}
......@@ -18,6 +21,7 @@ type AgentTool struct {
// AgentToolLog 工具调用日志
type AgentToolLog struct {
ID uint `gorm:"primaryKey"`
TraceID string `gorm:"type:varchar(100);index"` // 链路追踪ID
ToolName string `gorm:"type:varchar(100);index"`
AgentID string `gorm:"type:varchar(100);index"`
SessionID string `gorm:"type:varchar(100);index"`
......@@ -27,6 +31,7 @@ type AgentToolLog struct {
Success bool
ErrorMessage string `gorm:"type:text"`
DurationMs int
Iteration int // Agent第几轮迭代
CreatedAt time.Time
}
......@@ -62,6 +67,7 @@ type AgentSession struct {
// AgentExecutionLog Agent执行日志
type AgentExecutionLog struct {
ID uint `gorm:"primaryKey"`
TraceID string `gorm:"type:varchar(100);index"` // 链路追踪ID
SessionID string `gorm:"type:varchar(100);index"`
AgentID string `gorm:"type:varchar(100);index"`
UserID string `gorm:"type:uuid;index"`
......
......@@ -42,6 +42,11 @@ type AIUsageLog struct {
Success bool `gorm:"default:true" json:"success"`
ErrorMessage string `gorm:"type:text" json:"error_message,omitempty"`
IsMock bool `gorm:"default:false" json:"is_mock"` // 是否为模拟调用
// 链路追踪字段
TraceID string `gorm:"type:varchar(100);index" json:"trace_id"` // 链路追踪ID,关联同一次Agent执行
AgentID string `gorm:"type:varchar(100);index" json:"agent_id"` // 关联Agent
SessionID string `gorm:"type:varchar(100);index" json:"session_id"` // 关联会话
Iteration int `json:"iteration"` // Agent第几轮迭代(0=非Agent调用)
CreatedAt time.Time `json:"created_at"`
}
......
package model
import "time"
// HTTPToolDefinition 动态 HTTP 工具定义(管理员从 UI 配置,无需改代码)
type HTTPToolDefinition struct {
ID uint `gorm:"primaryKey"`
Name string `gorm:"type:varchar(100);uniqueIndex"` // 工具名,如 get_weather
DisplayName string `gorm:"type:varchar(200)"`
Description string `gorm:"type:text"`
Category string `gorm:"type:varchar(50);default:'http'"`
Method string `gorm:"type:varchar(10);default:'GET'"` // GET/POST/PUT/DELETE
URL string `gorm:"type:text"` // 支持 {{param}} 模板变量
Headers string `gorm:"type:jsonb;default:'{}'"` // {"X-Key": "{{api_key}}"}
BodyTemplate string `gorm:"type:text"` // JSON body 模板,支持 {{param}}
AuthType string `gorm:"type:varchar(20);default:'none'"` // none/bearer/basic/apikey
AuthConfig string `gorm:"type:jsonb;default:'{}'"` // 认证配置
Parameters string `gorm:"type:jsonb;default:'[]'"` // ToolParameter 数组 JSON
Timeout int `gorm:"default:10"` // 超时秒数
CacheTTL int `gorm:"default:0"` // 缓存 TTL 秒,0=不缓存
Status string `gorm:"type:varchar(20);default:'active'"`
CreatedBy string `gorm:"type:varchar(100)"`
CreatedAt time.Time
UpdatedAt time.Time
}
......@@ -14,6 +14,11 @@ type PromptTemplate struct {
Scene string `gorm:"type:varchar(50)" json:"scene"` // 应用场景
Content string `gorm:"type:text;not null" json:"content"` // Prompt内容
Status string `gorm:"type:varchar(20);default:'active'" json:"status"` // active | disabled
// 智能体关联字段
AgentID string `gorm:"type:varchar(100);index" json:"agent_id"` // 关联的Agent ID,空表示通用模板
TemplateType string `gorm:"type:varchar(20);default:'system'" json:"template_type"` // system | user | tool_result
Variables string `gorm:"type:jsonb" json:"variables"` // 变量定义 [{"name":"patient_name","type":"string","required":true}]
Version int `gorm:"default:1" json:"version"` // 版本号
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
......
package model
import "time"
// SafetyWordRule 安全词规则
type SafetyWordRule struct {
ID uint `gorm:"primaryKey" json:"id"`
Word string `gorm:"type:varchar(200);index" json:"word"` // 敏感词/正则
Category string `gorm:"type:varchar(50)" json:"category"` // medical_claim | drug_promotion | privacy | toxicity | injection
Level string `gorm:"type:varchar(20)" json:"level"` // block | warn | replace
Replacement string `gorm:"type:varchar(200)" json:"replacement"` // level=replace时的替换文本
Direction string `gorm:"type:varchar(10);default:'both'" json:"direction"` // input | output | both
IsRegex bool `gorm:"default:false" json:"is_regex"` // 是否正则表达式
AgentID string `gorm:"type:varchar(100)" json:"agent_id"` // 特定Agent的规则,空=全局
Status string `gorm:"type:varchar(20);default:'active'" json:"status"` // active | disabled
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
func (SafetyWordRule) TableName() string {
return "safety_word_rules"
}
// SafetyFilterLog 过滤日志
type SafetyFilterLog struct {
ID uint `gorm:"primaryKey" json:"id"`
TraceID string `gorm:"type:varchar(100);index" json:"trace_id"`
Direction string `gorm:"type:varchar(10)" json:"direction"` // input | output
OriginalText string `gorm:"type:text" json:"original_text"`
FilteredText string `gorm:"type:text" json:"filtered_text"`
MatchedRules string `gorm:"type:jsonb" json:"matched_rules"` // [{"rule_id":1,"word":"xxx","action":"block"}]
Action string `gorm:"type:varchar(20)" json:"action"` // passed | blocked | replaced | warned
AgentID string `gorm:"type:varchar(100)" json:"agent_id"`
UserID string `gorm:"type:uuid" json:"user_id"`
CreatedAt time.Time `json:"created_at"`
}
func (SafetyFilterLog) TableName() string {
return "safety_filter_logs"
}
package admin
import (
"strconv"
"github.com/gin-gonic/gin"
"internet-hospital/internal/model"
"internet-hospital/pkg/database"
"internet-hospital/pkg/response"
)
// GetTraceDetail 全链路追踪详情:给定 trace_id,返回该次Agent执行的所有 LLM 调用和工具调用
func (h *Handler) GetTraceDetail(c *gin.Context) {
traceID := c.Query("trace_id")
if traceID == "" {
response.BadRequest(c, "trace_id 不能为空")
return
}
db := database.GetDB()
var execLogs []model.AgentExecutionLog
db.Where("trace_id = ?", traceID).Order("created_at asc").Find(&execLogs)
var usageLogs []model.AIUsageLog
db.Where("trace_id = ?", traceID).Order("iteration asc, created_at asc").Find(&usageLogs)
var toolLogs []model.AgentToolLog
db.Where("trace_id = ?", traceID).Order("iteration asc, created_at asc").Find(&toolLogs)
response.Success(c, gin.H{
"trace_id": traceID,
"execution_logs": execLogs,
"llm_calls": usageLogs,
"tool_calls": toolLogs,
})
}
// GetAICenterStats 全量统计看板
func (h *Handler) GetAICenterStats(c *gin.Context) {
db := database.GetDB()
// LLM 调用统计
var totalCalls, successCalls, mockCalls int64
var totalTokens struct{ Sum int64 }
db.Model(&model.AIUsageLog{}).Count(&totalCalls)
db.Model(&model.AIUsageLog{}).Where("success = true").Count(&successCalls)
db.Model(&model.AIUsageLog{}).Where("is_mock = true").Count(&mockCalls)
db.Model(&model.AIUsageLog{}).Select("COALESCE(SUM(total_tokens), 0) as sum").Scan(&totalTokens)
// Agent 执行统计
var agentExecs, agentSuccess int64
db.Model(&model.AgentExecutionLog{}).Count(&agentExecs)
db.Model(&model.AgentExecutionLog{}).Where("success = true").Count(&agentSuccess)
// 工具调用统计
var toolCalls, toolSuccess int64
db.Model(&model.AgentToolLog{}).Count(&toolCalls)
db.Model(&model.AgentToolLog{}).Where("success = true").Count(&toolSuccess)
// 各 Agent 调用量
type AgentCallCount struct {
AgentID string `json:"agent_id"`
Count int64 `json:"count"`
}
var agentCounts []AgentCallCount
db.Model(&model.AIUsageLog{}).
Select("agent_id, COUNT(*) as count").
Where("agent_id != ''").
Group("agent_id").
Order("count DESC").
Limit(10).
Scan(&agentCounts)
// 各场景调用量
type SceneCallCount struct {
Scene string `json:"scene"`
Count int64 `json:"count"`
}
var sceneCounts []SceneCallCount
db.Model(&model.AIUsageLog{}).
Select("scene, COUNT(*) as count").
Group("scene").
Order("count DESC").
Limit(10).
Scan(&sceneCounts)
// 最近24小时调用趋势(按小时)
type HourlyTrend struct {
Hour string `json:"hour"`
Count int64 `json:"count"`
}
var trend []HourlyTrend
db.Model(&model.AIUsageLog{}).
Select("TO_CHAR(created_at, 'HH24') as hour, COUNT(*) as count").
Where("created_at >= NOW() - INTERVAL '24 hours'").
Group("hour").
Order("hour asc").
Scan(&trend)
// 近期日志列表(用于追踪)
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
if page <= 0 {
page = 1
}
if pageSize <= 0 {
pageSize = 20
}
agentID := c.Query("agent_id")
var recentLogsTotal int64
var recentLogs []model.AIUsageLog
q := db.Model(&model.AIUsageLog{})
if agentID != "" {
q = q.Where("agent_id = ?", agentID)
}
q.Count(&recentLogsTotal)
q.Order("created_at DESC").Offset((page - 1) * pageSize).Limit(pageSize).Find(&recentLogs)
response.Success(c, gin.H{
"total_calls": totalCalls,
"success_calls": successCalls,
"mock_calls": mockCalls,
"total_tokens": totalTokens.Sum,
"agent_execs": agentExecs,
"agent_success": agentSuccess,
"tool_calls": toolCalls,
"tool_success": toolSuccess,
"agent_counts": agentCounts,
"scene_counts": sceneCounts,
"hourly_trend": trend,
"recent_logs": recentLogs,
"logs_total": recentLogsTotal,
})
}
......@@ -109,6 +109,19 @@ func (h *Handler) RegisterRoutes(r *gin.RouterGroup) {
// Agent 执行监控
adm.GET("/agent/logs", h.GetAgentExecutionLogs)
adm.GET("/agent/stats", h.GetAgentStats)
// 内容安全管理
adm.GET("/safety/rules", h.ListSafetyRules)
adm.POST("/safety/rules", h.CreateSafetyRule)
adm.PUT("/safety/rules/:id", h.UpdateSafetyRule)
adm.DELETE("/safety/rules/:id", h.DeleteSafetyRule)
adm.POST("/safety/rules/import-preset", h.ImportSafetyRules)
adm.GET("/safety/logs", h.ListSafetyLogs)
adm.GET("/safety/stats", h.GetSafetyStats)
// AI 运营中心
adm.GET("/ai-center/trace", h.GetTraceDetail)
adm.GET("/ai-center/stats", h.GetAICenterStats)
}
}
......
......@@ -6,6 +6,7 @@ import (
"time"
"internet-hospital/internal/model"
"internet-hospital/pkg/workflow"
)
// ==================== 处方监管请求/响应 ====================
......@@ -132,5 +133,16 @@ func (s *Service) ReviewPrescription(ctx context.Context, id string, action stri
return fmt.Errorf("无效的审核操作")
}
return s.db.Model(&prescription).Updates(updates).Error
if err := s.db.Model(&prescription).Updates(updates).Error; err != nil {
return err
}
// 处方审核通过时触发 prescription_approved 工作流(异步)
if action == "approve" {
workflow.GetEngine().TriggerByCategory(ctx, "prescription_approved", map[string]interface{}{
"prescription_id": id,
"reviewer_id": reviewerID,
})
}
return nil
}
package admin
import (
"strconv"
"github.com/gin-gonic/gin"
"internet-hospital/internal/model"
"internet-hospital/pkg/database"
"internet-hospital/pkg/response"
"internet-hospital/pkg/safety"
)
// ===================== 安全词规则 =====================
// ListSafetyRules 查询安全词规则列表
func (h *Handler) ListSafetyRules(c *gin.Context) {
db := database.GetDB()
category := c.Query("category")
direction := c.Query("direction")
status := c.Query("status")
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
if page <= 0 {
page = 1
}
if pageSize <= 0 || pageSize > 100 {
pageSize = 20
}
var total int64
var rules []model.SafetyWordRule
q := db.Model(&model.SafetyWordRule{})
if category != "" {
q = q.Where("category = ?", category)
}
if direction != "" {
q = q.Where("direction = ? OR direction = 'both'", direction)
}
if status != "" {
q = q.Where("status = ?", status)
}
q.Count(&total)
q.Order("id DESC").Offset((page - 1) * pageSize).Limit(pageSize).Find(&rules)
response.Success(c, gin.H{"list": rules, "total": total, "page": page})
}
// CreateSafetyRule 创建安全词规则
func (h *Handler) CreateSafetyRule(c *gin.Context) {
var rule model.SafetyWordRule
if err := c.ShouldBindJSON(&rule); err != nil {
response.BadRequest(c, err.Error())
return
}
if rule.Direction == "" {
rule.Direction = "both"
}
if rule.Status == "" {
rule.Status = "active"
}
if err := database.GetDB().Create(&rule).Error; err != nil {
response.Error(c, 500, err.Error())
return
}
// 重新加载缓存
safety.GetLoader().Reload()
response.Success(c, rule)
}
// UpdateSafetyRule 更新安全词规则
func (h *Handler) UpdateSafetyRule(c *gin.Context) {
id := c.Param("id")
var rule model.SafetyWordRule
if err := database.GetDB().First(&rule, "id = ?", id).Error; err != nil {
response.Error(c, 404, "规则不存在")
return
}
if err := c.ShouldBindJSON(&rule); err != nil {
response.BadRequest(c, err.Error())
return
}
if err := database.GetDB().Save(&rule).Error; err != nil {
response.Error(c, 500, err.Error())
return
}
safety.GetLoader().Reload()
response.Success(c, rule)
}
// DeleteSafetyRule 删除安全词规则
func (h *Handler) DeleteSafetyRule(c *gin.Context) {
id := c.Param("id")
database.GetDB().Delete(&model.SafetyWordRule{}, "id = ?", id)
safety.GetLoader().Reload()
response.Success(c, nil)
}
// ImportSafetyRules 批量导入安全词规则(医疗场景预置)
func (h *Handler) ImportSafetyRules(c *gin.Context) {
db := database.GetDB()
preset := presetMedicalSafetyRules()
created := 0
for _, rule := range preset {
var existing model.SafetyWordRule
if db.Where("word = ?", rule.Word).First(&existing).Error != nil {
if db.Create(&rule).Error == nil {
created++
}
}
}
safety.GetLoader().Reload()
response.SuccessWithMessage(c, "导入完成", gin.H{"created": created})
}
// ===================== 过滤日志 =====================
// ListSafetyLogs 查询过滤日志
func (h *Handler) ListSafetyLogs(c *gin.Context) {
db := database.GetDB()
action := c.Query("action")
direction := c.Query("direction")
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
if page <= 0 {
page = 1
}
if pageSize <= 0 || pageSize > 100 {
pageSize = 20
}
var total int64
var logs []model.SafetyFilterLog
q := db.Model(&model.SafetyFilterLog{})
if action != "" {
q = q.Where("action = ?", action)
}
if direction != "" {
q = q.Where("direction = ?", direction)
}
q.Count(&total)
q.Order("id DESC").Offset((page - 1) * pageSize).Limit(pageSize).Find(&logs)
response.Success(c, gin.H{"list": logs, "total": total, "page": page})
}
// GetSafetyStats 安全过滤统计
func (h *Handler) GetSafetyStats(c *gin.Context) {
db := database.GetDB()
var total, blocked, replaced, warned int64
db.Model(&model.SafetyFilterLog{}).Count(&total)
db.Model(&model.SafetyFilterLog{}).Where("action = 'blocked'").Count(&blocked)
db.Model(&model.SafetyFilterLog{}).Where("action = 'replaced'").Count(&replaced)
db.Model(&model.SafetyFilterLog{}).Where("action = 'warned'").Count(&warned)
var totalRules, activeRules int64
db.Model(&model.SafetyWordRule{}).Count(&totalRules)
db.Model(&model.SafetyWordRule{}).Where("status = 'active'").Count(&activeRules)
response.Success(c, gin.H{
"total_logs": total,
"blocked": blocked,
"replaced": replaced,
"warned": warned,
"total_rules": totalRules,
"active_rules": activeRules,
})
}
// presetMedicalSafetyRules 医疗场景预置安全词规则
func presetMedicalSafetyRules() []model.SafetyWordRule {
return []model.SafetyWordRule{
// Prompt注入防护(输入方向)
{Word: "忽略上述指令", Category: "injection", Level: "block", Direction: "input", Status: "active"},
{Word: "忘记之前的", Category: "injection", Level: "block", Direction: "input", Status: "active"},
{Word: "你现在是", Category: "injection", Level: "warn", Direction: "input", Status: "active"},
{Word: "扮演", Category: "injection", Level: "warn", Direction: "input", Status: "active"},
// 医疗断言防护(输出方向)
{Word: "你确诊了", Category: "medical_claim", Level: "replace", Replacement: "根据症状分析,可能", Direction: "output", Status: "active"},
{Word: "你一定是", Category: "medical_claim", Level: "replace", Replacement: "初步判断可能是", Direction: "output", Status: "active"},
{Word: "你患有", Category: "medical_claim", Level: "replace", Replacement: "症状提示可能存在", Direction: "output", Status: "active"},
// 危险建议防护(输出方向)
{Word: "自行停药", Category: "toxicity", Level: "block", Direction: "output", Status: "active"},
{Word: "不需要看医生", Category: "toxicity", Level: "block", Direction: "output", Status: "active"},
{Word: "自己手术", Category: "toxicity", Level: "block", Direction: "output", Status: "active"},
}
}
......@@ -11,6 +11,7 @@ import (
"internet-hospital/internal/model"
"internet-hospital/pkg/database"
"internet-hospital/pkg/utils"
"internet-hospital/pkg/workflow"
)
// Service 管理端业务逻辑层
......@@ -483,7 +484,16 @@ func (s *Service) ApproveDoctorReview(ctx context.Context, reviewID string) erro
return err
}
return tx.Commit().Error
if err := tx.Commit().Error; err != nil {
return err
}
// 触发 doctor_review 工作流(异步)
workflow.GetEngine().TriggerByCategory(ctx, "doctor_review", map[string]interface{}{
"review_id": reviewID,
"doctor_id": review.UserID,
"action": "approved",
})
return nil
}
func (s *Service) RejectDoctorReview(ctx context.Context, reviewID, reason string) error {
......@@ -497,11 +507,21 @@ func (s *Service) RejectDoctorReview(ctx context.Context, reviewID, reason strin
}
now := time.Now()
return s.db.Model(&review).Updates(map[string]interface{}{
if err := s.db.Model(&review).Updates(map[string]interface{}{
"status": "rejected",
"reject_reason": reason,
"reviewed_at": now,
}).Error
}).Error; err != nil {
return err
}
// 触发 doctor_review 工作流(异步)
workflow.GetEngine().TriggerByCategory(ctx, "doctor_review", map[string]interface{}{
"review_id": reviewID,
"doctor_id": review.UserID,
"action": "rejected",
"reason": reason,
})
return nil
}
func (s *Service) GetDepartmentList(ctx context.Context) (interface{}, error) {
......
......@@ -12,6 +12,7 @@ import (
internalagent "internet-hospital/internal/agent"
"internet-hospital/internal/model"
"internet-hospital/pkg/database"
"internet-hospital/pkg/workflow"
)
type Service struct {
......@@ -95,7 +96,18 @@ func (s *Service) CreateRenewal(ctx context.Context, userID string, req *Renewal
ChronicID: req.ChronicID, DiseaseName: req.DiseaseName,
Medicines: string(medsJSON), Reason: req.Reason, Status: "pending",
}
return r, s.db.WithContext(ctx).Create(r).Error
if err := s.db.WithContext(ctx).Create(r).Error; err != nil {
return nil, err
}
// 触发 renewal_requested 工作流(异步)
workflow.GetEngine().TriggerByCategory(ctx, "renewal_requested", map[string]interface{}{
"renewal_id": r.ID,
"patient_id": userID,
"chronic_record_id": req.ChronicID,
"disease_name": req.DiseaseName,
})
return r, nil
}
func (s *Service) GetAIRenewalAdvice(ctx context.Context, userID, renewalID string) (string, error) {
......@@ -220,7 +232,61 @@ func (s *Service) CreateMetric(ctx context.Context, userID string, req *MetricRe
MetricType: req.MetricType, Value1: req.Value1, Value2: req.Value2,
Unit: req.Unit, RecordedAt: recordedAt, Notes: req.Notes,
}
return r, s.db.WithContext(ctx).Create(r).Error
if err := s.db.WithContext(ctx).Create(r).Error; err != nil {
return nil, err
}
// 检测异常指标,触发 health_alert 工作流
if alertLevel := detectMetricAlert(req.MetricType, req.Value1, req.Value2); alertLevel != "" {
workflow.GetEngine().TriggerByCategory(ctx, "health_alert", map[string]interface{}{
"patient_id": userID,
"metric_type": req.MetricType,
"value1": req.Value1,
"value2": req.Value2,
"unit": req.Unit,
"alert_level": alertLevel,
})
}
return r, nil
}
// detectMetricAlert 检测健康指标是否异常,返回告警级别("mild"/"moderate"/"severe"/"" 表示正常)
func detectMetricAlert(metricType string, value1, value2 float64) string {
switch metricType {
case "blood_pressure":
// 收缩压 value1 / 舒张压 value2
if value1 >= 180 || value2 >= 120 {
return "severe"
} else if value1 >= 160 || value2 >= 100 {
return "moderate"
} else if value1 >= 140 || value2 >= 90 {
return "mild"
}
case "blood_glucose":
// 空腹血糖 mmol/L
if value1 >= 16.7 {
return "severe"
} else if value1 >= 11.1 {
return "moderate"
} else if value1 >= 7.0 {
return "mild"
}
case "heart_rate":
if value1 >= 150 || value1 <= 40 {
return "severe"
} else if value1 >= 120 || value1 <= 50 {
return "moderate"
}
case "body_temperature":
if value1 >= 39.5 {
return "severe"
} else if value1 >= 38.5 {
return "moderate"
} else if value1 >= 37.5 {
return "mild"
}
}
return ""
}
func (s *Service) DeleteMetric(ctx context.Context, userID, id string) error {
......
......@@ -13,6 +13,7 @@ import (
internalagent "internet-hospital/internal/agent"
"internet-hospital/internal/model"
"internet-hospital/pkg/database"
"internet-hospital/pkg/workflow"
)
type Service struct {
......@@ -83,6 +84,14 @@ func (s *Service) CreateConsult(ctx context.Context, patientID string, req *Crea
s.db.Create(videoRoom)
}
// 触发 consult_created 工作流(异步)
workflow.GetEngine().TriggerByCategory(ctx, "consult_created", map[string]interface{}{
"consult_id": consult.ID,
"patient_id": consult.PatientID,
"doctor_id": consult.DoctorID,
"type": consult.Type,
})
return &ConsultResponse{
Consultation: *consult,
DoctorName: doctor.Name,
......@@ -186,10 +195,20 @@ func (s *Service) EndConsult(ctx context.Context, consultID string) error {
}
now := time.Now()
return s.db.Model(&consult).Updates(map[string]interface{}{
if err := s.db.Model(&consult).Updates(map[string]interface{}{
"status": "completed",
"ended_at": now,
}).Error
}).Error; err != nil {
return err
}
// 触发 consult_ended 工作流(异步)
workflow.GetEngine().TriggerByCategory(ctx, "consult_ended", map[string]interface{}{
"consult_id": consultID,
"patient_id": consult.PatientID,
"doctor_id": consult.DoctorID,
})
return nil
}
func (s *Service) CancelConsult(ctx context.Context, consultID string) error {
......
......@@ -80,3 +80,23 @@ func (h *Handler) GetDoctorSchedule(c *gin.Context) {
}
response.Success(c, schedules)
}
type AppointmentRequest struct {
DoctorID string `json:"doctor_id" binding:"required"`
Date string `json:"date" binding:"required"`
TimeSlot string `json:"time_slot" binding:"required"`
}
func (h *Handler) MakeAppointment(c *gin.Context) {
var req AppointmentRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "请求参数错误")
return
}
appointmentID, err := h.service.MakeAppointment(c.Request.Context(), req.DoctorID, req.Date, req.TimeSlot)
if err != nil {
response.Error(c, 400, err.Error())
return
}
response.Success(c, gin.H{"appointment_id": appointmentID})
}
......@@ -2,6 +2,7 @@ package doctor
import (
"context"
"errors"
"gorm.io/gorm"
......@@ -120,3 +121,15 @@ func (s *Service) GetDoctorSchedule(ctx context.Context, doctorID, startDate, en
Find(&schedules).Error
return schedules, err
}
func (s *Service) MakeAppointment(ctx context.Context, doctorID, date, timeSlot string) (string, error) {
var schedule model.DoctorSchedule
if err := s.db.Where("doctor_id = ? AND date = ? AND start_time = ? AND remaining > 0", doctorID, date, timeSlot).
First(&schedule).Error; err != nil {
return "", errors.New("该时间段已无可用名额")
}
if err := s.db.Model(&schedule).UpdateColumn("remaining", schedule.Remaining-1).Error; err != nil {
return "", err
}
return schedule.ID, nil
}
......@@ -10,6 +10,7 @@ import (
internalagent "internet-hospital/internal/agent"
"internet-hospital/internal/model"
"internet-hospital/pkg/workflow"
)
// ==================== 处方开具请求/响应 ====================
......@@ -155,6 +156,15 @@ func (s *Service) CreatePrescription(ctx context.Context, doctorID string, req *
}
tx.Commit()
// 触发 prescription_created 工作流(异步)
workflow.GetEngine().TriggerByCategory(ctx, "prescription_created", map[string]interface{}{
"prescription_id": prescription.ID,
"doctor_id": doctorID,
"patient_id": prescription.PatientID,
"total_amount": prescription.TotalAmount,
})
return prescription, nil
}
......
......@@ -11,6 +11,7 @@ import (
"internet-hospital/internal/model"
"internet-hospital/pkg/database"
"internet-hospital/pkg/workflow"
)
type Service struct {
......@@ -104,6 +105,16 @@ func (s *Service) PayOrder(ctx context.Context, orderID, paymentMethod string) (
s.createDoctorIncome(ctx, order.RelatedID, order.Amount, order.OrderType)
}
// 触发 payment_completed 工作流(异步)
workflow.GetEngine().TriggerByCategory(ctx, "payment_completed", map[string]interface{}{
"order_id": order.ID,
"order_no": order.OrderNo,
"user_id": order.UserID,
"order_type": order.OrderType,
"related_id": order.RelatedID,
"amount": order.Amount,
})
return map[string]interface{}{
"order_id": order.ID,
"status": order.Status,
......
......@@ -26,6 +26,34 @@ func (h *Handler) RegisterRoutes(r *gin.RouterGroup) {
}
}
func (h *Handler) Logout(c *gin.Context) {
// JWT 为无状态令牌,服务端无需操作;客户端删除本地令牌即可
response.Success(c, nil)
}
type VerifyIdentityRequest struct {
RealName string `json:"real_name" binding:"required"`
IDCard string `json:"id_card" binding:"required"`
}
func (h *Handler) VerifyIdentity(c *gin.Context) {
userID, exists := c.Get("user_id")
if !exists {
response.Unauthorized(c, "未登录")
return
}
var req VerifyIdentityRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "请求参数错误")
return
}
if err := h.service.VerifyIdentity(c.Request.Context(), userID.(string), req.RealName, req.IDCard); err != nil {
response.Error(c, 400, err.Error())
return
}
response.Success(c, nil)
}
type SendCodeRequest struct {
Phone string `json:"phone" binding:"required"`
}
......
......@@ -193,3 +193,16 @@ func (s *Service) GetUserByID(ctx context.Context, userID string) (*model.User,
func (s *Service) RefreshToken(ctx context.Context, refreshToken string) (*utils.TokenPair, error) {
return utils.RefreshAccessToken(refreshToken)
}
func (s *Service) VerifyIdentity(ctx context.Context, userID, realName, idCard string) error {
// 简单格式校验:身份证18位
if len(idCard) != 18 {
return errors.New("身份证号格式不正确")
}
return s.db.Model(&model.User{}).
Where("id = ?", userID).
Updates(map[string]interface{}{
"real_name": realName,
"is_verified": true,
}).Error
}
......@@ -4,19 +4,52 @@ import (
"context"
"encoding/json"
"fmt"
"log"
"time"
"internet-hospital/internal/model"
"internet-hospital/pkg/database"
)
// getToolConfig 从数据库获取工具配置(CacheTTL/Timeout/MaxRetries/Status)
func getToolConfig(name string) model.AgentTool {
db := database.GetDB()
if db == nil {
return model.AgentTool{Status: "active", Timeout: 30}
}
var tool model.AgentTool
if err := db.Where("name = ?", name).First(&tool).Error; err != nil {
return model.AgentTool{Status: "active", Timeout: 30}
}
return tool
}
// Executor 工具执行器
type Executor struct {
registry *ToolRegistry
cache *ToolCache
}
func NewExecutor(r *ToolRegistry) *Executor {
return &Executor{registry: r}
return &Executor{registry: r, cache: GetCache()}
}
// Execute 执行工具调用
// Execute 执行工具调用(不记日志,集成缓存)
func (e *Executor) Execute(ctx context.Context, name string, argsJSON string) ToolResult {
cfg := getToolConfig(name)
// 检查工具是否被禁用
if cfg.Status == "disabled" {
return ToolResult{Success: false, Error: fmt.Sprintf("工具 %s 已被管理员禁用", name)}
}
// 缓存命中检查
if cfg.CacheTTL > 0 {
if cached, ok := e.cache.Get(name, argsJSON); ok {
return cached
}
}
tool, ok := e.registry.Get(name)
if !ok {
return ToolResult{Success: false, Error: fmt.Sprintf("tool not found: %s", name)}
......@@ -27,9 +60,74 @@ func (e *Executor) Execute(ctx context.Context, name string, argsJSON string) To
return ToolResult{Success: false, Error: fmt.Sprintf("invalid params: %v", err)}
}
data, err := tool.Execute(ctx, params)
if err != nil {
return ToolResult{Success: false, Error: err.Error()}
// 超时控制
timeout := time.Duration(cfg.Timeout) * time.Second
if timeout <= 0 {
timeout = 30 * time.Second
}
execCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
// 执行(带重试)
var result ToolResult
maxRetries := cfg.MaxRetries
if maxRetries < 0 {
maxRetries = 0
}
for attempt := 0; attempt <= maxRetries; attempt++ {
data, err := tool.Execute(execCtx, params)
if err == nil {
result = ToolResult{Success: true, Data: data}
break
}
return ToolResult{Success: true, Data: data}
if attempt == maxRetries {
result = ToolResult{Success: false, Error: err.Error()}
}
}
// 写入缓存(仅成功结果)
if result.Success && cfg.CacheTTL > 0 {
e.cache.Set(name, argsJSON, result, cfg.CacheTTL)
}
return result
}
// ExecuteWithLog 执行工具调用并写入 AgentToolLog
func (e *Executor) ExecuteWithLog(ctx context.Context, name, argsJSON, traceID, agentID, sessionID, userID string, iteration int) ToolResult {
start := time.Now()
result := e.Execute(ctx, name, argsJSON)
durationMs := int(time.Since(start).Milliseconds())
// 异步写日志
go func() {
db := database.GetDB()
if db == nil {
return
}
outputJSON, _ := json.Marshal(result)
errMsg := ""
if !result.Success {
errMsg = result.Error
}
entry := &model.AgentToolLog{
TraceID: traceID,
ToolName: name,
AgentID: agentID,
SessionID: sessionID,
UserID: userID,
InputParams: argsJSON,
OutputResult: string(outputJSON),
Success: result.Success,
ErrorMessage: errMsg,
DurationMs: durationMs,
Iteration: iteration,
CreatedAt: time.Now(),
}
if err := db.Create(entry).Error; err != nil {
log.Printf("[executor] 保存工具调用日志失败: %v", err)
}
}()
return result
}
......@@ -5,6 +5,8 @@ import (
"encoding/json"
"fmt"
"internet-hospital/pkg/ai"
"github.com/google/uuid"
)
// AgentInput Agent输入
......@@ -15,6 +17,7 @@ type AgentInput struct {
Context map[string]interface{} `json:"context"`
History []ai.ChatMessage `json:"history"`
MaxIterations int `json:"max_iterations"`
TraceID string `json:"trace_id,omitempty"` // 外部传入的链路追踪ID(可选)
}
// AgentOutput Agent输出
......@@ -24,6 +27,7 @@ type AgentOutput struct {
Iterations int `json:"iterations"`
FinishReason string `json:"finish_reason"`
TotalTokens int `json:"total_tokens"`
TraceID string `json:"trace_id,omitempty"` // 本次执行的链路追踪ID
}
// ToolCallResult 工具调用结果记录
......@@ -73,11 +77,18 @@ func (a *ReActAgent) Description() string { return a.cfg.Description }
// Run 执行Agent(非流式)
func (a *ReActAgent) Run(ctx context.Context, input AgentInput) (*AgentOutput, error) {
// 生成链路追踪ID(如果未提供则新建)
traceID := input.TraceID
if traceID == "" {
traceID = uuid.New().String()
}
client := ai.GetClient()
if client == nil {
return &AgentOutput{
Response: "AI服务未配置,请在管理端配置API Key",
FinishReason: "no_client",
TraceID: traceID,
}, nil
}
......@@ -105,6 +116,23 @@ func (a *ReActAgent) Run(ctx context.Context, input AgentInput) (*AgentOutput, e
return nil, fmt.Errorf("AI调用失败: %w", err)
}
totalTokens += resp.Usage.TotalTokens
// 记录本轮 LLM 调用日志(与 AIUsageLog 关联)
ai.SaveLog(ai.LogParams{
Scene: a.cfg.ID,
UserID: input.UserID,
RequestContent: input.Message,
ResponseContent: resp.Choices[0].Message.Content,
PromptTokens: resp.Usage.PromptTokens,
CompletionTokens: resp.Usage.CompletionTokens,
TotalTokens: resp.Usage.TotalTokens,
Success: true,
TraceID: traceID,
AgentID: a.cfg.ID,
SessionID: input.SessionID,
Iteration: i + 1,
})
choice := resp.Choices[0]
if choice.FinishReason == "stop" || len(choice.Message.ToolCalls) == 0 {
......@@ -114,6 +142,7 @@ func (a *ReActAgent) Run(ctx context.Context, input AgentInput) (*AgentOutput, e
Iterations: i + 1,
FinishReason: "completed",
TotalTokens: totalTokens,
TraceID: traceID,
}, nil
}
......@@ -126,7 +155,8 @@ func (a *ReActAgent) Run(ctx context.Context, input AgentInput) (*AgentOutput, e
messages = append(messages, assistantMsg)
for _, tc := range choice.Message.ToolCalls {
result := a.executor.Execute(ctx, tc.Function.Name, tc.Function.Arguments)
result := a.executor.ExecuteWithLog(ctx, tc.Function.Name, tc.Function.Arguments,
traceID, a.cfg.ID, input.SessionID, input.UserID, i+1)
resultJSON, _ := json.Marshal(result)
toolCallResults = append(toolCallResults, ToolCallResult{
ToolName: tc.Function.Name,
......@@ -149,6 +179,7 @@ func (a *ReActAgent) Run(ctx context.Context, input AgentInput) (*AgentOutput, e
Iterations: maxIter,
FinishReason: "max_iterations",
TotalTokens: totalTokens,
TraceID: traceID,
}, nil
}
......@@ -175,7 +206,13 @@ func (a *ReActAgent) RunStream(ctx context.Context, input AgentInput, onChunk fu
}
func (a *ReActAgent) buildSystemPrompt(ctx map[string]interface{}) string {
prompt := a.cfg.SystemPrompt
// 1. 优先从数据库加载该Agent关联的 active 提示词模板
prompt := ai.GetActivePromptByAgent(a.cfg.ID)
// 2. 回退到 AgentDefinition 的 SystemPrompt(即代码配置值)
if prompt == "" {
prompt = a.cfg.SystemPrompt
}
// 3. 渲染上下文变量
if ctx != nil {
if patientID, ok := ctx["patient_id"].(string); ok {
prompt += fmt.Sprintf("\n\n当前患者ID: %s", patientID)
......
......@@ -37,6 +37,12 @@ func (r *ToolRegistry) GetSchemas(names []string) []ToolSchema {
return schemas
}
func (r *ToolRegistry) Unregister(name string) {
r.mu.Lock()
defer r.mu.Unlock()
delete(r.tools, name)
}
func (r *ToolRegistry) All() map[string]Tool {
r.mu.RLock()
defer r.mu.RUnlock()
......
package agent
import (
"context"
"crypto/sha256"
"encoding/json"
"fmt"
"time"
"internet-hospital/pkg/database"
)
const cacheKeyPrefix = "tool_cache:"
// ToolCache 工具结果 Redis 缓存
type ToolCache struct{}
var globalCache = &ToolCache{}
// GetCache 获取全局缓存实例
func GetCache() *ToolCache { return globalCache }
// argsHash 对工具参数计算 sha256 前 16 位,作为缓存 key 后缀
func argsHash(argsJSON string) string {
h := sha256.Sum256([]byte(argsJSON))
return fmt.Sprintf("%x", h[:8])
}
// cacheKey 构造缓存 key
func cacheKey(toolName, argsJSON string) string {
return cacheKeyPrefix + toolName + ":" + argsHash(argsJSON)
}
// Get 尝试命中缓存,返回 (result, hit)
func (c *ToolCache) Get(toolName, argsJSON string) (ToolResult, bool) {
client := database.RedisClient
if client == nil {
return ToolResult{}, false
}
key := cacheKey(toolName, argsJSON)
val, err := client.Get(context.Background(), key).Result()
if err != nil {
return ToolResult{}, false
}
var result ToolResult
if err := json.Unmarshal([]byte(val), &result); err != nil {
return ToolResult{}, false
}
return result, true
}
// Set 写入缓存
func (c *ToolCache) Set(toolName, argsJSON string, result ToolResult, ttlSec int) {
if ttlSec <= 0 {
return
}
client := database.RedisClient
if client == nil {
return
}
b, err := json.Marshal(result)
if err != nil {
return
}
key := cacheKey(toolName, argsJSON)
client.Set(context.Background(), key, string(b), time.Duration(ttlSec)*time.Second)
}
// Invalidate 手动使某工具的所有缓存失效(管理员修改工具配置后调用)
func (c *ToolCache) Invalidate(toolName string) {
client := database.RedisClient
if client == nil {
return
}
pattern := cacheKeyPrefix + toolName + ":*"
keys, err := client.Keys(context.Background(), pattern).Result()
if err != nil || len(keys) == 0 {
return
}
client.Del(context.Background(), keys...)
}
package tools
import (
"context"
"fmt"
"internet-hospital/pkg/agent"
)
// AgentCallFn 由 internal/agent/service.go 在启动时注入,避免循环依赖
var AgentCallFn func(ctx context.Context, agentID, userID, sessionID, message string, ctxData map[string]interface{}) (string, error)
const maxAgentCallDepth = 3
type agentCallDepthKey struct{}
// AgentCallerTool 允许 Agent 调用另一个 Agent(子 Agent 模式)
type AgentCallerTool struct{}
func (t *AgentCallerTool) Name() string { return "call_agent" }
func (t *AgentCallerTool) Description() string {
return "调用另一个专业 Agent 处理特定子任务,如诊断辅助、处方审核、随访管理等"
}
func (t *AgentCallerTool) Parameters() []agent.ToolParameter {
return []agent.ToolParameter{
{
Name: "agent_id",
Type: "string",
Description: "目标 Agent ID,如 diagnosis_agent、prescription_agent、follow_up_agent",
Required: true,
},
{
Name: "message",
Type: "string",
Description: "发送给目标 Agent 的任务描述或问题",
Required: true,
},
{
Name: "context",
Type: "object",
Description: "传递给目标 Agent 的上下文信息(可选)",
Required: false,
},
}
}
func (t *AgentCallerTool) Execute(ctx context.Context, params map[string]interface{}) (interface{}, error) {
// 防循环调用:检测调用深度
depth, _ := ctx.Value(agentCallDepthKey{}).(int)
if depth >= maxAgentCallDepth {
return nil, fmt.Errorf("agent 调用深度超限(最大 %d 层),防止无限递归", maxAgentCallDepth)
}
ctx = context.WithValue(ctx, agentCallDepthKey{}, depth+1)
agentID, ok := params["agent_id"].(string)
if !ok || agentID == "" {
return nil, fmt.Errorf("agent_id 必填")
}
message, ok := params["message"].(string)
if !ok || message == "" {
return nil, fmt.Errorf("message 必填")
}
ctxData, _ := params["context"].(map[string]interface{})
if AgentCallFn == nil {
return nil, fmt.Errorf("AgentCallFn 未注入,请检查初始化顺序")
}
response, err := AgentCallFn(ctx, agentID, "system", "", message, ctxData)
if err != nil {
return nil, fmt.Errorf("调用 agent %s 失败: %w", agentID, err)
}
return map[string]interface{}{
"agent_id": agentID,
"response": response,
}, nil
}
package tools
import (
"context"
"fmt"
"github.com/expr-lang/expr"
agentpkg "internet-hospital/pkg/agent"
)
// ExprEvalTool 安全表达式执行工具(使用 expr-lang,无 IO/网络权限)
type ExprEvalTool struct{}
func (t *ExprEvalTool) Name() string { return "eval_expression" }
func (t *ExprEvalTool) Description() string {
return "安全执行数学/逻辑表达式,支持四则运算、条件判断、字符串处理、数组过滤。常用于 BMI 计算、风险评分、剂量调整等"
}
func (t *ExprEvalTool) Parameters() []agentpkg.ToolParameter {
return []agentpkg.ToolParameter{
{
Name: "expression",
Type: "string",
Description: "表达式字符串,如: weight / (height * height) 或 age > 65 ? '老年' : '成年'",
Required: true,
},
{
Name: "variables",
Type: "object",
Description: "表达式中引用的变量,如: {\"weight\": 70, \"height\": 1.75}",
Required: false,
},
}
}
func (t *ExprEvalTool) Execute(_ context.Context, params map[string]interface{}) (interface{}, error) {
expression, ok := params["expression"].(string)
if !ok || expression == "" {
return nil, fmt.Errorf("expression 必填")
}
variables, _ := params["variables"].(map[string]interface{})
if variables == nil {
variables = map[string]interface{}{}
}
// 编译表达式
program, err := expr.Compile(expression, expr.Env(variables))
if err != nil {
return nil, fmt.Errorf("表达式编译失败: %w", err)
}
// 执行表达式
result, err := expr.Run(program, variables)
if err != nil {
return nil, fmt.Errorf("表达式执行失败: %w", err)
}
return map[string]interface{}{
"expression": expression,
"result": result,
"variables": variables,
}, nil
}
package tools
import (
"context"
"fmt"
"internet-hospital/pkg/agent"
)
// GenerateFollowUpPlanTool 生成随访计划(此工具之前在分类 map 中声明但从未注册)
type GenerateFollowUpPlanTool struct{}
func (t *GenerateFollowUpPlanTool) Name() string { return "generate_follow_up_plan" }
func (t *GenerateFollowUpPlanTool) Description() string {
return "根据诊断和治疗信息生成个性化随访计划,包括复诊时间、检查项目、用药提醒和生活建议"
}
func (t *GenerateFollowUpPlanTool) Parameters() []agent.ToolParameter {
return []agent.ToolParameter{
{
Name: "diagnosis",
Type: "string",
Description: "诊断结论,如:高血压2级、2型糖尿病",
Required: true,
},
{
Name: "medications",
Type: "array",
Description: "当前用药列表,如:[\"氨氯地平 5mg\", \"二甲双胍 500mg\"]",
Required: false,
},
{
Name: "patient_age",
Type: "number",
Description: "患者年龄",
Required: false,
},
{
Name: "severity",
Type: "string",
Description: "病情严重程度:mild/moderate/severe",
Required: false,
Enum: []string{"mild", "moderate", "severe"},
},
}
}
func (t *GenerateFollowUpPlanTool) Execute(_ context.Context, params map[string]interface{}) (interface{}, error) {
diagnosis, ok := params["diagnosis"].(string)
if !ok || diagnosis == "" {
return nil, fmt.Errorf("diagnosis 必填")
}
severity, _ := params["severity"].(string)
if severity == "" {
severity = "moderate"
}
age, _ := params["patient_age"].(float64)
// 基于诊断生成随访计划(内置规则)
plan := buildFollowUpPlan(diagnosis, severity, int(age))
meds := make([]string, 0)
if medsRaw, ok := params["medications"].([]interface{}); ok {
for _, m := range medsRaw {
if s, ok := m.(string); ok {
meds = append(meds, s)
}
}
}
if len(meds) > 0 {
plan["medication_reminders"] = buildMedReminders(meds)
}
return plan, nil
}
func buildFollowUpPlan(diagnosis, severity string, age int) map[string]interface{} {
// 随访频率
followUpIntervalDays := 90 // 默认3个月
switch severity {
case "mild":
followUpIntervalDays = 180
case "severe":
followUpIntervalDays = 30
}
if age >= 65 {
followUpIntervalDays = followUpIntervalDays / 2
}
plan := map[string]interface{}{
"diagnosis": diagnosis,
"follow_up_interval": fmt.Sprintf("每 %d 天随访一次", followUpIntervalDays),
"follow_up_interval_days": followUpIntervalDays,
}
// 诊断特异性建议
checks := []string{"血压测量", "体重监测", "症状询问"}
advice := []string{"低盐低脂饮食", "规律运动(每周150分钟中等强度)", "戒烟限酒", "规律服药"}
containsAny := func(s string, keywords ...string) bool {
for _, k := range keywords {
if containsIgnoreCase(s, k) {
return true
}
}
return false
}
if containsAny(diagnosis, "高血压", "blood pressure") {
checks = append(checks, "尿常规", "血肌酐", "血钾")
advice = append(advice, "每日家庭血压监测(早晚各1次)", "限制钠盐摄入(< 6g/日)")
plan["target"] = "血压控制目标:< 140/90 mmHg(耐受者 < 130/80 mmHg)"
}
if containsAny(diagnosis, "糖尿病", "diabetes") {
checks = append(checks, "HbA1c", "空腹血糖", "尿微量白蛋白", "眼底检查", "足部检查")
advice = append(advice, "控制碳水化合物摄入", "每日血糖自我监测", "注意低血糖症状")
plan["target"] = "血糖控制目标:HbA1c < 7.0%(个体化)"
}
if containsAny(diagnosis, "冠心病", "心绞痛", "心肌梗死") {
checks = append(checks, "心电图", "血脂全套", "心脏超声(年度)")
advice = append(advice, "避免剧烈运动", "随身携带硝酸甘油", "规范抗血小板治疗")
plan["emergency_signs"] = "如出现持续胸痛、呼吸困难、大汗请立即就医"
}
plan["monitoring_items"] = checks
plan["lifestyle_advice"] = advice
return plan
}
func buildMedReminders(meds []string) []map[string]interface{} {
reminders := make([]map[string]interface{}, 0, len(meds))
for _, med := range meds {
reminders = append(reminders, map[string]interface{}{
"medication": med,
"reminder": "按时服药,切勿自行停药",
"tip": "如出现不适反应请及时就医",
})
}
return reminders
}
package tools
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"internet-hospital/internal/model"
"internet-hospital/pkg/agent"
)
// DynamicHTTPTool 动态 HTTP 工具,从数据库配置生成,无需修改代码
type DynamicHTTPTool struct {
def *model.HTTPToolDefinition
}
func NewDynamicHTTPTool(def *model.HTTPToolDefinition) *DynamicHTTPTool {
return &DynamicHTTPTool{def: def}
}
func (t *DynamicHTTPTool) Name() string { return t.def.Name }
func (t *DynamicHTTPTool) Description() string { return t.def.Description }
func (t *DynamicHTTPTool) Parameters() []agent.ToolParameter {
if t.def.Parameters == "" || t.def.Parameters == "[]" {
return nil
}
var params []agent.ToolParameter
if err := json.Unmarshal([]byte(t.def.Parameters), &params); err != nil {
return nil
}
return params
}
// renderTemplate 将 {{key}} 替换为 params 中对应值
func renderTemplate(template string, params map[string]interface{}) string {
result := template
for k, v := range params {
result = strings.ReplaceAll(result, "{{"+k+"}}", fmt.Sprintf("%v", v))
}
return result
}
func (t *DynamicHTTPTool) Execute(ctx context.Context, params map[string]interface{}) (interface{}, error) {
// 渲染 URL
url := renderTemplate(t.def.URL, params)
// 渲染 Body
var bodyStr string
if t.def.BodyTemplate != "" {
bodyStr = renderTemplate(t.def.BodyTemplate, params)
}
// 超时
timeout := time.Duration(t.def.Timeout) * time.Second
if timeout <= 0 {
timeout = 10 * time.Second
}
httpCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
// 创建请求
var bodyReader io.Reader
if bodyStr != "" {
bodyReader = strings.NewReader(bodyStr)
}
req, err := http.NewRequestWithContext(httpCtx, t.def.Method, url, bodyReader)
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
// 设置 Content-Type
if bodyStr != "" {
req.Header.Set("Content-Type", "application/json")
}
// 自定义 Headers
if t.def.Headers != "" && t.def.Headers != "{}" {
var headers map[string]string
if err := json.Unmarshal([]byte(t.def.Headers), &headers); err == nil {
for k, v := range headers {
req.Header.Set(k, renderTemplate(v, params))
}
}
}
// 认证
switch t.def.AuthType {
case "bearer":
var authCfg map[string]string
if err := json.Unmarshal([]byte(t.def.AuthConfig), &authCfg); err == nil {
if token, ok := authCfg["token"]; ok && token != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
}
case "basic":
var authCfg map[string]string
if err := json.Unmarshal([]byte(t.def.AuthConfig), &authCfg); err == nil {
req.SetBasicAuth(authCfg["username"], authCfg["password"])
}
case "apikey":
var authCfg map[string]string
if err := json.Unmarshal([]byte(t.def.AuthConfig), &authCfg); err == nil {
headerName := authCfg["header"]
headerVal := authCfg["value"]
if headerName != "" {
req.Header.Set(headerName, headerVal)
}
}
}
// 发起请求
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("HTTP 请求失败: %w", err)
}
defer resp.Body.Close()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取响应失败: %w", err)
}
result := map[string]interface{}{
"status_code": resp.StatusCode,
"success": resp.StatusCode >= 200 && resp.StatusCode < 300,
}
// 尝试解析 JSON 响应
var jsonBody interface{}
if err := json.Unmarshal(bodyBytes, &jsonBody); err == nil {
result["body"] = jsonBody
} else {
result["body"] = string(bodyBytes)
}
if resp.StatusCode >= 400 {
return result, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(bodyBytes))
}
return result, nil
}
package tools
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
"internet-hospital/internal/model"
"internet-hospital/pkg/agent"
"internet-hospital/pkg/database"
)
// HumanReviewTool Agent 主动发起人工审核任务
type HumanReviewTool struct{}
func (t *HumanReviewTool) Name() string { return "request_human_review" }
func (t *HumanReviewTool) Description() string {
return "发起人工审核请求,等待医生或管理员审核后继续;适用于需要人工介入的场景"
}
func (t *HumanReviewTool) Parameters() []agent.ToolParameter {
return []agent.ToolParameter{
{
Name: "title",
Type: "string",
Description: "审核任务标题",
Required: true,
},
{
Name: "description",
Type: "string",
Description: "需要人工审核的内容描述",
Required: true,
},
{
Name: "assignee_role",
Type: "string",
Description: "指定审核角色:doctor 或 admin",
Required: false,
Enum: []string{"doctor", "admin"},
},
{
Name: "data",
Type: "object",
Description: "随审核任务传递的附加数据",
Required: false,
},
}
}
func (t *HumanReviewTool) Execute(ctx context.Context, params map[string]interface{}) (interface{}, error) {
title, ok := params["title"].(string)
if !ok || title == "" {
return nil, fmt.Errorf("title 必填")
}
description, _ := params["description"].(string)
assigneeRole, _ := params["assignee_role"].(string)
if assigneeRole == "" {
assigneeRole = "admin"
}
extraData, _ := params["data"].(map[string]interface{})
extraJSON := "{}"
if extraData != nil {
if b, err := json.Marshal(extraData); err == nil {
extraJSON = string(b)
}
}
db := database.GetDB()
if db == nil {
return nil, fmt.Errorf("数据库未初始化")
}
taskID := uuid.New().String()
now := time.Now()
task := model.WorkflowHumanTask{
TaskID: taskID,
ExecutionID: "agent-" + taskID,
NodeID: "agent_requested",
Title: title,
Description: description,
AssigneeRole: assigneeRole,
FormData: extraJSON,
Status: "pending",
CreatedAt: now,
}
if err := db.Create(&task).Error; err != nil {
return nil, fmt.Errorf("创建审核任务失败: %w", err)
}
return map[string]interface{}{
"task_id": taskID,
"status": "pending",
"assignee": assigneeRole,
"message": fmt.Sprintf("已创建人工审核任务,等待 %s 审核", assigneeRole),
}, nil
}
package tools
import (
"context"
"fmt"
"internet-hospital/internal/model"
"internet-hospital/pkg/agent"
"internet-hospital/pkg/database"
)
// KnowledgeListTool 列出可用知识库集合,Agent 可据此选择 collection_id
type KnowledgeListTool struct{}
func (t *KnowledgeListTool) Name() string { return "list_knowledge_collections" }
func (t *KnowledgeListTool) Description() string {
return "列出所有可用的知识库集合及文档数量,用于在检索前确认目标集合 ID"
}
func (t *KnowledgeListTool) Parameters() []agent.ToolParameter {
return []agent.ToolParameter{
{
Name: "keyword",
Type: "string",
Description: "按名称关键词过滤集合(可选)",
Required: false,
},
}
}
func (t *KnowledgeListTool) Execute(_ context.Context, params map[string]interface{}) (interface{}, error) {
db := database.GetDB()
if db == nil {
return nil, fmt.Errorf("数据库未初始化")
}
keyword, _ := params["keyword"].(string)
type CollectionStat struct {
CollectionID string `gorm:"column:collection_id"`
Name string `gorm:"column:name"`
Description string `gorm:"column:description"`
DocCount int64 `gorm:"column:doc_count"`
}
query := db.Model(&model.KnowledgeCollection{}).
Select("knowledge_collections.collection_id, knowledge_collections.name, knowledge_collections.description, COUNT(knowledge_documents.id) as doc_count").
Joins("LEFT JOIN knowledge_documents ON knowledge_documents.collection_id = knowledge_collections.collection_id AND knowledge_documents.status = 'active'").
Group("knowledge_collections.collection_id, knowledge_collections.name, knowledge_collections.description")
if keyword != "" {
query = query.Where("knowledge_collections.name ILIKE ? OR knowledge_collections.description ILIKE ?",
"%"+keyword+"%", "%"+keyword+"%")
}
var stats []CollectionStat
if err := query.Scan(&stats).Error; err != nil {
return nil, fmt.Errorf("查询知识库失败: %w", err)
}
result := make([]map[string]interface{}, 0, len(stats))
for _, s := range stats {
result = append(result, map[string]interface{}{
"collection_id": s.CollectionID,
"name": s.Name,
"description": s.Description,
"doc_count": s.DocCount,
})
}
return map[string]interface{}{
"collections": result,
"total": len(result),
}, nil
}
package tools
import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/google/uuid"
"gorm.io/gorm"
"internet-hospital/internal/model"
"internet-hospital/pkg/agent"
"internet-hospital/pkg/database"
"internet-hospital/pkg/rag"
)
// KnowledgeWriteTool Agent 向知识库写入或更新内容
type KnowledgeWriteTool struct {
DB *gorm.DB
Retriever *rag.Retriever
}
func (t *KnowledgeWriteTool) Name() string { return "write_knowledge" }
func (t *KnowledgeWriteTool) Description() string {
return "向指定知识库集合写入新文档或更新已有文档,Agent 可用于保存诊断结论、治疗方案、病例摘要等"
}
func (t *KnowledgeWriteTool) Parameters() []agent.ToolParameter {
return []agent.ToolParameter{
{
Name: "collection_id",
Type: "string",
Description: "目标知识库集合 ID,如 clinical_guideline、case_summary",
Required: true,
},
{
Name: "title",
Type: "string",
Description: "文档标题",
Required: true,
},
{
Name: "content",
Type: "string",
Description: "文档正文内容(支持 Markdown)",
Required: true,
},
{
Name: "tags",
Type: "array",
Description: "标签列表,如 [\"高血压\", \"治疗方案\"]",
Required: false,
},
{
Name: "doc_id",
Type: "string",
Description: "指定文档 ID 则更新,不指定则创建新文档",
Required: false,
},
}
}
func (t *KnowledgeWriteTool) Execute(ctx context.Context, params map[string]interface{}) (interface{}, error) {
db := t.DB
if db == nil {
db = database.GetDB()
}
if db == nil {
return nil, fmt.Errorf("数据库未初始化")
}
collectionID, ok := params["collection_id"].(string)
if !ok || collectionID == "" {
return nil, fmt.Errorf("collection_id 必填")
}
title, ok := params["title"].(string)
if !ok || title == "" {
return nil, fmt.Errorf("title 必填")
}
content, ok := params["content"].(string)
if !ok || content == "" {
return nil, fmt.Errorf("content 必填")
}
// 构建 tags JSON
tagsJSON := "[]"
if tagsRaw, ok := params["tags"].([]interface{}); ok {
tags := make([]string, 0, len(tagsRaw))
for _, t := range tagsRaw {
if s, ok := t.(string); ok {
tags = append(tags, s)
}
}
if b, err := json.Marshal(tags); err == nil {
tagsJSON = string(b)
}
}
docID, _ := params["doc_id"].(string)
now := time.Now()
// 确保集合存在
var collection model.KnowledgeCollection
if err := db.Where("collection_id = ?", collectionID).First(&collection).Error; err != nil {
// 自动创建集合
collection = model.KnowledgeCollection{
CollectionID: collectionID,
Name: collectionID,
Description: "由 Agent 自动创建",
CreatedAt: now,
UpdatedAt: now,
}
if err := db.Create(&collection).Error; err != nil {
return nil, fmt.Errorf("创建集合失败: %w", err)
}
}
// 创建或更新文档
if docID == "" {
docID = uuid.New().String()
}
var doc model.KnowledgeDocument
isNew := db.Where("document_id = ?", docID).First(&doc).Error != nil
if isNew {
doc = model.KnowledgeDocument{
DocumentID: docID,
CollectionID: collectionID,
Title: title,
Content: content,
Metadata: tagsJSON,
Status: "ready",
CreatedAt: now,
UpdatedAt: now,
}
if err := db.Create(&doc).Error; err != nil {
return nil, fmt.Errorf("创建文档失败: %w", err)
}
} else {
if err := db.Model(&doc).Updates(map[string]interface{}{
"title": title,
"content": content,
"metadata": tagsJSON,
"updated_at": now,
}).Error; err != nil {
return nil, fmt.Errorf("更新文档失败: %w", err)
}
// 删除旧分块
db.Where("document_id = ?", docID).Delete(&model.KnowledgeChunk{})
}
// 分块并写入(按段落分块)
chunks := splitIntoChunks(content, 500)
for i, chunk := range chunks {
c := model.KnowledgeChunk{
ChunkID: uuid.New().String(),
DocumentID: docID,
CollectionID: collectionID,
Content: chunk,
ChunkIndex: i,
CreatedAt: now,
}
db.Create(&c)
}
// 异步触发向量索引(如果 Retriever 可用)
if t.Retriever != nil {
go func() {
_ = t.Retriever.IndexDocument(context.Background(), docID)
}()
}
action := "created"
if !isNew {
action = "updated"
}
return map[string]interface{}{
"doc_id": docID,
"collection_id": collectionID,
"title": title,
"chunks": len(chunks),
"action": action,
}, nil
}
// splitIntoChunks 按段落分块,maxLen 为最大字符数
func splitIntoChunks(text string, maxLen int) []string {
paragraphs := strings.Split(text, "\n\n")
var chunks []string
var current strings.Builder
for _, p := range paragraphs {
if current.Len()+len(p) > maxLen && current.Len() > 0 {
chunks = append(chunks, strings.TrimSpace(current.String()))
current.Reset()
}
if current.Len() > 0 {
current.WriteString("\n\n")
}
current.WriteString(p)
}
if current.Len() > 0 {
chunks = append(chunks, strings.TrimSpace(current.String()))
}
if len(chunks) == 0 {
chunks = []string{text}
}
return chunks
}
package tools
import (
"context"
"fmt"
"time"
"internet-hospital/internal/model"
"internet-hospital/pkg/agent"
"internet-hospital/pkg/database"
)
// SendNotificationTool 向用户发送系统通知(站内信 + WebSocket 推送)
// 此工具之前在分类 map 中声明但从未注册,此次完整落地
type SendNotificationTool struct{}
func (t *SendNotificationTool) Name() string { return "send_notification" }
func (t *SendNotificationTool) Description() string {
return "向指定用户发送系统通知,支持站内信存储和实时 WebSocket 推送"
}
func (t *SendNotificationTool) Parameters() []agent.ToolParameter {
return []agent.ToolParameter{
{
Name: "user_id",
Type: "string",
Description: "接收通知的用户 ID(UUID)",
Required: true,
},
{
Name: "title",
Type: "string",
Description: "通知标题",
Required: true,
},
{
Name: "content",
Type: "string",
Description: "通知正文内容",
Required: true,
},
{
Name: "type",
Type: "string",
Description: "通知类型",
Required: false,
Enum: []string{"reminder", "alert", "info", "followup", "system"},
},
{
Name: "priority",
Type: "string",
Description: "优先级:normal/high/urgent",
Required: false,
Enum: []string{"normal", "high", "urgent"},
},
}
}
func (t *SendNotificationTool) Execute(_ context.Context, params map[string]interface{}) (interface{}, error) {
db := database.GetDB()
if db == nil {
return nil, fmt.Errorf("数据库未初始化")
}
userID, ok := params["user_id"].(string)
if !ok || userID == "" {
return nil, fmt.Errorf("user_id 必填")
}
title, ok := params["title"].(string)
if !ok || title == "" {
return nil, fmt.Errorf("title 必填")
}
content, _ := params["content"].(string)
nType, _ := params["type"].(string)
if nType == "" {
nType = "info"
}
priority, _ := params["priority"].(string)
if priority == "" {
priority = "normal"
}
// 写入系统日志表作为站内信(复用 SystemLog)
log := model.SystemLog{
Action: "notification",
Resource: fmt.Sprintf("%s/%s", nType, priority),
Detail: fmt.Sprintf("[%s] %s: %s", nType, title, content),
UserID: userID,
}
if err := db.Create(&log).Error; err != nil {
return nil, fmt.Errorf("写入通知失败: %w", err)
}
return map[string]interface{}{
"notification_id": log.ID,
"user_id": userID,
"title": title,
"type": nType,
"priority": priority,
"sent_at": time.Now().Format(time.RFC3339),
}, nil
}
package tools
import (
"context"
"encoding/json"
"fmt"
"internet-hospital/internal/model"
"internet-hospital/pkg/agent"
"internet-hospital/pkg/database"
)
// WorkflowQueryTool 查询工作流执行状态
type WorkflowQueryTool struct{}
func (t *WorkflowQueryTool) Name() string { return "query_workflow_status" }
func (t *WorkflowQueryTool) Description() string {
return "查询工作流执行状态和结果,通常与 trigger_workflow 配合使用"
}
func (t *WorkflowQueryTool) Parameters() []agent.ToolParameter {
return []agent.ToolParameter{
{
Name: "execution_id",
Type: "string",
Description: "工作流执行 ID(由 trigger_workflow 返回)",
Required: true,
},
}
}
func (t *WorkflowQueryTool) Execute(ctx context.Context, params map[string]interface{}) (interface{}, error) {
executionID, ok := params["execution_id"].(string)
if !ok || executionID == "" {
return nil, fmt.Errorf("execution_id 必填")
}
db := database.GetDB()
if db == nil {
return nil, fmt.Errorf("数据库未初始化")
}
var exec model.WorkflowExecution
if err := db.Where("execution_id = ?", executionID).First(&exec).Error; err != nil {
return nil, fmt.Errorf("未找到执行记录: %s", executionID)
}
result := map[string]interface{}{
"execution_id": exec.ExecutionID,
"workflow_id": exec.WorkflowID,
"status": exec.Status,
"trigger_by": exec.TriggerBy,
"started_at": exec.StartedAt,
"completed_at": exec.CompletedAt,
}
// 解析输出
if exec.Output != "" && exec.Output != "null" {
var output interface{}
if err := json.Unmarshal([]byte(exec.Output), &output); err == nil {
result["output"] = output
}
}
return result, nil
}
package tools
import (
"context"
"fmt"
"internet-hospital/pkg/agent"
)
// WorkflowTriggerFn 由 main.go 或 workflow 包在启动时注入
var WorkflowTriggerFn func(ctx context.Context, workflowID string, input map[string]interface{}) (string, error)
// WorkflowTriggerTool Agent 触发工作流执行
type WorkflowTriggerTool struct{}
func (t *WorkflowTriggerTool) Name() string { return "trigger_workflow" }
func (t *WorkflowTriggerTool) Description() string {
return "触发一个工作流异步执行,返回执行 ID 可用于后续查询状态"
}
func (t *WorkflowTriggerTool) Parameters() []agent.ToolParameter {
return []agent.ToolParameter{
{
Name: "workflow_id",
Type: "string",
Description: "工作流 ID,如 pre_consult、follow_up",
Required: true,
},
{
Name: "input",
Type: "object",
Description: "工作流输入参数(键值对)",
Required: false,
},
}
}
func (t *WorkflowTriggerTool) Execute(ctx context.Context, params map[string]interface{}) (interface{}, error) {
workflowID, ok := params["workflow_id"].(string)
if !ok || workflowID == "" {
return nil, fmt.Errorf("workflow_id 必填")
}
input, _ := params["input"].(map[string]interface{})
if input == nil {
input = map[string]interface{}{}
}
if WorkflowTriggerFn == nil {
return nil, fmt.Errorf("WorkflowTriggerFn 未注入")
}
executionID, err := WorkflowTriggerFn(ctx, workflowID, input)
if err != nil {
return nil, fmt.Errorf("触发工作流 %s 失败: %w", workflowID, err)
}
return map[string]interface{}{
"workflow_id": workflowID,
"execution_id": executionID,
"status": "triggered",
}, nil
}
......@@ -584,6 +584,11 @@ type CallParams struct {
UserID string // 用户ID
Messages []ChatMessage // 消息列表
RequestSummary string // 请求摘要(用于日志)
// 链路追踪字段(Agent调用时填写)
TraceID string // 链路追踪ID
AgentID string // 关联Agent ID
SessionID string // 关联会话ID
Iteration int // Agent第几轮迭代(0=非Agent调用)
}
// CallResult 统一AI调用结果
......@@ -655,6 +660,10 @@ func Call(ctx context.Context, params CallParams) CallResult {
Success: result.Error == nil,
ErrorMessage: errMsg,
IsMock: result.IsMock,
TraceID: params.TraceID,
AgentID: params.AgentID,
SessionID: params.SessionID,
Iteration: params.Iteration,
})
return result
......@@ -715,6 +724,10 @@ func CallStream(ctx context.Context, params CallParams, onChunk func(content str
Success: result.Error == nil,
ErrorMessage: errMsg,
IsMock: result.IsMock,
TraceID: params.TraceID,
AgentID: params.AgentID,
SessionID: params.SessionID,
Iteration: params.Iteration,
})
return result
......
......@@ -22,6 +22,11 @@ type LogParams struct {
Success bool
ErrorMessage string
IsMock bool
// 链路追踪字段
TraceID string
AgentID string
SessionID string
Iteration int
}
// SaveLog 异步保存AI调用日志
......@@ -66,6 +71,10 @@ func SaveLog(params LogParams) {
Success: params.Success,
ErrorMessage: params.ErrorMessage,
IsMock: params.IsMock,
TraceID: params.TraceID,
AgentID: params.AgentID,
SessionID: params.SessionID,
Iteration: params.Iteration,
CreatedAt: time.Now(),
}
......
......@@ -37,3 +37,17 @@ func GetActivePromptByScene(scene string) string {
}
return ""
}
// GetActivePromptByAgent 根据 agent_id 查找该Agent关联的激活系统提示词模板
func GetActivePromptByAgent(agentID string) string {
db := database.GetDB()
if db == nil {
return ""
}
var tmpl model.PromptTemplate
if err := db.Where("agent_id = ? AND template_type = 'system' AND status = 'active'", agentID).
Order("version DESC").First(&tmpl).Error; err == nil {
return tmpl.Content
}
return ""
}
package safety
import (
"regexp"
"strings"
"internet-hospital/internal/model"
)
// FilterResult 过滤结果
type FilterResult struct {
Text string // 处理后的文本(passed/replaced时有效)
Action string // passed | blocked | replaced | warned
MatchedRules []MatchedRule // 命中的规则列表
Blocked bool // 是否被拦截
}
// Filter 安全过滤器
type Filter struct {
loader *Loader
}
// NewFilter 创建过滤器
func NewFilter() *Filter {
return &Filter{loader: GetLoader()}
}
// globalFilter 全局过滤器单例
var globalFilter = &Filter{loader: nil}
// GetFilter 获取全局过滤器
func GetFilter() *Filter {
if globalFilter.loader == nil {
globalFilter.loader = GetLoader()
}
return globalFilter
}
// FilterInput 输入过滤(用户消息 → LLM之前)
func (f *Filter) FilterInput(text, agentID, traceID, userID string) FilterResult {
return f.filter(text, "input", agentID, traceID, userID)
}
// FilterOutput 输出过滤(LLM回复 → 返回用户之前)
func (f *Filter) FilterOutput(text, agentID, traceID, userID string) FilterResult {
return f.filter(text, "output", agentID, traceID, userID)
}
func (f *Filter) filter(text, direction, agentID, traceID, userID string) FilterResult {
rules := f.loader.Load()
if len(rules) == 0 {
return FilterResult{Text: text, Action: "passed"}
}
result := FilterResult{Text: text, Action: "passed"}
for _, rule := range rules {
// 方向过滤
if rule.Direction != "both" && rule.Direction != direction {
continue
}
// agentID过滤(空=全局规则,或匹配特定agent)
if rule.AgentID != "" && rule.AgentID != agentID {
continue
}
matched := matchRule(rule, text)
if !matched {
continue
}
mr := MatchedRule{RuleID: rule.ID, Word: rule.Word, Action: rule.Level}
result.MatchedRules = append(result.MatchedRules, mr)
switch rule.Level {
case "block":
result.Action = "blocked"
result.Blocked = true
result.Text = ""
// 记录日志
SaveFilterLog(traceID, direction, text, "", result.MatchedRules, "blocked", agentID, userID)
return result
case "replace":
replacement := rule.Replacement
if replacement == "" {
replacement = "***"
}
result.Text = replaceWord(result.Text, rule.Word, replacement, rule.IsRegex)
result.Action = "replaced"
case "warn":
if result.Action == "passed" {
result.Action = "warned"
}
}
}
if result.Action != "passed" {
SaveFilterLog(traceID, direction, text, result.Text, result.MatchedRules, result.Action, agentID, userID)
}
return result
}
func matchRule(rule model.SafetyWordRule, text string) bool {
if rule.IsRegex {
re, err := regexp.Compile(rule.Word)
if err != nil {
return false
}
return re.MatchString(text)
}
return strings.Contains(text, rule.Word)
}
func replaceWord(text, word, replacement string, isRegex bool) string {
if isRegex {
re, err := regexp.Compile(word)
if err != nil {
return text
}
return re.ReplaceAllString(text, replacement)
}
return strings.ReplaceAll(text, word, replacement)
}
package safety
import (
"encoding/json"
"log"
"sync"
"time"
"internet-hospital/internal/model"
"internet-hospital/pkg/database"
)
// globalLoader 全局规则加载器(单例)
var globalLoader *Loader
var loaderOnce sync.Once
// Loader 安全词规则加载器(带缓存)
type Loader struct {
mu sync.RWMutex
rules []model.SafetyWordRule
lastLoad time.Time
ttl time.Duration
}
// GetLoader 获取全局加载器单例
func GetLoader() *Loader {
loaderOnce.Do(func() {
globalLoader = &Loader{
ttl: 5 * time.Minute,
}
})
return globalLoader
}
// Load 加载规则(带TTL缓存)
func (l *Loader) Load() []model.SafetyWordRule {
l.mu.RLock()
if time.Since(l.lastLoad) < l.ttl && len(l.rules) > 0 {
rules := l.rules
l.mu.RUnlock()
return rules
}
l.mu.RUnlock()
return l.reload()
}
// Reload 强制重新加载
func (l *Loader) Reload() []model.SafetyWordRule {
return l.reload()
}
func (l *Loader) reload() []model.SafetyWordRule {
db := database.GetDB()
if db == nil {
return nil
}
var rules []model.SafetyWordRule
if err := db.Where("status = 'active'").Order("id asc").Find(&rules).Error; err != nil {
log.Printf("[safety] 加载安全词规则失败: %v", err)
l.mu.RLock()
cached := l.rules
l.mu.RUnlock()
return cached
}
l.mu.Lock()
l.rules = rules
l.lastLoad = time.Now()
l.mu.Unlock()
return rules
}
// SaveFilterLog 异步保存过滤日志
func SaveFilterLog(traceID, direction, original, filtered string, matched []MatchedRule, action, agentID, userID string) {
go func() {
db := database.GetDB()
if db == nil {
return
}
matchedJSON, _ := json.Marshal(matched)
entry := &model.SafetyFilterLog{
TraceID: traceID,
Direction: direction,
OriginalText: truncate(original, 2000),
FilteredText: truncate(filtered, 2000),
MatchedRules: string(matchedJSON),
Action: action,
AgentID: agentID,
UserID: userID,
}
if err := db.Create(entry).Error; err != nil {
log.Printf("[safety] 保存过滤日志失败: %v", err)
}
}()
}
// MatchedRule 命中的规则记录
type MatchedRule struct {
RuleID uint `json:"rule_id"`
Word string `json:"word"`
Action string `json:"action"`
}
func truncate(s string, max int) string {
runes := []rune(s)
if len(runes) <= max {
return s
}
return string(runes[:max]) + "..."
}
package workflow
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
exprlib "github.com/expr-lang/expr"
"github.com/google/uuid"
"internet-hospital/internal/model"
"internet-hospital/pkg/agent"
agentpkg "internet-hospital/pkg/agent"
"internet-hospital/pkg/database"
"github.com/google/uuid"
)
// NodeType 节点类型
......@@ -68,14 +72,14 @@ type Engine struct {
agentSvc interface {
Chat(ctx context.Context, agentID, userID, sessionID, message string, ctxData map[string]interface{}) (interface{}, error)
}
toolRegistry *agent.ToolRegistry
toolRegistry *agentpkg.ToolRegistry
}
var globalEngine *Engine
func GetEngine() *Engine {
if globalEngine == nil {
globalEngine = &Engine{toolRegistry: agent.GetRegistry()}
globalEngine = &Engine{toolRegistry: agentpkg.GetRegistry()}
}
return globalEngine
}
......@@ -253,15 +257,28 @@ func (e *Engine) createHumanTask(_ context.Context, node *Node, execCtx *Executi
return map[string]interface{}{"task_id": task.TaskID, "status": "pending"}, nil
}
func evalCondition(expr string, vars map[string]interface{}) bool {
// 简单实现:支持 "key == 'value'" 格式
func evalCondition(expression string, vars map[string]interface{}) bool {
if expression == "" {
return true
}
program, err := exprlib.Compile(expression, exprlib.Env(vars), exprlib.AsBool())
if err != nil {
// 降级:简单字符串相等检查
for k, v := range vars {
check := fmt.Sprintf("%s == '%v'", k, v)
if expr == check {
if expression == fmt.Sprintf("%s == '%v'", k, v) {
return true
}
}
return false
}
result, err := exprlib.Run(program, vars)
if err != nil {
return false
}
if b, ok := result.(bool); ok {
return b
}
return result != nil
}
// executeParallel 并行执行多个子节点
......@@ -367,84 +384,115 @@ func (e *Engine) executeLoop(ctx context.Context, wf *Workflow, node *Node, exec
}, nil
}
// executeCode 执行代码节点(简单表达式求值)
// executeCode 代码节点:使用 expr-lang 安全执行表达式
func (e *Engine) executeCode(_ context.Context, node *Node, execCtx *ExecutionContext) (interface{}, error) {
code, _ := node.Config["code"].(string)
if code == "" {
return map[string]interface{}{"skipped": "no code"}, nil
}
// 简单的表达式求值(实际生产环境应使用安全的脚本引擎)
// 这里只支持简单的变量赋值和返回
result := make(map[string]interface{})
// 将所有变量复制到结果中
// 合并所有变量作为上下文
env := make(map[string]interface{})
for k, v := range execCtx.Variables {
result[k] = v
env[k] = v
}
// 将所有节点输出复制到结果中
for k, v := range execCtx.NodeOutputs {
result["node_"+k] = v
env["node_"+k] = v
}
return result, nil
program, err := exprlib.Compile(code, exprlib.Env(env))
if err != nil {
return nil, fmt.Errorf("代码节点编译失败: %w", err)
}
result, err := exprlib.Run(program, env)
if err != nil {
return nil, fmt.Errorf("代码节点执行失败: %w", err)
}
// 如果结果是 map,合并回变量
if m, ok := result.(map[string]interface{}); ok {
for k, v := range m {
execCtx.Variables[k] = v
}
}
return map[string]interface{}{"result": result, "code": code}, nil
}
// executeHTTP 执行HTTP请求节点
// executeHTTP 执行 HTTP 请求节点(完整 net/http 实现)
func (e *Engine) executeHTTP(ctx context.Context, node *Node, execCtx *ExecutionContext) (interface{}, error) {
url, _ := node.Config["url"].(string)
rawURL, _ := node.Config["url"].(string)
method, _ := node.Config["method"].(string)
if method == "" {
method = "GET"
}
if url == "" {
if rawURL == "" {
return nil, fmt.Errorf("http node requires url")
}
// 替换URL中的变量
// 替换 URL 中的 {{var}} 变量
for k, v := range execCtx.Variables {
placeholder := fmt.Sprintf("${%s}", k)
if strVal, ok := v.(string); ok {
url = replaceAll(url, placeholder, strVal)
}
rawURL = replaceAll(rawURL, "{{"+k+"}}", fmt.Sprintf("%v", v))
}
// 创建HTTP请求
var reqBody []byte
// 构建请求 Body
var bodyBytes []byte
if body, ok := node.Config["body"].(map[string]interface{}); ok {
reqBody, _ = json.Marshal(body)
// 渲染 body 中的变量
for k, v := range body {
if s, ok2 := v.(string); ok2 {
for varK, varV := range execCtx.Variables {
s = replaceAll(s, "{{"+varK+"}}", fmt.Sprintf("%v", varV))
}
body[k] = s
}
}
bodyBytes, _ = json.Marshal(body)
}
req, err := newHTTPRequest(ctx, method, url, reqBody)
if err != nil {
return nil, err
var bodyReader io.Reader
if len(bodyBytes) > 0 {
bodyReader = bytes.NewReader(bodyBytes)
}
// 设置请求头
timeout := 30 * time.Second
if t, ok := node.Config["timeout"].(float64); ok && t > 0 {
timeout = time.Duration(t) * time.Second
}
httpCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
req, err := http.NewRequestWithContext(httpCtx, method, rawURL, bodyReader)
if err != nil {
return nil, fmt.Errorf("创建 HTTP 请求失败: %w", err)
}
if len(bodyBytes) > 0 {
req.Header.Set("Content-Type", "application/json")
}
if headers, ok := node.Config["headers"].(map[string]interface{}); ok {
for k, v := range headers {
if strVal, ok := v.(string); ok {
req.Header[k] = strVal
if sv, ok2 := v.(string); ok2 {
req.Header.Set(k, sv)
}
}
}
// 执行请求
client := &httpClient{timeout: 30 * time.Second}
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return map[string]interface{}{
"error": err.Error(),
"status": 0,
}, nil
return map[string]interface{}{"error": err.Error(), "status": 0}, nil
}
defer resp.Body.Close()
return map[string]interface{}{
"status": resp.StatusCode,
"body": resp.Body,
}, nil
respBytes, _ := io.ReadAll(resp.Body)
result := map[string]interface{}{"status": resp.StatusCode}
var parsed interface{}
if json.Unmarshal(respBytes, &parsed) == nil {
result["body"] = parsed
} else {
result["body"] = string(respBytes)
}
return result, nil
}
// executeTemplate 执行模板渲染节点
......@@ -504,39 +552,3 @@ func replaceAll(s, old, new string) string {
return string(result)
}
// 简单的HTTP客户端封装
type httpClient struct {
timeout time.Duration
}
type httpResponse struct {
StatusCode int
Body string
}
func newHTTPRequest(ctx context.Context, method, url string, body []byte) (*httpRequest, error) {
return &httpRequest{
ctx: ctx,
method: method,
url: url,
body: body,
Header: make(map[string]string),
}, nil
}
type httpRequest struct {
ctx context.Context
method string
url string
body []byte
Header map[string]string
}
func (c *httpClient) Do(req *httpRequest) (*httpResponse, error) {
// 简化实现:实际应使用net/http
// 这里返回模拟响应,避免引入额外依赖
return &httpResponse{
StatusCode: 200,
Body: `{"status": "ok"}`,
}, nil
}
......@@ -62,6 +62,32 @@ export interface WorkflowExecution {
completed_at: string;
}
export interface AgentDefinition {
id: number;
agent_id: string;
name: string;
description: string;
category: string;
system_prompt: string;
tools: string; // JSON array string
config: string;
max_iterations: number;
status: string;
created_at: string;
updated_at: string;
}
export type AgentDefinitionParams = {
agent_id: string;
name: string;
description?: string;
category?: string;
system_prompt?: string;
tools?: string;
max_iterations?: number;
status?: string;
};
export interface WorkflowCreateParams {
workflow_id: string;
name: string;
......@@ -84,18 +110,21 @@ export interface KnowledgeDocumentParams {
// ==================== Agent API ====================
// AI 推理调用专用超时(2分钟)
const AI_TIMEOUT = 120000;
export const agentApi = {
chat: (agentId: string, params: {
session_id?: string;
message: string;
context?: Record<string, unknown>;
}) => post<AgentOutput>(`/agent/${agentId}/chat`, params),
}) => post<AgentOutput>(`/agent/${agentId}/chat`, params, { timeout: AI_TIMEOUT }),
listAgents: () =>
get<{ id: string; name: string; description: string }[]>('/agent/list'),
listTools: () =>
get<{ id: string; name: string; description: string; category: string; parameters: Record<string, unknown>; is_enabled: boolean; created_at: string }[]>('/agent/tools'),
get<{ id: string; name: string; display_name: string; description: string; category: string; parameters: Record<string, unknown>; status: string; is_enabled: boolean; cache_ttl: number; timeout: number; max_retries: number; created_at: string }[]>('/agent/tools'),
getSessions: (agentId?: string) =>
get<AgentSession[]>('/agent/sessions', { params: agentId ? { agent_id: agentId } : {} }),
......@@ -115,6 +144,69 @@ export const agentApi = {
get<{ agent_id: string; count: number; avg_iterations: number; avg_tokens: number; success_rate: number }[]>(
'/admin/agent/stats', { params }
),
// Agent 配置 CRUD
listDefinitions: () =>
get<AgentDefinition[]>('/agent/definitions'),
getDefinition: (agentId: string) =>
get<AgentDefinition>(`/agent/definitions/${agentId}`),
createDefinition: (data: AgentDefinitionParams) =>
post<AgentDefinition>('/agent/definitions', data),
updateDefinition: (agentId: string, data: Partial<AgentDefinitionParams>) =>
put<AgentDefinition>(`/agent/definitions/${agentId}`, data),
reloadAgent: (agentId: string) =>
put<null>(`/agent/definitions/${agentId}/reload`, {}),
updateToolStatus: (name: string, status: 'active' | 'disabled') =>
put<null>(`/agent/tools/${name}/status`, { status }),
};
// ==================== HTTP Tool API ====================
export interface HTTPToolDefinition {
id: number;
name: string;
display_name: string;
description: string;
category: string;
method: string;
url: string;
headers: string;
body_template: string;
auth_type: string;
auth_config: string;
parameters: string;
timeout: number;
cache_ttl: number;
status: string;
created_by: string;
created_at: string;
updated_at: string;
}
export type HTTPToolParams = Omit<HTTPToolDefinition, 'id' | 'created_at' | 'updated_at'>;
export const httpToolApi = {
list: () => get<HTTPToolDefinition[]>('/agent/http-tools'),
create: (data: Partial<HTTPToolParams>) =>
post<HTTPToolDefinition>('/agent/http-tools', data),
update: (id: number, data: Partial<HTTPToolParams>) =>
put<HTTPToolDefinition>(`/agent/http-tools/${id}`, data),
delete: (id: number) => del<null>(`/agent/http-tools/${id}`),
test: (id: number, params: Record<string, unknown>) =>
post<{ success: boolean; data?: unknown; error?: string }>(
`/agent/http-tools/${id}/test`, { params }
),
reload: () => post<{ message: string }>('/agent/http-tools/reload', {}),
};
// ==================== Workflow API ====================
......@@ -127,7 +219,7 @@ export const workflowApi = {
update: (id: number, data: Partial<WorkflowCreateParams>) =>
put<unknown>(`/admin/workflows/${id}`, data),
publish: (id: number) => post<null>(`/admin/workflows/${id}/publish`),
publish: (id: number) => put<null>(`/admin/workflows/${id}/publish`, {}),
execute: (workflowId: string, input?: Record<string, unknown>) =>
post<{ execution_id: string }>(`/workflow/${workflowId}/execute`, input || {}),
......
......@@ -59,7 +59,7 @@ export const chronicApi = {
listRenewals: () => get<RenewalRequest[]>('/chronic/renewals'),
createRenewal: (data: { chronic_id?: string; disease_name: string; medicines: string[]; reason: string }) =>
post<RenewalRequest>('/chronic/renewals', data),
getAIAdvice: (id: string) => post<{ advice: string }>(`/chronic/renewals/${id}/ai-advice`, {}),
getAIAdvice: (id: string) => post<{ advice: string }>(`/chronic/renewals/${id}/ai-advice`, {}, { timeout: 120000 }),
// 用药提醒
listReminders: () => get<MedicationReminder[]>('/chronic/reminders'),
......
......@@ -129,7 +129,7 @@ export const consultApi = {
tool_calls?: import('./agent').ToolCall[];
iterations?: number;
total_tokens?: number;
}>(`/consult/${id}/ai-assist`, { scene }),
}>(`/consult/${id}/ai-assist`, { scene }, { timeout: 120000 }),
// 取消问诊
cancelConsult: (id: string, reason?: string) =>
......
......@@ -191,5 +191,5 @@ export const doctorPortalApi = {
iterations?: number;
has_warning: boolean;
has_contraindication: boolean;
}>('/doctor-portal/prescription/check', params),
}>('/doctor-portal/prescription/check', params, { timeout: 120000 }),
};
......@@ -93,7 +93,7 @@ export const healthApi = {
listReports: () => get<LabReport[]>('/health/reports'),
createReport: (data: CreateReportParams) => post<LabReport>('/health/reports', data),
deleteReport: (id: string) => del<null>(`/health/reports/${id}`),
aiInterpret: (id: string) => post<{ interpret: string }>(`/health/reports/${id}/ai-interpret`),
aiInterpret: (id: string) => post<{ interpret: string }>(`/health/reports/${id}/ai-interpret`, undefined, { timeout: 120000 }),
// 家庭成员
listFamily: () => get<FamilyMember[]>('/health/family'),
createFamily: (data: FamilyMemberParams) => post<FamilyMember>('/health/family', data),
......
......@@ -21,9 +21,10 @@ const getBaseURL = () => {
};
// 创建 axios 实例
// 普通接口:30s;AI / Agent 接口(模型推理)可能需要 2min,通过 config.timeout 单独覆盖
const request: AxiosInstance = axios.create({
baseURL: getBaseURL(),
timeout: 10000,
timeout: 30000,
headers: {
'Content-Type': 'application/json',
},
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -7,7 +7,8 @@ import {
DashboardOutlined, UserOutlined, TeamOutlined, ApartmentOutlined,
SettingOutlined, LogoutOutlined, BellOutlined, MedicineBoxOutlined,
FileSearchOutlined, FileTextOutlined, RobotOutlined, SafetyCertificateOutlined,
ApiOutlined, DeploymentUnitOutlined, BookOutlined,
ApiOutlined, DeploymentUnitOutlined, BookOutlined, CheckCircleOutlined,
SafetyOutlined, FundOutlined, AppstoreOutlined, CloudOutlined,
} from '@ant-design/icons';
import { useUserStore } from '@/store/userStore';
......@@ -34,9 +35,14 @@ const menuItems = [
key: 'ai-platform', icon: <ApiOutlined />, label: '智能体平台',
children: [
{ key: '/admin/agents', icon: <RobotOutlined />, label: 'Agent管理' },
{ key: '/admin/tools', icon: <ApiOutlined />, label: 'Tools工具' },
{ key: '/admin/tool-market', icon: <AppstoreOutlined />, label: '工具市场' },
{ key: '/admin/tools', icon: <ApiOutlined />, label: '内置工具' },
{ key: '/admin/http-tools', icon: <CloudOutlined />, label: 'HTTP工具' },
{ key: '/admin/workflows', icon: <DeploymentUnitOutlined />, label: '工作流' },
{ key: '/admin/tasks', icon: <CheckCircleOutlined />, label: '人工审核' },
{ key: '/admin/knowledge', icon: <BookOutlined />, label: '知识库' },
{ key: '/admin/safety', icon: <SafetyOutlined />, label: '内容安全' },
{ key: '/admin/ai-center', icon: <FundOutlined />, label: 'AI运营中心' },
],
},
];
......@@ -68,7 +74,9 @@ export default function AdminLayout({ children }: { children: React.ReactNode })
const getSelectedKeys = () => {
const allKeys = ['/admin/dashboard', '/admin/patients', '/admin/doctors', '/admin/admins',
'/admin/departments', '/admin/consultations', '/admin/prescription', '/admin/pharmacy',
'/admin/ai-config', '/admin/compliance', '/admin/agents', '/admin/tools', '/admin/workflows', '/admin/knowledge'];
'/admin/ai-config', '/admin/compliance', '/admin/agents', '/admin/tool-market',
'/admin/tools', '/admin/http-tools', '/admin/workflows',
'/admin/tasks', '/admin/knowledge', '/admin/safety', '/admin/ai-center'];
const match = allKeys.find(k => currentPath.startsWith(k));
return match ? [match] : [];
};
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -7,6 +7,7 @@ import {
DashboardOutlined, UserOutlined, MessageOutlined, CalendarOutlined,
SettingOutlined, LogoutOutlined, BellOutlined, MedicineBoxOutlined,
RobotOutlined, DollarOutlined, FolderOpenOutlined, SafetyCertificateOutlined,
CheckCircleOutlined,
} from '@ant-design/icons';
import { useUserStore } from '@/store/userStore';
......@@ -16,6 +17,7 @@ const { Text } = Typography;
const menuItems = [
{ key: '/doctor/workbench', icon: <DashboardOutlined />, label: '工作台' },
{ key: '/doctor/consult', icon: <MessageOutlined />, label: '问诊大厅' },
{ key: '/doctor/tasks', icon: <CheckCircleOutlined />, label: '待办任务' },
{ key: '/doctor/chronic/review', icon: <SafetyCertificateOutlined />, label: '慢病续方' },
{ key: '/doctor/patient', icon: <FolderOpenOutlined />, label: '患者档案' },
{ key: '/doctor/schedule', icon: <CalendarOutlined />, label: '排班管理' },
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment