# NOTE(review): removed non-code extraction artifacts that preceded the module
# ("Spaces:" / "Sleeping" / "Sleeping") — they were not valid Python.
__author__ = "qiao"

"""
Conduct the first stage retrieval by the hybrid retriever
"""

# Standard library
import json
import os
import sys

# Third-party
import faiss
import numpy as np
import torch
import tqdm
from beir import util, LoggingHandler
from beir.datasets.data_loader import GenericDataLoader
from nltk import word_tokenize
from rank_bm25 import BM25Okapi
from transformers import AutoTokenizer, AutoModel

# Device detection - use CUDA if available, otherwise CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
def get_bm25_corpus_index(corpus):
    """Build (or load from cache) the BM25 lexical index for a corpus.

    The tokenized corpus is cached as JSON under trialgpt_retrieval/ so the
    (slow) NLTK tokenization only runs once per corpus.

    Args:
        corpus: corpus name (e.g. "trec_2021", "trec_2022", "sigir");
            dataset/{corpus}/corpus.jsonl must exist on the first call.

    Returns:
        (bm25, corpus_nctids): a BM25Okapi index and the list of NCT IDs
        aligned positionally with the indexed documents.
    """
    # fix: the original single-argument os.path.join was a no-op wrapper
    corpus_path = os.path.join("trialgpt_retrieval", f"bm25_corpus_{corpus}.json")

    # if already cached then load, otherwise build
    if os.path.exists(corpus_path):
        # fix: close the cache file instead of leaking json.load(open(...))
        with open(corpus_path) as f:
            corpus_data = json.load(f)

        tokenized_corpus = corpus_data["tokenized_corpus"]
        corpus_nctids = corpus_data["corpus_nctids"]
    else:
        tokenized_corpus = []
        corpus_nctids = []

        with open(f"dataset/{corpus}/corpus.jsonl", "r") as f:
            # iterate lazily instead of readlines() — corpus.jsonl can be large
            for line in f:
                entry = json.loads(line)
                corpus_nctids.append(entry["_id"])

                # weighting: 3 * title, 2 * condition, 1 * text
                tokens = word_tokenize(entry["title"].lower()) * 3
                for disease in entry["metadata"]["diseases_list"]:
                    tokens += word_tokenize(disease.lower()) * 2
                tokens += word_tokenize(entry["text"].lower())

                tokenized_corpus.append(tokens)

        corpus_data = {
            "tokenized_corpus": tokenized_corpus,
            "corpus_nctids": corpus_nctids,
        }

        with open(corpus_path, "w") as f:
            json.dump(corpus_data, f, indent=4)

    bm25 = BM25Okapi(tokenized_corpus)

    return bm25, corpus_nctids
def get_medcpt_corpus_index(corpus):
    """Build (or load from cache) the MedCPT dense Faiss index for a corpus.

    Article embeddings (768-d [CLS] vectors from ncbi/MedCPT-Article-Encoder)
    are cached as a .npy file, with the aligned NCT IDs in a sibling JSON file,
    so the GPU encoding pass only runs once per corpus.

    Args:
        corpus: corpus name (e.g. "trec_2021", "trec_2022", "sigir");
            dataset/{corpus}/corpus.jsonl must exist on the first call.

    Returns:
        (index, corpus_nctids): an inner-product faiss.IndexFlatIP over the
        embeddings and the list of NCT IDs aligned with the index rows.
    """
    corpus_path = f"trialgpt_retrieval/{corpus}_embeds.npy"
    nctids_path = f"trialgpt_retrieval/{corpus}_nctids.json"

    # if already cached then load, otherwise build
    if os.path.exists(corpus_path):
        embeds = np.load(corpus_path)
        # fix: close the cache file instead of leaking json.load(open(...))
        with open(nctids_path) as f:
            corpus_nctids = json.load(f)
    else:
        embeds = []
        corpus_nctids = []

        model = AutoModel.from_pretrained("ncbi/MedCPT-Article-Encoder").to(device)
        tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Article-Encoder")

        with open(f"dataset/{corpus}/corpus.jsonl", "r") as f:
            print("Encoding the corpus")
            # iterate lazily instead of readlines() — corpus.jsonl can be large
            for line in tqdm.tqdm(f):
                entry = json.loads(line)
                corpus_nctids.append(entry["_id"])

                title = entry["title"]
                text = entry["text"]

                with torch.no_grad():
                    # tokenize the articles as (title, text) pairs
                    encoded = tokenizer(
                        [[title, text]],
                        truncation=True,
                        padding=True,
                        return_tensors='pt',
                        max_length=512,
                    ).to(device)

                    # [CLS] last hidden state is the article representation
                    embed = model(**encoded).last_hidden_state[:, 0, :]

                    embeds.append(embed[0].cpu().numpy())

        embeds = np.array(embeds)

        np.save(corpus_path, embeds)
        with open(nctids_path, "w") as f:
            json.dump(corpus_nctids, f, indent=4)

    # 768 = MedCPT hidden size; inner product over unnormalized embeddings
    index = faiss.IndexFlatIP(768)
    index.add(embeds)

    return index, corpus_nctids
if __name__ == "__main__":
    # different corpora, "trec_2021", "trec_2022", "sigir"
    corpus = sys.argv[1]

    # query type
    q_type = sys.argv[2]

    # different k for reciprocal-rank fusion
    k = int(sys.argv[3])

    # bm25 weight
    bm25_wt = int(sys.argv[4])

    # medcpt weight
    medcpt_wt = int(sys.argv[5])

    # how many to rank
    N = 2000

    # loading the qrels
    _, _, qrels = GenericDataLoader(data_folder=f"dataset/{corpus}/").load(split="test")

    # loading all types of queries (fix: close the handle instead of leaking it)
    with open(f"dataset/{corpus}/id2queries.json") as f:
        id2queries = json.load(f)

    # loading the indices
    bm25, bm25_nctids = get_bm25_corpus_index(corpus)
    medcpt, medcpt_nctids = get_medcpt_corpus_index(corpus)

    # loading the query encoder for MedCPT
    model = AutoModel.from_pretrained("ncbi/MedCPT-Query-Encoder").to(device)
    tokenizer = AutoTokenizer.from_pretrained("ncbi/MedCPT-Query-Encoder")

    # then conduct the searches, saving the fused top-N per query
    os.makedirs("results", exist_ok=True)
    output_path = f"results/qid2nctids_results_{q_type}_{corpus}_k{k}_bm25wt{bm25_wt}_medcptwt{medcpt_wt}_N{N}.json"

    qid2nctids = {}
    recalls = []

    with open(f"dataset/{corpus}/queries.jsonl", "r") as f:
        for line in tqdm.tqdm(f.readlines()):
            entry = json.loads(line)
            query = entry["text"]
            qid = entry["_id"]

            if qid not in qrels:
                continue

            # total relevance mass for this query (denominator of recall)
            truth_sum = sum(qrels[qid].values())

            # get the keyword list for the chosen query type
            if q_type in ["raw", "human_summary"]:
                conditions = [id2queries[qid][q_type]]
            elif "turbo" in q_type:
                conditions = id2queries[qid][q_type]["conditions"]
            elif "Clinician" in q_type:
                conditions = id2queries[qid].get(q_type, [])
            else:
                # fix: previously an unknown q_type fell through, leaving
                # `conditions` undefined (NameError) or stale from the
                # previous iteration — fail loudly instead
                raise ValueError(f"Unknown query type: {q_type}")

            if len(conditions) == 0:
                nctid2score = {}
            else:
                # a list of nctid lists for the bm25 retriever
                bm25_condition_top_nctids = []

                for condition in conditions:
                    tokens = word_tokenize(condition.lower())
                    top_nctids = bm25.get_top_n(tokens, bm25_nctids, n=N)
                    bm25_condition_top_nctids.append(top_nctids)

                # doing MedCPT retrieval
                with torch.no_grad():
                    encoded = tokenizer(
                        conditions,
                        truncation=True,
                        padding=True,
                        return_tensors='pt',
                        max_length=256,
                    ).to(device)

                    # encode the queries (use the [CLS] last hidden states as the representations)
                    embeds = model(**encoded).last_hidden_state[:, 0, :].cpu().numpy()

                    # search the Faiss index
                    scores, inds = medcpt.search(embeds, k=N)

                medcpt_condition_top_nctids = []
                for ind_list in inds:
                    top_nctids = [medcpt_nctids[ind] for ind in ind_list]
                    medcpt_condition_top_nctids.append(top_nctids)

                # reciprocal-rank fusion across both retrievers; earlier
                # conditions are considered more important (1/(idx+1) decay)
                nctid2score = {}

                for condition_idx, (bm25_top_nctids, medcpt_top_nctids) in enumerate(zip(bm25_condition_top_nctids, medcpt_condition_top_nctids)):
                    if bm25_wt > 0:
                        for rank, nctid in enumerate(bm25_top_nctids):
                            if nctid not in nctid2score:
                                nctid2score[nctid] = 0

                            # fix: multiply by the CLI weight — previously any
                            # non-zero bm25_wt behaved identically (weight was
                            # only used as a > 0 gate); unchanged for wt == 1
                            nctid2score[nctid] += bm25_wt * (1 / (rank + k)) * (1 / (condition_idx + 1))

                    if medcpt_wt > 0:
                        for rank, nctid in enumerate(medcpt_top_nctids):
                            if nctid not in nctid2score:
                                nctid2score[nctid] = 0

                            # fix: same weighting bug as above for medcpt_wt
                            nctid2score[nctid] += medcpt_wt * (1 / (rank + k)) * (1 / (condition_idx + 1))

            nctid2score = sorted(nctid2score.items(), key=lambda x: -x[1])
            top_nctids = [nctid for nctid, _ in nctid2score[:N]]
            qid2nctids[qid] = top_nctids

            actual_sum = sum([qrels[qid].get(nctid, 0) for nctid in top_nctids])
            # fix: guard the division — a query whose qrels are all zero
            # previously crashed with ZeroDivisionError
            if truth_sum > 0:
                recalls.append(actual_sum / truth_sum)

    with open(output_path, "w") as f:
        json.dump(qid2nctids, f, indent=4)

    # fix: `recalls` was computed but never reported
    if recalls:
        print(f"Average recall@{N}: {sum(recalls) / len(recalls):.4f}")