diff --git a/tests/reranker/__init__.py b/tests/reranker/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/tests/reranker/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/reranker/test_base.py b/tests/reranker/test_base.py new file mode 100644 index 000000000..4b880351c --- /dev/null +++ b/tests/reranker/test_base.py @@ -0,0 +1,11 @@ +from memos.reranker.base import BaseReranker +from memos.reranker.strategies.base import BaseRerankerStrategy +from tests.utils import check_module_base_class + + +def test_base_reranker_class_contract(): + check_module_base_class(BaseReranker) + + +def test_base_reranker_strategy_class_contract(): + check_module_base_class(BaseRerankerStrategy) diff --git a/tests/reranker/test_factory.py b/tests/reranker/test_factory.py new file mode 100644 index 000000000..29d6b28c3 --- /dev/null +++ b/tests/reranker/test_factory.py @@ -0,0 +1,62 @@ +import pytest + +from memos.configs.reranker import RerankerConfigFactory +from memos.reranker.cosine_local import CosineLocalReranker +from memos.reranker.factory import RerankerFactory +from memos.reranker.noop import NoopReranker +from memos.reranker.strategies.concat_background import ConcatBackgroundStrategy +from memos.reranker.strategies.concat_docsource import ConcatDocSourceStrategy +from memos.reranker.strategies.factory import RerankerStrategyFactory +from memos.reranker.strategies.single_turn import SingleTurnStrategy +from memos.reranker.strategies.singleturn_outmem import SingleTurnOutMemStrategy + + +def test_reranker_factory_returns_none_for_missing_config(): + assert RerankerFactory.from_config(None) is None + + +def test_reranker_factory_builds_noop_reranker(): + config = RerankerConfigFactory(backend="noop") + + assert isinstance(RerankerFactory.from_config(config), NoopReranker) + + +def test_reranker_factory_builds_cosine_local_reranker_with_options(): + config = RerankerConfigFactory( + backend="cosine_local", + config={ + "level_weights": {"topic": 2.0}, + "level_field": "background", + }, + ) + + reranker = RerankerFactory.from_config(config) + + assert isinstance(reranker, CosineLocalReranker) + assert reranker.level_weights == {"topic": 2.0} + assert reranker.level_field == "background" + + +def test_reranker_factory_rejects_unknown_backend(): + config = RerankerConfigFactory(backend="unknown") + + with pytest.raises(ValueError, match="Unknown reranker backend"): + RerankerFactory.from_config(config) + + +@pytest.mark.parametrize( + ("backend", "expected_class"), + [ + ("single_turn", SingleTurnStrategy), + ("concat_background", ConcatBackgroundStrategy), + ("singleturn_outmem", SingleTurnOutMemStrategy), + ("concat_docsource", ConcatDocSourceStrategy), + ], +) +def test_reranker_strategy_factory_builds_supported_strategies(backend, expected_class): + assert isinstance(RerankerStrategyFactory.from_config(backend), expected_class) + + +def test_reranker_strategy_factory_rejects_unknown_backend(): + with pytest.raises(ValueError, match="Invalid backend"): + RerankerStrategyFactory.from_config("missing") diff --git a/tests/reranker/test_local_rerankers.py b/tests/reranker/test_local_rerankers.py new file mode 100644 index 000000000..cf6a1fb04 --- /dev/null +++ b/tests/reranker/test_local_rerankers.py @@ -0,0 +1,79 @@ +from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata +from memos.reranker.cosine_local import CosineLocalReranker +from memos.reranker.noop import NoopReranker + + +def _memory_item(item_id: str, memory: str, embedding: list[float] | None = None): + return TextualMemoryItem( + id=item_id, + memory=memory, + metadata=TreeNodeTextualMemoryMetadata( + background="topic", + embedding=embedding, + sources=[], + ), + ) + + +def test_noop_reranker_returns_top_k_items_with_zero_scores(): + items = [ + _memory_item("00000000-0000-0000-0000-000000000001", "first"), + _memory_item("00000000-0000-0000-0000-000000000002", "second"), + _memory_item("00000000-0000-0000-0000-000000000003", "third"), + ] + + ranked = NoopReranker().rerank("query", items, top_k=2) + + assert ranked == [(items[0], 0.0), (items[1], 0.0)] + + +def test_cosine_local_reranker_returns_empty_for_empty_results(): + assert CosineLocalReranker().rerank("query", [], top_k=3, query_embedding=[1.0, 0.0]) == [] + + +def test_cosine_local_reranker_falls_back_without_query_embedding(): + items = [_memory_item("00000000-0000-0000-0000-000000000001", "first")] + + assert CosineLocalReranker().rerank("query", items, top_k=1) == [(items[0], 0.0)] + + +def test_cosine_local_reranker_scores_and_sorts_embedded_items(): + near = _memory_item( + "00000000-0000-0000-0000-000000000001", + "near", + embedding=[1.0, 0.0], + ) + far = _memory_item( + "00000000-0000-0000-0000-000000000002", + "far", + embedding=[0.0, 1.0], + ) + + ranked = CosineLocalReranker().rerank( + "query", + [far, near], + top_k=2, + query_embedding=[1.0, 0.0], + ) + + assert ranked[0][0] == near + assert ranked[0][1] > ranked[1][1] + + +def test_cosine_local_reranker_fills_missing_embeddings_with_negative_score(): + embedded = _memory_item( + "00000000-0000-0000-0000-000000000001", + "embedded", + embedding=[1.0, 0.0], + ) + missing = _memory_item("00000000-0000-0000-0000-000000000002", "missing") + + ranked = CosineLocalReranker().rerank( + "query", + [missing, embedded], + top_k=2, + query_embedding=[1.0, 0.0], + ) + + assert ranked[0][0] == embedded + assert ranked[1] == (missing, -1.0) diff --git a/tests/reranker/test_strategies.py b/tests/reranker/test_strategies.py new file mode 100644 index 000000000..c92832844 --- /dev/null +++ b/tests/reranker/test_strategies.py @@ -0,0 +1,141 @@ +from memos.memories.textual.item import ( + SourceMessage, + TextualMemoryItem, + TreeNodeTextualMemoryMetadata, +) +from memos.reranker.strategies.concat_background import ConcatBackgroundStrategy +from memos.reranker.strategies.concat_docsource import ConcatDocSourceStrategy +from memos.reranker.strategies.dialogue_common import DialogueRankingTracker, extract_content +from memos.reranker.strategies.single_turn import SingleTurnStrategy +from memos.reranker.strategies.singleturn_outmem import SingleTurnOutMemStrategy + + +def _memory_item( + item_id: str = "00000000-0000-0000-0000-000000000001", + memory: str = "[tag] remembers preferred answer style", + sources: list[dict] | None = None, + background: str = "background note", +): + return TextualMemoryItem( + id=item_id, + memory=memory, + metadata=TreeNodeTextualMemoryMetadata( + background=background, + embedding=[1.0, 0.0], + sources=sources or [], + ), + ) + + +def test_extract_content_supports_dict_source_message_and_string(): + source = SourceMessage(role="user", content="from model", chat_time="2026-05-12") + + assert extract_content({"content": "from dict"}) == "from dict" + assert extract_content(source) == "from model" + assert extract_content("raw text") == "raw text" + + +def test_dialogue_ranking_tracker_builds_documents_and_bounds_lookup(): + tracker = DialogueRankingTracker() + pair_id = tracker.add_dialogue_pair( + "memory-1", + 0, + {"content": "hello"}, + {"content": "hi"}, + "memory body", + chat_time="2026-05-12", + ) + + assert pair_id == "memory-1_0" + assert tracker.get_dialogue_pair_by_index(0).user_content == "hello" + assert tracker.get_dialogue_pair_by_index(99) is None + assert tracker.get_documents_for_ranking() == [ + "memory body\n\n[2026-05-12]: \nuser: hello\nassistant: hi" + ] + + +def test_single_turn_strategy_prepares_documents_from_chat_pairs(): + item = _memory_item( + sources=[ + {"role": "user", "content": "What should I do?", "chat_time": "2026-05-12"}, + {"role": "assistant", "content": "Use the concise answer style."}, + ] + ) + + tracker, original_items, documents = SingleTurnStrategy().prepare_documents( + "answer style", + [item], + top_k=1, + ) + + assert original_items == {item.id: item} + assert len(tracker.dialogue_pairs) == 1 + assert documents == [ + "remembers preferred answer style\n\n" + "[2026-05-12]: \nuser: What should I do?\nassistant: Use the concise answer style." + ] + + +def test_single_turn_strategy_reconstructs_ranked_dialogue_items(): + item = _memory_item( + sources=[ + {"role": "user", "content": "question"}, + {"role": "assistant", "content": "answer"}, + ] + ) + strategy = SingleTurnStrategy() + tracker, original_items, _documents = strategy.prepare_documents("query", [item], top_k=1) + + ranked = strategy.reconstruct_items([0], [0.9], tracker, original_items, top_k=1) + + assert ranked[0][1] == 0.9 + assert "sources-dialogue-pairs" in ranked[0][0].memory + assert ranked[0][0] is not item + + +def test_single_turn_outmem_strategy_aggregates_by_original_memory(): + item = _memory_item( + sources=[ + {"role": "user", "content": "first"}, + {"role": "assistant", "content": "answer"}, + {"role": "user", "content": "second"}, + {"role": "assistant", "content": "answer"}, + ] + ) + strategy = SingleTurnOutMemStrategy() + tracker, original_items, _documents = strategy.prepare_documents("query", [item], top_k=1) + + ranked = strategy.reconstruct_items([0, 1], [0.2, 0.8], tracker, original_items, top_k=1) + + assert ranked == [(item, 0.8)] + + +def test_concat_background_strategy_includes_background_text(): + item = _memory_item() + + _tracker, _original_items, documents = ConcatBackgroundStrategy().prepare_documents( + "answer style", + [item], + top_k=1, + ) + + assert documents == ["remembers preferred answer style\nbackground note"] + + +def test_concat_docsource_strategy_includes_file_source_content(): + item = _memory_item( + sources=[ + { + "type": "file", + "content": "file chunk", + } + ] + ) + + _tracker, _original_items, documents = ConcatDocSourceStrategy().prepare_documents( + "answer style", + [item], + top_k=1, + ) + + assert documents == ["remembers preferred answer style\n\n[Sources]:\nfile chunk"]