"""Shared research memory layer for all orchestration modes. |
|
|
|
|
|
Design Pattern: Dependency Injection |
|
|
- Receives embedding service via constructor |
|
|
- Uses service_loader.get_embedding_service() as default (Strategy Pattern) |
|
|
- Allows testing with mock services |
|
|
|
|
|
SOLID Principles: |
|
|
- Dependency Inversion: Depends on EmbeddingServiceProtocol, not concrete class |
|
|
- Open/Closed: Works with any service implementing the protocol |
|
|
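
Example (a minimal sketch; FakeEmbeddingService is a hypothetical test double
implementing EmbeddingServiceProtocol, not part of this codebase):

    memory = ResearchMemory("my query", embedding_service=FakeEmbeddingService())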
""" |

from typing import TYPE_CHECKING, Any, get_args

import structlog

from src.agents.graph.state import Conflict, Hypothesis
from src.utils.models import Citation, Evidence, SourceName

if TYPE_CHECKING:
    from src.services.embedding_protocol import EmbeddingServiceProtocol

logger = structlog.get_logger()


class ResearchMemory:
    """Shared cognitive state for research workflows.

    This is the memory layer that ALL orchestration modes use. It mirrors
    LangGraph state management, but for manual orchestration.

    The embedding service is selected via get_embedding_service(), which returns:
    - LlamaIndexRAGService (premium tier) if OPENAI_API_KEY is available
    - EmbeddingService (free tier) as a fallback
    """

    def __init__(self, query: str, embedding_service: "EmbeddingServiceProtocol | None" = None):
        """Initialize ResearchMemory with a query and an optional embedding service.

        Args:
            query: The research query to track evidence for.
            embedding_service: Service for semantic search and deduplication.
                Defaults to get_embedding_service(), which selects the best
                available implementation.
        """
        self.query = query
        self.hypotheses: list[Hypothesis] = []
        self.conflicts: list[Conflict] = []
        self.evidence_ids: list[str] = []
        self._evidence_cache: dict[str, Evidence] = {}
        self.iteration_count: int = 0

        if embedding_service is None:
            # Imported lazily: the loader is only needed when no service
            # is injected (e.g. in tests that pass a mock).
            from src.utils.service_loader import get_embedding_service

            self._embedding_service: EmbeddingServiceProtocol = get_embedding_service()
        else:
            self._embedding_service = embedding_service

    async def store_evidence(self, evidence: list[Evidence]) -> list[str]:
        """Deduplicate and store evidence, returning the IDs of newly added items.
        """
        if not self._embedding_service:
            return []

        # Semantic deduplication of the batch is delegated to the service.
        unique = await self._embedding_service.deduplicate(evidence)

        new_ids: list[str] = []
        for ev in unique:
            # The citation URL doubles as the evidence ID.
            ev_id = ev.citation.url
            if ev_id not in self._evidence_cache:
                new_ids.append(ev_id)
            # Cache the latest version of the evidence for this URL, but only
            # report IDs that were not seen in an earlier iteration as new.
            self._evidence_cache[ev_id] = ev

        self.evidence_ids.extend(new_ids)
        if new_ids:
            logger.info("Stored new evidence", count=len(new_ids))
        return new_ids

    def get_all_evidence(self) -> list[Evidence]:
        """Get all accumulated evidence objects."""
        return list(self._evidence_cache.values())

    async def get_relevant_evidence(self, n: int = 20) -> list[Evidence]:
        """Retrieve up to n pieces of evidence most relevant to the query.
        """
        if not self._embedding_service:
            return []

        results = await self._embedding_service.search_similar(self.query, n_results=n)
        evidence_list = []

        for r in results:
            meta = r.get("metadata", {})
            authors_str = meta.get("authors", "")
            authors = [a.strip() for a in authors_str.split(",")] if authors_str else []

            # Narrow the raw source string to a valid SourceName literal,
            # falling back to "web" for anything unrecognized.
            source_raw = meta.get("source", "web")
            valid_sources = get_args(SourceName)
            source_name: Any = source_raw if source_raw in valid_sources else "web"

            citation = Citation(
                source=source_name,
                title=meta.get("title", "Unknown"),
                url=meta.get("url", r.get("id", "")),
                date=meta.get("date", "Unknown"),
                authors=authors,
            )

            evidence_list.append(
                Evidence(
                    content=r.get("content", ""),
                    citation=citation,
                    # Convert vector distance to a similarity-style relevance score.
                    relevance=1.0 - r.get("distance", 0.5),
                )
            )

        return evidence_list

    async def get_context_summary(self) -> str:
        """Generate a Markdown summary of the collected evidence for the final report.
        """
        if not self.evidence_ids:
            return "No evidence collected."

        summary = [f"Research Query: {self.query}\n"]

        if self.hypotheses:
            summary.append("## Hypotheses")
            for h in self.hypotheses:
                summary.append(f"- {h.statement} (Conf: {h.confidence})")
            summary.append("")

        evidence = self.get_all_evidence()
        summary.append(f"## Evidence ({len(evidence)} items)")

        # Cap the listing at the first 20 items to keep the context compact.
        for i, ev in enumerate(evidence[:20], 1):
            summary.append(f"{i}. {ev.citation.title} ({ev.citation.date})")
            summary.append(f"   {ev.content[:200]}...")

        return "\n".join(summary)

    def add_hypothesis(self, hypothesis: Hypothesis) -> None:
        """Add a hypothesis to tracking."""
        self.hypotheses.append(hypothesis)
        logger.info("Added hypothesis", id=hypothesis.id, confidence=hypothesis.confidence)

    def add_conflict(self, conflict: Conflict) -> None:
        """Add a detected conflict."""
        self.conflicts.append(conflict)
        logger.info("Added conflict", id=conflict.id)

    def get_open_conflicts(self) -> list[Conflict]:
        """Get unresolved conflicts."""
        return [c for c in self.conflicts if c.status == "open"]

    def get_confirmed_hypotheses(self) -> list[Hypothesis]:
        """Get high-confidence hypotheses (confidence above 0.8)."""
        return [h for h in self.hypotheses if h.confidence > 0.8]
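

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original workflow).
    # It exercises the default service selection described in the class
    # docstring, so it assumes get_embedding_service() can construct a
    # service in this environment; with no stored evidence the summary is
    # simply the "No evidence collected." fallback.
    import asyncio

    async def _demo() -> None:
        memory = ResearchMemory("example query")
        print(await memory.get_context_summary())

    asyncio.run(_demo())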