Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/reranker/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

11 changes: 11 additions & 0 deletions tests/reranker/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from memos.reranker.base import BaseReranker
from memos.reranker.strategies.base import BaseRerankerStrategy
from tests.utils import check_module_base_class


def test_base_reranker_class_contract():
    """Run the shared base-class contract check against ``BaseReranker``."""
    base_cls = BaseReranker
    check_module_base_class(base_cls)


def test_base_reranker_strategy_class_contract():
    """Run the shared base-class contract check against ``BaseRerankerStrategy``."""
    base_cls = BaseRerankerStrategy
    check_module_base_class(base_cls)
62 changes: 62 additions & 0 deletions tests/reranker/test_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import pytest

from memos.configs.reranker import RerankerConfigFactory
from memos.reranker.cosine_local import CosineLocalReranker
from memos.reranker.factory import RerankerFactory
from memos.reranker.noop import NoopReranker
from memos.reranker.strategies.concat_background import ConcatBackgroundStrategy
from memos.reranker.strategies.concat_docsource import ConcatDocSourceStrategy
from memos.reranker.strategies.factory import RerankerStrategyFactory
from memos.reranker.strategies.single_turn import SingleTurnStrategy
from memos.reranker.strategies.singleturn_outmem import SingleTurnOutMemStrategy


def test_reranker_factory_returns_none_for_missing_config():
    """Passing ``None`` as the config must make the factory return ``None``."""
    built = RerankerFactory.from_config(None)

    assert built is None


def test_reranker_factory_builds_noop_reranker():
    """A config with the ``"noop"`` backend must yield a ``NoopReranker``."""
    noop_config = RerankerConfigFactory(backend="noop")

    built = RerankerFactory.from_config(noop_config)

    assert isinstance(built, NoopReranker)


def test_reranker_factory_builds_cosine_local_reranker_with_options():
    """The ``"cosine_local"`` backend must build a ``CosineLocalReranker``.

    The options supplied under ``config`` are expected to be exposed on the
    resulting instance as ``level_weights`` and ``level_field``.
    """
    backend_options = {
        "level_weights": {"topic": 2.0},
        "level_field": "background",
    }
    factory_config = RerankerConfigFactory(
        backend="cosine_local",
        config=backend_options,
    )

    built = RerankerFactory.from_config(factory_config)

    assert isinstance(built, CosineLocalReranker)
    assert built.level_weights == {"topic": 2.0}
    assert built.level_field == "background"


def test_reranker_factory_rejects_unknown_backend():
    """An unrecognised backend name must raise ``ValueError`` from the factory."""
    unsupported_config = RerankerConfigFactory(backend="unknown")
    expect_unknown_backend = pytest.raises(ValueError, match="Unknown reranker backend")

    # Only the factory call sits inside the context, so an error raised while
    # building the config itself is not mistaken for the expected failure.
    with expect_unknown_backend:
        RerankerFactory.from_config(unsupported_config)


@pytest.mark.parametrize(
    ("backend", "expected_class"),
    [
        ("single_turn", SingleTurnStrategy),
        ("concat_background", ConcatBackgroundStrategy),
        ("singleturn_outmem", SingleTurnOutMemStrategy),
        ("concat_docsource", ConcatDocSourceStrategy),
    ],
)
def test_reranker_strategy_factory_builds_supported_strategies(backend, expected_class):
    """Each supported backend name must map to its matching strategy class."""
    strategy = RerankerStrategyFactory.from_config(backend)

    assert isinstance(strategy, expected_class)


def test_reranker_strategy_factory_rejects_unknown_backend():
    """An unsupported strategy name must raise ``ValueError``."""
    expect_invalid_backend = pytest.raises(ValueError, match="Invalid backend")

    with expect_invalid_backend:
        RerankerStrategyFactory.from_config("missing")
79 changes: 79 additions & 0 deletions tests/reranker/test_local_rerankers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
from memos.reranker.cosine_local import CosineLocalReranker
from memos.reranker.noop import NoopReranker


def _memory_item(item_id: str, memory: str, embedding: list[float] | None = None):
    """Build a ``TextualMemoryItem`` fixture for the local reranker tests.

    The metadata always uses ``background="topic"`` and an empty ``sources``
    list; only the id, memory text and optional embedding vary per call.
    """
    tree_metadata = TreeNodeTextualMemoryMetadata(
        background="topic",
        embedding=embedding,
        sources=[],
    )
    return TextualMemoryItem(id=item_id, memory=memory, metadata=tree_metadata)


def test_noop_reranker_returns_top_k_items_with_zero_scores():
    """``NoopReranker`` must keep input order, cut to ``top_k`` and score 0.0."""
    first = _memory_item("00000000-0000-0000-0000-000000000001", "first")
    second = _memory_item("00000000-0000-0000-0000-000000000002", "second")
    third = _memory_item("00000000-0000-0000-0000-000000000003", "third")
    candidates = [first, second, third]

    result = NoopReranker().rerank("query", candidates, top_k=2)

    assert result == [(first, 0.0), (second, 0.0)]


def test_cosine_local_reranker_returns_empty_for_empty_results():
    """Reranking an empty candidate list must return an empty list."""
    reranker = CosineLocalReranker()

    outcome = reranker.rerank("query", [], top_k=3, query_embedding=[1.0, 0.0])

    assert outcome == []


def test_cosine_local_reranker_falls_back_without_query_embedding():
    """Without a query embedding the item must come back with a 0.0 score."""
    only_item = _memory_item("00000000-0000-0000-0000-000000000001", "first")
    candidates = [only_item]

    outcome = CosineLocalReranker().rerank("query", candidates, top_k=1)

    assert outcome == [(only_item, 0.0)]


def test_cosine_local_reranker_scores_and_sorts_embedded_items():
    """The item whose embedding matches the query must rank first.

    ``near`` shares the query vector while ``far`` uses a different axis; the
    candidates are passed in the opposite order so the test observes sorting.
    """
    query_vector = [1.0, 0.0]
    near = _memory_item(
        "00000000-0000-0000-0000-000000000001",
        "near",
        embedding=[1.0, 0.0],
    )
    far = _memory_item(
        "00000000-0000-0000-0000-000000000002",
        "far",
        embedding=[0.0, 1.0],
    )
    reranker = CosineLocalReranker()

    ranked = reranker.rerank("query", [far, near], top_k=2, query_embedding=query_vector)

    best = ranked[0]
    assert best[0] == near
    assert best[1] > ranked[1][1]


def test_cosine_local_reranker_fills_missing_embeddings_with_negative_score():
    """An item without an embedding must rank last with a score of -1.0."""
    embedded = _memory_item(
        "00000000-0000-0000-0000-000000000001",
        "embedded",
        embedding=[1.0, 0.0],
    )
    missing = _memory_item("00000000-0000-0000-0000-000000000002", "missing")
    # The embedding-less item goes first so the test observes reordering.
    candidates = [missing, embedded]

    ranked = CosineLocalReranker().rerank(
        "query",
        candidates,
        top_k=2,
        query_embedding=[1.0, 0.0],
    )

    assert ranked[0][0] == embedded
    assert ranked[1] == (missing, -1.0)
141 changes: 141 additions & 0 deletions tests/reranker/test_strategies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
from memos.memories.textual.item import (
SourceMessage,
TextualMemoryItem,
TreeNodeTextualMemoryMetadata,
)
from memos.reranker.strategies.concat_background import ConcatBackgroundStrategy
from memos.reranker.strategies.concat_docsource import ConcatDocSourceStrategy
from memos.reranker.strategies.dialogue_common import DialogueRankingTracker, extract_content
from memos.reranker.strategies.single_turn import SingleTurnStrategy
from memos.reranker.strategies.singleturn_outmem import SingleTurnOutMemStrategy


def _memory_item(
    item_id: str = "00000000-0000-0000-0000-000000000001",
    memory: str = "[tag] remembers preferred answer style",
    sources: list[dict] | None = None,
    background: str = "background note",
):
    """Build a ``TextualMemoryItem`` fixture for the strategy tests.

    Every argument has a default, so a bare call returns a usable item. A
    falsy ``sources`` value is replaced by an empty list, and the embedding
    is always ``[1.0, 0.0]``.
    """
    source_entries = sources or []
    tree_metadata = TreeNodeTextualMemoryMetadata(
        background=background,
        embedding=[1.0, 0.0],
        sources=source_entries,
    )
    return TextualMemoryItem(id=item_id, memory=memory, metadata=tree_metadata)


def test_extract_content_supports_dict_source_message_and_string():
    """``extract_content`` must accept a dict, a ``SourceMessage`` and a string."""
    model_source = SourceMessage(role="user", content="from model", chat_time="2026-05-12")

    from_dict = extract_content({"content": "from dict"})
    assert from_dict == "from dict"

    from_model = extract_content(model_source)
    assert from_model == "from model"

    from_text = extract_content("raw text")
    assert from_text == "raw text"


def test_dialogue_ranking_tracker_builds_documents_and_bounds_lookup():
    """Check pair ids, index lookups and document rendering on the tracker.

    One dialogue pair is registered. The test then checks the generated pair
    id, that a valid index returns the pair, that an out-of-range index
    returns ``None``, and the exact text of the ranking document.
    """
    tracker = DialogueRankingTracker()
    user_message = {"content": "hello"}
    assistant_message = {"content": "hi"}

    pair_id = tracker.add_dialogue_pair(
        "memory-1",
        0,
        user_message,
        assistant_message,
        "memory body",
        chat_time="2026-05-12",
    )

    assert pair_id == "memory-1_0"

    first_pair = tracker.get_dialogue_pair_by_index(0)
    assert first_pair.user_content == "hello"

    out_of_range = tracker.get_dialogue_pair_by_index(99)
    assert out_of_range is None

    ranking_documents = tracker.get_documents_for_ranking()
    assert ranking_documents == [
        "memory body\n\n[2026-05-12]: \nuser: hello\nassistant: hi"
    ]


def test_single_turn_strategy_prepares_documents_from_chat_pairs():
    """A user/assistant source pair must become one ranking document.

    The expected document omits the leading ``[tag]`` marker from the memory
    text and appends the dated dialogue turn.
    """
    user_turn = {"role": "user", "content": "What should I do?", "chat_time": "2026-05-12"}
    assistant_turn = {"role": "assistant", "content": "Use the concise answer style."}
    item = _memory_item(sources=[user_turn, assistant_turn])
    strategy = SingleTurnStrategy()

    prepared = strategy.prepare_documents("answer style", [item], top_k=1)
    tracker, original_items, documents = prepared

    assert original_items == {item.id: item}
    assert len(tracker.dialogue_pairs) == 1
    assert documents == [
        "remembers preferred answer style\n\n"
        "[2026-05-12]: \nuser: What should I do?\nassistant: Use the concise answer style."
    ]


def test_single_turn_strategy_reconstructs_ranked_dialogue_items():
    """Reconstruction must return a new item carrying the dialogue pair.

    The top entry must keep the supplied score, contain the
    ``sources-dialogue-pairs`` marker in its memory text, and be a different
    object from the original input item.
    """
    strategy = SingleTurnStrategy()
    item = _memory_item(
        sources=[
            {"role": "user", "content": "question"},
            {"role": "assistant", "content": "answer"},
        ]
    )
    tracker, original_items, _documents = strategy.prepare_documents("query", [item], top_k=1)

    ranked = strategy.reconstruct_items(
        [0],
        [0.9],
        tracker,
        original_items,
        top_k=1,
    )

    assert ranked[0][1] == 0.9
    assert "sources-dialogue-pairs" in ranked[0][0].memory
    assert ranked[0][0] is not item


def test_single_turn_outmem_strategy_aggregates_by_original_memory():
    """Two dialogue pairs from one memory must collapse into a single result.

    Scores 0.2 and 0.8 are supplied for the two pairs; the expected output is
    the original item paired with 0.8.
    """
    strategy = SingleTurnOutMemStrategy()
    dialogue_sources = [
        {"role": "user", "content": "first"},
        {"role": "assistant", "content": "answer"},
        {"role": "user", "content": "second"},
        {"role": "assistant", "content": "answer"},
    ]
    item = _memory_item(sources=dialogue_sources)
    tracker, original_items, _documents = strategy.prepare_documents("query", [item], top_k=1)

    ranked = strategy.reconstruct_items(
        [0, 1],
        [0.2, 0.8],
        tracker,
        original_items,
        top_k=1,
    )

    assert ranked == [(item, 0.8)]


def test_concat_background_strategy_includes_background_text():
    """The prepared document must be the memory text followed by the background."""
    item = _memory_item()
    strategy = ConcatBackgroundStrategy()

    prepared = strategy.prepare_documents("answer style", [item], top_k=1)
    _tracker, _original_items, documents = prepared

    assert documents == ["remembers preferred answer style\nbackground note"]


def test_concat_docsource_strategy_includes_file_source_content():
    """A ``file`` source's content must appear under a ``[Sources]:`` section."""
    file_source = {
        "type": "file",
        "content": "file chunk",
    }
    item = _memory_item(sources=[file_source])
    strategy = ConcatDocSourceStrategy()

    prepared = strategy.prepare_documents("answer style", [item], top_k=1)
    _tracker, _original_items, documents = prepared

    assert documents == ["remembers preferred answer style\n\n[Sources]:\nfile chunk"]
Loading