#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import json
import os
from abc import ABC, abstractmethod
from enum import Enum, unique
from typing import List

from pyserini.search import JLuceneSearcherResult


@unique
class OutputFormat(Enum):
    TREC = 'trec'
    MSMARCO = 'msmarco'
    KILT = 'kilt'


class OutputWriter(ABC):

    def __init__(self, file_path: str, mode: str = 'w',
                 max_hits: int = 1000, tag: str = None, topics: dict = None,
                 use_max_passage: bool = False, max_passage_delimiter: str = None,
                 max_passage_hits: int = 100):
        self.file_path = file_path
        self.mode = mode
        self.tag = tag
        self.topics = topics
        self.use_max_passage = use_max_passage
        self.max_passage_delimiter = max_passage_delimiter if use_max_passage else None
        self.max_hits = max_passage_hits if use_max_passage else max_hits
        self._file = None

    def __enter__(self):
        dirname = os.path.dirname(self.file_path)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        self._file = open(self.file_path, self.mode)
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self._file.close()

    def hits_iterator(self, hits: List[JLuceneSearcherResult]):
        unique_docs = set()
        rank = 1
        for hit in hits:
            if self.use_max_passage and self.max_passage_delimiter:
                # Collapse passage-level ids (e.g., 'docid#passage') to the parent docid.
                docid = hit.docid.split(self.max_passage_delimiter)[0]
            else:
                docid = hit.docid.strip()

            if self.use_max_passage:
                # Keep only the highest-ranked passage per document.
                if docid in unique_docs:
                    continue
                unique_docs.add(docid)

            yield docid, rank, hit.score, hit

            rank = rank + 1
            if rank > self.max_hits:
                break

    @abstractmethod
    def write(self, topic: str, hits: List[JLuceneSearcherResult]):
        raise NotImplementedError()


class TrecWriter(OutputWriter):
    def write(self, topic: str, hits: List[JLuceneSearcherResult]):
        for docid, rank, score, _ in self.hits_iterator(hits):
            self._file.write(f'{topic} Q0 {docid} {rank} {score:.6f} {self.tag}\n')
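

# Illustration only (not part of Pyserini): a minimal sketch of how the
# max-passage option collapses passage-level ids such as 'D123#0' to the
# parent docid 'D123'. _FakeHit is a hypothetical stand-in for
# JLuceneSearcherResult, used purely so this sketch is self-contained.
class _FakeHit:
    def __init__(self, docid: str, score: float):
        self.docid = docid
        self.score = score


def _demo_max_passage_dedup():
    hits = [_FakeHit('D123#0', 3.2), _FakeHit('D123#4', 2.9), _FakeHit('D456#1', 2.5)]
    writer = TrecWriter('demo.run', tag='demo', use_max_passage=True,
                        max_passage_delimiter='#', max_passage_hits=100)
    # hits_iterator keeps one entry per parent docid:
    # [('D123', 1, 3.2), ('D456', 2, 2.5)]
    return [(docid, rank, score) for docid, rank, score, _ in writer.hits_iterator(hits)]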


class MsMarcoWriter(OutputWriter):
    def write(self, topic: str, hits: List[JLuceneSearcherResult]):
        for docid, rank, score, _ in self.hits_iterator(hits):
            self._file.write(f'{topic}\t{docid}\t{rank}\n')


class KiltWriter(OutputWriter):
    def write(self, topic: str, hits: List[JLuceneSearcherResult]):
        datapoint = self.topics[topic]
        provenance = []
        for docid, rank, score, _ in self.hits_iterator(hits):
            provenance.append({"wikipedia_id": docid})
        datapoint["output"] = [{"provenance": provenance}]
        json.dump(datapoint, self._file)
        self._file.write('\n')
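

# Illustration (not part of the module): for a topics entry such as
# {"id": "q1", "input": "who wrote hamlet"} (a hypothetical KILT datapoint),
# KiltWriter appends a single JSON line of the form
#   {"id": "q1", "input": "who wrote hamlet",
#    "output": [{"provenance": [{"wikipedia_id": "12345"}, {"wikipedia_id": "67890"}]}]}
# i.e. the original datapoint with the retrieved Wikipedia ids attached as provenance.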


def get_output_writer(file_path: str, output_format: OutputFormat, *args, **kwargs) -> OutputWriter:
    mapping = {
        OutputFormat.TREC: TrecWriter,
        OutputFormat.MSMARCO: MsMarcoWriter,
        OutputFormat.KILT: KiltWriter,
    }
    return mapping[output_format](file_path, *args, **kwargs)


def tie_breaker(hits):
    # Sort by descending score, breaking ties by ascending docid for determinism.
    return sorted(hits, key=lambda x: (-x.score, x.docid))
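

# A minimal usage sketch (illustration only): write one topic's hits as a TREC
# run via get_output_writer. Assumes `hits` is a List[JLuceneSearcherResult]
# obtained from a Pyserini searcher elsewhere; the file name, topic id, and
# run tag are arbitrary placeholders.
def _demo_write_trec_run(hits: List[JLuceneSearcherResult]):
    with get_output_writer('run.demo.txt', OutputFormat.TREC,
                           tag='demo-run', max_hits=1000) as writer:
        # tie_breaker makes the run deterministic when scores tie; each output
        # line is "<topic> Q0 <docid> <rank> <score> <tag>".
        writer.write('q1', tie_breaker(hits))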