#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import json
import os

import faiss
import numpy as np
import torch
from tqdm import tqdm

class DocumentEncoder:
    def encode(self, texts, **kwargs):
        pass

    @staticmethod
    def _mean_pooling(last_hidden_state, attention_mask):
        token_embeddings = last_hidden_state
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask
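
    # The helper above implements masked mean pooling: given last_hidden_state of
    # shape (batch, seq_len, hidden) and attention_mask of shape (batch, seq_len),
    # it returns a (batch, hidden) tensor in which each row is the average of that
    # sequence's token embeddings, with padding positions excluded via the mask.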

class QueryEncoder:
    def encode(self, text, **kwargs):
        pass

class PcaEncoder:
    def __init__(self, encoder, pca_model_path):
        self.encoder = encoder
        self.pca_mat = faiss.read_VectorTransform(pca_model_path)

    def encode(self, text, **kwargs):
        if isinstance(text, str):
            embeddings = self.encoder.encode(text, **kwargs)
            embeddings = self.pca_mat.apply_py(np.array([embeddings]))
            embeddings = embeddings[0]
        else:
            embeddings = self.encoder.encode(text, **kwargs)
            embeddings = self.pca_mat.apply_py(embeddings)
        return embeddings
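
# Illustrative sketch of wrapping an existing encoder with a PCA transform; the
# base encoder, PCA model path, and query string below are placeholders, not part
# of this module:
#
#   base_encoder = SomeQueryEncoder(...)            # any object with .encode()
#   encoder = PcaEncoder(base_encoder, 'pca.pca')   # path to a faiss VectorTransform
#   query_vec = encoder.encode('what is dense retrieval?')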

class JsonlCollectionIterator:
    def __init__(self, collection_path: str, fields=None, delimiter="\n\n"):
        if fields:
            self.fields = fields
        else:
            self.fields = ["text"]
        self.delimiter = delimiter
        self.all_info = self._load(collection_path)
        self.size = len(self.all_info["id"])
        self.batch_size = 1
        self.shard_id = 0
        self.shard_num = 1

    def __call__(self, batch_size=1, shard_id=0, shard_num=1):
        self.batch_size = batch_size
        self.shard_id = shard_id
        self.shard_num = shard_num
        return self

    def __iter__(self):
        total_len = self.size
        shard_size = int(total_len / self.shard_num)
        start_idx = self.shard_id * shard_size
        end_idx = min(start_idx + shard_size, total_len)
        if self.shard_id == self.shard_num - 1:
            end_idx = total_len
        to_yield = {}
        for idx in tqdm(range(start_idx, end_idx, self.batch_size)):
            for key in self.all_info:
                to_yield[key] = self.all_info[key][
                    idx : min(idx + self.batch_size, end_idx)
                ]
            yield to_yield

    def _parse_fields_from_info(self, info):
        """
        :param info: dict containing all fields specified in self.fields, either under
            the key of each field name or under the key 'contents'. If under 'contents',
            this function parses the contents into the individual fields based on
            self.delimiter.
        :return: List, where each element corresponds to the value of one of self.fields.
        """
        n_fields = len(self.fields)

        # if all fields are keys of info, read them directly rather than 'contents'
        if all([field in info for field in self.fields]):
            return [info[field].strip() for field in self.fields]

        assert "contents" in info, f"contents not found in info: {info}"
        contents = info["contents"]
        # Decide whether to remove a final trailing self.delimiter (especially \n).
        # In CACM, a \n is always present at the end of contents, which we want to remove;
        # but in SciFact, FiQA, and others, some documents have only a title and no text
        # (e.g. "This is a title\n"), where the trailing \n indicates an empty field.
        if contents.count(self.delimiter) == n_fields:
            # the user appended one extra delimiter to the end, so we remove it
            if contents.endswith(self.delimiter):
                # not using .rstrip() as there might be more than one delimiter at the end
                contents = contents[: -len(self.delimiter)]
        return [field.strip(" ") for field in contents.split(self.delimiter)]
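
    # Example records this parser accepts, assuming fields=['title', 'text'] and the
    # default "\n\n" delimiter (illustrative data, not part of any shipped collection):
    #
    #   {"id": "d1", "title": "Dense retrieval", "text": "Dual encoders map queries ..."}
    #   {"id": "d2", "contents": "Dense retrieval\n\nDual encoders map queries ..."}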

    def _load(self, collection_path):
        filenames = []
        if os.path.isfile(collection_path):
            filenames.append(collection_path)
        else:
            for filename in os.listdir(collection_path):
                filenames.append(os.path.join(collection_path, filename))
        all_info = {field: [] for field in self.fields}
        all_info["id"] = []
        for filename in filenames:
            with open(filename) as f:
                for line_i, line in tqdm(enumerate(f)):
                    info = json.loads(line)
                    _id = info.get("id", info.get("docid", None))
                    if _id is None:
                        raise ValueError(f"Cannot find 'id' or 'docid' in {filename}.")
                    all_info["id"].append(str(_id))
                    fields_info = self._parse_fields_from_info(info)
                    if len(fields_info) != len(self.fields):
                        raise ValueError(
                            f"{len(fields_info)} fields found at line #{line_i} in file {filename}, "
                            f"but {len(self.fields)} fields were expected. "
                            f"Line content: {info['contents']}"
                        )
                    for i in range(len(fields_info)):
                        all_info[self.fields[i]].append(fields_info[i])
        return all_info
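
# Illustrative iteration sketch; the corpus path, field names, and batch size are
# placeholders:
#
#   collection = JsonlCollectionIterator('corpus/', fields=['title', 'text'])
#   for batch in collection(batch_size=64, shard_id=0, shard_num=1):
#       ...  # batch is a dict: {'id': [...], 'title': [...], 'text': [...]}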

class RepresentationWriter:
    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass

    def write(self, batch_info, fields=None):
        pass

class JsonlRepresentationWriter(RepresentationWriter):
    def __init__(self, dir_path):
        self.dir_path = dir_path
        self.filename = "embeddings.jsonl"
        self.file = None

    def __enter__(self):
        if not os.path.exists(self.dir_path):
            os.makedirs(self.dir_path)
        self.file = open(os.path.join(self.dir_path, self.filename), "w")

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def write(self, batch_info, fields=None):
        for i in range(len(batch_info["id"])):
            contents = "\n".join([batch_info[key][i] for key in fields])
            vector = batch_info["vector"][i]
            vector = vector.tolist() if isinstance(vector, np.ndarray) else vector
            self.file.write(
                json.dumps(
                    {"id": batch_info["id"][i], "contents": contents, "vector": vector}
                )
                + "\n"
            )
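
# Each write() call appends one JSON line per document to embeddings.jsonl;
# the values below are illustrative only:
#
#   {"id": "d1", "contents": "Dense retrieval\nDual encoders ...", "vector": [0.01, -0.23, ...]}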

class FaissRepresentationWriter(RepresentationWriter):
    def __init__(self, dir_path, dimension=768):
        self.dir_path = dir_path
        self.index_name = "index"
        self.id_file_name = "docid"
        self.dimension = dimension
        self.index = faiss.IndexFlatIP(self.dimension)
        self.id_file = None

    def __enter__(self):
        if not os.path.exists(self.dir_path):
            os.makedirs(self.dir_path)
        self.id_file = open(os.path.join(self.dir_path, self.id_file_name), "w")

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.id_file.close()
        faiss.write_index(self.index, os.path.join(self.dir_path, self.index_name))

    def write(self, batch_info, fields=None):
        for id_ in batch_info["id"]:
            self.id_file.write(f"{id_}\n")
        self.index.add(np.ascontiguousarray(batch_info["vector"]))
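
# End-to-end sketch tying the pieces together. The document encoder is a placeholder
# for any DocumentEncoder subclass with a working encode(); paths, field names, and
# the 768-dimensional assumption are illustrative only:
#
#   collection = JsonlCollectionIterator('corpus/', fields=['title', 'text'])
#   encoder = SomeDocumentEncoder(...)  # hypothetical subclass of DocumentEncoder
#   writer = FaissRepresentationWriter('indexes/my-dense-index', dimension=768)
#   with writer:
#       for batch in collection(batch_size=64):
#           batch['vector'] = encoder.encode(batch['text'])  # shape: (batch, 768)
#           writer.write(batch)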