import gradio as gr
import numpy as np
import pymupdf4llm
import spacy
from transformers import AutoTokenizer
from adapters import AutoAdapterModel

from extract_citations import fetch_citations_for_dois
from extract_embeddings import (
    prune_contexts,
    embed_abstracts,
    embed_contexts,
    restore_inverted_abstract,
    calculate_distances,
)
from extract_mentions import extract_citation_contexts
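
# Pipeline: fetch metadata for the papers cited by a given DOI, convert the uploaded
# PDF to markdown, locate the in-text citation contexts, embed contexts and cited-paper
# abstracts with SPECTER2, and score each citation by the distance between its context
# embedding and the cited paper's abstract embedding.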


def extract_text(pdf_file):
    if not pdf_file:
        return "Please upload a PDF file."
    try:
        return pymupdf4llm.to_markdown(pdf_file)
    except Exception as e:
        return f"Error while processing PDF: {e}"
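

# Standalone DOI lookup; not wired into the Gradio UI below, which calls
# get_cite_context_distance directly.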
def extract_citations(doi):
    try:
        citations_data = fetch_citations_for_dois([doi])
    except Exception as e:
        return f"Please submit a valid DOI. {e}"
    return citations_data
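

# End-to-end scoring: fetch the cited papers for the DOI, locate citation contexts in
# the PDF, embed both with SPECTER2, and return per-citation distances plus an
# aggregate score.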
def get_cite_context_distance(pdf, doi):
    # Load models
    tokenizer = AutoTokenizer.from_pretrained('allenai/specter2_base')
    model = AutoAdapterModel.from_pretrained('allenai/specter2_base')
    nlp = spacy.load("en_core_web_sm")
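
    # Note: this loads the bare allenai/specter2_base encoder. The SPECTER2 release
    # also provides task adapters (e.g. a proximity adapter) that can be attached with
    # the adapters library's model.load_adapter(...); whether that would improve the
    # scores here is untested.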

    # Fetch metadata for the cited papers from OpenAlex
    citations_data = fetch_citations_for_dois([doi])
    # Convert the PDF to markdown text
    text = extract_text(pdf.name)
    # Extract the context around each citation marker
    citations = extract_citation_contexts(citations_data, text)
    # Prune each context and record the fraction of known tokens
    citations["pruned_contexts"], citations["known_tokens_fraction"] = prune_contexts(
        citations, nlp, tokenizer
    )
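
    # Embed only the contexts that survived pruning and have a known-token fraction of
    # at least 0.7; the skipped rows surface as NaN distances further down.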
    citation_context_embedding = embed_contexts(
        citations[
            (citations["known_tokens_fraction"] >= 0.7) &
            (~citations["pruned_contexts"].isna())
        ]["pruned_contexts"].to_list(),
        model,
        tokenizer,
    ).detach().numpy()

    # Re-key the cited-paper metadata by id for direct lookup
    citations_data = {entry["id"]: entry for cite in citations_data.values() for entry in cite}
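
    # Embed the title and abstract of each unique cited paper. OpenAlex returns
    # abstracts as inverted indices ({word: [positions]}), which
    # restore_inverted_abstract turns back into plain text before embedding.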
    citation_abstract_embedding = embed_abstracts(
        [
            {
                "title": citations_data[cite]["title"],
                "abstract": (
                    restore_inverted_abstract(
                        citations_data[cite]["abstract_inverted_index"]
                    )
                    if citations_data[cite]["abstract_inverted_index"] is not None
                    else None
                )
            }
            for cite in citations["citation_id"].unique()
        ],
        model,
        tokenizer,
        batch_size=4,
    ).detach().numpy()
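
    # Map every citation row to a pair (row in the context-embedding matrix, row in the
    # abstract-embedding matrix); rows filtered out above get (None, None), which
    # calculate_distances is expected to skip.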
    index_left = citations.index[
        (citations["known_tokens_fraction"] >= 0.7) &
        (~citations["pruned_contexts"].isna())
    ].tolist()
    index_right = citations["citation_id"].unique().tolist()
    # NOTE: this assumes `citations` keeps its default RangeIndex, so the positional
    # index from enumerate() matches the labels collected in index_left.
    indices = [
        (index_left.index(i), index_right.index(cite_id))
        if i in index_left else (None, None)
        for i, cite_id in enumerate(citations["citation_id"])
    ]
    # Calculate the context-to-abstract distances
    distances = np.array(
        calculate_distances(citation_context_embedding, citation_abstract_embedding, indices)
    )
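
    # Assemble one record per scored citation: a short context preview, the cited
    # paper's title and id, and the embedding distance.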
    results = []
    for i, dist in enumerate(distances):
        if not np.isnan(dist):
            left_context = citations.left_context[i][-50:].replace('\n', '')
            right_context = citations.right_context[i][:50].replace('\n', '')
            results.append({
                "cite_context_short": f"...{left_context}{citations.mention[i]}{right_context}...",
                "cited_paper": citations_data[citations.citation_id[i]]["title"],
                "cited_paper_id": citations.citation_id[i],
                "distance": float(dist),
            })
    return {"score": float(np.nanmean(distances)), "individual_citations": results}
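

# The returned dict is shown verbatim in the output textbox: an aggregate "score"
# (mean distance over the scored citations, NaNs ignored) plus one entry per citation.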
# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## Citation Integrity Score")
    doi_input = gr.Textbox(label="Enter the DOI of the uploaded paper")
    pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
    output = gr.Textbox(label="Citation integrity results", lines=20)
    submit_btn = gr.Button("Submit")
    submit_btn.click(
        fn=get_cite_context_distance,
        inputs=[pdf_input, doi_input],
        outputs=output,
    )

demo.launch()