#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import multiprocessing
import string
from multiprocessing.pool import ThreadPool

import jnius
from nltk import SnowballStemmer, bigrams, word_tokenize
from nltk.corpus import stopwords
from tqdm import tqdm

import kilt.kilt_utils as utils
from kilt.retrievers.base_retriever import Retriever
from pyserini.search.lucene import LuceneSearcher

ent_start_token = "[START_ENT]"
ent_end_token = "[END_ENT]"

STOPWORDS = set(stopwords.words("english") + list(string.punctuation))
stemmer = SnowballStemmer("english")
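

# parse_hits (below) collapses Lucene hits into unique Wikipedia page ids.
# This assumes the KILT index uses passage-level docids of the form
# "<wikipedia_id>-<passage_number>", so splitting on "-" recovers the page id;
# only the first (highest-ranked) score per page is kept.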
def parse_hits(hits):
    doc_ids = []
    doc_scores = []
    for hit in hits:
        wikipedia_id = hit.docid.split("-")[0]
        if wikipedia_id and wikipedia_id not in doc_ids:
            doc_ids.append(wikipedia_id)
            doc_scores.append(hit.score)
    return doc_ids, doc_scores
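

# Worker function executed by each thread. Each element of queries_data is a
# KILT query record, assumed to carry at least an "id" and a "query" field
# (possibly containing [START_ENT]/[END_ENT] entity markers, which are stripped
# before search). Returns three parallel lists: retrieved Wikipedia ids, their
# BM25 scores, and the corresponding query ids.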
def _get_predictions_thread(arguments):
    id = arguments["id"]
    queries_data = arguments["queries_data"]
    topk = arguments["topk"]
    ranker = arguments["ranker"]
    logger = arguments["logger"]
    use_bigrams = arguments["use_bigrams"]
    stem_bigrams = arguments["stem_bigrams"]

    # only the first worker shows a progress bar
    if id == 0:
        iter_ = tqdm(queries_data)
    else:
        iter_ = queries_data

    result_doc_ids = []
    result_doc_scores = []
    result_query_id = []

    for query_element in iter_:
        query = (
            query_element["query"]
            .replace(ent_start_token, "")
            .replace(ent_end_token, "")
            .strip()
        )
        result_query_id.append(query_element["id"])

        doc_ids = []
        doc_scores = []

        if use_bigrams:
            # drop stopwords/punctuation, optionally stem, then append the
            # concatenated bigrams (e.g. "new york" -> "newyork") to the query
            tokens = filter(
                lambda word: word.lower() not in STOPWORDS, word_tokenize(query)
            )
            if stem_bigrams:
                tokens = map(stemmer.stem, tokens)
            bigram_query = bigrams(tokens)
            bigram_query = " ".join(["".join(bigram) for bigram in bigram_query])
            query += " " + bigram_query
        try:
            hits = ranker.search(query, k=topk)
            doc_ids, doc_scores = parse_hits(hits)
        except RuntimeError as e:
            if logger:
                logger.warning("RuntimeError: {}".format(e))
        except jnius.JavaException as e:
            if logger:
                logger.warning("{} jnius.JavaException: {}".format(query_element, e))
            if "maxClauseCount" in str(e):
                # Lucene caps the number of boolean clauses per query;
                # truncate the query and retry once
                query = " ".join(query.split()[:950])
                hits = ranker.search(query, k=topk)
                doc_ids, doc_scores = parse_hits(hits)
            else:
                print(query, str(e))
                raise e
        except Exception as e:
            print(query, str(e))
            raise e

        result_doc_ids.append(doc_ids)
        result_doc_scores.append(doc_scores)

    return result_doc_ids, result_doc_scores, result_query_id
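

# KILT retriever wrapper around Pyserini's LuceneSearcher. One searcher is
# instantiated per thread; fed_data() splits the queries across the threads
# and run() collects predictions via a ThreadPool, returning results in the
# format expected by KILT (including a per-query provenance dict of
# wikipedia_ids).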
class Anserini(Retriever):
    def __init__(
        self,
        name,
        num_threads,
        index_dir=None,
        k1=0.9,
        b=0.4,
        use_bigrams=False,
        stem_bigrams=False,
    ):
        super().__init__(name)
        self.num_threads = min(num_threads, int(multiprocessing.cpu_count()))

        # initialize a ranker per thread
        self.arguments = []
        for id in tqdm(range(self.num_threads)):
            ranker = LuceneSearcher(index_dir)
            ranker.set_bm25(k1, b)
            self.arguments.append(
                {
                    "id": id,
                    "ranker": ranker,
                    "use_bigrams": use_bigrams,
                    "stem_bigrams": stem_bigrams,
                }
            )

    def fed_data(self, queries_data, topk, logger=None):
        chunked_queries = utils.chunk_it(queries_data, self.num_threads)

        for idx, arg in enumerate(self.arguments):
            arg["queries_data"] = chunked_queries[idx]
            arg["topk"] = topk
            arg["logger"] = logger

    def run(self):
        pool = ThreadPool(self.num_threads)
        results = pool.map(_get_predictions_thread, self.arguments)

        all_doc_id = []
        all_doc_scores = []
        all_query_id = []
        provenance = {}

        for x in results:
            i, s, q = x
            all_doc_id.extend(i)
            all_doc_scores.extend(s)
            all_query_id.extend(q)
            for query_id, doc_ids in zip(q, i):
                provenance[query_id] = []
                for d_id in doc_ids:
                    provenance[query_id].append({"wikipedia_id": str(d_id).strip()})

        pool.terminate()
        pool.join()

        return all_doc_id, all_doc_scores, all_query_id, provenance
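

# Minimal usage sketch: it only exercises the Anserini class defined above.
# The index path and the example queries are placeholders, and a KILT Lucene
# index is assumed to have been built separately with Pyserini/Anserini.
if __name__ == "__main__":
    # hypothetical path to a prebuilt KILT passage index
    example_index_dir = "/path/to/kilt_lucene_index"

    retriever = Anserini(
        name="anserini-bm25",
        num_threads=2,
        index_dir=example_index_dir,
        use_bigrams=False,
    )

    # KILT-style query records: each needs an "id" and a "query"
    queries = [
        {"id": "0", "query": "Who wrote [START_ENT] Dracula [END_ENT]?"},
        {"id": "1", "query": "capital of Finland"},
    ]

    retriever.fed_data(queries, topk=10)
    doc_ids, doc_scores, query_ids, provenance = retriever.run()
    print(provenance)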