# Spaces: Runtime error
import argparse
import json
import os

from pyserini.search.faiss import (
    AnceQueryEncoder,
    AutoQueryEncoder,
    DprQueryEncoder,
    TctColBertQueryEncoder,
)
def _init_encoder_from_str(encoder, device="cpu"):
    """Instantiate a pyserini query encoder chosen by keywords in its name.

    The lower-cased ``encoder`` string is checked against keywords in a fixed
    priority order:

    * ``"dpr"`` -> ``DprQueryEncoder``
    * ``"tct_colbert"`` -> ``TctColBertQueryEncoder``
    * ``"ance"`` -> ``AnceQueryEncoder``

    Anything else falls back to ``AutoQueryEncoder``. If the name contains
    ``"sentence"``, that encoder is built with mean pooling and L2
    normalisation.

    Args:
        encoder: encoder name or checkpoint path; passed through unchanged as
            ``encoder_dir``.
        device: device string forwarded to the encoder (default ``"cpu"``).

    Returns:
        The constructed query encoder.
    """
    lowered_name = encoder.lower()
    # Order matters: the first matching keyword wins.
    keyword_table = (
        ("dpr", DprQueryEncoder),
        ("tct_colbert", TctColBertQueryEncoder),
        ("ance", AnceQueryEncoder),
    )
    for keyword, encoder_cls in keyword_table:
        if keyword in lowered_name:
            return encoder_cls(encoder_dir=encoder, device=device)
    # Names containing "sentence" (presumably sentence-transformers style
    # checkpoints) get mean pooling + L2 normalisation.
    if "sentence" in lowered_name:
        auto_kwargs = {"pooling": "mean", "l2_norm": True}
    else:
        auto_kwargs = {}
    return AutoQueryEncoder(encoder_dir=encoder, device=device, **auto_kwargs)
def load_index(searcher_class, index_dir, query_encoder=None):
    """Build ``searcher_class`` over a local index directory or a prebuilt index.

    When ``index_dir`` is an existing directory (or ``searcher_class`` has no
    ``from_prebuilt_index`` classmethod), the searcher is constructed directly.
    This is the original behaviour. Otherwise ``index_dir`` is treated as the
    name of a prebuilt index, as promised by the ``--index`` help text, and is
    resolved through ``searcher_class.from_prebuilt_index``.

    Args:
        searcher_class: searcher type to build (e.g. pyserini's
            ``LuceneSearcher`` or ``FaissSearcher``).
        index_dir: path to a local index directory, or a prebuilt index name.
        query_encoder: optional query encoder; forwarded only when not None.

    Returns:
        The constructed searcher.

    Raises:
        ValueError: if ``index_dir`` is not a directory and the prebuilt
            loader could not resolve it.
    """
    prebuilt_loader = getattr(searcher_class, "from_prebuilt_index", None)
    if os.path.isdir(index_dir) or prebuilt_loader is None:
        if query_encoder is not None:
            return searcher_class(index_dir=index_dir, query_encoder=query_encoder)
        return searcher_class(index_dir=index_dir)

    # Not a local directory: try it as a prebuilt index name. The query encoder
    # is passed positionally, as the second argument of the dense loader.
    if query_encoder is not None:
        searcher = prebuilt_loader(index_dir, query_encoder)
    else:
        searcher = prebuilt_loader(index_dir)
    if searcher is None:
        # pyserini's prebuilt loaders report unknown names by returning None.
        raise ValueError(
            f"{index_dir!r} is neither an index directory nor a known prebuilt index"
        )
    return searcher
class OnlineSearcher:
    """Run ad-hoc queries against a pyserini sparse, dense or hybrid index.

    The backend is chosen by ``args.index_type``:

    * ``"sparse"``: a ``LuceneSearcher`` over ``args.index``.
    * ``"dense"``: a ``FaissSearcher`` over ``args.index``, using the query
      encoder named by ``args.encoder``.
    * ``"hybrid"``: a ``HybridSearcher`` that combines a ``LuceneSearcher``
      and a ``FaissSearcher``. ``args.index`` holds both paths as
      ``"sparse,dense"``.

    ``args`` is an ``argparse.Namespace``-like object. The attributes read are
    ``index_type``, ``index``, ``encoder``, ``device``, ``lang_abbr``,
    ``alpha``, ``normalization`` and ``hide_text``.
    """

    def __init__(self, args):
        """Build the searcher(s) described by ``args``.

        ``args`` itself is not modified, so the same object can be reused to
        build further searchers.

        Raises:
            ValueError: if ``index_type`` is unknown, if dense/hybrid search is
                requested without an ``encoder``, or if a hybrid ``index`` does
                not contain exactly two comma-separated paths.
        """
        self.args = args
        if args.index_type == "sparse":
            query_encoder = None
        elif args.index_type in ("dense", "hybrid"):
            # Fail with a clear message instead of an AttributeError on None.lower().
            if not args.encoder:
                raise ValueError(
                    f"--encoder is required when index_type is {args.index_type}"
                )
            query_encoder = _init_encoder_from_str(
                encoder=args.encoder, device=args.device
            )
        else:
            raise ValueError(
                f"index_type {args.index_type} should be chosen among sparse, dense, or hybrid"
            )

        if args.index_type == "hybrid":
            # Split into a local list rather than overwriting args.index; the
            # in-place mutation made a second OnlineSearcher(args) crash.
            index_paths = args.index
            if isinstance(index_paths, str):
                index_paths = [path.strip() for path in index_paths.split(",")]
            if len(index_paths) != 2:
                raise ValueError(
                    "require both sparse and dense index delimited by comma"
                )
            from pyserini.search.faiss import FaissSearcher
            from pyserini.search.hybrid import HybridSearcher
            from pyserini.search.lucene import LuceneSearcher

            self.ssearcher = load_index(
                searcher_class=LuceneSearcher, index_dir=index_paths[0]
            )
            self.ssearcher.set_language(args.lang_abbr)
            self.dsearcher = load_index(
                searcher_class=FaissSearcher,
                index_dir=index_paths[1],
                query_encoder=query_encoder,
            )
            self.searcher = HybridSearcher(self.dsearcher, self.ssearcher)
            print(f"load {self.ssearcher.num_docs} documents from {index_paths}")
        else:
            if args.index_type == "sparse":
                from pyserini.search.lucene import LuceneSearcher as Searcher
            else:
                from pyserini.search.faiss import FaissSearcher as Searcher
            self.searcher = load_index(
                searcher_class=Searcher,
                index_dir=args.index,
                query_encoder=query_encoder,
            )
            if args.index_type == "sparse":
                # language-specific handling for sparse retrieval (--lang_abbr)
                self.searcher.set_language(args.lang_abbr)
            print(f"load {self.searcher.num_docs} documents from {args.index}")

    def search(self, query, k=10):
        """Return up to ``k`` hits for ``query`` from the configured searcher.

        For hybrid search, ``args.alpha`` and ``args.normalization`` control
        the score fusion.
        """
        if self.args.index_type == "hybrid":
            # k0 is how many candidates each sub-searcher contributes to the
            # fusion. Keep the previous 10 as a floor, but never go below k,
            # otherwise the fused list can be shorter than requested.
            return self.searcher.search(
                query,
                k0=max(k, 10),
                k=k,
                alpha=self.args.alpha,
                normalization=self.args.normalization,
            )
        # Forward k: without it the searcher's own default is always used.
        return self.searcher.search(query, k=k)

    def _lookup_doc(self, docid):
        # Return the stored Lucene document for docid, or None when there is no sparse index.
        if self.args.index_type == "sparse":
            return self.searcher.doc(docid)
        if self.args.index_type == "hybrid":
            return self.searcher.sparse_searcher.doc(docid)
        # faiss searcher does not store document raw text
        return None

    def print_result(self, hits, k):
        """Print the top ``k`` hits and return their document texts as one string.

        Each hit is printed as ``rank docid score``. Unless ``args.hide_text``
        is set, the stored raw document is printed beneath it. ``hide_text``
        only affects printing, not the return value.

        Args:
            hits: search results exposing ``docid`` and ``score``.
            k: maximum number of hits to print.

        Returns:
            str: one block per document whose stored raw JSON was available.
            Each block is a ``"문서 <n>"`` ("document <n>") header line followed
            by the JSON's ``"contents"`` field. Blocks are separated by blank
            lines. Returns ``""`` for dense-only searchers, which store no text.
        """
        docs = []
        for rank, hit in enumerate(hits[: max(k, 0)], start=1):
            print(f"{rank:2} {hit.docid:15} {hit.score:.5f}")
            doc = self._lookup_doc(hit.docid)
            if doc is None:
                continue
            doc_raw = doc.raw()
            if doc_raw is None:
                # raw text was not stored in the index for this document
                continue
            docs.append(json.loads(doc_raw))
            if not self.args.hide_text:
                print(doc_raw)
        return "\n\n".join(
            f'문서 {idx}\n{doc["contents"]}' for idx, doc in enumerate(docs, start=1)
        )
def build_arg_parser():
    """Build the command-line parser for the interactive search script."""
    parser = argparse.ArgumentParser(description="Search interactively")
    parser.add_argument(
        "--index_type",
        type=str,
        required=True,
        help="choose indexing type",
        choices=["sparse", "dense", "hybrid"],
    )
    parser.add_argument(
        "--index",
        type=str,
        required=True,
        help="Path to index or name of prebuilt index.",
    )
    parser.add_argument("--query", type=str, required=True, help="Query text")
    parser.add_argument(
        "--lang_abbr",
        type=str,
        required=False,
        default="ko",
        help="language abbreviation for language-specific algorithms in sparse retrieval",
    )
    parser.add_argument(
        "--encoder",
        type=str,
        required=False,
        help="encoder name or checkpoint path (required for dense and hybrid)",
    )
    parser.add_argument(
        "--device",
        type=str,
        required=False,
        default="cpu",
        help="device to use for encoding queries (cf. pyserini does not support faiss-gpu)",
    )
    # for hybrid search
    parser.add_argument(
        "--alpha",
        type=float,
        default=0.5,
        help="weight for hybrid search: alpha*score(sparse) + score(dense)",
    )
    parser.add_argument(
        "--normalization",
        action="store_true",
        help="normalize sparse & dense score before fusion",
    )
    # search range
    parser.add_argument(
        "--k",
        type=int,
        default=10,
        help="the number of passages to return (default: 10)",
    )
    # print option
    parser.add_argument(
        "--hide_text",
        action="store_true",
        help="do not print the text of retrieved documents",
    )
    return parser


def main(argv=None):
    """Parse ``argv`` (``sys.argv[1:]`` when None), run one query and print the hits."""
    args = build_arg_parser().parse_args(argv)
    # make searcher
    searcher = OnlineSearcher(args)
    print(f"given query: {args.query}")
    # Forward --k so more than the searcher's default number of hits is retrieved.
    hits = searcher.search(args.query, k=args.k)
    # print results
    searcher.print_result(hits, args.k)


if __name__ == "__main__":
    main()