# memory/base.py
import copy
from abc import ABC, abstractmethod
from typing import List, Dict


class BaseMemory(ABC):
    """Abstract interface for a store that holds a list of message dicts.

    Subclasses must implement ``load``, ``save``, ``clear`` and ``append``.
    ``extend`` has a default implementation built on ``load`` and ``save``;
    subclasses may override it with something more efficient.
    """

    @abstractmethod
    def load(self) -> List[Dict]:
        """Return the stored messages."""
        ...

    @abstractmethod
    def save(self, messages: List[Dict]) -> None:
        """Replace the stored messages with ``messages``."""
        ...

    @abstractmethod
    def clear(self) -> None:
        """Remove all stored messages."""
        ...

    @abstractmethod
    def append(self, message: Dict) -> None:
        """Add a single message to the end of the stored messages."""
        ...

    def extend(self, messages_to_add: List[Dict]) -> None:
        """Add several messages to the end of the stored messages.

        Deliberately not abstract: existing subclasses that only implement
        the four abstract methods keep working and inherit this behavior.
        """
        messages = self.load()
        messages.extend(messages_to_add)
        self.save(messages)

# memory/redis_memory.py
import json


class RedisMemory(BaseMemory):
    """Message memory stored as one JSON-encoded list under a single Redis key.

    The key has the form ``memory:{session_id}:{agent_id}``.
    """

    def __init__(self, session_id: str, agent_id: str, redis_client):
        """Bind this memory to a session/agent pair.

        Args:
            session_id: Session part of the storage key.
            agent_id: Agent part of the storage key.
            redis_client: Object providing ``get``, ``set`` and ``delete``
                (presumably a redis-py client -- confirm with callers).
        """
        self.key = f"memory:{session_id}:{agent_id}"
        self.redis = redis_client

    def load(self) -> List[Dict]:
        """Return the stored message list.

        Returns an empty list when the key is missing or empty, when the
        stored value is not valid JSON, or when it decodes to something
        other than a list. Errors raised by ``redis_client.get`` propagate.
        """
        raw = self.redis.get(self.key)
        if not raw:
            return []
        try:
            messages = json.loads(raw)
        except (TypeError, ValueError, RecursionError) as e:
            # ValueError covers JSONDecodeError and UnicodeDecodeError.
            print(f"[Memory Load Error] {e}")
            return []
        if not isinstance(messages, list):
            # A non-list payload (null, object, string, number) would break
            # append()/extend(), which call list methods on the result.
            print(
                "[Memory Load Error] expected a JSON list, "
                f"got {type(messages).__name__}"
            )
            return []
        return messages

    def save(self, messages: List[Dict]) -> None:
        """Serialize ``messages`` to JSON and store it under the key.

        Best-effort: any serialization or Redis error is printed and
        swallowed, so a failed save does not raise.
        """
        try:
            self.redis.set(self.key, json.dumps(messages, ensure_ascii=False))
        except Exception as e:
            print(f"[Memory Save Error] {e}")

    def clear(self) -> None:
        """Delete the key from Redis."""
        self.redis.delete(self.key)

    def append(self, message: Dict) -> None:
        """Append one message via a load / modify / save round trip."""
        messages = self.load()
        messages.append(message)
        self.save(messages)

    def extend(self, messages_to_add: list) -> None:
        """Append several messages via a single load / modify / save round trip."""
        messages = self.load()
        messages.extend(messages_to_add)
        self.save(messages)

    def __repr__(self):
        # NOTE: computing the count calls load(), i.e. performs a Redis read.
        return f"<RedisMemory key={self.key} count={len(self.load())}>"

# Module-level dict shared by all InMemoryMemory instances; stands in for Redis.
in_memory_store: Dict[str, List[Dict]] = {}

class InMemoryMemory(BaseMemory):
    """Message memory kept in the module-level ``in_memory_store`` dict.

    Entries are stored under ``memory:{session_id}:{agent_id}``. Messages are
    deep-copied on the way in and on the way out, so later mutation of a
    message object by the caller never changes what is stored (and vice
    versa).
    """

    def __init__(self, session_id: str, agent_id: str):
        """Bind this memory to a session/agent pair."""
        self.key = f"memory:{session_id}:{agent_id}"

    def load(self) -> List[Dict]:
        """Return an independent copy of the stored messages ([] if none)."""
        return copy.deepcopy(in_memory_store.get(self.key, []))

    def save(self, messages: List[Dict]):
        """Replace the stored messages with an independent copy of ``messages``."""
        in_memory_store[self.key] = copy.deepcopy(list(messages))

    def clear(self):
        """Remove this key from the store; a missing key is not an error."""
        in_memory_store.pop(self.key, None)

    def append(self, message: Dict):
        """Append an independent copy of ``message`` to the stored list."""
        in_memory_store.setdefault(self.key, []).append(copy.deepcopy(message))

    def extend(self, messages_to_add: List[Dict]):
        """Append independent copies of several messages to the stored list."""
        in_memory_store.setdefault(self.key, []).extend(
            copy.deepcopy(list(messages_to_add))
        )

    def __repr__(self):
        count = len(in_memory_store.get(self.key, []))
        return f"<InMemoryMemory key={self.key} count={count}>"