# Change notes:
# 1. generate_reply: 改为优先关键词匹配,不匹配则调用 AI
# 2. _ai_generate_reply: 改进 prompt,加入对话上下文、微信风格要求
# 3. 要求回复简洁(50字以内),符合聊天风格
# (414 lines, 14 KiB, Python)
"""
|
||
核心引擎
|
||
WeChat Agent Core Engine
|
||
"""
|
||
|
||
import time
|
||
import logging
|
||
import threading
|
||
from dataclasses import dataclass, field
|
||
from typing import List, Optional, Callable, Dict, Any
|
||
from enum import Enum
|
||
from queue import Queue
|
||
|
||
# Module-level logger shared by every class in this file.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class AgentState(Enum):
    """Lifecycle states reported by WeChatAgent; each value is the lowercase member name."""

    IDLE = "idle"        # never started, or stopped
    RUNNING = "running"  # polling loop is active
    PAUSED = "paused"    # loop is alive but skips polling
    ERROR = "error"      # the last poll raised an exception
|
||
|
||
|
||
@dataclass
class ChatSnapshot:
    """One parsed view of the chat window.

    WeChatAgent builds it from a screenshot plus the VLM's analysis of that
    screenshot and hands it to MessageProcessor.
    """

    # Wall-clock time (seconds since the epoch) when the snapshot was taken.
    timestamp: float
    # Name of the chat currently open in the window.
    chat_name: str
    # Parsed messages, newest first; keys read in this module are
    # "sender", "content", "time" and "is_self".
    messages: List[Dict[str, Any]]
    # Path of the screenshot the messages were extracted from.
    screenshot_path: str
    # Whether the VLM reported unread/new messages.
    has_new: bool = False
|
||
|
||
|
||
@dataclass
class ReplyResult:
    """Outcome of one attempt to send a reply."""

    # True when the controller reported the text as sent.
    success: bool
    # The reply text that was (or was meant to be) sent.
    content: str
    # Short status text, or the exception message when sending raised.
    reason: str = ""
|
||
|
||
|
||
class MessageProcessor:
    """Decide whether to answer a chat and produce the reply text.

    Reply strategy: enabled keyword rules are tried first against the newest
    message; if none matches, an LLM-generated reply is used when an enabled
    rule with ``reply_type == "AI"`` exists or when no rules are configured.
    """

    def __init__(self, vlm_client, llm_client, config):
        """Store the clients and configuration.

        Args:
            vlm_client: Vision-language client (stored; not used by this class).
            llm_client: Client whose ``chat(messages)`` returns a dict with a
                ``"text"`` entry; used for AI replies.
            config: Configuration exposing ``rules``; each rule has
                ``enabled``, ``reply_type``, ``keywords`` and ``reply_content``.
        """
        self.vlm_client = vlm_client
        self.llm_client = llm_client
        self.config = config
        # Keep a reference (not a copy) so rule edits made through ``config``
        # remain visible; tolerate a missing rule list instead of crashing later.
        self._rules = config.rules if config.rules is not None else []

    def should_reply(self, chat_snapshot: ChatSnapshot) -> bool:
        """Return True when the newest message is new and was not sent by us.

        Args:
            chat_snapshot: Snapshot whose ``messages`` are ordered newest first.

        Returns:
            False if there are no messages, the newest one is our own, or the
            snapshot is not flagged ``has_new``; True otherwise.
        """
        messages = chat_snapshot.messages
        if not messages:
            logger.debug("没有消息,不回复")
            return False

        # The VLM lists messages newest-first, so index 0 is the latest one.
        latest_msg = messages[0]
        sender = latest_msg.get("sender", "")
        content = latest_msg.get("content", "") or ""
        is_self = latest_msg.get("is_self", False)

        logger.debug(f"最新消息: sender={sender}, is_self={is_self}, content={content[:30]}")

        if is_self:
            logger.debug("自己发的消息,不回复")
            return False

        if not chat_snapshot.has_new:
            logger.debug("没有新消息,不回复")
            return False

        return True

    @staticmethod
    def _latest_content(chat_snapshot: ChatSnapshot) -> str:
        """Return the text of the newest message, or "" when there is none."""
        if not chat_snapshot.messages:
            return ""
        return chat_snapshot.messages[0].get("content", "") or ""

    def _find_keyword_rule(self, content: str) -> Optional[tuple]:
        """Return ``(keyword, rule)`` for the first enabled keyword rule whose
        keyword occurs in *content*, or None when nothing matches."""
        for rule in self._rules:
            if not rule.enabled or rule.reply_type != "keyword":
                continue
            for keyword in rule.keywords:
                if keyword in content:
                    return keyword, rule
        return None

    def generate_reply(self, chat_snapshot: ChatSnapshot) -> str:
        """Produce the reply for *chat_snapshot*.

        Strategy: keyword rules first (matched against the newest message);
        otherwise an AI reply if an enabled "AI" rule exists or no rules are
        configured at all.

        Returns:
            The reply text, or "" when no rule applies or the AI call fails.
        """
        latest_content = self._latest_content(chat_snapshot)

        match = self._find_keyword_rule(latest_content)
        if match is not None:
            keyword, rule = match
            logger.info(f"关键词匹配: {keyword}")
            return rule.reply_content

        if any(rule.enabled and rule.reply_type == "AI" for rule in self._rules):
            logger.info("使用 AI 生成回复")
            return self._ai_generate_reply(chat_snapshot)

        if not self._rules:
            logger.info("无规则配置,使用 AI 生成回复")
            return self._ai_generate_reply(chat_snapshot)

        return ""

    def _ai_generate_reply(self, chat_snapshot: ChatSnapshot) -> str:
        """Ask the LLM for a short chat-style reply based on recent history.

        Returns:
            The stripped LLM text, or "" if the call fails for any reason
            (the failure is logged with its traceback).
        """
        try:
            chat_name = chat_snapshot.chat_name
            # Newest 10 messages; the VLM returns them newest-first.
            recent = chat_snapshot.messages[:10]

            history_lines = []
            for msg in recent:
                speaker = "我" if msg.get("is_self") else "对方"
                msg_time = msg.get("time", "")
                history_lines.append(f"[{msg_time}] {speaker}:{msg.get('content', '')}\n")

            # Built from explicit lines so source indentation never leaks into the prompt.
            header = (
                f"你是微信聊天助手,正在和「{chat_name}」对话。\n"
                "\n"
                "对话历史(最新在前):\n"
            )
            instructions = (
                "\n"
                "请根据对话上下文,生成一条自然的回复。\n"
                "要求:\n"
                "1. 回复要符合微信聊天风格,轻松友好\n"
                "2. 简洁明了,不要太长(50字以内)\n"
                "3. 如果对方提问,尽量回答问题\n"
                "4. 如果对方分享事情,给予适当回应\n"
                "5. 只返回回复内容,不要其他文字"
            )
            prompt = header + "".join(history_lines) + instructions

            logger.debug(f"AI 回复 prompt:\n{prompt[:300]}...")

            response = self.llm_client.chat([
                {"role": "user", "content": prompt}
            ])

            # A missing or None "text" yields an empty reply.
            text = (response.get("text") or "").strip()
            logger.info(f"AI 生成回复: {text[:50]}...")
            return text

        except Exception as e:
            # Best-effort: an AI failure must not break the polling loop.
            logger.exception(f"AI 生成回复失败: {e}")
            return ""

    def match_keyword_rule(self, content: str) -> Optional[str]:
        """Return the reply of the first enabled keyword rule matching *content*.

        Args:
            content: Message text; None is treated as "".

        Returns:
            The matching rule's ``reply_content``, or None if no rule matches.
        """
        match = self._find_keyword_rule(content or "")
        return match[1].reply_content if match is not None else None
|
||
|
||
|
||
class WeChatAgent:
    """Poll the WeChat window, read it with a VLM and send automatic replies.

    A daemon worker thread takes a screenshot every
    ``config.wechat.poll_interval`` seconds, asks the VLM to parse it, and
    lets a MessageProcessor produce a reply when the newest message comes
    from the other party. Observers subscribe with :meth:`on` to
    ``on_message``, ``on_reply``, ``on_error`` and ``on_state_change``.
    """

    # Seconds during which the same (chat, newest message) pair is not handled again.
    DEDUP_WINDOW_SECONDS = 10

    def __init__(
        self,
        wechat_controller,
        vlm_client,
        llm_client,
        config,
        message_queue: Optional[Queue] = None
    ):
        """Wire up the controller, model clients and configuration.

        Args:
            wechat_controller: Object providing screenshot/send_text/
                is_connected/connect.
            vlm_client: Object providing ``analyze_chat_screenshot(path)``.
            llm_client: LLM client passed to the MessageProcessor.
            config: Configuration with ``wechat.poll_interval`` and ``rules``.
            message_queue: Optional queue; a new Queue is created if omitted.
        """
        self.wechat = wechat_controller
        self.vlm = vlm_client
        self.llm = llm_client
        self.config = config
        self.processor = MessageProcessor(vlm_client, llm_client, config)

        self._state = AgentState.IDLE
        self._thread: Optional[threading.Thread] = None
        self._stop_event = threading.Event()
        self._pause_event = threading.Event()

        # Stored for callers; not read elsewhere in this class.
        self._message_queue = message_queue or Queue()
        self._callbacks: Dict[str, List[Callable]] = {
            "on_message": [],       # new chat content was read
            "on_reply": [],         # a reply was sent (ReplyResult)
            "on_error": [],         # the poll loop raised (error text)
            "on_state_change": [],  # AgentState changed
        }

        # dedup key "<chat>_<first 50 chars of newest message>" -> last handled time
        self._last_processed_time: Dict[str, float] = {}

    @property
    def state(self) -> AgentState:
        """Current AgentState."""
        return self._state

    def _set_state(self, state: AgentState):
        """Record *state* and notify ``on_state_change`` observers."""
        self._state = state
        self._emit("on_state_change", state)

    def _worker_alive(self) -> bool:
        """Return True while a worker thread is running and has not been asked to stop."""
        return (
            self._thread is not None
            and self._thread.is_alive()
            and not self._stop_event.is_set()
        )

    def start(self):
        """Start polling.

        If a worker is still alive (e.g. paused or recovering from an error),
        it is un-paused and reused rather than starting a second poller, which
        would otherwise reply to every message twice.
        """
        if self._worker_alive():
            if self._state == AgentState.RUNNING:
                logger.warning("Agent 已经在运行中")
                return
            self._pause_event.clear()
            self._set_state(AgentState.RUNNING)
            logger.info("Agent 已启动")
            return

        # Each worker gets its own stop event: an old worker still finishing
        # after a timed-out stop() keeps its (set) event and exits instead of
        # being revived alongside the new one.
        self._stop_event = threading.Event()
        self._pause_event.clear()
        self._state = AgentState.RUNNING

        self._thread = threading.Thread(
            target=self._run_loop, args=(self._stop_event,), daemon=True
        )
        self._thread.start()

        self._emit("on_state_change", self._state)
        logger.info("Agent 已启动")

    def stop(self):
        """Ask the worker to stop, wait up to 5 s for it, and switch to IDLE."""
        self._stop_event.set()
        worker = self._thread
        # Joining the current thread raises RuntimeError, e.g. when stop() is
        # called from a callback that runs on the worker itself.
        if worker is not None and worker.is_alive() and worker is not threading.current_thread():
            worker.join(timeout=5)
        self._set_state(AgentState.IDLE)
        logger.info("Agent 已停止")

    def pause(self):
        """Pause polling; the worker stays alive but skips polls. No-op if not running."""
        if not self._worker_alive():
            logger.warning("Agent 未在运行,忽略暂停")
            return
        self._pause_event.set()
        self._set_state(AgentState.PAUSED)
        logger.info("Agent 已暂停")

    def resume(self):
        """Resume a paused worker. No-op if no worker is running (use start())."""
        if not self._worker_alive():
            logger.warning("Agent 未在运行,忽略恢复")
            return
        self._pause_event.clear()
        self._set_state(AgentState.RUNNING)
        logger.info("Agent 已恢复")

    def _run_loop(self, stop_event: Optional[threading.Event] = None):
        """Worker-thread body: poll until *stop_event* is set.

        Args:
            stop_event: Event that ends this worker; defaults to the agent's
                current stop event.
        """
        if stop_event is None:
            stop_event = self._stop_event
        poll_interval = self.config.wechat.poll_interval

        while not stop_event.is_set():
            try:
                if self._pause_event.is_set():
                    # Event.wait instead of time.sleep so stop() takes effect at once.
                    stop_event.wait(0.5)
                    continue

                if not self.wechat.is_connected():
                    logger.warning("微信未连接,尝试重连...")
                    if not self.wechat.connect():
                        stop_event.wait(poll_interval)
                        continue

                self._poll_once()

                # A clean poll ends a previous error, unless paused/stopped meanwhile.
                if (
                    self._state == AgentState.ERROR
                    and not self._pause_event.is_set()
                    and not stop_event.is_set()
                ):
                    self._set_state(AgentState.RUNNING)

                stop_event.wait(poll_interval)

            except Exception as e:
                logger.exception(f"轮询异常: {e}")
                if not stop_event.is_set():
                    self._set_state(AgentState.ERROR)
                self._emit("on_error", str(e))
                stop_event.wait(poll_interval)

    def _is_duplicate(self, chat_name: str, latest_msg: Dict[str, Any], now: float) -> bool:
        """Return True if this chat's newest message was handled within the
        dedup window; otherwise record it as handled at *now* and return False."""
        window = self.DEDUP_WINDOW_SECONDS
        # Drop expired entries so the map does not grow for the agent's lifetime.
        expired = [
            key for key, seen in self._last_processed_time.items()
            if now - seen >= window
        ]
        for key in expired:
            del self._last_processed_time[key]

        latest_content = (latest_msg.get("content", "") or "")[:50]
        dedup_key = f"{chat_name}_{latest_content}"
        if dedup_key in self._last_processed_time:
            logger.debug(f"消息已处理过,跳过: {latest_content[:30]}")
            return True
        self._last_processed_time[dedup_key] = now
        return False

    def _poll_once(self):
        """Run one screenshot -> VLM -> (maybe) reply cycle.

        Raises:
            Any exception from the controller, VLM or processor, after logging it.
        """
        try:
            screenshot_path = self.wechat.screenshot()
            chat_info = self.vlm.analyze_chat_screenshot(screenshot_path)

            has_new = chat_info.get("has_new_message", False)
            chat_name = chat_info.get("current_chat", "")
            # Tolerate an explicit None from the VLM.
            messages = chat_info.get("messages", []) or []

            current_time = time.time()
            if messages and self._is_duplicate(chat_name, messages[0], current_time):
                return

            if not (has_new or messages):
                return

            snapshot = ChatSnapshot(
                timestamp=current_time,
                chat_name=chat_name,
                messages=messages,
                screenshot_path=screenshot_path,
                has_new=has_new
            )

            logger.debug(f"消息列表(共{len(messages)}条,新消息在前):")
            for i, msg in enumerate(messages[:5]):
                preview = (msg.get('content', '') or '')[:30]
                logger.debug(f" [{i}] sender={msg.get('sender')}, is_self={msg.get('is_self')}, time={msg.get('time')}, content={preview}")

            latest_msg = messages[0] if messages else {}
            self._emit("on_message", {
                "chat_name": chat_name,
                "latest_message": latest_msg,
                "has_new": has_new,
                "all_messages": messages
            })

            # Reply whenever the newest message is from the other party,
            # regardless of has_new_message. A missing is_self counts as ours.
            if messages and not messages[0].get("is_self", True):
                logger.info("检测到对方新消息,准备回复...")
                reply = self.processor.generate_reply(snapshot)
                if reply:
                    logger.info(f"发送回复: {reply[:50]}...")
                    result = self.send_reply(reply)
                    self._emit("on_reply", result)
                else:
                    logger.warning("生成回复为空")
            else:
                logger.debug("不需要回复(无新消息或自己发送)")

        except Exception as e:
            logger.error(f"轮询处理异常: {e}")
            raise

    def send_reply(self, text: str) -> ReplyResult:
        """Send *text* through the controller and report the outcome.

        Never raises: controller exceptions are logged and returned in
        ``ReplyResult.reason``.
        """
        try:
            success = self.wechat.send_text(text)
            return ReplyResult(
                success=success,
                content=text,
                reason="发送成功" if success else "发送失败"
            )
        except Exception as e:
            logger.warning(f"发送回复失败: {e}")
            return ReplyResult(
                success=False,
                content=text,
                reason=str(e)
            )

    def on(self, event: str, callback: Callable):
        """Register *callback* for *event*; unknown event names are ignored with a warning."""
        if event in self._callbacks:
            self._callbacks[event].append(callback)
        else:
            logger.warning(f"未知事件: {event}")

    def _emit(self, event: str, *args):
        """Call every callback registered for *event*; callback errors are logged, not raised."""
        # Iterate a copy so callbacks may register further callbacks safely.
        for callback in list(self._callbacks.get(event, ())):
            try:
                callback(*args)
            except Exception as e:
                logger.exception(f"回调执行异常: {e}")

    def get_status(self) -> Dict[str, Any]:
        """Return a status dict: state value, connection flag, poll interval, enabled-rule count."""
        return {
            "state": self._state.value,
            "connected": self.wechat.is_connected(),
            "poll_interval": self.config.wechat.poll_interval,
            "rules_count": len([r for r in self.config.rules if r.enabled])
        }
|
||
|
||
|
||
class MockWeChatController:
    """In-memory stand-in for the WeChat controller (for tests).

    Holds a scripted two-message conversation, records sent texts, and
    writes a blank 800x600 image instead of capturing the real window.
    """

    def __init__(self):
        self._connected = True
        self._messages = [
            {"sender": "张三", "content": "你好", "time": "10:30", "is_self": False},
            {"sender": "张三", "content": "这个产品怎么卖?", "time": "10:31", "is_self": False},
        ]

    def connect(self, timeout: float = 10) -> bool:
        """Mark the mock as connected; *timeout* is accepted for API parity and ignored."""
        # Restore the flag so connect() after disconnect() really reconnects.
        self._connected = True
        return True

    def is_connected(self) -> bool:
        """Return the current connection flag."""
        return self._connected

    def screenshot(self, output_path: Optional[str] = None) -> str:
        """Write a blank white 800x600 PNG and return its path.

        Args:
            output_path: Where to save the image; defaults to
                ``<tempdir>/mock_screenshot.png``.
        """
        import tempfile
        from pathlib import Path
        from PIL import Image

        if output_path:
            path = Path(output_path)
        else:
            path = Path(tempfile.gettempdir()) / "mock_screenshot.png"
        img = Image.new("RGB", (800, 600), color="white")
        img.save(str(path))
        return str(path)

    def send_text(self, text: str) -> bool:
        """Append *text* as our own message and report success."""
        self._messages.append({"sender": "我", "content": text, "time": "10:32", "is_self": True})
        return True

    def get_message_list(self, count: int = 10) -> List:
        """Return the last *count* messages, oldest first; [] when count <= 0."""
        # Guard needed: ``lst[-0:]`` would return the whole list.
        if count <= 0:
            return []
        return self._messages[-count:]

    def disconnect(self):
        """Mark the mock as disconnected."""
        self._connected = False
|