# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
from functools import lru_cache
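
# Utility helpers for Winograd-style coreference data: converting between the
# marked-up text format (query wrapped in "_..._", pronoun in "[...]") and
# span-indexed JSON, iterating over WSC/Winogrande jsonl files, and extracting
# candidate noun chunks with spaCy.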


def convert_sentence_to_json(sentence):
    if "_" in sentence:
        prefix, rest = sentence.split("_", 1)
        query, rest = rest.split("_", 1)
        query_index = len(prefix.rstrip().split(" "))
    else:
        query, query_index = None, None

    prefix, rest = sentence.split("[", 1)
    pronoun, rest = rest.split("]", 1)
    pronoun_index = len(prefix.rstrip().split(" "))

    sentence = sentence.replace("_", "").replace("[", "").replace("]", "")

    return {
        "idx": 0,
        "text": sentence,
        "target": {
            "span1_index": query_index,
            "span1_text": query,
            "span2_index": pronoun_index,
            "span2_text": pronoun,
        },
    }
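
# Example (illustrative):
#   convert_sentence_to_json(
#       "The trophy does not fit in _the suitcase_ because [it] is too small."
#   )
# returns:
#   {"idx": 0,
#    "text": "The trophy does not fit in the suitcase because it is too small.",
#    "target": {"span1_index": 6, "span1_text": "the suitcase",
#               "span2_index": 9, "span2_text": "it"}}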


def extended_noun_chunks(sentence):
    noun_chunks = {(np.start, np.end) for np in sentence.noun_chunks}
    np_start, cur_np = 0, "NONE"
    for i, token in enumerate(sentence):
        np_type = token.pos_ if token.pos_ in {"NOUN", "PROPN"} else "NONE"
        if np_type != cur_np:
            if cur_np != "NONE":
                noun_chunks.add((np_start, i))
            if np_type != "NONE":
                np_start = i
            cur_np = np_type
    if cur_np != "NONE":
        noun_chunks.add((np_start, len(sentence)))
    return [sentence[s:e] for (s, e) in sorted(noun_chunks)]
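
# Note: besides spaCy's own noun_chunks, the loop above adds every maximal run
# of consecutive tokens sharing the same NOUN (or PROPN) tag, so bare noun
# sequences that the chunker misses still become candidate spans.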


def find_token(sentence, start_pos):
    found_tok = None
    for tok in sentence:
        if tok.idx == start_pos:
            found_tok = tok
            break
    return found_tok


def find_span(sentence, search_text, start=0):
    search_text = search_text.lower()
    for tok in sentence[start:]:
        remainder = sentence[tok.i :].text.lower()
        if remainder.startswith(search_text):
            len_to_consume = len(search_text)
            start_idx = tok.idx
            for next_tok in sentence[tok.i :]:
                end_idx = next_tok.idx + len(next_tok.text)
                if end_idx - start_idx == len_to_consume:
                    span = sentence[tok.i : next_tok.i + 1]
                    return span
    return None
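
# Example (illustrative): with doc = nlp("She gave the big dog a bone."),
# find_span(doc, "the big dog") returns doc[2:5]. Matching is case-insensitive
# and extends token by token until the consumed characters exactly cover
# search_text; if no token boundary lines up, None is returned.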


@lru_cache(maxsize=1)
def get_detokenizer():
    # imported lazily so the module loads without sacremoses installed;
    # cached so the detokenizer is constructed only once
    from sacremoses import MosesDetokenizer

    detok = MosesDetokenizer(lang="en")
    return detok


@lru_cache(maxsize=1)
def get_spacy_nlp():
    # imported lazily; requires the en_core_web_lg spaCy model package.
    # cached so the large model is loaded only once
    import en_core_web_lg

    nlp = en_core_web_lg.load()
    return nlp


def jsonl_iterator(input_fname, positive_only=False, ngram_order=3, eval=False):
    detok = get_detokenizer()
    nlp = get_spacy_nlp()

    with open(input_fname) as fin:
        for line in fin:
            sample = json.loads(line.strip())
            if positive_only and "label" in sample and not sample["label"]:
                # only consider examples where the query is correct
                continue
            target = sample["target"]

            # clean up the query
            query = target["span1_text"]
            if query is not None:
                if "\n" in query:
                    continue
                if query.endswith(".") or query.endswith(","):
                    query = query[:-1]

            # split tokens
            tokens = sample["text"].split(" ")

            def strip_pronoun(x):
                return x.rstrip('.,"')

            # find the pronoun
            pronoun_idx = target["span2_index"]
            pronoun = strip_pronoun(target["span2_text"])
            if strip_pronoun(tokens[pronoun_idx]) != pronoun:
                # hack: sometimes the index is misaligned
                if strip_pronoun(tokens[pronoun_idx + 1]) == pronoun:
                    pronoun_idx += 1
                else:
                    raise Exception("Misaligned pronoun!")
            assert strip_pronoun(tokens[pronoun_idx]) == pronoun

            # split tokens before and after the pronoun
            before = tokens[:pronoun_idx]
            after = tokens[pronoun_idx + 1 :]

            # the GPT BPE attaches leading spaces to tokens, so we keep track
            # of whether we need spaces before or after the pronoun
            leading_space = " " if pronoun_idx > 0 else ""
            trailing_space = " " if len(after) > 0 else ""

            # detokenize
            before = detok.detokenize(before, return_str=True)
            pronoun = detok.detokenize([pronoun], return_str=True)
            after = detok.detokenize(after, return_str=True)

            # hack: when the pronoun ends in a period (or comma), move the
            # punctuation to the "after" part
            if pronoun.endswith(".") or pronoun.endswith(","):
                after = pronoun[-1] + trailing_space + after
                pronoun = pronoun[:-1]

            # hack: when the "after" part begins with a comma or period, remove
            # the trailing space
            if after.startswith(".") or after.startswith(","):
                trailing_space = ""

            # parse sentence with spacy
            sentence = nlp(before + leading_space + pronoun + trailing_space + after)

            # find pronoun span
            start = len(before + leading_space)
            first_pronoun_tok = find_token(sentence, start_pos=start)
            pronoun_span = find_span(sentence, pronoun, start=first_pronoun_tok.i)
            assert pronoun_span.text == pronoun

            if eval:
                # convert to format where pronoun is surrounded by "[]" and
                # query is surrounded by "_"
                query_span = find_span(sentence, query)
                query_with_ws = "_{}_{}".format(
                    query_span.text,
                    (" " if query_span.text_with_ws.endswith(" ") else ""),
                )
                pronoun_with_ws = "[{}]{}".format(
                    pronoun_span.text,
                    (" " if pronoun_span.text_with_ws.endswith(" ") else ""),
                )
                if query_span.start < pronoun_span.start:
                    first = (query_span, query_with_ws)
                    second = (pronoun_span, pronoun_with_ws)
                else:
                    first = (pronoun_span, pronoun_with_ws)
                    second = (query_span, query_with_ws)
                sentence = (
                    sentence[: first[0].start].text_with_ws
                    + first[1]
                    + sentence[first[0].end : second[0].start].text_with_ws
                    + second[1]
                    + sentence[second[0].end :].text
                )
                yield sentence, sample.get("label", None)
            else:
                yield sentence, pronoun_span, query, sample.get("label", None)
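
# Example usage (illustrative; the file name is hypothetical):
#   for sentence, pronoun_span, query, label in jsonl_iterator("val.jsonl"):
#       ...
# With eval=True the iterator instead yields (text, label) pairs in which the
# pronoun is wrapped in "[...]" and the query in "_..._", i.e. the inverse of
# convert_sentence_to_json.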


def winogrande_jsonl_iterator(input_fname, eval=False):
    with open(input_fname) as fin:
        for line in fin:
            sample = json.loads(line.strip())
            sentence, option1, option2 = (
                sample["sentence"],
                sample["option1"],
                sample["option2"],
            )

            pronoun_span = (sentence.index("_"), sentence.index("_") + 1)

            if eval:
                query, cand = option1, option2
            else:
                query = option1 if sample["answer"] == "1" else option2
                cand = option2 if sample["answer"] == "1" else option1
            yield sentence, pronoun_span, query, cand
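
# Note: in the Winogrande format the blank is a literal "_" character, so
# pronoun_span here is a (start, end) character-offset pair into the raw
# sentence string, unlike the spaCy Span yielded by jsonl_iterator.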


def filter_noun_chunks(
    chunks, exclude_pronouns=False, exclude_query=None, exact_match=False
):
    if exclude_pronouns:
        chunks = [
            np
            for np in chunks
            if (np.lemma_ != "-PRON-" and not all(tok.pos_ == "PRON" for tok in np))
        ]

    if exclude_query is not None:
        excl_txt = [exclude_query.lower()]
        filtered_chunks = []
        for chunk in chunks:
            lower_chunk = chunk.text.lower()
            found = False
            for excl in excl_txt:
                if (
                    not exact_match and (lower_chunk in excl or excl in lower_chunk)
                ) or lower_chunk == excl:
                    found = True
                    break
            if not found:
                filtered_chunks.append(chunk)
        chunks = filtered_chunks

    return chunks
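
# Example (illustrative): for chunks drawn from "The trophy does not fit in
# the suitcase", filter_noun_chunks(chunks, exclude_query="the suitcase")
# keeps "The trophy" but drops "the suitcase"; with exact_match=False (the
# default) any chunk containing, or contained in, the query is dropped too.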