Spaces:
Runtime error
Runtime error
| # | |
| # Pyserini: Reproducible IR research with sparse and dense representations | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # | |
| import cmd | |
| import json | |
| import os | |
| import random | |
| from pyserini.search.lucene import LuceneSearcher | |
| from pyserini.search.faiss import FaissSearcher, TctColBertQueryEncoder, AnceQueryEncoder | |
| from pyserini.search.hybrid import HybridSearcher | |
| from pyserini import search | |
class MsMarcoDemo(cmd.Cmd):
    """Interactive shell for searching the MS MARCO passage collection.

    Input that does not match a command is treated as a query and sent to the
    currently selected searcher (see ``default``). Commands may be typed with
    a leading ``/`` (e.g. ``/k 5``); ``precmd`` strips it before dispatch.
    """

    # NOTE: the two lookups below run when the class body is executed, i.e. at
    # import time of this module.
    # Each topic is indexed with ['title'] in ``do_random``.
    dev_topics = list(search.get_topics('msmarco-passage-dev-subset').values())
    # Sparse (Lucene) searcher; always available and the default retriever.
    ssearcher = LuceneSearcher.from_prebuilt_index('msmarco-passage')
    # Dense and hybrid searchers are only created by ``do_model``.
    dsearcher = None
    hsearcher = None
    # Searcher that queries are currently routed to.
    searcher = ssearcher
    # Number of hits requested per query.
    k = 10
    prompt = '>>> '

    # https://stackoverflow.com/questions/35213134/command-prefixes-in-python-cli-using-cmd-in-pythons-standard-library
    def precmd(self, line):
        """Strip one leading '/' so that '/cmd' dispatches like 'cmd'.

        ``startswith`` is used instead of ``line[0]`` so that an empty input
        line is passed through unchanged rather than raising IndexError.
        """
        if line.startswith('/'):
            line = line[1:]
        return line

    def do_help(self, arg):
        """Print the list of supported commands; ``arg`` is ignored."""
        print('/help : returns this message')
        print('/k [NUM] : sets k (number of hits to return) to [NUM]')
        print('/model [MODEL] : sets encoder to use the model [MODEL] (one of tct, ance)')
        print('/mode [MODE] : sets retriever type to [MODE] (one of sparse, dense, hybrid)')
        print('/random : returns results for a random question from dev subset')

    def do_k(self, arg):
        """Set the number of hits returned per query.

        ``arg`` must parse as a positive integer; otherwise a message is
        printed and ``k`` is left unchanged (the shell keeps running).
        """
        try:
            value = int(arg)
        except ValueError:
            print(f'k "{arg}" is invalid. k should be a positive integer.')
            return
        if value < 1:
            print(f'k "{arg}" is invalid. k should be a positive integer.')
            return
        print(f'setting k = {value}')
        self.k = value

    def do_mode(self, arg):
        """Select the active retriever: one of sparse, dense, hybrid.

        Dense and hybrid require that a model was selected first via
        ``/model``; if not, a hint is printed and the mode is unchanged.
        """
        if arg == "sparse":
            self.searcher = self.ssearcher
        elif arg == "dense":
            if self.dsearcher is None:
                print(f'Specify model through /model before using dense retrieval.')
                return
            self.searcher = self.dsearcher
        elif arg == "hybrid":
            if self.hsearcher is None:
                print(f'Specify model through /model before using hybrid retrieval.')
                return
            self.searcher = self.hsearcher
        else:
            print(
                f'Mode "{arg}" is invalid. Mode should be one of [sparse, dense, hybrid].')
            return
        print(f'setting retriver = {arg}')

    def do_model(self, arg):
        """Select the query encoder (tct or ance) and build dense/hybrid searchers.

        If dense or hybrid mode is already active, the active searcher is
        switched to the newly built one so the model change takes effect
        immediately instead of only after the next ``/mode`` command.
        """
        if arg == "tct":
            encoder = TctColBertQueryEncoder("castorini/tct_colbert-msmarco")
            index = "msmarco-passage-tct_colbert-hnsw"
        elif arg == "ance":
            encoder = AnceQueryEncoder("castorini/ance-msmarco-passage")
            index = "msmarco-passage-ance-bf"
        else:
            print(
                f'Model "{arg}" is invalid. Model should be one of [tct, ance].')
            return

        old_dense = self.dsearcher
        old_hybrid = self.hsearcher
        new_dense = FaissSearcher.from_prebuilt_index(
            index,
            encoder
        )
        new_hybrid = HybridSearcher(new_dense, self.ssearcher)

        # Keep the active retriever in sync with the replaced searchers.
        if old_dense is not None and self.searcher is old_dense:
            self.searcher = new_dense
        elif old_hybrid is not None and self.searcher is old_hybrid:
            self.searcher = new_hybrid

        self.dsearcher = new_dense
        self.hsearcher = new_hybrid
        print(f'setting model = {arg}')

    def do_random(self, arg):
        """Run a randomly chosen question from ``dev_topics`` as a query."""
        q = random.choice(self.dev_topics)['title']
        print(f'question: {q}')
        self.default(q)

    def do_EOF(self, line):
        """Return True on end-of-file so that the command loop stops."""
        return True

    def default(self, q):
        """Treat ``q`` as a query and print rank, score and contents per hit.

        For the Lucene searcher the raw document is read from the hit itself;
        for any other searcher it is looked up by docid. The raw document is
        parsed as JSON and its "contents" field is printed. A hit whose
        document cannot be found is reported with a placeholder line instead
        of crashing on ``json.loads(None)``.
        """
        hits = self.searcher.search(q, self.k)
        # The searcher type cannot change while the loop runs; check it once.
        is_sparse = isinstance(self.searcher, LuceneSearcher)
        for rank, hit in enumerate(hits, start=1):
            if is_sparse:
                raw_doc = hit.raw
            else:
                doc = self.searcher.doc(hit.docid)
                raw_doc = doc.raw() if doc else None
            if raw_doc is None:
                print(f'{rank:2} {hit.score:.5f} [no stored document for docid {hit.docid}]')
                continue
            jsondoc = json.loads(raw_doc)
            print(f'{rank:2} {hit.score:.5f} {jsondoc["contents"]}')
if __name__ == '__main__':
    # Start the interactive prompt only when executed as a script.
    demo = MsMarcoDemo()
    demo.cmdloop()