Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions evaluation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ First prepare the dataset `longmemeval_s` from https://huggingface.co/datasets/x
./scripts/run_lme_eval.sh
```

#### Question date and `reference_time`

LongMemEval gives each question a **question date**; evaluation should use that as the reference “now”, not the time when you run the script. The LongMemEval search script passes `question_date` as **`reference_time`** where the backend supports it.

**MemOS Cloud** currently does not support supplying question date on search the same way, so LongMemEval scores there may differ from a spec-faithful run. **Prefer evaluating LongMemEval against the open-source MemOS server** when you need comparable numbers.

### PrefEval Evaluation
Download benchmark_dataset/filtered_inter_turns.json from https://github.com/amazon-science/PrefEval/blob/main/benchmark_dataset/filtered_inter_turns.json and save it as `./data/prefeval/filtered_inter_turns.json`.
To evaluate the **Prefeval** dataset — run the following [script](./scripts/run_prefeval_eval.sh):
Expand Down
14 changes: 10 additions & 4 deletions evaluation/scripts/longmemeval/lme_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,11 @@ def mem0_search(client, query, user_id, top_k):
return context, duration_ms


def memos_search(client, query, user_id, top_k):
def memos_search(client, query, user_id, top_k, reference_time=None):
start = time()
results = client.search(query=query, user_id=user_id, top_k=top_k)
results = client.search(
query=query, user_id=user_id, top_k=top_k, reference_time=reference_time
)
context = (
"\n".join([i["memory"] for i in results["text_mem"][0]["memories"]])
+ f"\n{results.get('pref_string', '')}"
Expand Down Expand Up @@ -122,12 +124,16 @@ def process_user(lme_df, conv_idx, frame, version, top_k=20):
from utils.client import MemosApiClient

client = MemosApiClient()
context, duration_ms = memos_search(client, question, user_id, top_k)
context, duration_ms = memos_search(
client, question, user_id, top_k, reference_time=question_date
)
elif frame == "memos-api-online":
from utils.client import MemosApiOnlineClient

client = MemosApiOnlineClient()
context, duration_ms = memos_search(client, question, user_id, top_k)
context, duration_ms = memos_search(
client, question, user_id, top_k, reference_time=question_date
)
elif frame == "memu":
from utils.client import MemuClient

Expand Down
2 changes: 2 additions & 0 deletions src/memos/api/handlers/search_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from memos.multi_mem_cube.composite_cube import CompositeCubeView
from memos.multi_mem_cube.single_cube import SingleCubeView
from memos.multi_mem_cube.views import MemCubeView
from memos.plugins.hooks import hookable


logger = get_logger(__name__)
Expand All @@ -44,6 +45,7 @@ def __init__(self, dependencies: HandlerDependencies):
"naive_mem_cube", "mem_scheduler", "searcher", "deepsearch_agent"
)

@hookable("search")
def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse:
"""
Main handler for search memories endpoint.
Expand Down
19 changes: 19 additions & 0 deletions src/memos/api/product_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,14 @@ class APISearchRequest(BaseRequest):
)

# ==== Context ====
reference_time: str | None = Field(
None,
description=(
"Optional reference time for time-sensitive search parsing. "
"If omitted, search uses the current server time."
),
)

chat_history: MessageList | None = Field(
None,
description=(
Expand Down Expand Up @@ -608,6 +616,17 @@ class APIADDRequest(BaseRequest):
description=("Whether this request represents user feedback. Default: False."),
)

# ==== Upload skill flag ====
is_upload_skill: bool = Field(
False,
description=(
"Whether this request is an upload skill request. "
"When True, the messages field should contain file items "
"with zip file download URLs for pre-built skill packages. "
"Default: False."
),
)

# ==== Backward compatibility fields (will delete later) ====
mem_cube_id: str | None = Field(
None,
Expand Down
28 changes: 25 additions & 3 deletions src/memos/mem_reader/multi_modal_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from memos.mem_reader.utils import parse_json_result
from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
from memos.plugins.hook_defs import H
from memos.plugins.hooks import trigger_single_hook
from memos.plugins.hooks import trigger_hook, trigger_single_hook
from memos.templates.mem_reader_prompts import MEMORY_MERGE_PROMPT_EN, MEMORY_MERGE_PROMPT_ZH
from memos.templates.tool_mem_prompts import TOOL_TRAJECTORY_PROMPT_EN, TOOL_TRAJECTORY_PROMPT_ZH
from memos.types import MessagesType
Expand Down Expand Up @@ -785,6 +785,14 @@ def _process_one_item(
)
rawfile_node.metadata.summary_ids = [mem_node.id for mem_node in fine_items]
fine_items.append(rawfile_node)
enriched_items = trigger_hook(
H.MEMORY_ITEMS_AFTER_FINE_EXTRACT,
items=fine_items,
user_context=kwargs.get("user_context"),
mem_reader=self,
extract_mode="fine",
)
fine_items = enriched_items if enriched_items is not None else fine_items
return fine_items

fine_memory_items: list[TextualMemoryItem] = []
Expand Down Expand Up @@ -993,6 +1001,7 @@ def _process_multi_modal_data(
return fast_memory_items

# Stage: llm_extract — fine mode 4-way parallel LLM + per-source serial
is_upload_skill = kwargs.pop("is_upload_skill", False)
non_file_url_fast_items = [
item for item in fast_memory_items if not self._is_file_url_only_item(item)
]
Expand All @@ -1009,14 +1018,17 @@ def _process_multi_modal_data(
)
future_skill = executor.submit(
process_skill_memory_fine,
fast_memory_items=non_file_url_fast_items,
fast_memory_items=fast_memory_items
if is_upload_skill
else non_file_url_fast_items,
info=info,
searcher=self.searcher,
graph_db=self.graph_db,
llm=self.general_llm,
embedder=self.embedder,
oss_config=self.oss_config,
skills_dir_config=self.skills_dir_config,
is_upload_skill=is_upload_skill,
**kwargs,
)
future_pref = executor.submit(
Expand All @@ -1039,6 +1051,10 @@ def _process_multi_modal_data(
fine_memory_items.extend(fine_memory_items_pref_parser)

# Part B: per-source serial processing
if is_upload_skill:
# (skip for upload skill to avoid zip being parsed)
return fine_memory_items

with timed_stage("add", "per_source") as ts_ps:
for fast_item in fast_memory_items:
sources = fast_item.metadata.sources
Expand Down Expand Up @@ -1072,6 +1088,8 @@ def _process_transfer_multi_modal_data(
logger.warning("[MultiModalStruct] No raw nodes found.")
return []

is_upload_skill = kwargs.pop("is_upload_skill", False)

# Extract info from raw_nodes (same as simple_struct.py)
info = {
"user_id": raw_nodes[0].metadata.user_id,
Expand All @@ -1093,14 +1111,15 @@ def _process_transfer_multi_modal_data(
)
future_skill = executor.submit(
process_skill_memory_fine,
non_file_url_nodes,
raw_nodes if is_upload_skill else non_file_url_nodes,
info,
searcher=self.searcher,
llm=self.general_llm,
embedder=self.embedder,
graph_db=self.graph_db,
oss_config=self.oss_config,
skills_dir_config=self.skills_dir_config,
is_upload_skill=is_upload_skill,
**kwargs,
)
# Add preference memory extraction
Expand All @@ -1125,6 +1144,9 @@ def _process_transfer_multi_modal_data(
fine_memory_items.extend(fine_memory_items_pref_parser)

# Part B: get fine multimodal items
if is_upload_skill:
# (skip for upload skill to avoid zip being parsed)
return fine_memory_items
for raw_node in raw_nodes:
sources = raw_node.metadata.sources
for source in sources:
Expand Down
55 changes: 17 additions & 38 deletions src/memos/mem_reader/read_skill_memory/process_skill_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,36 +252,6 @@ def create_task(skill_mem, gen_type, prompt, requirements, context, **kwargs):
return [item[0] for item in raw_skills_data]


def add_id_to_mysql(memory_id: str, mem_cube_id: str):
    """Register a skill memory id with the external skills MySQL service.

    Temporary bridge endpoint; will be deprecated once skill ids are stored
    natively. Configuration comes from two environment variables:
    ``SKILLS_MYSQL_URL`` (endpoint) and ``SKILLS_MYSQL_BEARER`` (auth token).

    Args:
        memory_id: The skill memory item's id to register.
        mem_cube_id: The memory cube the skill belongs to.

    Returns:
        The decoded JSON response on success, or ``None`` if configuration
        is missing or the request fails (best-effort; never raises).
    """
    # TODO: tmp function, deprecate soon
    import requests

    skill_mysql_url = os.getenv("SKILLS_MYSQL_URL", "")
    skill_mysql_bearer = os.getenv("SKILLS_MYSQL_BEARER", "")

    if not skill_mysql_url or not skill_mysql_bearer:
        logger.warning("[PROCESS_SKILLS] SKILLS_MYSQL_URL or SKILLS_MYSQL_BEARER is not set")
        return None
    headers = {"Authorization": skill_mysql_bearer, "Content-Type": "application/json"}
    data = {"memCubeId": mem_cube_id, "skillId": memory_id}
    try:
        # Timeout so a hung endpoint can't block the extraction pipeline.
        response = requests.post(skill_mysql_url, headers=headers, json=data, timeout=10)

        # NOTE: do NOT log the bearer token or Authorization headers here —
        # they are secrets and must never reach the logs.
        logger.info(f"[PROCESS_SKILLS] memory_id: {memory_id}")
        logger.info(f"[PROCESS_SKILLS] mem_cube_id: {mem_cube_id}")
        logger.info(f"[PROCESS_SKILLS] response: {response.json()}")

        return response.json()
    except Exception as e:
        # Best-effort registration: log and continue rather than failing the add.
        logger.warning(f"[PROCESS_SKILLS] Error adding id to mysql: {e}")
        return None


@require_python_package(
import_name="alibabacloud_oss_v2",
install_command="pip install alibabacloud-oss-v2",
Expand Down Expand Up @@ -948,6 +918,7 @@ def create_skill_memory_item(
scripts=skill_memory.get("scripts"),
others=skill_memory.get("others"),
url=skill_memory.get("url", ""),
skill_source=skill_memory.get("skill_source"),
manager_user_id=manager_user_id,
project_id=project_id,
)
Expand Down Expand Up @@ -1024,6 +995,21 @@ def process_skill_memory_fine(
complete_skill_memory: bool = True,
**kwargs,
) -> list[TextualMemoryItem]:
is_upload_skill = kwargs.pop("is_upload_skill", False)
if is_upload_skill:
from memos.mem_reader.read_skill_memory.upload_skill_memory import (
process_upload_skill_memory,
)

return process_upload_skill_memory(
fast_memory_items=fast_memory_items,
info=info,
embedder=embedder,
oss_config=oss_config,
skills_dir_config=skills_dir_config,
**kwargs,
)

skills_repo_backend = _get_skill_file_storage_location()
oss_client, _missing_keys, flag = _skill_init(
skills_repo_backend, oss_config, skills_dir_config
Expand Down Expand Up @@ -1252,6 +1238,7 @@ def _full_extract():
if source:
skill_sources.append(source)

skill_memory["skill_source"] = "auto_create"
memory_item = create_skill_memory_item(
skill_memory, info, embedder, sources=skill_sources, **kwargs
)
Expand All @@ -1260,12 +1247,4 @@ def _full_extract():
logger.warning(f"[PROCESS_SKILLS] Error creating skill memory item: {e}")
continue

# TODO: deprecate this funtion and call
for skill_memory, skill_memory_item in zip(skill_memories, skill_memory_items, strict=False):
if skill_memory.get("update", False) and skill_memory.get("old_memory_id", ""):
continue
add_id_to_mysql(
memory_id=skill_memory_item.id,
mem_cube_id=kwargs.get("user_name", info.get("user_id", "")),
)
return skill_memory_items
Loading
Loading