# NOTE: This code is taken from the original KILT library's retrieval evaluation script:
# https://github.com/facebookresearch/KILT/blob/9bcb119a7ed5fda88826058b062d0e45c726c676/kilt/eval_retrieval.py

import argparse
import json
import os
import pprint
from collections import defaultdict, OrderedDict

from pyserini.query_iterator import KiltQueryIterator


##########################################################################################
# Replaced:
#   from kilt import kilt_utils
# With the following directly imported code:
def load_data(filename):
    """Load a KILT-format JSONL file: one JSON record per line."""
    data = []
    with open(filename, "r") as fin:
        lines = fin.readlines()
        for line in lines:
            data.append(json.loads(line))
    return data
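
# Example (illustrative, not from the original script): a minimal gold line in a
# KILT file looks roughly like
#   {"id": "0", "input": "...", "output": [{"provenance": [{"wikipedia_id": "123"}]}]}
# so load_data("gold.jsonl") returns a list of such dicts, one per input line.
# The concrete field values above are made up.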


##########################################################################################
# Replaced:
#   from kilt import eval_downstream
# With the following directly imported code:
def validate_input(gold_records, guess_records):

    if len(gold_records) != len(guess_records):
        print(
            "WARNING: DIFFERENT SIZE gold: {} guess: {}".format(
                len(gold_records), len(guess_records)
            )
        )

    # align order
    gold_ids = []
    for gold in gold_records:
        assert str(gold["id"]).strip() not in gold_ids, "Gold IDs should be unique"
        gold_ids.append(str(gold["id"]).strip())

    id2guess_record = {}
    for guess in guess_records:
        assert (
            str(guess["id"]).strip() not in id2guess_record
        ), "Prediction IDs should be unique"
        id2guess_record[str(guess["id"]).strip()] = guess

    guess_records = []
    for id in gold_ids:
        if id in id2guess_record:
            guess_records.append(id2guess_record[id])
        else:
            raise ValueError("ERROR: no prediction provided for id: {}".format(id))

    return gold_records, guess_records
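
# Example (illustrative): if the gold ids are ["a", "b"] and the guesses arrive
# as [{"id": "b"}, {"id": "a"}], validate_input reorders the guesses to match
# ["a", "b"]. A guess id absent from gold is silently dropped; a gold id with
# no matching guess raises ValueError.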


##########################################################################################


def _remove_duplicates(obj):
    obj_tmp = []
    for o in obj:
        if o not in obj_tmp:
            obj_tmp.append(o)
    return obj_tmp
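
# _remove_duplicates is used instead of set() because it preserves the original
# ranking order, which matters downstream, e.g. (illustrative):
#   _remove_duplicates(["12", "7", "12"]) -> ["12", "7"]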


def _get_ids_list(datapoint, rank_keys, verbose=False):
    # collect all gold ids
    ids_list = []
    for output in datapoint["output"]:
        current_ids_list = []
        if "provenance" in output:
            for provenance in output["provenance"]:
                if any(rank_key not in provenance for rank_key in rank_keys):
                    missing = set(rank_keys) - set(provenance.keys())
                    if verbose:
                        print(
                            f"WARNING: missing key(s) {missing} in provenance, unable to compute retrieval for those."
                        )
                else:
                    current_ids_list.append(
                        "+".join(
                            [
                                str(provenance[rank_key]).strip()
                                for rank_key in rank_keys
                            ]
                        )
                    )
        # consider only unique ids
        ids_list.append(_remove_duplicates(current_ids_list))
    return ids_list
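
# Example (illustrative, values made up): with rank_keys=["wikipedia_id"] and a
# datapoint whose single output has provenance
#   [{"wikipedia_id": "12"}, {"wikipedia_id": "34"}]
# _get_ids_list returns [["12", "34"]]. With multiple rank keys, each id is a
# "+"-joined string of the corresponding provenance values.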


def get_rank(guess_item, gold_item, k, rank_keys, verbose=False):
    """
    The main idea is to consider each evidence set as a single point in the rank.
    The score in the rank for an evidence set is given by the lowest-scored evidence in the set.
    """

    assert k > 0, "k must be a positive integer greater than 0."

    rank = []
    num_distinct_evidence_sets = 0

    guess_ids = _get_ids_list(guess_item, rank_keys)[0]

    if guess_ids and len(guess_ids) > 0:

        # 1. collect evidence sets and their sizes
        evidence_sets = []
        e_size = defaultdict(int)
        for output in gold_item["output"]:
            if "provenance" in output:
                e_set = {
                    "+".join(
                        [str(provenance[rank_key]).strip() for rank_key in rank_keys]
                    )
                    for provenance in output["provenance"]
                }
                if e_set not in evidence_sets:  # no duplicate evidence set
                    evidence_sets.append(e_set)
                    e_size[len(e_set)] += 1
        num_distinct_evidence_sets = len(evidence_sets)

        # 2. check the minimum number of predicted pages needed for a robust P/R@k
        min_prediction_size = 0
        c = 0
        for size, freq in sorted(e_size.items(), reverse=True):
            for _ in range(freq):
                min_prediction_size += size
                c += 1
                if c == k:
                    break
            if c == k:
                break
        # if the number of evidence sets is smaller than k
        min_prediction_size += k - c

        if verbose and len(guess_ids) < min_prediction_size:
            print(
                f"WARNING: you should provide at least {min_prediction_size} provenance items for a robust recall@{k} computation (you provided {len(guess_ids)} item(s))."
            )

        # 3. rank by grouping the pages of each evidence set (each evidence set counts as 1);
        # the position of an evidence set in the rank is given by the last of its pages
        # to appear in guess_ids. Non-evidence pages count as 1 each.
        rank = []
        for guess_id in guess_ids:
            guess_id = str(guess_id).strip()
            found = False
            for idx, e_set in enumerate(evidence_sets):

                e_set_id = f"evidence_set:{idx}"

                if guess_id in e_set:
                    found = True

                    # remove from the rank previous points referring to this evidence set
                    if e_set_id in rank:
                        rank.remove(e_set_id)

                    # remove the guess_id from the evidence set
                    e_set.remove(guess_id)

                    if len(e_set) == 0:
                        # it was the last missing evidence: the whole set is now
                        # retrieved, so it counts as a hit (True) in the rank
                        rank.append(True)
                    else:
                        # add a placeholder for this partially retrieved evidence set
                        rank.append(e_set_id)
            if not found:
                rank.append(False)

    return rank, num_distinct_evidence_sets
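
# Worked example (illustrative, values made up): suppose the gold item has two
# evidence sets {"1", "2"} and {"3"}, and guess_ids = ["1", "3", "2", "9"].
#   - "1" partially covers set 0  -> rank = ["evidence_set:0"]
#   - "3" fully covers set 1      -> rank = ["evidence_set:0", True]
#   - "2" completes set 0, so its placeholder moves to this later position
#                                 -> rank = [True, True]
#   - "9" matches nothing         -> rank = [True, True, False]
# get_rank returns ([True, True, False], 2): each evidence set is scored at the
# position of the last of its pages to be retrieved.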


# 1. Precision computation
def _precision_at_k(rank, k):
    # precision @ k: fraction of the top-k rank positions that are fully
    # retrieved evidence sets
    p = rank[:k].count(True) / k
    return p


# 2. Recall computation
def _recall_at_k(rank, num_distinct_evidence_sets, k):
    # recall @ k: fraction of all distinct evidence sets fully retrieved
    # within the top k
    r = rank[:k].count(True) / num_distinct_evidence_sets
    return r


# 3. Success rate computation
def _success_rate_at_k(rank, k):
    # success rate @ k: 1 if at least one evidence set is fully retrieved
    # within the top k, else 0
    p = int(True in rank[:k])
    return p
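
# Worked example (illustrative): with rank = [True, True, False] and
# num_distinct_evidence_sets = 2 from the sketch above,
#   _precision_at_k(rank, 2)    = 2/2 = 1.0
#   _recall_at_k(rank, 2, 2)    = 2/2 = 1.0
#   _success_rate_at_k(rank, 1) = 1   (the top-ranked point is a full hit)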


def _computeRprec(guess_ids, gold_ids):
    R = len(gold_ids)
    num = 0
    for prediction in guess_ids[:R]:
        if str(prediction).strip() in gold_ids:
            num += 1
    Rprec = num / R if R > 0 else 0
    return Rprec
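
# Worked example (illustrative): gold_ids = ["1", "2", "3"] gives R = 3, so only
# the first 3 predictions are scored; guess_ids = ["2", "9", "1", "3"] yields
# 2 hits in the top 3, so _computeRprec returns 2/3.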


# R-precision: https://link.springer.com/referenceworkentry/10.1007%2F978-0-387-39940-9_486
def rprecision(guess_item, gold_item, rank_keys):
    gold_ids_list = _get_ids_list(gold_item, rank_keys)
    guess_ids = _get_ids_list(guess_item, rank_keys)[0]
    Rprec_vector = []
    for gold_ids in gold_ids_list:
        Rprec = _computeRprec(guess_ids, gold_ids)
        Rprec_vector.append(Rprec)
    # a gold item may carry several alternative outputs; score against each
    # and keep the best
    return max(Rprec_vector)


def get_ranking_metrics(guess_item, gold_item, ks, rank_keys):

    Rprec = 0
    P_at_k = {"precision@{}".format(k): 0 for k in sorted(ks) if k > 0}
    R_at_k = {"recall@{}".format(k): 0 for k in sorted(ks) if k > 1}
    S_at_k = {"success_rate@{}".format(k): 0 for k in sorted(ks) if k > 1}

    assert (
        "output" in guess_item and len(guess_item["output"]) == 1
    ), f"guess should provide exactly one output for {guess_item['id']}"

    Rprec = rprecision(guess_item, gold_item, rank_keys=rank_keys)
    for k in ks:

        # 0. get rank
        rank, num_distinct_evidence_sets = get_rank(
            guess_item, gold_item, k, rank_keys=rank_keys
        )

        if num_distinct_evidence_sets > 0:
            # 1. precision
            P_at_k["precision@{}".format(k)] = _precision_at_k(rank, k)

            # 2. recall
            R_at_k["recall@{}".format(k)] = _recall_at_k(
                rank, num_distinct_evidence_sets, k
            )

            # 3. success rate
            S_at_k["success_rate@{}".format(k)] = _success_rate_at_k(rank, k)
        # else:
        #     print(
        #         "WARNING: the number of distinct evidence sets is 0 for {}".format(
        #             gold_item
        #         )
        #     )

    return {"Rprec": Rprec, **P_at_k, **R_at_k, **S_at_k}
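
# Example return value (illustrative, values made up): with ks=[1, 5] the dict
# looks like
#   {"Rprec": 1.0, "precision@1": 1.0, "precision@5": 0.4,
#    "recall@5": 1.0, "success_rate@5": 1}
# Note that recall@k and success_rate@k are only reported for k > 1, while
# precision@k is reported for every k > 0.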


def compute(gold_dataset, guess_dataset, ks, rank_keys):

    ks = sorted([int(x) for x in ks])

    result = OrderedDict()
    result["Rprec"] = 0.0
    for k in ks:
        if k > 0:
            result["precision@{}".format(k)] = 0.0
        if k > 1:
            result["recall@{}".format(k)] = 0.0
            result["success_rate@{}".format(k)] = 0.0

    assert len(guess_dataset) == len(
        gold_dataset
    ), "different size gold: {} guess: {}".format(len(gold_dataset), len(guess_dataset))

    for gold, guess in zip(gold_dataset, guess_dataset):
        assert (
            str(gold["id"]).strip() == str(guess["id"]).strip()
        ), "Items must have same order with same IDs"

    for guess_item, gold_item in zip(guess_dataset, gold_dataset):
        ranking_metrics = get_ranking_metrics(guess_item, gold_item, ks, rank_keys)

        result["Rprec"] += ranking_metrics["Rprec"]
        for k in ks:
            if k > 0:
                result["precision@{}".format(k)] += ranking_metrics[
                    "precision@{}".format(k)
                ]
            if k > 1:
                result["recall@{}".format(k)] += ranking_metrics["recall@{}".format(k)]
                result["success_rate@{}".format(k)] += ranking_metrics[
                    "success_rate@{}".format(k)
                ]

    if len(guess_dataset) > 0:
        result["Rprec"] /= len(guess_dataset)
        for k in ks:
            if k > 0:
                result["precision@{}".format(k)] /= len(guess_dataset)
            if k > 1:
                result["recall@{}".format(k)] /= len(guess_dataset)
                result["success_rate@{}".format(k)] /= len(guess_dataset)

    return result
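
# Note: compute macro-averages over queries; every query contributes equally to
# the final scores, regardless of how many evidence sets it has.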


def evaluate(gold, guess, ks, rank_keys):
    pp = pprint.PrettyPrinter(indent=4)

    gold_dataset = load_data(gold)
    guess_dataset = load_data(guess)

    # 0. validate input
    gold_dataset, guess_dataset = validate_input(gold_dataset, guess_dataset)

    # 1. get retrieval metrics
    result = compute(gold_dataset, guess_dataset, ks, rank_keys)
    pp.pprint(result)
    return result
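
# Example programmatic use (illustrative; the file names are placeholders):
#   metrics = evaluate("gold.jsonl", "run.jsonl", ks=[1, 5, 10, 20],
#                      rank_keys=["wikipedia_id"])
#   print(metrics["Rprec"])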


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("guess", help="Guess KILT file")
    parser.add_argument("gold", help="Gold KILT file")
    parser.add_argument(
        "--ks",
        type=str,
        required=False,
        default="1,5,10,20",
        help="Comma-separated list of positive integers for recall@k and precision@k",
    )
    parser.add_argument(
        "--rank_keys",
        type=str,
        required=False,
        default="wikipedia_id",
        help="Comma-separated list of rank keys for recall@k and precision@k",
    )

    args = parser.parse_args()
    args.ks = [int(k) for k in args.ks.split(",")]
    args.rank_keys = args.rank_keys.split(",")

    ##########################################################################################
    # Pyserini change:
    # Download the gold file if it is not already present locally.
    gold = args.gold
    if not os.path.exists(args.gold):
        gold = KiltQueryIterator.download_kilt_topics(gold)
    ##########################################################################################

    evaluate(gold, args.guess, args.ks, args.rank_keys)
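
# Example CLI invocation (illustrative; file names are placeholders; note the
# guess file comes first, then the gold file):
#   python eval_retrieval.py run.jsonl gold.jsonl --ks 1,5,10,20 --rank_keys wikipedia_id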