#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Example usage:
# python -m pyserini.eval.trec_eval -m ndcg_cut.10,20 -m all_trec -m judged.20,50 -remove-unjudged qrels.dev.small.tsv runs/run.Colbert.txt
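#
# '-remove-unjudged' drops hits that have no qrels judgment from the run before
# scoring, and '-m judged.<k1>,<k2>,...' reports the fraction of judged documents
# among the top-k hits (judged@k). Both options are handled by this wrapper
# script rather than by trec_eval itself.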

import os
import platform
import re
import subprocess
import sys
import tempfile

import pandas as pd

from pyserini.search import get_qrels_file
from pyserini.util import download_evaluation_script
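
# Download the trec_eval jar if necessary and invoke it through Java.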
script_path = download_evaluation_script('trec_eval')
cmd_prefix = ['java', '-jar', script_path]
args = sys.argv

# Option to discard non-judged hits in run file
judged_docs_only = ''
judged_result = []
cutoffs = []

if '-remove-unjudged' in args:
    judged_docs_only = args.pop(args.index('-remove-unjudged'))

if any([i.startswith('judged.') for i in args]):
    # Find what position the arg is in.
    idx = [i.startswith('judged.') for i in args].index(True)
    cutoffs = args.pop(idx)
    cutoffs = list(map(int, cutoffs[7:].split(',')))
    # Get rid of the '-m' before the 'judged.xxx' option
    args.pop(idx - 1)

temp_file = ''

if len(args) > 1:
    if not os.path.exists(args[-2]):
        args[-2] = get_qrels_file(args[-2])
    if os.path.exists(args[-1]):
        # Convert the run to TREC format if it's an MS MARCO run
        with open(args[-1]) as f:
            first_line = f.readline()
        if 'Q0' not in first_line:
            temp_file = tempfile.NamedTemporaryFile(delete=False).name
            print('msmarco run detected. Converting to trec...')
            run = pd.read_csv(args[-1], delim_whitespace=True, header=None, names=['query_id', 'doc_id', 'rank'])
            run['score'] = 1 / run['rank']
            run.insert(1, 'Q0', 'Q0')
            run['name'] = 'TEMPRUN'
            run.to_csv(temp_file, sep='\t', header=None, index=None)
            args[-1] = temp_file
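
    # Load the (possibly converted) run and the qrels; they are used below to
    # drop unjudged hits and to compute judged@k.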
    run = pd.read_csv(args[-1], delim_whitespace=True, header=None)
    qrels = pd.read_csv(args[-2], delim_whitespace=True, header=None)

    # Cast the query id column (column 0) to string so run and qrels merge correctly
    run[0] = run[0].astype(str)
    qrels[0] = qrels[0].astype(str)

    # Discard non-judged hits
    if judged_docs_only:
        if not temp_file:
            temp_file = tempfile.NamedTemporaryFile(delete=False).name
        judged_indexes = pd.merge(run[[0, 2]].reset_index(), qrels[[0, 2]], on=[0, 2])['index']
        run = run.loc[judged_indexes]
        run.to_csv(temp_file, sep='\t', header=None, index=None)
        args[-1] = temp_file

    # Measure judged@cutoffs
    for cutoff in cutoffs:
        run_cutoff = run.groupby(0).head(cutoff)
        judged = len(pd.merge(run_cutoff[[0, 2]], qrels[[0, 2]], on=[0, 2])) / len(run_cutoff)
        metric_name = f'judged_{cutoff}'
        judged_result.append(f'{metric_name:22}\tall\t{judged:.4f}')

    cmd = cmd_prefix + args[1:]
else:
    cmd = cmd_prefix
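
# Run trec_eval as a subprocess (going through the shell on Windows) and
# capture its output.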
print(f'Running command: {cmd}')
shell = platform.system() == "Windows"
process = subprocess.Popen(cmd,
                           stdout=subprocess.PIPE,
                           stderr=subprocess.PIPE,
                           shell=shell)
stdout, stderr = process.communicate()
if stderr:
    print(stderr.decode("utf-8"))

print('Results:')
print(stdout.decode("utf-8").rstrip())

for judged in judged_result:
    print(judged)
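
# Clean up the temporary run file, if one was created.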
if temp_file:
    os.remove(temp_file)