"""Shared research memory layer for all orchestration modes.
Design Pattern: Dependency Injection
- Receives embedding service via constructor
- Uses service_loader.get_embedding_service() as default (Strategy Pattern)
- Allows testing with mock services
SOLID Principles:
- Dependency Inversion: Depends on EmbeddingServiceProtocol, not concrete class
- Open/Closed: Works with any service implementing the protocol
"""
from typing import TYPE_CHECKING, Any, get_args

import structlog

from src.agents.graph.state import Conflict, Hypothesis
from src.utils.models import Citation, Evidence, SourceName

if TYPE_CHECKING:
    from src.services.embedding_protocol import EmbeddingServiceProtocol

logger = structlog.get_logger()
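
# The concrete service hides behind EmbeddingServiceProtocol. Judging only by the
# calls made in this module (the real definition lives in
# src/services/embedding_protocol.py), the protocol looks roughly like:
#
#     class EmbeddingServiceProtocol(Protocol):
#         async def deduplicate(self, evidence: list[Evidence]) -> list[Evidence]: ...
#         async def search_similar(self, query: str, n_results: int) -> list[dict]: ...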


class ResearchMemory:
    """Shared cognitive state for research workflows.

    This is the memory layer that ALL modes use. It mimics LangGraph's state
    management, but for manual orchestration.

    The embedding service is selected via get_embedding_service(), which returns:
    - LlamaIndexRAGService (premium tier) if OPENAI_API_KEY is available
    - EmbeddingService (free tier) as a fallback
    """

    def __init__(self, query: str, embedding_service: "EmbeddingServiceProtocol | None" = None):
        """Initialize ResearchMemory with a query and an optional embedding service.

        Args:
            query: The research query to track evidence for.
            embedding_service: Service for semantic search and deduplication.
                Defaults to get_embedding_service(), which selects the best
                available service.
        """
self.query = query
self.hypotheses: list[Hypothesis] = []
self.conflicts: list[Conflict] = []
self.evidence_ids: list[str] = []
self._evidence_cache: dict[str, Evidence] = {}
self.iteration_count: int = 0
# Use service loader for tiered service selection (Strategy Pattern)
if embedding_service is None:
from src.utils.service_loader import get_embedding_service
self._embedding_service: EmbeddingServiceProtocol = get_embedding_service()
else:
self._embedding_service = embedding_service
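
    # Because the service is injected, tests can swap in a stub without touching
    # the service loader (see the runnable smoke-test sketch at the bottom of
    # this file):
    #
    #     memory = ResearchMemory("q", embedding_service=stub)  # stub: any object
    #                                                           # satisfying the protocol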

    async def store_evidence(self, evidence: list[Evidence]) -> list[str]:
        """Store evidence and return the IDs of newly stored items (deduplicated)."""
        if not self._embedding_service:
            return []
        # Deduplicate and store (deduplicate() already calls add_evidence() internally)
        unique = await self._embedding_service.deduplicate(evidence)
        # Track IDs and cache locally (vectors were already persisted by deduplicate());
        # the citation URL doubles as the stable evidence ID
        new_ids: list[str] = []
        for ev in unique:
            ev_id = ev.citation.url
            # Skip URLs already tracked so repeat calls do not report duplicates
            if ev_id not in self._evidence_cache:
                new_ids.append(ev_id)
            self._evidence_cache[ev_id] = ev
self.evidence_ids.extend(new_ids)
if new_ids:
logger.info("Stored new evidence", count=len(new_ids))
return new_ids
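
    # Call-pattern sketch: pass a freshly fetched batch; only IDs not already
    # cached come back, so repeat calls with overlapping batches are safe:
    #
    #     new_ids = await memory.store_evidence(batch)  # batch: list[Evidence]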

    def get_all_evidence(self) -> list[Evidence]:
        """Get all accumulated evidence objects, in insertion order."""
        return list(self._evidence_cache.values())

    async def get_relevant_evidence(self, n: int = 20) -> list[Evidence]:
        """Retrieve the evidence most semantically relevant to the current query."""
        if not self._embedding_service:
            return []
        results = await self._embedding_service.search_similar(self.query, n_results=n)
evidence_list = []
for r in results:
meta = r.get("metadata", {})
authors_str = meta.get("authors", "")
authors = [a.strip() for a in authors_str.split(",")] if authors_str else []
# Reconstruct Evidence object
source_raw = meta.get("source", "web")
# Validate source against canonical SourceName type (avoids drift)
valid_sources = get_args(SourceName)
source_name: Any = source_raw if source_raw in valid_sources else "web"
citation = Citation(
source=source_name,
title=meta.get("title", "Unknown"),
url=meta.get("url", r.get("id", "")),
date=meta.get("date", "Unknown"),
authors=authors,
)
            evidence_list.append(
                Evidence(
                    content=r.get("content", ""),
                    citation=citation,
                    # Approximate distance-to-relevance conversion; clamped so metrics
                    # with distances outside [0, 1] cannot yield invalid scores
                    relevance=max(0.0, min(1.0, 1.0 - r.get("distance", 0.5))),
                )
            )
return evidence_list
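
    # For reference, search_similar() results are consumed here as dicts of
    # roughly this shape (inferred from the lookups above; the authoritative
    # schema belongs to the embedding service):
    #
    #     {
    #         "id": "<url>",
    #         "content": "<chunk text>",
    #         "distance": 0.23,  # smaller = more similar
    #         "metadata": {
    #             "source": "web", "title": "...", "url": "...",
    #             "date": "...", "authors": "A. Author, B. Author",
    #         },
    #     }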

    async def get_context_summary(self) -> str:
        """Generate a summary of all collected evidence for the final report."""
        if not self.evidence_ids:
            return "No evidence collected."
summary = [f"Research Query: {self.query}\n"]
# Add Hypotheses
if self.hypotheses:
summary.append("## Hypotheses")
for h in self.hypotheses:
summary.append(f"- {h.statement} (Conf: {h.confidence})")
summary.append("")
        # Add top evidence, capped to avoid token overflow in the final prompt;
        # a future pass could summarize rather than truncate if this grows too large
        evidence = self.get_all_evidence()
        summary.append(f"## Evidence ({len(evidence)} items)")
        # Items are listed in insertion order (not grouped or ranked by source)
        for i, ev in enumerate(evidence[:20], 1):  # cap at the first 20 items
summary.append(f"{i}. {ev.citation.title} ({ev.citation.date})")
summary.append(f" {ev.content[:200]}...") # Brief snippet
return "\n".join(summary)

    def add_hypothesis(self, hypothesis: Hypothesis) -> None:
"""Add a hypothesis to tracking."""
self.hypotheses.append(hypothesis)
logger.info("Added hypothesis", id=hypothesis.id, confidence=hypothesis.confidence)

    def add_conflict(self, conflict: Conflict) -> None:
"""Add a detected conflict."""
self.conflicts.append(conflict)
logger.info("Added conflict", id=conflict.id)

    def get_open_conflicts(self) -> list[Conflict]:
        """Get conflicts whose status is still 'open'."""
return [c for c in self.conflicts if c.status == "open"]

    def get_confirmed_hypotheses(self) -> list[Hypothesis]:
        """Get high-confidence hypotheses (confidence above 0.8)."""
return [h for h in self.hypotheses if h.confidence > 0.8]
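

if __name__ == "__main__":
    # Runnable smoke-test sketch, not part of the library API. It wires in a
    # hypothetical in-memory fake (so no API keys or vector store are needed);
    # the query, title, URL, and field values below are made up, and the
    # Citation/Evidence fields mirror those used in get_relevant_evidence() above.
    import asyncio

    class _FakeEmbeddingService:
        """Minimal stand-in exposing just the methods this module calls."""

        async def deduplicate(self, evidence: list[Evidence]) -> list[Evidence]:
            # Collapse duplicates by citation URL, mirroring how ResearchMemory
            # keys its cache
            seen: set[str] = set()
            unique: list[Evidence] = []
            for ev in evidence:
                if ev.citation.url not in seen:
                    seen.add(ev.citation.url)
                    unique.append(ev)
            return unique

        async def search_similar(self, query: str, n_results: int = 20) -> list[dict]:
            return []  # the fake has no vector store to search

    async def _demo() -> None:
        memory = ResearchMemory("example query", embedding_service=_FakeEmbeddingService())
        ev = Evidence(
            content="Example finding.",
            citation=Citation(
                source="web",
                title="Example",
                url="https://example.org/example",
                date="2024",
                authors=[],
            ),
            relevance=0.9,
        )
        new_ids = await memory.store_evidence([ev, ev])  # duplicate collapses to one
        print(f"stored {len(new_ids)} new item(s)")
        print(await memory.get_context_summary())

    asyncio.run(_demo())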