#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
| """ | |
| Most of the tokenization code here is copied from Facebook/DPR & DrQA codebase to avoid adding an extra dependency | |
| """ | |
import argparse
import collections
import copy
import json
import logging
import os
import re
import unicodedata

import numpy as np
import regex
from tqdm import tqdm

logger = logging.getLogger(__name__)

DIRNAME = os.path.dirname(os.path.abspath(__file__))

# download dependencies
if not os.path.exists('data/nq-annotations.jsonl'):
    ANNOTATIONS_TO_DOWNLOAD = [
        ('https://dl.fbaipublicfiles.com/qaoverlap/data/nq-annotations.jsonl', 'nq-annotations.jsonl'),
        ('https://dl.fbaipublicfiles.com/qaoverlap/data/triviaqa-annotations.jsonl', 'triviaqa-annotations.jsonl'),
        ('https://dl.fbaipublicfiles.com/qaoverlap/data/webquestions-annotations.jsonl', 'webquestions-annotations.jsonl')
    ]

    for link, dest in ANNOTATIONS_TO_DOWNLOAD:
        os.system(f'wget {link} -P data/')

ANNOTATION_PATHS = {
    'tqa': os.path.join(DIRNAME, '../../data/triviaqa-annotations.jsonl'),
    'nq': os.path.join(DIRNAME, '../../data/nq-annotations.jsonl'),
    'webquestions': os.path.join(DIRNAME, '../../data/webquestions-annotations.jsonl'),
}
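
# The annotation files come from the QA-Overlap test-train overlap release (see the
# URLs above); each JSONL line carries an example 'id' and a 'labels' list with
# overlap categories such as 'question_overlap' and 'answer_overlap', which are used
# to bucket the accuracy numbers computed below.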


class Tokens(object):
    """A class to represent a list of tokenized text."""
    TEXT = 0
    TEXT_WS = 1
    SPAN = 2
    POS = 3
    LEMMA = 4
    NER = 5
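
    # Each entry in `data` is a tuple indexed by the constants above:
    # (token text, token text with trailing whitespace, (start, end) character span,
    # and, when the corresponding annotators are enabled, POS tag, lemma, NER tag).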

    def __init__(self, data, annotators, opts=None):
        self.data = data
        self.annotators = annotators
        self.opts = opts or {}

    def __len__(self):
        """The number of tokens."""
        return len(self.data)

    def slice(self, i=None, j=None):
        """Return a view of the list of tokens from [i, j)."""
        new_tokens = copy.copy(self)
        new_tokens.data = self.data[i: j]
        return new_tokens

    def untokenize(self):
        """Returns the original text (with whitespace reinserted)."""
        return ''.join([t[self.TEXT_WS] for t in self.data]).strip()

    def words(self, uncased=False):
        """Returns a list of the text of each token.

        Args:
            uncased: lower cases text
        """
        if uncased:
            return [t[self.TEXT].lower() for t in self.data]
        else:
            return [t[self.TEXT] for t in self.data]

    def offsets(self):
        """Returns a list of [start, end) character offsets of each token."""
        return [t[self.SPAN] for t in self.data]

    def pos(self):
        """Returns a list of part-of-speech tags of each token.

        Returns None if this annotation was not included.
        """
        if 'pos' not in self.annotators:
            return None
        return [t[self.POS] for t in self.data]

    def lemmas(self):
        """Returns a list of the lemmatized text of each token.

        Returns None if this annotation was not included.
        """
        if 'lemma' not in self.annotators:
            return None
        return [t[self.LEMMA] for t in self.data]

    def entities(self):
        """Returns a list of named-entity-recognition tags of each token.

        Returns None if this annotation was not included.
        """
        if 'ner' not in self.annotators:
            return None
        return [t[self.NER] for t in self.data]
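
    # The ngrams method below enumerates every ngram of length 1..n. Illustrative
    # example: with words ['a', 'b', 'c'] and n=2, ngrams(n=2, as_strings=True)
    # returns ['a', 'a b', 'b', 'b c', 'c'].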

    def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True):
        """Returns a list of all ngrams from length 1 to n.

        Args:
            n: upper limit of ngram length
            uncased: lower cases text
            filter_fn: user function that takes in an ngram list and returns
              True or False to keep or not keep the ngram
            as_strings: return each ngram as a string vs a list
        """
        def _skip(gram):
            if not filter_fn:
                return False
            return filter_fn(gram)

        words = self.words(uncased)
        ngrams = [(s, e + 1)
                  for s in range(len(words))
                  for e in range(s, min(s + n, len(words)))
                  if not _skip(words[s:e + 1])]

        # Concatenate into strings
        if as_strings:
            ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams]

        return ngrams

    def entity_groups(self):
        """Group consecutive entity tokens with the same NER tag."""
        entities = self.entities()
        if not entities:
            return None
        non_ent = self.opts.get('non_ent', 'O')
        groups = []
        idx = 0
        while idx < len(entities):
            ner_tag = entities[idx]
            # Check for entity tag
            if ner_tag != non_ent:
                # Chomp the sequence
                start = idx
                while (idx < len(entities) and entities[idx] == ner_tag):
                    idx += 1
                groups.append((self.slice(start, idx).untokenize(), ner_tag))
            else:
                idx += 1
        return groups


class Tokenizer(object):
    """Base tokenizer class.
    Tokenizers implement tokenize, which should return a Tokens class.
    """
    def tokenize(self, text):
        raise NotImplementedError

    def shutdown(self):
        pass

    def __del__(self):
        self.shutdown()


class SimpleTokenizer(Tokenizer):
    ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
    NON_WS = r'[^\p{Z}\p{C}]'
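
    # ALPHA_NUM matches runs of Unicode letters (\p{L}), digits (\p{N}), and combining
    # marks (\p{M}); NON_WS matches any single character that is not a separator (\p{Z})
    # or control/format character (\p{C}), so punctuation becomes its own token.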

    def __init__(self, **kwargs):
        """
        Args:
            annotators: None or empty set (only tokenizes).
        """
        self._regexp = regex.compile(
            '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
            flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
        )
        if len(kwargs.get('annotators', {})) > 0:
            logger.warning('%s only tokenizes! Skipping annotators: %s' %
                           (type(self).__name__, kwargs.get('annotators')))
        self.annotators = set()

    def tokenize(self, text):
        data = []
        matches = [m for m in self._regexp.finditer(text)]
        for i in range(len(matches)):
            # Get text
            token = matches[i].group()

            # Get whitespace
            span = matches[i].span()
            start_ws = span[0]
            if i + 1 < len(matches):
                end_ws = matches[i + 1].span()[0]
            else:
                end_ws = span[1]

            # Format data
            data.append((
                token,
                text[start_ws: end_ws],
                span,
            ))
        return Tokens(data, self.annotators)
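
# Illustrative example: SimpleTokenizer().tokenize('Hello, world!').words(uncased=True)
# yields ['hello', ',', 'world', '!'].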


def regex_match(text, pattern):
    """Test if a regex pattern is contained within a text."""
    try:
        pattern = re.compile(
            pattern,
            flags=re.IGNORECASE + re.UNICODE + re.MULTILINE,
        )
    except BaseException:
        return False
    return pattern.search(text) is not None
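

# NFD normalization decomposes accented characters into a base character plus
# combining marks (e.g. 'é' -> 'e' + U+0301), so candidate passages and answer
# strings are compared in a consistent Unicode form.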
def _normalize(text):
    return unicodedata.normalize('NFD', text)


def read_jsonl(path):
    with open(path) as f:
        return [json.loads(l) for l in f]


def read_annotations(annotations_data_path):
    return read_jsonl(annotations_data_path)
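

# Illustrative example: has_answers('The capital of France is Paris.', ['paris'], SimpleTokenizer())
# returns True, because the uncased token sequence ['paris'] occurs as a contiguous
# subsequence of the tokenized passage; with regex=True, answers are treated as patterns instead.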
def has_answers(text, answers, tokenizer, regex=False):
    text = _normalize(text)
    if regex:
        for ans in answers:
            ans = _normalize(ans)
            if regex_match(text, ans):
                return True
    else:
        text = tokenizer.tokenize(text).words(uncased=True)
        for ans in answers:
            ans = _normalize(ans)
            ans = tokenizer.tokenize(ans).words(uncased=True)
            for i in range(0, len(text) - len(ans) + 1):
                if ans == text[i: i + len(ans)]:
                    return True
    return False


def evaluate_retrieval(retrieval_file, topk, annotation_file, regex=False):
    tokenizer = SimpleTokenizer()
    retrieval = json.load(open(retrieval_file))
    annotations = read_annotations(annotation_file)
    annotation_ids = {int(a['id']): a['labels'] for a in annotations}
    accuracy = {k: collections.defaultdict(list) for k in topk}
    max_k = max(topk)

    annotation_labels = [
        'total',
        'no_overlap',
        'question_overlap',
        'no_question_overlap',
        'answer_overlap',
        'no_answer_overlap',
        'answer_overlap_only'
    ]
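
    # 'total' covers every question; 'no_overlap' is derived below as the intersection
    # of 'no_question_overlap' and 'no_answer_overlap'; the remaining labels are read
    # directly from the annotation file.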

    for qid in retrieval.keys():
        answers = retrieval[qid]['answers']
        contexts = retrieval[qid]['contexts']
        has_ans_idx = max_k  # first index in contexts that has answers

        for idx, ctx in enumerate(contexts):
            if idx >= max_k:
                break
            if 'has_answer' in ctx:
                if ctx['has_answer']:
                    has_ans_idx = idx
                    break
            else:
                text = ctx['text'].split('\n')[1]  # [0] is title, [1] is text
                if has_answers(text, answers, tokenizer, regex):
                    has_ans_idx = idx
                    break

        for annotation_label in annotation_labels:
            if annotation_label in annotation_ids[int(qid)] or annotation_label == 'total' or \
                    (annotation_label == 'no_overlap' and ('no_question_overlap' in annotation_ids[int(qid)]) and ('no_answer_overlap' in annotation_ids[int(qid)])):
                for k in topk:
                    accuracy[k][annotation_label].append(0 if has_ans_idx >= k else 1)

    for k in topk:
        for annotation_label in annotation_labels:
            print(f'Top{k}\taccuracy: {np.mean(accuracy[k][annotation_label])} \t {annotation_label}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--retrieval', type=str, metavar='path',
                        help="Path to retrieval output file.")
    parser.add_argument('--topk', type=int, nargs='+', help="topk to evaluate")
    parser.add_argument('--regex', action='store_true', default=False, help="regex match")
    parser.add_argument('--dataset_name', choices=['nq', 'tqa', 'webquestions'], type=str,
                        help='name of dataset to evaluate on')
    args = parser.parse_args()

    evaluate_retrieval(args.retrieval, args.topk, ANNOTATION_PATHS[args.dataset_name], args.regex)
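
# Example invocation (script name and run file path are illustrative):
#   python evaluate_retrieval_with_overlap.py --retrieval runs/run.nq-test.json --topk 20 100 --dataset_name nq
# The retrieval file is expected to be a JSON dict keyed by question id, where each entry
# has an 'answers' list and a 'contexts' list (each context carrying either an explicit
# 'has_answer' flag or a 'text' field of the form 'title\npassage').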