diff --git a/evaluation/README.md b/evaluation/README.md index a5a4f32ca..8c1896943 100644 --- a/evaluation/README.md +++ b/evaluation/README.md @@ -64,6 +64,12 @@ First prepare the dataset `longmemeval_s` from https://huggingface.co/datasets/x ./scripts/run_lme_eval.sh ``` +#### Question date and `reference_time` + +LongMemEval gives each question a **question date**; evaluation should use that date as the reference “now”, not the time when you run the script. The LongMemEval search script passes `question_date` as **`reference_time`** where the backend supports it. + +**MemOS Cloud** currently does not support supplying the question date at search time in the same way, so LongMemEval scores there may differ from a spec-faithful run. **Prefer evaluating LongMemEval against the open-source MemOS server** when you need comparable numbers. + ### PrefEval Evaluation Download benchmark_dataset/filtered_inter_turns.json from https://github.com/amazon-science/PrefEval/blob/main/benchmark_dataset/filtered_inter_turns.json and save it as `./data/prefeval/filtered_inter_turns.json`. To evaluate the **PrefEval** dataset, run the following [script](./scripts/run_prefeval_eval.sh): diff --git a/evaluation/scripts/longmemeval/lme_search.py b/evaluation/scripts/longmemeval/lme_search.py index 8e0e3c5c2..1eea8cd37 100644 --- a/evaluation/scripts/longmemeval/lme_search.py +++ b/evaluation/scripts/longmemeval/lme_search.py @@ -41,9 +41,11 @@ def mem0_search(client, query, user_id, top_k): return context, duration_ms -def memos_search(client, query, user_id, top_k): +def memos_search(client, query, user_id, top_k, reference_time=None): start = time() - results = client.search(query=query, user_id=user_id, top_k=top_k) + results = client.search( + query=query, user_id=user_id, top_k=top_k, reference_time=reference_time + ) context = ( "\n".join([i["memory"] for i in results["text_mem"][0]["memories"]]) + f"\n{results.get('pref_string', '')}" @@ -122,12 +124,16 @@ def process_user(lme_df, conv_idx, frame, version, top_k=20): from utils.client import MemosApiClient client = MemosApiClient() - context, duration_ms = memos_search(client, question, user_id, top_k) + context, duration_ms = memos_search( + client, question, user_id, top_k, reference_time=question_date + ) elif frame == "memos-api-online": from utils.client import MemosApiOnlineClient client = MemosApiOnlineClient() - context, duration_ms = memos_search(client, question, user_id, top_k) + context, duration_ms = memos_search( + client, question, user_id, top_k, reference_time=question_date + ) elif frame == "memu": from utils.client import MemuClient diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index ba1c50b07..15eb7c38e 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -20,6 +20,7 @@ from memos.multi_mem_cube.composite_cube import CompositeCubeView from memos.multi_mem_cube.single_cube import SingleCubeView from memos.multi_mem_cube.views import MemCubeView +from memos.plugins.hooks import hookable logger = get_logger(__name__) @@ -44,6 +45,7 @@ def __init__(self, dependencies: HandlerDependencies): "naive_mem_cube", "mem_scheduler", "searcher", "deepsearch_agent" ) + @hookable("search") def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse: """ Main handler for search memories endpoint.
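For reference, this is what a spec-faithful search call looks like after the change above, sketched with the client wrapper and keyword arguments from `lme_search.py`. The question text, user id, and date string are illustrative placeholders; in the benchmark, `reference_time` should carry the dataset's `question_date` verbatim.

```python
from utils.client import MemosApiClient  # evaluation-side API wrapper

client = MemosApiClient()
# Anchoring reference_time to the question date makes time-sensitive query
# parsing use the question's "now" rather than the current server time.
results = client.search(
    query="How long did I stay in Tokyo?",    # placeholder question text
    user_id="lme_user_0",                     # placeholder LongMemEval user id
    top_k=20,
    reference_time="2023/05/20 (Sat) 02:21",  # placeholder question_date string
)
```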
diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 78dcfc797..049ca544a 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -438,6 +438,14 @@ class APISearchRequest(BaseRequest): ) # ==== Context ==== + reference_time: str | None = Field( + None, + description=( + "Optional reference time for time-sensitive search parsing. " + "If omitted, search uses the current server time." + ), + ) + chat_history: MessageList | None = Field( None, description=( @@ -608,6 +616,17 @@ class APIADDRequest(BaseRequest): description=("Whether this request represents user feedback. Default: False."), ) + # ==== Upload skill flag ==== + is_upload_skill: bool = Field( + False, + description=( + "Whether this request is an upload skill request. " + "When True, the messages field should contain file items " + "with zip file download URLs for pre-built skill packages. " + "Default: False." + ), + ) + # ==== Backward compatibility fields (will delete later) ==== mem_cube_id: str | None = Field( None, diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index 26287ff4f..da35a4656 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -17,7 +17,7 @@ from memos.mem_reader.utils import parse_json_result from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata from memos.plugins.hook_defs import H -from memos.plugins.hooks import trigger_single_hook +from memos.plugins.hooks import trigger_hook, trigger_single_hook from memos.templates.mem_reader_prompts import MEMORY_MERGE_PROMPT_EN, MEMORY_MERGE_PROMPT_ZH from memos.templates.tool_mem_prompts import TOOL_TRAJECTORY_PROMPT_EN, TOOL_TRAJECTORY_PROMPT_ZH from memos.types import MessagesType @@ -785,6 +785,14 @@ def _process_one_item( ) rawfile_node.metadata.summary_ids = [mem_node.id for mem_node in fine_items] fine_items.append(rawfile_node) + enriched_items = trigger_hook( + H.MEMORY_ITEMS_AFTER_FINE_EXTRACT, + items=fine_items, + user_context=kwargs.get("user_context"), + mem_reader=self, + extract_mode="fine", + ) + fine_items = enriched_items if enriched_items is not None else fine_items return fine_items fine_memory_items: list[TextualMemoryItem] = [] @@ -993,6 +1001,7 @@ def _process_multi_modal_data( return fast_memory_items # Stage: llm_extract — fine mode 4-way parallel LLM + per-source serial + is_upload_skill = kwargs.pop("is_upload_skill", False) non_file_url_fast_items = [ item for item in fast_memory_items if not self._is_file_url_only_item(item) ] @@ -1009,7 +1018,9 @@ def _process_multi_modal_data( ) future_skill = executor.submit( process_skill_memory_fine, - fast_memory_items=non_file_url_fast_items, + fast_memory_items=fast_memory_items + if is_upload_skill + else non_file_url_fast_items, info=info, searcher=self.searcher, graph_db=self.graph_db, @@ -1017,6 +1028,7 @@ def _process_multi_modal_data( embedder=self.embedder, oss_config=self.oss_config, skills_dir_config=self.skills_dir_config, + is_upload_skill=is_upload_skill, **kwargs, ) future_pref = executor.submit( @@ -1039,6 +1051,10 @@ def _process_multi_modal_data( fine_memory_items.extend(fine_memory_items_pref_parser) # Part B: per-source serial processing + if is_upload_skill: + # (skip for upload skill to avoid zip being parsed) + return fine_memory_items + with timed_stage("add", "per_source") as ts_ps: for fast_item in fast_memory_items: sources = fast_item.metadata.sources @@ -1072,6 +1088,8 
@@ def _process_transfer_multi_modal_data( logger.warning("[MultiModalStruct] No raw nodes found.") return [] + is_upload_skill = kwargs.pop("is_upload_skill", False) + # Extract info from raw_nodes (same as simple_struct.py) info = { "user_id": raw_nodes[0].metadata.user_id, @@ -1093,7 +1111,7 @@ def _process_transfer_multi_modal_data( ) future_skill = executor.submit( process_skill_memory_fine, - non_file_url_nodes, + raw_nodes if is_upload_skill else non_file_url_nodes, info, searcher=self.searcher, llm=self.general_llm, @@ -1101,6 +1119,7 @@ def _process_transfer_multi_modal_data( graph_db=self.graph_db, oss_config=self.oss_config, skills_dir_config=self.skills_dir_config, + is_upload_skill=is_upload_skill, **kwargs, ) # Add preference memory extraction @@ -1125,6 +1144,9 @@ def _process_transfer_multi_modal_data( fine_memory_items.extend(fine_memory_items_pref_parser) # Part B: get fine multimodal items + if is_upload_skill: + # (skip for upload skill to avoid zip being parsed) + return fine_memory_items for raw_node in raw_nodes: sources = raw_node.metadata.sources for source in sources: diff --git a/src/memos/mem_reader/read_skill_memory/process_skill_memory.py b/src/memos/mem_reader/read_skill_memory/process_skill_memory.py index 4c93a30c5..269372f25 100644 --- a/src/memos/mem_reader/read_skill_memory/process_skill_memory.py +++ b/src/memos/mem_reader/read_skill_memory/process_skill_memory.py @@ -252,36 +252,6 @@ def create_task(skill_mem, gen_type, prompt, requirements, context, **kwargs): return [item[0] for item in raw_skills_data] -def add_id_to_mysql(memory_id: str, mem_cube_id: str): - """Add id to mysql, will deprecate this function in the future""" - # TODO: tmp function, deprecate soon - import requests - - skill_mysql_url = os.getenv("SKILLS_MYSQL_URL", "") - skill_mysql_bearer = os.getenv("SKILLS_MYSQL_BEARER", "") - - if not skill_mysql_url or not skill_mysql_bearer: - logger.warning("[PROCESS_SKILLS] SKILLS_MYSQL_URL or SKILLS_MYSQL_BEARER is not set") - return None - headers = {"Authorization": skill_mysql_bearer, "Content-Type": "application/json"} - data = {"memCubeId": mem_cube_id, "skillId": memory_id} - try: - response = requests.post(skill_mysql_url, headers=headers, json=data) - - logger.info(f"[PROCESS_SKILLS] response: \n\n{response.json()}") - logger.info(f"[PROCESS_SKILLS] memory_id: \n\n{memory_id}") - logger.info(f"[PROCESS_SKILLS] mem_cube_id: \n\n{mem_cube_id}") - logger.info(f"[PROCESS_SKILLS] skill_mysql_url: \n\n{skill_mysql_url}") - logger.info(f"[PROCESS_SKILLS] skill_mysql_bearer: \n\n{skill_mysql_bearer}") - logger.info(f"[PROCESS_SKILLS] headers: \n\n{headers}") - logger.info(f"[PROCESS_SKILLS] data: \n\n{data}") - - return response.json() - except Exception as e: - logger.warning(f"[PROCESS_SKILLS] Error adding id to mysql: {e}") - return None - - @require_python_package( import_name="alibabacloud_oss_v2", install_command="pip install alibabacloud-oss-v2", @@ -948,6 +918,7 @@ def create_skill_memory_item( scripts=skill_memory.get("scripts"), others=skill_memory.get("others"), url=skill_memory.get("url", ""), + skill_source=skill_memory.get("skill_source"), manager_user_id=manager_user_id, project_id=project_id, ) @@ -1024,6 +995,21 @@ def process_skill_memory_fine( complete_skill_memory: bool = True, **kwargs, ) -> list[TextualMemoryItem]: + is_upload_skill = kwargs.pop("is_upload_skill", False) + if is_upload_skill: + from memos.mem_reader.read_skill_memory.upload_skill_memory import ( + process_upload_skill_memory, + ) + + return 
process_upload_skill_memory( + fast_memory_items=fast_memory_items, + info=info, + embedder=embedder, + oss_config=oss_config, + skills_dir_config=skills_dir_config, + **kwargs, + ) + skills_repo_backend = _get_skill_file_storage_location() oss_client, _missing_keys, flag = _skill_init( skills_repo_backend, oss_config, skills_dir_config @@ -1252,6 +1238,7 @@ def _full_extract(): if source: skill_sources.append(source) + skill_memory["skill_source"] = "auto_create" memory_item = create_skill_memory_item( skill_memory, info, embedder, sources=skill_sources, **kwargs ) @@ -1260,12 +1247,4 @@ def _full_extract(): logger.warning(f"[PROCESS_SKILLS] Error creating skill memory item: {e}") continue - # TODO: deprecate this funtion and call - for skill_memory, skill_memory_item in zip(skill_memories, skill_memory_items, strict=False): - if skill_memory.get("update", False) and skill_memory.get("old_memory_id", ""): - continue - add_id_to_mysql( - memory_id=skill_memory_item.id, - mem_cube_id=kwargs.get("user_name", info.get("user_id", "")), - ) return skill_memory_items diff --git a/src/memos/mem_reader/read_skill_memory/upload_skill_memory.py b/src/memos/mem_reader/read_skill_memory/upload_skill_memory.py new file mode 100644 index 000000000..fef6b60f2 --- /dev/null +++ b/src/memos/mem_reader/read_skill_memory/upload_skill_memory.py @@ -0,0 +1,312 @@ +import re +import shutil +import tempfile +import zipfile + +from pathlib import Path +from typing import Any +from urllib.parse import urlparse +from uuid import uuid4 + +import requests + +from memos.embedders.base import BaseEmbedder +from memos.log import get_logger +from memos.mem_reader.read_skill_memory.process_skill_memory import ( + create_skill_memory_item, +) +from memos.memories.textual.item import TextualMemoryItem +from memos.utils import timed + + +logger = get_logger(__name__) + +_TEXT_MAX_LEN = 20 + + +def _truncate(text: str) -> str: + """Truncate a string to at most ``_TEXT_MAX_LEN`` characters.""" + return text[:_TEXT_MAX_LEN] + + +def _extract_zip_url_from_items(items: list[TextualMemoryItem]) -> str | None: + """ + Extract the zip download URL from fast-stage memory items. + + FileContentParser.parse_fast stores the URL in source.file_info["file_data"]. + Each upload-skill request contains exactly one zip URL. 
+ """ + for item in items: + for source in getattr(item.metadata, "sources", None) or []: + file_info = getattr(source, "file_info", None) + if not isinstance(file_info, dict): + continue + file_data = file_info.get("file_data", "") + if isinstance(file_data, str) and file_data.startswith(("http://", "https://")): + url_path = urlparse(file_data).path + if url_path.lower().endswith(".zip"): + return file_data + return None + + +def _extract_file_ids_from_items(items: list[TextualMemoryItem]) -> list[str]: + """Extract uploaded file ids from fast-stage memory metadata and sources.""" + file_ids: list[str] = [] + + def _append_file_id(file_id: Any) -> None: + if isinstance(file_id, str) and file_id and file_id not in file_ids: + file_ids.append(file_id) + + for item in items: + metadata = getattr(item, "metadata", None) + metadata_file_ids = getattr(metadata, "file_ids", None) if metadata else None + if isinstance(metadata_file_ids, list): + for file_id in metadata_file_ids: + _append_file_id(file_id) + + for source in getattr(metadata, "sources", None) or []: + file_info = getattr(source, "file_info", None) + if isinstance(file_info, dict): + _append_file_id(file_info.get("file_id")) + + return file_ids + + +def _download_zip(url: str, tmp_dir: Path) -> Path: + """Download a zip file to a local temporary directory.""" + try: + resp = requests.get(url, stream=True, timeout=60) + resp.raise_for_status() + except Exception as e: + raise ValueError(f"Failed to download zip from {url}: {e}") from e + + zip_path = tmp_dir / f"{uuid4()}.zip" + with open(zip_path, "wb") as f: + for chunk in resp.iter_content(chunk_size=8192): + f.write(chunk) + + if not zipfile.is_zipfile(zip_path): + raise ValueError(f"Downloaded file is not a valid zip: {url}") + + return zip_path + + +def _extract_and_parse_skill_zip(zip_path: Path) -> dict[str, Any]: + """ + Extract a skill zip and parse SKILL.md + directory contents into a skill_memory dict. + + The SKILL.md format mirrors the output of ``_write_skills_to_file`` in + ``process_skill_memory.py``. Section headings at any level (``#`` through + ``######``) are matched by title text (case-insensitive). 
+ """ + # Step 1: extract & locate SKILL.md + extract_dir = zip_path.parent / zip_path.stem + with zipfile.ZipFile(zip_path, "r") as zf: + zf.extractall(extract_dir) + + skill_md_path = None + for candidate in extract_dir.rglob("SKILL.md"): + skill_md_path = candidate + break + + if skill_md_path is None: + raise FileNotFoundError(f"SKILL.md not found in zip: {zip_path.name}") + + skill_root = skill_md_path.parent + raw_text = skill_md_path.read_text(encoding="utf-8") + + # Step 2: parse frontmatter → name, description + name = "" + description = "" + fm_match = re.match(r"^---\s*\n(.*?)\n---", raw_text, re.DOTALL) + if fm_match: + for line in fm_match.group(1).splitlines(): + if line.startswith("name:"): + name = line[len("name:") :].strip() + elif line.startswith("description:"): + description = line[len("description:") :].strip() + + if not name: + name = zip_path.stem + + # Step 3: split body by any-level heading and parse each section + trigger: str = "" + procedure: str = "" + experience: list[str] = [] + preference: list[str] = [] + examples: list[str] = [] + tool: str | None = None + others_inline: dict[str, str] = {} + + known_sections = { + "trigger", + "procedure", + "experience", + "user preferences", + "examples", + "scripts", + "tool usage", + "additional information", + } + + body = raw_text[fm_match.end() :] if fm_match else raw_text + sections = re.split(r"\n(?=#{1,6}\s)", body) + + for section in sections: + section = section.strip() + if not section: + continue + + heading_match = re.match(r"^(#{1,6})\s+(.*)", section) + if not heading_match: + continue + + title = heading_match.group(2).strip() + content = section[heading_match.end() :].strip() + title_lower = title.lower() + + if title_lower not in known_sections: + logger.warning("[UPLOAD_SKILL] Unknown section '%s' in SKILL.md, skipping", title) + continue + + if title_lower == "trigger": + trigger = content + + elif title_lower == "procedure": + procedure = content + + elif title_lower == "experience": + items = re.findall(r"^\d+\.\s+(.+)$", content, re.MULTILINE) + experience = [item.strip() for item in items] if items else [] + + elif title_lower == "user preferences": + items = re.findall(r"^-\s+(.+)$", content, re.MULTILINE) + preference = [item.strip() for item in items] if items else [] + + elif title_lower == "examples": + blocks = re.findall(r"```markdown\n(.*?)\n```", content, re.DOTALL) + examples = [b.strip() for b in blocks] + + elif title_lower == "scripts": + pass + + elif title_lower == "tool usage": + tool = content.strip() if content.strip() else None + + elif title_lower == "additional information": + sub_sections = re.split(r"\n(?=#{1,6}\s)", content) + for sub in sub_sections: + sub = sub.strip() + if not sub or sub.startswith("See also:"): + continue + sub_heading = re.match(r"^(#{1,6})\s+(.*)", sub) + if not sub_heading: + continue + sub_key = sub_heading.group(2).strip() + sub_val = sub[sub_heading.end() :].strip() + if sub_val: + others_inline[sub_key] = sub_val + + # Step 4: read scripts/ directory + scripts: dict[str, str] | None = None + scripts_dir = skill_root / "scripts" + if scripts_dir.is_dir(): + scripts = {} + for py_file in scripts_dir.glob("*.py"): + scripts[py_file.name] = py_file.read_text(encoding="utf-8") + + # Step 5: read reference/ directory → merge into others + others = dict(others_inline) + reference_dir = skill_root / "reference" + if reference_dir.is_dir(): + for md_file in reference_dir.glob("*.md"): + others[md_file.name] = md_file.read_text(encoding="utf-8") + + # 
Step 6: truncate text fields & assemble return dict + truncated_trigger = _truncate(trigger) + + result: dict[str, Any] = { + "name": name, + "description": description, + "tags": [truncated_trigger] if truncated_trigger else [], + "procedure": _truncate(procedure), + "experience": [_truncate(e) for e in experience], + "preference": [_truncate(p) for p in preference], + "examples": [_truncate(e) for e in examples], + "tool": _truncate(tool) if tool else None, + "scripts": {k: _truncate(v) for k, v in scripts.items()} if scripts else None, + "others": {k: _truncate(v) for k, v in others.items()} if others else None, + } + # Only include trigger when non-empty; create_skill_memory_item uses + # `skill_memory.get("tags") or skill_memory.get("trigger", [])`, + # an empty-string trigger would override the correct [] fallback. + if truncated_trigger: + result["trigger"] = truncated_trigger + return result + + +@timed +def process_upload_skill_memory( + fast_memory_items: list[TextualMemoryItem], + info: dict[str, Any], + embedder: BaseEmbedder | None = None, + oss_config: dict[str, Any] | None = None, + skills_dir_config: dict[str, Any] | None = None, + **kwargs, +) -> list[TextualMemoryItem]: + """ + Process a user-uploaded skill zip, parse it, and build a SkillMemory node. + + The zip URL is taken from the fast-stage ``TextualMemoryItem`` sources + (``source.file_info["file_data"]``), consistent with both sync-fine and + async-transfer paths. + """ + zip_url = _extract_zip_url_from_items(fast_memory_items) + if not zip_url: + logger.warning("[UPLOAD_SKILL] No zip URL found in fast_memory_items") + return [] + file_ids = _extract_file_ids_from_items(fast_memory_items) + + tmp_dir = Path(tempfile.mkdtemp(prefix="upload_skill_")) + try: + zip_path = _download_zip(zip_url, tmp_dir) + except Exception as e: + logger.warning("[UPLOAD_SKILL] Failed to download zip: %s", e) + shutil.rmtree(tmp_dir, ignore_errors=True) + return [] + + try: + skill_memory = _extract_and_parse_skill_zip(zip_path) + except FileNotFoundError as e: + logger.warning("[UPLOAD_SKILL] %s", e) + shutil.rmtree(tmp_dir, ignore_errors=True) + return [] + except Exception as e: + logger.error("[UPLOAD_SKILL] Failed to parse skill zip: %s", e) + shutil.rmtree(tmp_dir, ignore_errors=True) + return [] + + skill_memory["url"] = zip_url + skill_memory["skill_source"] = "user_upload" + + try: + skill_memory_item = create_skill_memory_item(skill_memory, info, embedder, **kwargs) + if file_ids: + skill_memory_item.metadata.file_ids = file_ids + metadata_info = dict(skill_memory_item.metadata.info or {}) + metadata_info.setdefault("file_id", file_ids[0]) + skill_memory_item.metadata.info = metadata_info + except Exception as e: + logger.error("[UPLOAD_SKILL] Failed to create skill memory item: %s", e) + shutil.rmtree(tmp_dir, ignore_errors=True) + return [] + + # Cleanup temp files + shutil.rmtree(tmp_dir, ignore_errors=True) + + logger.info( + "[UPLOAD_SKILL] Successfully created SkillMemory from uploaded zip: name=%s, id=%s", + skill_memory.get("name"), + skill_memory_item.id, + ) + return [skill_memory_item] diff --git a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py index b85e4ea71..36cc97bdf 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py +++ b/src/memos/mem_scheduler/task_schedule_modules/handlers/mem_read_handler.py @@ -165,6 +165,9 @@ def _process_memories_with_reader( 
logger.info("Processing %s memories with mem_reader", len(memory_items)) + info = dict(info or {}) + is_upload_skill = info.pop("is_upload_skill", False) + try: processed_memories = mem_reader.fine_transfer_simple_mem( memory_items, @@ -173,6 +176,7 @@ def _process_memories_with_reader( user_name=user_name, chat_history=chat_history, user_context=user_context, + is_upload_skill=is_upload_skill, ) except Exception as e: logger.warning("%s: Fail to transfer mem: %s", e, memory_items) diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py index f34cf1efd..b7004c84c 100644 --- a/src/memos/memories/textual/item.py +++ b/src/memos/memories/textual/item.py @@ -81,6 +81,14 @@ class ArchivedTextualMemory(BaseModel): default_factory=lambda: datetime.now().isoformat(), description="The time the memory was created.", ) + timespec: dict[str, Any] | None = Field( + default=None, + description="Compact temporal index snapshot for this archived version, used by retrieval-side version selection.", + ) + memory_form: Literal["state", "event"] | None = Field( + default=None, + description="Internal memory form snapshot for this archived version, used by retrieval-side routing.", + ) class TextualMemoryMetadata(BaseModel): diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index eb15b48ed..3eccb570a 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -1,4 +1,5 @@ import copy +import re import traceback from concurrent.futures import as_completed @@ -32,6 +33,7 @@ logger = get_logger(__name__) +KEYWORD_EXTRACT_TOP_K = 12 COT_DICT = { "fine": {"en": COT_PROMPT, "zh": COT_PROMPT_ZH}, "fast": {"en": SIMPLE_COT_PROMPT, "zh": SIMPLE_COT_PROMPT_ZH}, @@ -278,7 +280,7 @@ def _parse_task( # retrieve related nodes by embedding related_nodes = [ - self.graph_store.get_node(n["id"]) + self.graph_store.get_node(n["id"], user_name=user_name) for n in self.graph_store.search_by_embedding( query_embedding, top_k=top_k, @@ -505,6 +507,39 @@ def _retrieve_from_working_memory( search_filter=search_filter, ) + @staticmethod + def _require_keyword_user_name(user_name: str | None) -> str: + normalized_user_name = user_name.strip() if isinstance(user_name, str) else "" + if not normalized_user_name: + raise ValueError( + "[PATH-KEYWORD] user_name is required for PolarDB fulltext keyword search" + ) + return normalized_user_name + + def _extract_weighted_keyword_terms(self, query: str) -> list[str]: + if detect_lang(query) == "zh": + import jieba.analyse + + weighted_terms = jieba.analyse.extract_tags(query, topK=KEYWORD_EXTRACT_TOP_K) + else: + weighted_terms = [] + if self.tokenizer: + weighted_terms = self.tokenizer.tokenize_mixed(query) + else: + weighted_terms = re.findall(r"\b[a-zA-Z0-9]+\b", query.lower()) + + query_words: list[str] = [] + seen_words: set[str] = set() + for term in weighted_terms: + normalized_term = str(term).strip() + if not normalized_term or normalized_term in seen_words: + continue + seen_words.add(normalized_term) + query_words.append(normalized_term) + if len(query_words) >= KEYWORD_EXTRACT_TOP_K: + break + return query_words + @timed def _retrieve_from_keyword( self, @@ -524,22 +559,21 @@ def _retrieve_from_keyword( return [] if not query_embedding: return [] + user_name = self._require_keyword_user_name(user_name) - query_words: list[str] = [] - if self.tokenizer: - query_words = 
self.tokenizer.tokenize_mixed(query) - else: - query_words = query.strip().split() - # Use unique tokens; avoid passing the raw query into `to_tsquery(...)` because it may contain - # spaces/operators that cause tsquery parsing errors. - query_words = list(dict.fromkeys(query_words)) - if len(query_words) > 64: - query_words = query_words[:64] + query_words = self._extract_weighted_keyword_terms(query) if not query_words: return [] + # Quote weighted terms before `to_tsquery(...)` to avoid parsing operators from user input. tsquery_terms = ["'" + w.replace("'", "''") + "'" for w in query_words if w and w.strip()] if not tsquery_terms: return [] + logger.info( + "[PATH-KEYWORD] weighted query_words=%s top_k=%s user_name=%s", + query_words, + top_k, + user_name, + ) scopes = [memory_type] if memory_type != "All" else ["LongTermMemory", "UserMemory"] @@ -548,7 +582,7 @@ def _retrieve_from_keyword( try: hits = self.graph_store.search_by_fulltext( query_words=tsquery_terms, - top_k=top_k * 2, + top_k=top_k, status="activated", scope=scope, search_filter=None, diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 22d3a253c..f84fc60e1 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -429,7 +429,9 @@ def extract_edge_info(edges_info: list[dict], neighbor_relativity: float): for edge in edges_info: chunk_target_id = edge.get("to") edge_type = edge.get("type") - item_neighbor = self.searcher.graph_store.get_node(chunk_target_id) + item_neighbor = self.searcher.graph_store.get_node( + chunk_target_id, user_name=user_name + ) if item_neighbor: item_neighbor_mem = TextualMemoryItem(**item_neighbor) item_neighbor_mem.metadata.relativity = neighbor_relativity @@ -529,7 +531,10 @@ def _schedule_memory_tasks( content=json.dumps(mem_ids), timestamp=datetime.utcnow(), user_name=self.cube_id, - info=add_req.info, + info={ + **(add_req.info or {}), + "is_upload_skill": getattr(add_req, "is_upload_skill", False), + }, chat_history=add_req.chat_history, user_context=user_context, ) @@ -709,6 +714,7 @@ def _process_text_mem( user_name=user_context.mem_cube_id, chat_history=add_req.chat_history, user_context=user_context, + is_upload_skill=getattr(add_req, "is_upload_skill", False), ) get_memory_ms = ts_gm.duration_ms flattened_local = [mm for m in memories_local for mm in m] diff --git a/src/memos/plugins/hook_defs.py b/src/memos/plugins/hook_defs.py index 536068c8d..5ec73cc86 100644 --- a/src/memos/plugins/hook_defs.py +++ b/src/memos/plugins/hook_defs.py @@ -77,6 +77,7 @@ class H: # mem_reader — generic extension point before LLM extraction MEM_READER_PRE_EXTRACT = "mem_reader.pre_extract" + MEMORY_ITEMS_AFTER_FINE_EXTRACT = "memory_items.after_fine_extract" # memory version — single-provider business hooks MEMORY_VERSION_PREPARE_UPDATES = "memory_version.prepare_updates" @@ -105,6 +106,13 @@ class H: pipe_key="prompt", ) +define_hook( + H.MEMORY_ITEMS_AFTER_FINE_EXTRACT, + description="Post-process memory items after mem_reader fine extraction completes", + params=["items", "user_context", "mem_reader", "extract_mode"], + pipe_key="items", +) + define_hook( H.MEMORY_VERSION_PREPARE_UPDATES, description=(
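The upload-skill path is easiest to see end to end with a concrete SKILL.md. The sketch below builds a minimal zip and feeds it to `_extract_and_parse_skill_zip`; the skill name and section contents are invented for illustration, the headings come from the parser's `known_sections` set, and the call exercises a private helper directly, so treat it as a test harness rather than a public API. Note that every parsed text field is capped at `_TEXT_MAX_LEN` characters by `_truncate`.

```python
import tempfile
import zipfile

from pathlib import Path

from memos.mem_reader.read_skill_memory.upload_skill_memory import (
    _extract_and_parse_skill_zip,
)

# Invented sample skill: frontmatter supplies name/description; headings
# must match the parser's known_sections (any heading level is accepted).
SKILL_MD = """---
name: csv-cleaner
description: Clean and normalize CSV exports.
---
# Trigger
User asks to tidy a CSV.

# Procedure
Load the file, drop empty rows, save.

## Experience
1. Check the delimiter before parsing.

## User Preferences
- Prefer UTF-8 output.
"""

tmp = Path(tempfile.mkdtemp(prefix="skill_demo_"))
skill_dir = tmp / "csv-cleaner"
skill_dir.mkdir()
(skill_dir / "SKILL.md").write_text(SKILL_MD, encoding="utf-8")

# Zip the skill directory the way an upload would deliver it.
zip_path = tmp / "upload.zip"
with zipfile.ZipFile(zip_path, "w") as zf:
    zf.write(skill_dir / "SKILL.md", "csv-cleaner/SKILL.md")

parsed = _extract_and_parse_skill_zip(zip_path)
# name/description come from the frontmatter; experience and preference
# are parsed from numbered ("1. ...") and dashed ("- ...") items, then
# truncated to _TEXT_MAX_LEN characters like every other text field.
print(parsed["name"], parsed["experience"], parsed["preference"])
```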