import json
import string
from time import time

import en_core_web_lg
import inflect
import nltk
import numpy as np
import pandas as pd
import streamlit as st
from nltk.tokenize import sent_tokenize
from transformers import pipeline

# Set constant values
INFLECT_ENGINE = inflect.engine()
TOP_K = 30
NLI_LIMIT = 0.9

st.set_page_config(layout="wide")

def get_top_k():
    return TOP_K


def get_nli_limit():
    return NLI_LIMIT

### Streamlit specific
# NOTE: caching the heavy model loaders with st.cache_resource is an
# assumption here (the cached-value comment in run_prent suggests the
# deployed app cached them as well)
@st.cache_resource
def load_model_prompting():
    return pipeline("fill-mask", model="distilbert-base-uncased")

@st.cache_resource
def load_model_nli():
    try:
        # Prefer Apple's MPS backend when available
        return pipeline(
            task="sentiment-analysis", model="roberta-large-mnli", device="mps"
        )
    except Exception:
        # Fall back to the default device (CPU)
        return pipeline(task="sentiment-analysis", model="roberta-large-mnli")

@st.cache_resource
def load_spacy_pipeline():
    return en_core_web_lg.load()


def download_punkt():
    nltk.download("punkt")


download_punkt()

def read_json_from_web(uploaded_json):
    return json.load(uploaded_json)


def read_csv_from_web(uploaded_file):
    """Read CSV from the streamlit interface

    :param uploaded_file: File to read
    :type uploaded_file: UploadedFile (BytesIO)
    :return: Dataframe
    :rtype: pandas DataFrame
    """
    try:
        # Try first to read comma-separated and semicolon-separated files;
        # sep=None lets the Python engine sniff the delimiter.
        # If neither is correct, it errors and falls through to the except
        data = pd.read_csv(uploaded_file, sep=None, engine="python")
    except pd.errors.ParserError:
        # This should be the case when there is no separator (1-column csv).
        # Reset the IO object due to the previous crash
        uploaded_file.seek(0)
        # Use standard reading of CSV (no separator)
        data = pd.read_csv(uploaded_file)
    return data

def apply_style():
    # Avoid having ellipsis in the multi-select options
    styl = """
    <style>
        .stMultiSelect span{
            max-width: none;
        }
    </style>
    """
    st.markdown(styl, unsafe_allow_html=True)

    # Set color of multiselect tags to red
    st.markdown(
        """
        <style>
            span[data-baseweb="tag"] {
                background-color: red !important;
            }
        </style>
        """,
        unsafe_allow_html=True,
    )

    # Hide the default Streamlit menu, footer and header
    hide_st_style = """
    <style>
        #MainMenu {visibility: hidden;}
        footer {visibility: hidden;}
        header {visibility: hidden;}
    </style>
    """
    st.markdown(hide_st_style, unsafe_allow_html=True)

def choose_text_menu(text):
    if "text" not in st.session_state:
        st.session_state.text = "Several demonstrators were injured."
    text = st.text_area("Event description", st.session_state.text)
    return text

def initiate_widget_st_state(widget_key, perm_key, default_value):
    if perm_key not in st.session_state:
        st.session_state[perm_key] = default_value
    if widget_key not in st.session_state:
        st.session_state[widget_key] = st.session_state[perm_key]


def get_idx_column(col_name, col_list):
    if col_name in col_list:
        return col_list.index(col_name)
    else:
        return 0

def callback_add_to_multiselect(str_to_add, multiselect_key, text_input_key, *keys):
    if len(str_to_add) == 0:
        st.warning("Word is empty, did you press Enter in the text field?")
        return
    # Walk down the session-state dict: all keys but the last point to nested
    # dictionaries, the last one holds the list to append to
    current_dict = st.session_state
    *dict_keys, item_key = keys
    try:
        for key in dict_keys:
            current_dict = current_dict[key]
        current_dict[item_key].append(str_to_add)
    except KeyError as e:
        raise KeyError(keys) from e
    if multiselect_key in st.session_state:
        st.session_state[multiselect_key].append(str_to_add)
    else:
        st.session_state[multiselect_key] = [str_to_add]
    # Clear the text input after a successful add
    st.session_state[text_input_key] = ""

# Split the text into sentences. Necessary for NLI models
def split_sentences(text):
    return sent_tokenize(text)
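
# Illustrative example of the punkt sentence splitter used above
# (the input sentences are hypothetical):
#   split_sentences("Protesters marched. Police intervened.")
#   # -> ["Protesters marched.", "Police intervened."]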

def get_num_sentences_in_list_text(list_texts):
    num_sentences = 0
    for text in list_texts:
        num_sentences += len(split_sentences(text))
    return num_sentences

###### Prompting
def query_model_prompting(model, text, prompt_with_mask, top_k, targets):
    """Query the prompting model

    :param model: Prompting model object
    :type model: Huggingface pipeline object
    :param text: Event description (context)
    :type text: str
    :param prompt_with_mask: Prompt with a mask
    :type prompt_with_mask: str
    :param top_k: Number of tokens to output
    :type top_k: integer
    :param targets: Restrict the answer to these possible tokens
    :type targets: list
    :return: Results of the prompting model
    :rtype: list of dict
    """
    sequence = text + prompt_with_mask
    output_tokens = model(sequence, top_k=top_k, targets=targets)
    return output_tokens
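
# Illustrative sketch of the fill-mask output consumed downstream (the
# template, token strings and scores are hypothetical):
#   query_model_prompting(
#       model, "Several demonstrators were injured.",
#       " This event involves [MASK].", top_k=2, targets=None,
#   )
#   # -> [{"token_str": "police", "score": 0.21, ...},
#   #     {"token_str": "violence", "score": 0.14, ...}]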

def do_sentence_entailment(sentence, hypothesis, model):
    """Concatenate context and hypothesis then perform entailment

    :param sentence: Event description (context), 1 sentence
    :type sentence: str
    :param hypothesis: Mask filled with a token
    :type hypothesis: str
    :param model: NLI Model
    :type model: Huggingface pipeline
    :return: DataFrame containing the result of the entailment
    :rtype: pandas DataFrame
    """
    # RoBERTa expects the premise/hypothesis pair joined by separator tokens
    text = sentence + "</s></s>" + hypothesis
    # NOTE: return_all_scores is deprecated in recent transformers versions
    # (top_k=None is the newer equivalent)
    res = model(text, return_all_scores=True)
    df_res = pd.DataFrame(res[0])
    df_res["label"] = df_res["label"].apply(lambda x: x.lower())
    df_res.columns = ["Label", "Score"]
    return df_res
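
# Illustrative sketch of the resulting DataFrame (sentences and scores
# hypothetical); roberta-large-mnli scores the pair over three labels:
#   do_sentence_entailment("Several demonstrators were injured.",
#                          "This event involves protesters.", model)
#   #            Label  Score
#   # 0  contradiction   0.01
#   # 1        neutral   0.20
#   # 2     entailment   0.79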

def softmax(x):
    """Compute softmax values for each set of scores in x."""
    return np.exp(x) / np.sum(np.exp(x), axis=0)
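
# Quick sanity example:
#   softmax(np.array([1.0, 2.0]))  # -> array([0.26894142, 0.73105858])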

def get_singular_form(word):
    """Get the singular form of a word

    :param word: word
    :type word: string
    :return: singular form of the word
    :rtype: string
    """
    # inflect returns False when the word is already singular
    singular = INFLECT_ENGINE.singular_noun(word)
    return singular if singular else word
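
# Example of the inflect behaviour relied on above:
#   INFLECT_ENGINE.singular_noun("protests")  # -> "protest"
#   INFLECT_ENGINE.singular_noun("protest")   # -> False, so the word is kept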

######### NLI + PROMPTING
def do_text_entailment(text, hypothesis, model):
    """
    Run entailment on each sentence of the event description, as the
    model was trained on sentence pairs

    :param text: Event Description (context)
    :type text: str
    :param hypothesis: Mask filled with a token
    :type hypothesis: str
    :param model: Model NLI
    :type model: Huggingface pipeline
    :return: List of entailment results for each sentence of the text
    :rtype: list
    """
    text_entailment_results = []
    for sentence in split_sentences(text):
        df_score = do_sentence_entailment(sentence, hypothesis, model)
        text_entailment_results.append((sentence, hypothesis, df_score))
    return text_entailment_results

def get_true_entailment(text_entailment_results, nli_limit):
    """
    From the per-sentence entailment results, extract the maximum entailment
    score and check whether it is higher than the entailment threshold.
    """
    true_hypothesis_list = []
    max_score = 0
    for sentence_entailment in text_entailment_results:
        df_score = sentence_entailment[2]
        score = df_score[df_score["Label"] == "entailment"]["Score"].values.max()
        if score > max_score:
            max_score = score
    # Compare the maximum over all sentences once, so the hypothesis is
    # appended at most one time
    if text_entailment_results and max_score > nli_limit:
        true_hypothesis_list.append(
            (text_entailment_results[0][1], np.round(max_score, 2))
        )
    return list(set(true_hypothesis_list))
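
# Illustrative output (hypothesis and score hypothetical): a single
# (hypothesis, max score) pair when a sentence clears the threshold:
#   get_true_entailment(results, 0.9)
#   # -> [("This event involves protesters.", 0.97)]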

def run_model_nli(data, batch_size, model_nli, use_tf=False):
    if not use_tf:
        return model_nli(data, top_k=3, batch_size=batch_size)
    else:
        raise NotImplementedError
        # return run_pipeline_on_gpu(data, batch_size, model_nli["tokenizer"], model_nli["model"])

def prompt_to_nli_batching(
    text,
    prompt,
    model_prompting,
    nli_model,
    nlp,
    top_k=10,
    nli_limit=0.5,
    targets=None,
    additional_words=None,
    remove_lemma=False,
    use_tf=False,
):
    # Check whether the text has end punctuation
    if text[-1] not in string.punctuation:
        text += "."
    prompt_masked = prompt.format(model_prompting.tokenizer.mask_token)
    output_prompting = query_model_prompting(
        model_prompting, text, prompt_masked, top_k, targets=targets
    )
    if remove_lemma:
        output_prompting = filter_prompt_output_by_lemma(prompt, output_prompting, nlp)
    # Build one premise/hypothesis pair per (candidate token, sentence)
    full_batch_concat = []
    prompt_tokens = []
    for token in output_prompting:
        hypothesis = prompt.format(token["token_str"])
        for sentence in split_sentences(text):
            full_batch_concat.append(sentence + "</s></s>" + hypothesis)
            prompt_tokens.append((token["token_str"], token["score"]))
    # Add words that must be tried for entailment even if prompting did not
    # propose them, and increase the batch size accordingly
    if additional_words:
        for sentence in split_sentences(text):
            for token in additional_words:
                hypothesis = prompt.format(token)
                full_batch_concat.append(sentence + "</s></s>" + hypothesis)
                prompt_tokens.append((token, 1))
        top_k = top_k + 1
    # top_k is reused as the batch size here
    results_nli = run_model_nli(full_batch_concat, top_k, nli_model, use_tf)
    # Get entailed tokens
    entailed_tokens = []
    for i, res in enumerate(results_nli):
        entailed_tokens.extend(
            [
                (get_singular_form(prompt_tokens[i][0]), x["score"])
                for x in res
                if (x["label"] == "ENTAILMENT") and (x["score"] > nli_limit)
            ]
        )
    # Keep the maximum score per (singularized) token
    if entailed_tokens:
        entailed_tokens = list(
            pd.DataFrame(entailed_tokens).groupby(0).max()[1].items()
        )
    return entailed_tokens, list(set(prompt_tokens))
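
# Illustrative end-to-end usage sketch (template and outputs hypothetical):
#   entailed, candidates = prompt_to_nli_batching(
#       "Several demonstrators were injured.",
#       "This event involves {}.",
#       load_model_prompting(), load_model_nli(), load_spacy_pipeline(),
#       top_k=10, nli_limit=0.9,
#   )
#   # entailed   -> e.g. [("protester", 0.97)]  (singular form, max NLI score)
#   # candidates -> deduplicated (token, score) pairs from the fill-mask step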

def remove_similar_lemma_from_list(prompt, list_words, nlp):
    # Compute a dictionary mapping each lemma to its tokens.
    # If several tokens share a lemma, the dictionary value is the list
    # of corresponding tokens
    lemma_dict = {}
    for each in list_words:
        mask_filled = nlp(prompt.strip(".").format(each))
        lemma_dict.setdefault(mask_filled[-1].lemma_, []).append(each)
    # Get back the list of tokens.
    # If multiple tokens are available, take the shortest one
    new_token_list = []
    for key in lemma_dict.keys():
        if len(lemma_dict[key]) >= 1:
            new_token_list.append(min(lemma_dict[key], key=len))
        else:
            raise ValueError("Lemma dict has 0 corresponding words")
    return new_token_list


def filter_prompt_output_by_lemma(prompt, output_prompting, nlp):
    """
    Remove all similar lemmas from the prompt output (e.g. "protest", "protests")
    """
    list_words = [x["token_str"] for x in output_prompting]
    new_token_list = remove_similar_lemma_from_list(prompt, list_words, nlp)
    return [x for x in output_prompting if x["token_str"] in new_token_list]

# Streamlit specific run functions
def do_prent(text, template, top_k, nli_limit, additional_words=None):
    """Function used to execute the PRENT model

    :param text: Event text
    :type text: string
    :param template: Template with mask
    :type template: string
    :param top_k: Maximum tokens to output from prompting model
    :type top_k: int
    :param nli_limit: Threshold of entailment for NLI [0,1]
    :type nli_limit: float
    :param additional_words: List of words that bypass prompting and go directly to NLI, defaults to None
    :type additional_words: list, optional
    :return: (Results Entailment, Results Prompting)
    :rtype: tuple
    """
    results_nli, results_pr = prompt_to_nli_batching(
        text,
        template,
        load_model_prompting(),
        load_model_nli(),
        load_spacy_pipeline(),
        top_k=top_k,
        nli_limit=nli_limit,
        targets=None,
        additional_words=additional_words,
        remove_lemma=True,
    )
    return results_nli, results_pr

def get_additional_words():
    """Extract the additional words from the codebook

    :return: list of additional words
    :rtype: list
    """
    if "add_words" in st.session_state.codebook:
        additional_words = st.session_state.codebook["add_words"]
    else:
        additional_words = None
    return additional_words

def run_prent(
    text="", templates=[], additional_words=None, progress=True, display_text=True
):
    """Execute PRENT over a list of templates and display streamlit widgets

    :param text: Event description, defaults to ""
    :type text: str, optional
    :param templates: Templates with a mask, defaults to []
    :type templates: list, optional
    :param additional_words: List of words to bypass prompting, defaults to None
    :type additional_words: list, optional
    :param progress: Whether to display the progress bar, defaults to True
    :type progress: bool, optional
    :param display_text: Whether to display the event description while computing, defaults to True
    :type display_text: bool, optional
    :return: (results of prent, computation time)
    :rtype: tuple
    """
    # Check that a template and an event description are available
    if not templates:
        st.warning("Template list is empty. Please add one.")
        return None, None
    if not text:
        st.warning("Event description is empty.")
        return None, None
    # Display the text only while computing
    if display_text:
        temp_text = st.empty()
        temp_text.markdown("**Event description:** {}".format(text))
    # Start progress bar
    if progress:
        progress_bar = st.progress(0)
    num_prent_call = len(templates)
    num_sentences = get_num_sentences_in_list_text([text])
    num_done = 0
    t0 = time()
    # Reset the streamlit radio choice to Ignore at first
    if "accept_reject_text_perm" in st.session_state:
        st.session_state["accept_reject_text_perm"] = "Ignore"
    res = {}
    for template in templates:
        template = template.replace("[Z]", "{}")
        results_nli, results_pr = do_prent(
            text,
            template,
            top_k=TOP_K,
            nli_limit=NLI_LIMIT,
            additional_words=additional_words,
        )
        # results_nli contains the entailment scores as well;
        # we only care about the token strings
        res[template] = [x[0] for x in results_nli]
        # Update progress bar
        num_done += 1
        if progress:
            progress_bar.progress((1 / num_prent_call) * num_done)
    if display_text:
        temp_text.markdown("")
    time_comput = (time() - t0) / num_sentences
    # Without this check, the computation time would be overwritten by the
    # much smaller time measured when cached values are used
    if not time_comput < st.session_state.time_comput / 5:
        st.session_state.time_comput = int(time_comput)
    # Store some metadata alongside the results
    res["templates_used"] = templates
    res["additional_words_used"] = additional_words
    return res, time_comput
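
# Illustrative sketch of the returned structure (event text, template and
# entailed token hypothetical); note that keys use "{}" after the "[Z]"
# substitution while "templates_used" keeps the original form:
#   res, secs = run_prent(
#       text="Several demonstrators were injured.",
#       templates=["This event involves [Z]."],
#   )
#   # res -> {"This event involves {}.": ["protester"],
#   #         "templates_used": ["This event involves [Z]."],
#   #         "additional_words_used": None}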

####### Find event types based on codebook and PRENT results
def check_any_conds(cond_any, list_res):
    """Function that evaluates the "OR" conditions of the codebook against the list of filled templates

    :param cond_any: List of groundtruth filled templates
    :type cond_any: list
    :param list_res: A list of the filled templates given by PRENT
    :type list_res: list
    :return: True if any groundtruth template is inside the list given by PRENT
    :rtype: bool
    """
    cond_any = list(cond_any)
    condition = False
    # Return False if there are no "any" conditions
    if not cond_any:
        return False
    for cond in cond_any:
        # With the current codebook design, this should never be true.
        # Previously, recursion allowed checking AND conditions inside an OR condition
        if isinstance(cond, dict):
            condition = check_all_conds(cond["all"], list_res)
        else:
            # Compare lowercase versions of the templates
            if cond.lower() in [x.lower() for x in list_res]:
                condition = True
                # Exit early, as the remaining templates cannot change the outcome
                return condition
    return condition

def check_all_conds(cond_all, list_res):
    """Function that evaluates the "AND" conditions of the codebook against the list of filled templates

    :param cond_all: List of groundtruth filled templates
    :type cond_all: list
    :param list_res: A list of the filled templates given by PRENT
    :type list_res: list
    :return: True if all groundtruth templates are inside the list given by PRENT
    :rtype: bool
    """
    cond_all = list(cond_all)
    # Return False if there are no "all" conditions
    if not cond_all:
        return False
    # Start on True, and set to False if any template is missing
    condition = True
    for cond in cond_all:
        # With the current codebook design, this should never be true.
        # Previously, recursion allowed checking OR conditions inside an AND condition
        if isinstance(cond, dict):
            condition = check_any_conds(cond["any"], list_res)
        else:
            # Compare lowercase versions of the templates
            if cond.lower() not in [x.lower() for x in list_res]:
                condition = False
                # Exit early, as the remaining templates cannot change the outcome
                return condition
    return condition
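
# Illustrative sketch of the two condition checks (filled templates hypothetical):
#   list_res = ["This event involves protester.", "This event involves police."]
#   check_any_conds(["This event involves police."], list_res)    # -> True
#   check_all_conds(["This event involves police.",
#                    "This event involves soldier."], list_res)   # -> False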

def find_event_types(codebook, list_res):
    """Evaluate the codebook and output the list of event types corresponding
    to the given PRENT results (list of filled templates).

    :param codebook: A codebook in the format given by the dashboard
    :type codebook: dict
    :param list_res: A list of the filled templates given by PRENT
    :type list_res: list
    :return: List of event types
    :rtype: list
    """
    list_event_type = []
    # Iterate over all defined event types
    for event_type in codebook["events"]:
        code_event = codebook["events"][event_type]
        is_not_all_event, is_not_any_event, is_not_event = False, False, False
        is_all_event, is_any_event, is_event = False, False, False
        # First check whether the NOT conditions are met,
        # e.g. a filled template that contradicts the event is present
        if "not_all" in code_event:
            cond_all = code_event["not_all"]
            if check_all_conds(cond_all, list_res):
                is_not_all_event = True
        if "not_any" in code_event:
            cond_any = code_event["not_any"]
            if check_any_conds(cond_any, list_res):
                is_not_any_event = True
        # Next, check whether "not_all" and "not_any" are related by an
        # "OR" or an "AND". The "AND" case needs special care because one
        # of the two lists can be empty (and therefore False)
        if code_event["not_all_any_rel"] == "AND":
            if is_not_all_event and (not code_event.get("not_any")):
                # If ALL is true and ANY is empty (so false)
                is_not_event = True
            elif is_not_any_event and (not code_event.get("not_all")):
                # If ANY is true and ALL is empty (so false)
                is_not_event = True
            if is_not_all_event and is_not_any_event:
                is_not_event = True
        elif code_event["not_all_any_rel"] == "OR":
            if is_not_all_event or is_not_any_event:
                is_not_event = True
        # The other checks are unnecessary if this is true,
        # so skip to the next event type
        if is_not_event:
            continue
        # Similar to the previous checks, but this time we look for
        # templates that should be present
        if "all" in code_event:
            # Check whether all "all" conditions are met
            cond_all = code_event["all"]
            if check_all_conds(cond_all, list_res):
                is_all_event = True
        if "any" in code_event:
            # Check whether any "any" condition is met
            cond_any = code_event["any"]
            if check_any_conds(cond_any, list_res):
                is_any_event = True
        # Again, the "AND" case needs special care because one of the two
        # lists can be empty (and therefore False)
        if code_event["all_any_rel"] == "AND":
            if is_all_event and (not code_event.get("any")):
                # If ALL is true and ANY is empty (so false)
                is_event = True
            elif is_any_event and (not code_event.get("all")):
                # If ANY is true and ALL is empty (so false)
                is_event = True
            elif is_all_event and is_any_event:
                is_event = True
        elif code_event["all_any_rel"] == "OR":
            if is_all_event or is_any_event:
                is_event = True
        # If all checks pass, add the event type to the output list
        if is_event:
            list_event_type.append(event_type)
    return list_event_type
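
# Illustrative sketch: a minimal codebook in the shape this function expects
# (keys inferred from the checks above; event name and templates hypothetical):
#   codebook = {
#       "events": {
#           "riot": {
#               "all": ["This event involves protester."],
#               "any": ["This event involves violence."],
#               "all_any_rel": "AND",
#               "not_all": [],
#               "not_any": ["This event involves soldier."],
#               "not_all_any_rel": "OR",
#           }
#       }
#   }
#   find_event_types(codebook, ["This event involves protester.",
#                               "This event involves violence."])
#   # -> ["riot"]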