Files
wechat-agent/src/core/engine.py
jesxion f325025365 修复:改进VLM未读判断 + 修复callback显示逻辑
1. VLM prompt: has_new_message 改为检查左侧边栏红点,而非右上角
2. engine.py: callback 显示最新消息,清晰标注 has_new
3. main.py: on_message 回调更新以显示 has_new 状态
2026-04-13 12:08:18 +08:00

384 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
核心引擎
WeChat Agent Core Engine
"""
import time
import logging
import threading
from dataclasses import dataclass, field
from typing import List, Optional, Callable, Dict, Any
from enum import Enum
from queue import Queue
logger = logging.getLogger(__name__)
class AgentState(Enum):
    """Lifecycle states an agent can report."""

    IDLE = "idle"
    RUNNING = "running"
    PAUSED = "paused"
    ERROR = "error"
@dataclass
class ChatSnapshot:
    """What was observed in the chat window during one poll."""

    # Wall-clock time at which the snapshot was taken.
    timestamp: float
    # Name of the chat the messages belong to.
    chat_name: str
    # Message dicts; code in this file reads index 0 as the newest message.
    messages: List[Dict[str, Any]]
    # Path of the screenshot the messages were extracted from.
    screenshot_path: str
    # Whether the chat was reported as having unread messages.
    has_new: bool = False
@dataclass
class ReplyResult:
    """Outcome of one attempt to send a reply."""

    # True when the controller reported the text as sent.
    success: bool
    # The reply text that was (or was meant to be) sent.
    content: str
    # Human-readable explanation of the outcome; empty when not provided.
    reason: str = ""
class MessageProcessor:
    """Decide whether a chat needs a reply and produce the reply text.

    Replies are driven by ``config.rules``. Rules are evaluated in order: an
    enabled ``"keyword"`` rule answers with its canned ``reply_content`` when
    one of its keywords occurs in the newest message, and an enabled ``"AI"``
    rule delegates to the LLM client. The first rule that produces an answer
    wins.
    """

    def __init__(self, vlm_client, llm_client, config):
        """Keep the clients and cache the rule list.

        Args:
            vlm_client: Stored on the instance; not used by this class itself.
            llm_client: Object whose ``chat(messages)`` returns a mapping
                with a ``"text"`` entry.
            config: Object exposing ``rules``.
        """
        self.vlm_client = vlm_client
        self.llm_client = llm_client
        self.config = config
        self._rules = config.rules

    def should_reply(self, chat_snapshot: ChatSnapshot) -> bool:
        """Return True when the newest message warrants an automatic reply.

        A reply is wanted only if the snapshot has messages, the newest one
        (index 0) was not sent by ourselves, and the snapshot is flagged as
        having unread messages.
        """
        messages = chat_snapshot.messages
        if not messages:
            logger.debug("没有消息,不回复")
            return False

        # The message list is newest-first, so index 0 is the latest message.
        latest_msg = messages[0]
        sender = latest_msg.get("sender", "")
        # ``or ""`` also covers an explicit None value, which would otherwise
        # break the slice in the log call below.
        content = latest_msg.get("content") or ""
        is_self = latest_msg.get("is_self", False)
        logger.debug(
            "最新消息: sender=%s, is_self=%s, content=%s",
            sender, is_self, content[:30],
        )

        if is_self:
            logger.debug("自己发的消息,不回复")
            return False

        if not chat_snapshot.has_new:
            logger.debug("没有新消息,不回复")
            return False

        return True

    def generate_reply(self, chat_snapshot: ChatSnapshot) -> str:
        """Return the reply text for the snapshot, or "" when no rule applies.

        Rules are tried in configuration order; disabled rules are skipped.
        """
        messages = chat_snapshot.messages
        latest_content = (messages[0].get("content") or "") if messages else ""

        for rule in self._rules:
            if not rule.enabled:
                continue
            if rule.reply_type == "keyword":
                matched = self._first_matching_keyword(rule, latest_content)
                if matched is not None:
                    logger.info("关键词匹配: %s", matched)
                    return rule.reply_content
            elif rule.reply_type == "AI":
                return self._ai_generate_reply(chat_snapshot)
        return ""

    @staticmethod
    def _first_matching_keyword(rule, content: str) -> Optional[str]:
        """Return the first keyword of ``rule`` found in ``content``, else None."""
        for keyword in rule.keywords:
            # An empty keyword is a substring of every string; skip it so a
            # blank configuration entry cannot turn the rule into a catch-all.
            if keyword and keyword in content:
                return keyword
        return None

    def _ai_generate_reply(self, chat_snapshot: ChatSnapshot) -> str:
        """Ask the LLM for a reply based on the ten newest messages.

        Returns "" when the LLM call fails or yields no text; failures are
        logged rather than raised so one bad call does not stop polling.
        """
        try:
            history_lines = []
            for msg in chat_snapshot.messages[:10]:
                # Label our own messages explicitly; an empty label would
                # leave the model unable to tell who said what.
                speaker = "我" if msg.get("is_self") else "对方"
                history_lines.append(f"- [{speaker}] {msg.get('content', '')}\n")

            prompt = (
                f"当前聊天: {chat_snapshot.chat_name}\n"
                "历史消息(按时间倒序):\n"
                + "".join(history_lines)
                + "\n请生成一条合适的回复,只返回回复内容,不要其他文字。"
            )

            response = self.llm_client.chat([
                {"role": "user", "content": prompt}
            ])
            # Normalise a missing or None text to "" to honour the str contract.
            return response.get("text") or ""
        except Exception as e:
            logger.error("AI 生成回复失败: %s", e)
            return ""

    def match_keyword_rule(self, content: str) -> Optional[str]:
        """Return the canned reply of the first enabled keyword rule matching ``content``.

        Returns None when no enabled keyword rule matches.
        """
        text = content or ""
        for rule in self._rules:
            if not rule.enabled or rule.reply_type != "keyword":
                continue
            if self._first_matching_keyword(rule, text) is not None:
                return rule.reply_content
        return None
class WeChatAgent:
    """Polling agent that watches a WeChat window and sends automatic replies.

    A background thread repeatedly takes a screenshot, asks the VLM client to
    extract the visible chat, notifies ``on_message`` listeners and, when the
    message processor asks for it, sends a reply through the controller.

    Events and the argument passed to their callbacks:
        on_message: dict with ``chat_name``, ``latest_message``, ``has_new``
            and ``all_messages``.
        on_reply: the ``ReplyResult`` of the send attempt.
        on_error: the error text.
        on_state_change: the new ``AgentState``.
    """

    # An identical (chat name, newest message) pair seen again within this
    # many seconds is skipped instead of being processed a second time.
    _DEDUP_WINDOW_SECONDS = 5

    def __init__(
        self,
        wechat_controller,
        vlm_client,
        llm_client,
        config,
        message_queue: Optional[Queue] = None,
    ):
        """Wire up the collaborators; no thread is started here.

        Args:
            wechat_controller: Object providing ``is_connected``, ``connect``,
                ``screenshot`` and ``send_text``.
            vlm_client: Object providing ``analyze_chat_screenshot(path)``.
            llm_client: Passed through to the message processor.
            config: Object exposing ``rules`` and ``wechat.poll_interval``.
            message_queue: Optional queue to store; a new one is created when
                omitted.
        """
        self.wechat = wechat_controller
        self.vlm = vlm_client
        self.llm = llm_client
        self.config = config
        self.processor = MessageProcessor(vlm_client, llm_client, config)
        self._state = AgentState.IDLE
        self._thread: Optional[threading.Thread] = None
        self._stop_event = threading.Event()
        self._pause_event = threading.Event()
        self._message_queue = message_queue if message_queue is not None else Queue()
        self._callbacks: Dict[str, List[Callable]] = {
            "on_message": [],       # a chat snapshot was observed
            "on_reply": [],         # a reply was attempted
            "on_error": [],         # the poll loop hit an exception
            "on_state_change": [],  # the agent state changed
        }
        # Dedup key -> time of last processing; pruned on every poll.
        self._last_processed_time: Dict[str, float] = {}

    @property
    def state(self) -> AgentState:
        """Current agent state."""
        return self._state

    def start(self):
        """Start the polling thread unless one is already active."""
        worker_alive = self._thread is not None and self._thread.is_alive()
        # Checking the thread as well as the state prevents a second worker
        # from being spawned while the agent is paused or in the error state.
        if self._state == AgentState.RUNNING or worker_alive:
            logger.warning("Agent 已经在运行中")
            return

        self._stop_event.clear()
        self._pause_event.clear()
        self._state = AgentState.RUNNING
        self._thread = threading.Thread(target=self._run_loop, daemon=True)
        self._thread.start()
        self._emit("on_state_change", self._state)
        logger.info("Agent 已启动")

    def stop(self):
        """Signal the polling thread to stop and wait up to 5 seconds for it."""
        self._stop_event.set()
        if self._thread and self._thread.is_alive():
            self._thread.join(timeout=5)
        self._state = AgentState.IDLE
        self._emit("on_state_change", self._state)
        logger.info("Agent 已停止")

    def pause(self):
        """Suspend polling; the thread stays alive and idles."""
        self._pause_event.set()
        self._state = AgentState.PAUSED
        self._emit("on_state_change", self._state)
        logger.info("Agent 已暂停")

    def resume(self):
        """Resume polling after ``pause``."""
        self._pause_event.clear()
        self._state = AgentState.RUNNING
        self._emit("on_state_change", self._state)
        logger.info("Agent 已恢复")

    def _run_loop(self):
        """Thread body: poll until the stop event is set.

        Waiting is done on the stop event rather than with ``time.sleep`` so
        that ``stop`` takes effect immediately instead of after a full poll
        interval.
        """
        poll_interval = self.config.wechat.poll_interval
        while not self._stop_event.is_set():
            try:
                if self._pause_event.is_set():
                    self._stop_event.wait(0.5)
                    continue

                if not self.wechat.is_connected():
                    logger.warning("微信未连接,尝试重连...")
                    if not self.wechat.connect():
                        self._stop_event.wait(poll_interval)
                        continue

                self._poll_once()

                # A successful poll after a failure means we have recovered.
                if self._state == AgentState.ERROR:
                    self._state = AgentState.RUNNING
                    self._emit("on_state_change", self._state)

                self._stop_event.wait(poll_interval)
            except Exception as e:
                logger.error("轮询异常: %s", e)
                if self._state != AgentState.ERROR:
                    self._state = AgentState.ERROR
                    self._emit("on_state_change", self._state)
                self._emit("on_error", str(e))
                self._stop_event.wait(poll_interval)

    def _prune_processed(self, now: float) -> None:
        """Drop dedup entries older than the window so the table stays small."""
        window = self._DEDUP_WINDOW_SECONDS
        stale_keys = [
            key for key, seen in self._last_processed_time.items()
            if now - seen >= window
        ]
        for key in stale_keys:
            del self._last_processed_time[key]

    def _poll_once(self):
        """Run one screenshot -> analyse -> notify -> maybe-reply cycle.

        Raises:
            Exception: Anything raised by the controller, the VLM client or
                the processor is logged and re-raised to the caller.
        """
        try:
            screenshot_path = self.wechat.screenshot()
            chat_info = self.vlm.analyze_chat_screenshot(screenshot_path)

            has_new = chat_info.get("has_new_message", False)
            chat_name = chat_info.get("current_chat", "")
            # ``or []`` also covers an explicit None from the VLM.
            messages = chat_info.get("messages") or []

            current_time = time.time()
            self._prune_processed(current_time)

            # Messages are newest-first, so the dedup key must be built from
            # the first element; keying on the last (oldest) one would make a
            # freshly arrived message look like an already-processed state.
            chat_key = f"{chat_name}_{hash(str(messages[:1]))}"
            last_seen = self._last_processed_time.get(chat_key)
            if (
                last_seen is not None
                and current_time - last_seen < self._DEDUP_WINDOW_SECONDS
            ):
                return

            if not (has_new or messages):
                return

            self._last_processed_time[chat_key] = current_time
            snapshot = ChatSnapshot(
                timestamp=current_time,
                chat_name=chat_name,
                messages=messages,
                screenshot_path=screenshot_path,
                has_new=has_new,
            )

            logger.debug("消息列表(共%d条,新消息在前):", len(messages))
            for i, msg in enumerate(messages[:5]):
                logger.debug(
                    " [%s] sender=%s, is_self=%s, time=%s, content=%s",
                    i,
                    msg.get("sender"),
                    msg.get("is_self"),
                    msg.get("time"),
                    (msg.get("content") or "")[:30],
                )

            # Listeners get the newest message separately for display.
            latest_msg = messages[0] if messages else {}
            self._emit("on_message", {
                "chat_name": chat_name,
                "latest_message": latest_msg,
                "has_new": has_new,
                "all_messages": messages,
            })

            if not self.processor.should_reply(snapshot):
                logger.debug("不需要回复")
                return

            logger.info("需要回复,生成回复内容...")
            reply = self.processor.generate_reply(snapshot)
            if not reply:
                logger.warning("生成回复为空")
                return

            logger.info("发送回复: %s...", reply[:50])
            result = self.send_reply(reply)
            self._emit("on_reply", result)
        except Exception as e:
            logger.error("轮询处理异常: %s", e)
            raise

    def send_reply(self, text: str) -> ReplyResult:
        """Send ``text`` through the controller and report the outcome.

        Never raises: an exception from the controller is turned into a
        failed ``ReplyResult`` whose ``reason`` is the exception text.
        """
        try:
            success = self.wechat.send_text(text)
        except Exception as e:
            return ReplyResult(success=False, content=text, reason=str(e))
        return ReplyResult(
            success=success,
            content=text,
            reason="发送成功" if success else "发送失败",
        )

    def on(self, event: str, callback: Callable):
        """Register ``callback`` for ``event``.

        Unknown event names are not registered; a warning is logged so that a
        misspelt name does not go unnoticed.
        """
        listeners = self._callbacks.get(event)
        if listeners is None:
            logger.warning("未知事件: %s", event)
            return
        listeners.append(callback)

    def _emit(self, event: str, *args):
        """Call every callback registered for ``event`` with ``args``.

        An exception raised by a callback is logged and does not prevent the
        remaining callbacks from running.
        """
        # Iterate over a copy so a callback may register further callbacks.
        for callback in tuple(self._callbacks.get(event, ())):
            try:
                callback(*args)
            except Exception as e:
                logger.error("回调执行异常: %s", e)

    def get_status(self) -> Dict[str, Any]:
        """Return a summary: state, connection, poll interval, enabled rule count."""
        return {
            "state": self._state.value,
            "connected": self.wechat.is_connected(),
            "poll_interval": self.config.wechat.poll_interval,
            "rules_count": sum(1 for rule in self.config.rules if rule.enabled),
        }
class MockWeChatController:
    """In-memory stand-in for the WeChat controller, intended for tests."""

    def __init__(self):
        """Start connected, with two canned incoming messages."""
        self._connected = True
        self._messages = [
            {"sender": "张三", "content": "你好", "time": "10:30", "is_self": False},
            {"sender": "张三", "content": "这个产品怎么卖?", "time": "10:31", "is_self": False},
        ]

    def connect(self, timeout: float = 10) -> bool:
        """Mark the mock as connected and report success.

        ``timeout`` is accepted for interface compatibility and ignored.
        """
        # Restore the flag so is_connected() agrees after a disconnect().
        self._connected = True
        return True

    def is_connected(self) -> bool:
        """Return the current connection flag."""
        return self._connected

    def screenshot(self, output_path: Optional[str] = None) -> str:
        """Write a blank 800x600 white PNG and return its path.

        Args:
            output_path: Where to save the image; when omitted or empty the
                file ``mock_screenshot.png`` in the temp directory is used.
        """
        import tempfile
        from pathlib import Path

        # Imported lazily so the mock can be used without Pillow as long as
        # no screenshot is requested.
        from PIL import Image

        if output_path:
            path = Path(output_path)
        else:
            path = Path(tempfile.gettempdir()) / "mock_screenshot.png"
        img = Image.new("RGB", (800, 600), color="white")
        img.save(str(path))
        return str(path)

    def send_text(self, text: str) -> bool:
        """Record ``text`` as a message sent by ourselves; always succeeds."""
        self._messages.append({"sender": "", "content": text, "time": "10:32", "is_self": True})
        return True

    def get_message_list(self, count: int = 10) -> List:
        """Return the last ``count`` recorded messages, oldest first.

        A non-positive ``count`` yields an empty list.
        """
        # Guard explicitly: ``lst[-0:]`` is the whole list, not an empty one.
        if count <= 0:
            return []
        return self._messages[-count:]

    def disconnect(self):
        """Mark the mock as disconnected."""
        self._connected = False