Commit f1d6381
Parent(s): efd6737
Fix in-memory Qdrant initialization for T4 deployment
- Added support for in-memory Qdrant mode when host is None
- Auto-initialize collection and upload vectors for in-memory mode
- Load embeddings from data/vectors/embeddings.npy
- Batch upload vectors to prevent memory issues
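The new mode is selected entirely through the two constructor arguments changed in the diff below. A minimal usage sketch follows; the import path simply mirrors the file location and assumes the repository root is on the import path, and all other constructor arguments keep their defaults (neither detail is part of this commit):

# Usage sketch, not part of this commit.
from src.retrieval.hybrid_retriever import HybridRetriever

# On a T4 Space with no Qdrant server available, pass host=None to get the
# in-memory client; the collection is then built from data/vectors/embeddings.npy.
retriever = HybridRetriever(qdrant_host=None, qdrant_port=None)

# Against a running Qdrant server, the previous behavior is unchanged:
retriever = HybridRetriever(qdrant_host="localhost", qdrant_port=6333)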
src/retrieval/hybrid_retriever.py
CHANGED
@@ -72,8 +72,8 @@ class HybridRetriever:
     CPT_PATTERN = re.compile(r'\b\d{5}\b')

     def __init__(self,
-                 qdrant_host: str = "localhost",
-                 qdrant_port: int = 6333,
+                 qdrant_host: Optional[str] = "localhost",
+                 qdrant_port: Optional[int] = 6333,
                  collection_name: str = "ip_medcpt",
                  chunks_file: str = "data/chunks/chunks.jsonl",
                  cpt_index_file: str = "data/term_index/cpt_codes.jsonl",
@@ -84,8 +84,8 @@ class HybridRetriever:
         Initialize hybrid retriever.

         Args:
-            qdrant_host: Qdrant server host
-            qdrant_port: Qdrant server port
+            qdrant_host: Qdrant server host (None for in-memory)
+            qdrant_port: Qdrant server port (None for in-memory)
            collection_name: Name of Qdrant collection
            chunks_file: Path to chunks JSONL file
            cpt_index_file: Path to CPT codes index
@@ -93,8 +93,11 @@ class HybridRetriever:
            query_encoder_model: Model for query encoding
            reranker_model: Cross-encoder for reranking
        """
-        # Initialize Qdrant client
-        self.qdrant = QdrantClient(host=qdrant_host, port=qdrant_port)
+        # Initialize Qdrant client (in-memory if host is None)
+        if qdrant_host is None:
+            self.qdrant = QdrantClient(":memory:")
+        else:
+            self.qdrant = QdrantClient(host=qdrant_host, port=qdrant_port)
        self.collection_name = collection_name

        # Load chunks for BM25 and metadata
@@ -130,8 +133,70 @@ class HybridRetriever:
        print(f"Loading reranker: {reranker_model}")
        self.reranker = CrossEncoder(reranker_model)

+        # Initialize in-memory collection if needed
+        if qdrant_host is None:
+            self._initialize_inmemory_collection()
+
        print("Hybrid retriever initialized")

+    def _initialize_inmemory_collection(self):
+        """Initialize in-memory Qdrant collection with vectors."""
+        print("Initializing in-memory Qdrant collection...")
+
+        from qdrant_client.models import Distance, VectorParams, PointStruct
+
+        # Create collection
+        try:
+            self.qdrant.create_collection(
+                collection_name=self.collection_name,
+                vectors_config=VectorParams(size=768, distance=Distance.COSINE)
+            )
+            print(f"Created collection: {self.collection_name}")
+        except Exception as e:
+            print(f"Collection might already exist: {e}")
+
+        # Load embeddings if available
+        embeddings_file = Path("data/vectors/embeddings.npy")
+        if embeddings_file.exists():
+            print("Loading embeddings...")
+            embeddings = np.load(embeddings_file)
+
+            # Prepare points for upload
+            points = []
+            for i, (chunk, embedding) in enumerate(zip(self.chunks[:len(embeddings)], embeddings)):
+                # Use chunk_id if available, otherwise use index
+                chunk_id = chunk.get('chunk_id', chunk.get('id', str(i)))
+
+                point = PointStruct(
+                    id=i,
+                    vector=embedding.tolist(),
+                    payload={
+                        "chunk_id": chunk_id,
+                        "text": chunk.get('text', ''),
+                        "doc_id": chunk.get('doc_id', ''),
+                        "section_title": chunk.get('section_title', ''),
+                        "authority_tier": chunk.get('authority_tier', 'A4'),
+                        "evidence_level": chunk.get('evidence_level', 'H4'),
+                        "year": chunk.get('year', 2020),
+                        "doc_type": chunk.get('doc_type', 'article')
+                    }
+                )
+                points.append(point)
+
+            # Upload in batches
+            batch_size = 100
+            for i in range(0, len(points), batch_size):
+                batch = points[i:i+batch_size]
+                self.qdrant.upsert(
+                    collection_name=self.collection_name,
+                    points=batch
+                )
+                print(f"Uploaded batch {i//batch_size + 1}/{(len(points) + batch_size - 1)//batch_size}")
+
+            print(f"Uploaded {len(points)} vectors to in-memory collection")
+        else:
+            print("Warning: No embeddings file found. Semantic search will not work.")
+
    def _load_term_index(self, index_file: str) -> Dict[str, List[str]]:
        """Load term index from JSONL file."""
        index = defaultdict(list)