russellmiller49 committed on
Commit
f1d6381
·
1 Parent(s): efd6737

Fix in-memory Qdrant initialization for T4 deployment

Browse files

- Added support for in-memory Qdrant mode when host is None
- Auto-initialize collection and upload vectors for in-memory mode
- Load embeddings from data/vectors/embeddings.npy
- Batch upload vectors to prevent memory issues

Files changed (1) hide show
  1. src/retrieval/hybrid_retriever.py +71 -6
src/retrieval/hybrid_retriever.py CHANGED
@@ -72,8 +72,8 @@ class HybridRetriever:
72
  CPT_PATTERN = re.compile(r'\b\d{5}\b')
73
 
74
  def __init__(self,
75
- qdrant_host: str = "localhost",
76
- qdrant_port: int = 6333,
77
  collection_name: str = "ip_medcpt",
78
  chunks_file: str = "data/chunks/chunks.jsonl",
79
  cpt_index_file: str = "data/term_index/cpt_codes.jsonl",
@@ -84,8 +84,8 @@ class HybridRetriever:
84
  Initialize hybrid retriever.
85
 
86
  Args:
87
- qdrant_host: Qdrant server host
88
- qdrant_port: Qdrant server port
89
  collection_name: Name of Qdrant collection
90
  chunks_file: Path to chunks JSONL file
91
  cpt_index_file: Path to CPT codes index
@@ -93,8 +93,11 @@ class HybridRetriever:
93
  query_encoder_model: Model for query encoding
94
  reranker_model: Cross-encoder for reranking
95
  """
96
- # Initialize Qdrant client
97
- self.qdrant = QdrantClient(host=qdrant_host, port=qdrant_port)
 
 
 
98
  self.collection_name = collection_name
99
 
100
  # Load chunks for BM25 and metadata
@@ -130,8 +133,70 @@ class HybridRetriever:
130
  print(f"Loading reranker: {reranker_model}")
131
  self.reranker = CrossEncoder(reranker_model)
132
 
 
 
 
 
133
  print("Hybrid retriever initialized")
134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  def _load_term_index(self, index_file: str) -> Dict[str, List[str]]:
136
  """Load term index from JSONL file."""
137
  index = defaultdict(list)
 
72
  CPT_PATTERN = re.compile(r'\b\d{5}\b')
73
 
74
  def __init__(self,
75
+ qdrant_host: Optional[str] = "localhost",
76
+ qdrant_port: Optional[int] = 6333,
77
  collection_name: str = "ip_medcpt",
78
  chunks_file: str = "data/chunks/chunks.jsonl",
79
  cpt_index_file: str = "data/term_index/cpt_codes.jsonl",
 
84
  Initialize hybrid retriever.
85
 
86
  Args:
87
+ qdrant_host: Qdrant server host (None for in-memory)
88
+ qdrant_port: Qdrant server port (None for in-memory)
89
  collection_name: Name of Qdrant collection
90
  chunks_file: Path to chunks JSONL file
91
  cpt_index_file: Path to CPT codes index
 
93
  query_encoder_model: Model for query encoding
94
  reranker_model: Cross-encoder for reranking
95
  """
96
+ # Initialize Qdrant client (in-memory if host is None)
97
+ if qdrant_host is None:
98
+ self.qdrant = QdrantClient(":memory:")
99
+ else:
100
+ self.qdrant = QdrantClient(host=qdrant_host, port=qdrant_port)
101
  self.collection_name = collection_name
102
 
103
  # Load chunks for BM25 and metadata
 
133
  print(f"Loading reranker: {reranker_model}")
134
  self.reranker = CrossEncoder(reranker_model)
135
 
136
+ # Initialize in-memory collection if needed
137
+ if qdrant_host is None:
138
+ self._initialize_inmemory_collection()
139
+
140
  print("Hybrid retriever initialized")
141
 
142
+ def _initialize_inmemory_collection(self):
143
+ """Initialize in-memory Qdrant collection with vectors."""
144
+ print("Initializing in-memory Qdrant collection...")
145
+
146
+ from qdrant_client.models import Distance, VectorParams, PointStruct
147
+
148
+ # Create collection
149
+ try:
150
+ self.qdrant.create_collection(
151
+ collection_name=self.collection_name,
152
+ vectors_config=VectorParams(size=768, distance=Distance.COSINE)
153
+ )
154
+ print(f"Created collection: {self.collection_name}")
155
+ except Exception as e:
156
+ print(f"Collection might already exist: {e}")
157
+
158
+ # Load embeddings if available
159
+ embeddings_file = Path("data/vectors/embeddings.npy")
160
+ if embeddings_file.exists():
161
+ print("Loading embeddings...")
162
+ embeddings = np.load(embeddings_file)
163
+
164
+ # Prepare points for upload
165
+ points = []
166
+ for i, (chunk, embedding) in enumerate(zip(self.chunks[:len(embeddings)], embeddings)):
167
+ # Use chunk_id if available, otherwise use index
168
+ chunk_id = chunk.get('chunk_id', chunk.get('id', str(i)))
169
+
170
+ point = PointStruct(
171
+ id=i,
172
+ vector=embedding.tolist(),
173
+ payload={
174
+ "chunk_id": chunk_id,
175
+ "text": chunk.get('text', ''),
176
+ "doc_id": chunk.get('doc_id', ''),
177
+ "section_title": chunk.get('section_title', ''),
178
+ "authority_tier": chunk.get('authority_tier', 'A4'),
179
+ "evidence_level": chunk.get('evidence_level', 'H4'),
180
+ "year": chunk.get('year', 2020),
181
+ "doc_type": chunk.get('doc_type', 'article')
182
+ }
183
+ )
184
+ points.append(point)
185
+
186
+ # Upload in batches
187
+ batch_size = 100
188
+ for i in range(0, len(points), batch_size):
189
+ batch = points[i:i+batch_size]
190
+ self.qdrant.upsert(
191
+ collection_name=self.collection_name,
192
+ points=batch
193
+ )
194
+ print(f"Uploaded batch {i//batch_size + 1}/{(len(points) + batch_size - 1)//batch_size}")
195
+
196
+ print(f"Uploaded {len(points)} vectors to in-memory collection")
197
+ else:
198
+ print("Warning: No embeddings file found. Semantic search will not work.")
199
+
200
  def _load_term_index(self, index_file: str) -> Dict[str, List[str]]:
201
  """Load term index from JSONL file."""
202
  index = defaultdict(list)