Spaces:
Runtime error
Runtime error
#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
| import heapq | |
| import json | |
| import os | |
| import shutil | |
| import tarfile | |
| import unittest | |
| from random import randint | |
| from urllib.request import urlretrieve | |
| from pyserini import analysis, search | |
| from pyserini.index.lucene import IndexReader | |
class TestIndexUtilsForLucene8(unittest.TestCase):
    """Index-reader tests that require Lucene 8 backwards compatibility.

    This class contains the test cases from test_index_reader that must run
    against a pre-built CACM index built using Lucene 8. Every test downloads
    and unpacks its own copy of the index in ``setUp`` and deletes it again in
    ``tearDown``.
    """

    # Number of top-ranked documents compared between the searcher and the
    # brute-force ranking computed from a dump file.
    TOP_K = 10

    # Queries used to compare searcher scores with compute_query_document_score.
    SCORE_QUERIES = ('information retrieval', 'databases')

    # Queries used to compare the searcher with brute-force search over a dump.
    DUMP_QUERIES = (
        'I am interested in articles written either by Prieve or Udo Pooch',
        'Performance evaluation and modelling of computer systems',
        'Addressing schemes for resources in networks; resource addressing in network operating systems',
    )

    def setUp(self):
        """Download and unpack the Lucene 8 CACM index, then open it."""
        # Append a random value to avoid filename clashes.
        r = randint(0, 10000000)
        self.collection_url = 'https://github.com/castorini/anserini-data/raw/master/CACM/lucene8-index.cacm.tar.gz'
        self.tarball_name = 'lucene-index.cacm-{}.tar.gz'.format(r)
        self.index_dir = 'index{}/'.format(r)
        self.temp_folders = []

        try:
            urlretrieve(self.collection_url, self.tarball_name)
            # The context manager closes the archive even if extraction fails.
            with tarfile.open(self.tarball_name) as tarball:
                tarball.extractall(self.index_dir)

            self.index_path = os.path.join(self.index_dir, 'lucene-index.cacm')
            self.searcher = search.LuceneSearcher(self.index_path)
            self.index_reader = IndexReader(self.index_path)
        except Exception:
            # unittest skips tearDown when setUp raises, so remove whatever
            # was already downloaded or unpacked before re-raising.
            self._remove_index_files()
            raise

    def tearDown(self):
        """Delete the downloaded tarball, the unpacked index and temp folders."""
        self._remove_index_files()

    def _remove_index_files(self):
        """Remove every file and folder created by ``setUp``, if present."""
        if os.path.exists(self.tarball_name):
            os.remove(self.tarball_name)
        shutil.rmtree(self.index_dir, ignore_errors=True)
        for folder in self.temp_folders:
            shutil.rmtree(folder, ignore_errors=True)

    @staticmethod
    def _remove_if_exists(file_path):
        """Delete ``file_path`` unless it is already gone."""
        if os.path.exists(file_path):
            os.remove(file_path)

    @staticmethod
    def _load_dump(file_path):
        """Parse a JSON Lines dump file into a list with one dict per line."""
        with open(file_path, 'r') as dump_file:
            return [json.loads(line) for line in dump_file]

    @staticmethod
    def _brute_force_top(docs, query_terms, k):
        """Rank dumped documents by the sum of their query-term weights.

        Parameters
        ----------
        docs : list of dict
            Parsed dump entries, each with an ``'id'`` and a ``'vector'``
            mapping terms to weights.
        query_terms : list of str
            Analyzed query terms; repeated terms are counted repeatedly.
        k : int
            Number of top documents to return.

        Returns
        -------
        list of tuple
            ``(docid, score)`` pairs, best score first; ties are ordered by
            docid.
        """
        # Scores are negated so that the smallest tuples are the best matches.
        negated = (
            (-1 * sum(doc['vector'][term] for term in query_terms if term in doc['vector']), doc['id'])
            for doc in docs
        )
        return [(docid, -1 * neg_score) for neg_score, docid in heapq.nsmallest(k, negated)]

    def _assert_scores_match(self, queries, similarity=None):
        """Check each hit's score against ``compute_query_document_score``.

        Parameters
        ----------
        queries : iterable of str
            Queries to run through the searcher.
        similarity : optional
            Similarity forwarded to ``compute_query_document_score``; when
            ``None`` the argument is omitted and the reader's default is used.
        """
        extra = {} if similarity is None else {'similarity': similarity}
        for query in queries:
            for hit in self.searcher.search(query):
                expected = self.index_reader.compute_query_document_score(hit.docid, query, **extra)
                self.assertAlmostEqual(hit.score, expected, places=4)

    def _assert_rankings_roughly_match(self, docs, query, tolerance=1):
        """Compare LuceneSearcher with brute-force search over quantized weights.

        If the weights are quantized the scores will not match but the
        rankings should still roughly match.

        Parameters
        ----------
        docs : list of dict
            Parsed entries of the quantized dump.
        query : str
            The query for search.
        tolerance : int
            Number of places within which rankings should match, i.e. if the
            ranking of some document in the dump is 2, then with a tolerance
            of 1 its ranking with the Lucene searcher should be between 1-3.
        """
        query_terms = self.index_reader.analyze(query, analyzer=analysis.get_lucene_analyzer())
        expected = self._brute_force_top(docs, query_terms, self.TOP_K)
        # Fetch ``tolerance`` extra hits so the window around the last ranks
        # can be checked without indexing past the end of the hit list.
        hits = self.searcher.search(query, k=self.TOP_K + tolerance)
        hit_docids = [hit.docid for hit in hits]
        for rank, (docid, _) in enumerate(expected):
            window = hit_docids[max(0, rank - tolerance):rank + tolerance + 1]
            self.assertIn(docid, window)

    def test_query_doc_score_default(self):
        """Searcher scores equal compute_query_document_score with defaults."""
        self._assert_scores_match(self.SCORE_QUERIES)

    def test_query_doc_score_custom_similarity(self):
        """Searcher scores equal compute_query_document_score for BM25 and QLD."""
        self.searcher.set_bm25(0.8, 0.2)
        self._assert_scores_match(self.SCORE_QUERIES, similarity=search.LuceneSimilarities.bm25(0.8, 0.2))

        self.searcher.set_qld(500)
        self._assert_scores_match(self.SCORE_QUERIES, similarity=search.LuceneSimilarities.qld(500))

    def test_dump_documents_BM25(self):
        """Brute-force search over the BM25 dump reproduces the searcher results."""
        file_path = 'collections/cacm_documents_bm25_dump.jsonl'
        # Registered before dumping so the file is removed even if the test fails.
        self.addCleanup(self._remove_if_exists, file_path)
        self.index_reader.dump_documents_BM25(file_path)

        docs = self._load_dump(file_path)
        self.assertEqual(len(docs), self.index_reader.stats()['documents'])

        analyzer = analysis.get_lucene_analyzer()
        for query in self.DUMP_QUERIES:
            with self.subTest(query=query):
                query_terms = self.index_reader.analyze(query, analyzer=analyzer)
                expected = self._brute_force_top(docs, query_terms, self.TOP_K)
                hits = self.searcher.search(query, k=self.TOP_K)
                self.assertEqual(len(hits), len(expected))
                # The docids and (approximately) the scores should match.
                for hit, (docid, score) in zip(hits, expected):
                    self.assertEqual(hit.docid, docid)
                    self.assertAlmostEqual(hit.score, score, places=3)

    def test_quantize_weights(self):
        """Rankings from quantized weights stay close to the searcher rankings."""
        dump_file_path = 'collections/cacm_documents_bm25_dump.jsonl'
        quantized_file_path = 'collections/cacm_documents_bm25_dump_quantized.jsonl'
        # Both the intermediate dump and the quantized file are removed,
        # whether or not the test passes.
        self.addCleanup(self._remove_if_exists, dump_file_path)
        self.addCleanup(self._remove_if_exists, quantized_file_path)
        self.index_reader.dump_documents_BM25(dump_file_path)
        self.index_reader.quantize_weights(dump_file_path, quantized_file_path)

        docs = self._load_dump(quantized_file_path)
        self.assertEqual(len(docs), self.index_reader.stats()['documents'])

        for query in self.DUMP_QUERIES:
            with self.subTest(query=query):
                self._assert_rankings_roughly_match(docs, query)
if __name__ == '__main__':
    # Run every test case in this module when executed as a script.
    unittest.main()