Spaces:
Paused
Paused
| import json | |
| import os | |
| import gradio as gr | |
| import requests | |
| from huggingface_hub import HfApi | |
| import traceback | |
# Client for the Hugging Face Hub API.
hf_api = HfApi()

# Map short dataset name -> dataset metadata for every dataset under the
# "bigscience-data" organization (e.g. "roots_en_wikipedia" -> DatasetInfo).
# Used by get_docid_html() to decide link style (private vs. public, license).
# NOTE(review): this performs a network call at import time and reads the
# "bigscience_data_token" env var; `use_auth_token` is the older huggingface_hub
# parameter name (newer releases call it `token`) — confirm the pinned version.
roots_datasets = {
    dset.id.split("/")[-1]: dset
    for dset in hf_api.list_datasets(
        author="bigscience-data", use_auth_token=os.environ.get("bigscience_data_token")
    )
}
def get_docid_html(docid):
    """Render a ROOTS document id as an HTML link to its dataset page.

    Args:
        docid: Identifier of the form "<org>/<dataset>/<document>".

    Returns:
        An HTML snippet linking the dataset part to its Hugging Face Hub page,
        followed by the document part. Private datasets are styled red with an
        explanatory tooltip; public ones blue with the license taken from the
        dataset's first tag.
    """
    _, dataset, doc_id = docid.split("/")
    metadata = roots_datasets[dataset]
    if metadata.private:
        title = "This dataset is private. See the introductory text for more information"
        color = "#AA4A44"
        # NOTE(review): "π" looks like a mojibake'd emoji (lock?) — confirm
        # the intended character against the deployed Space.
        label = f"π{dataset}"
    else:
        # The license is encoded in the first tag, e.g. "license:cc-by-4.0".
        title = f'This dataset is licensed {metadata.tags[0].split(":")[-1]}'
        color = "#2D31FA"
        label = dataset
    # Bug fix: the original concatenated attributes with no separating spaces
    # (class="..."title="..."style="..."), producing invalid HTML.
    return (
        f'<a class="underline-on-hover" '
        f'title="{title}" '
        f'style="color:{color};" '
        f'href="https://huggingface.co/datasets/bigscience-data/{dataset}" '
        f'target="_blank"><b>{label}</b></a>'
        f'<span style="color: #7978FF;">/{doc_id}</span>'
    )
# Tags emitted by the PII tagger; each occurrence appears in text as "PI:<TAG>".
PII_TAGS = {"KEY", "EMAIL", "USER", "IP_ADDRESS", "ID", "IPv4", "IPv6"}
PII_PREFIX = "PI:"


def process_pii(text):
    """Replace every PII placeholder ("PI:<TAG>") in *text* with a highlighted
    "REDACTED <TAG>" HTML marker and return the resulting string."""
    for pii_tag in PII_TAGS:
        marker = PII_PREFIX + pii_tag
        replacement = (
            '<b><mark style="background: Fuchsia; color: Lime;">'
            "REDACTED {}</mark></b>".format(pii_tag)
        )
        text = text.replace(marker, replacement)
    return text
def format_meta(result):
    """Format a single search hit's metadata block as HTML.

    Args:
        result: Hit dict with at least "docid"; optionally "meta" (a dict that
            may carry "url") and "lang".

    Returns:
        HTML containing the source URL (when available), a linked document id,
        and the document language (None when absent).
    """
    meta = result.get("meta")
    meta_html = (
        """
<p class='underline-on-hover' style='font-size:12px; font-family: Arial; color:#585858; text-align: left;'>
<a href='{}' target='_blank'>{}</a></p>""".format(meta["url"], meta["url"])
        if meta is not None and "url" in meta
        else ""
    )
    docid_html = get_docid_html(result["docid"])
    return """{}
<p style='font-size:14px; font-family: Arial; color:#7978FF; text-align: left;'>Document ID: {}</p>
<p style='font-size:12px; font-family: Arial; color:MediumAquaMarine'>Language: {}</p>
""".format(
        meta_html,
        docid_html,
        # Bug fix: the original tested `lang in result` — `lang` is the global
        # Gradio dropdown, not the string "lang" — so this was always falsy and
        # the language always printed as None. Also removed the unreachable
        # trailing `return meta_html`.
        result.get("lang"),
    )
def process_results(results, highlight_terms):
    """Render a list of search hits as HTML.

    Each hit's text is split on whitespace; tokens found in *highlight_terms*
    are bolded, PII placeholders are marked as redacted, and the hit's metadata
    header is prepended. Returns a "no results" message when *results* is empty.
    """
    if not results:
        return """<br><p style='font-family: Arial; color:Silver; text-align: center;'>
No results retrieved.</p><br><hr>"""
    rendered = []
    for hit in results:
        highlighted = " ".join(
            "<b>{}</b>".format(tok) if tok in highlight_terms else tok
            for tok in hit["text"].split()
        )
        highlighted = process_pii(highlighted)
        body = format_meta(hit)
        body += """
<p style='font-family: Arial;'>{}</p>
<br>
""".format(highlighted)
        rendered.append(body)
    return "".join(rendered) + "<hr>"
def process_exact_match_payload(payload, query):
    """Render an exact-match (phrase) search payload as HTML.

    Args:
        payload: Response dict with "results" (list of hits) and "num_results".
        query: The exact phrase searched for; its first occurrence in each
            hit's text is bolded.

    Returns:
        Tuple of (results HTML, list of dataset names the hits came from).
    """
    datasets = set()
    results = payload["results"]
    # Bug fix: "nubmer" -> "number" in the user-facing summary line.
    results_html = "<p style='font-family: Arial;'>Total number of results: {}</p>".format(
        payload["num_results"]
    )
    for result in results:
        _, dataset, _ = result["docid"].split("/")
        datasets.add(dataset)
        text = result["text"]
        meta_html = format_meta(result)
        start = text.find(query)
        if start >= 0:
            end = start + len(query)
            tokens_html = (
                text[:start] + "<b>{}</b>".format(text[start:end]) + text[end:]
            )
        else:
            # Robustness: the phrase may be absent from the stored text (e.g.
            # normalization differences). The original sliced with find()'s -1
            # sentinel, mangling the output; fall back to the raw text instead.
            tokens_html = text
        results_html += (
            meta_html
            + """
<p style='font-family: Arial;'>{}</p>
<br>
""".format(tokens_html)
        )
    return results_html + "<hr>", list(datasets)
def process_bm25_match_payload(payload, language):
    """Render a BM25 search payload as HTML.

    Args:
        payload: Response dict. Either {"err": {...}} on failure, or
            {"results": ..., "highlight_terms": [...]}. For language == "all",
            "results" maps language code -> list of hits; otherwise it is a
            flat list of hits.
        language: Requested language ("detect_language", "all", or a code).

    Returns:
        Tuple of (results HTML, list of dataset names the hits came from).
    """
    if "err" in payload:
        if payload["err"]["type"] == "unsupported_lang":
            detected_lang = payload["err"]["meta"]["detected_lang"]
            # Bug fix: return a (html, datasets) tuple like every other path,
            # so callers can always unpack two values (the original returned a
            # bare string here, breaking the caller's unpacking).
            return (
                f"""
<p style='font-size:18px; font-family: Arial; color:MediumVioletRed; text-align: center;'>
Detected language <b>{detected_lang}</b> is not supported.<br>
Please choose a language from the dropdown or type another query.
</p><br><hr><br>""",
                [],
            )
    results = payload["results"]
    highlight_terms = payload["highlight_terms"]
    if language == "detect_language":
        # Announce the language the backend detected, then the hits. No
        # dataset filter is offered in this mode.
        header = (
            f"""<p style='font-family: Arial; color:MediumAquaMarine; text-align: center; line-height: 3em'>
Detected language: <b>{results[0]["lang"]}</b></p><br><hr><br>"""
            if len(results) > 0
            else ""
        )
        return header + process_results(results, highlight_terms), []
    if language == "all":
        # Bug fix: removed a stray `get_docid_html(result["docid"])` call that
        # referenced an undefined name and raised NameError on this path.
        datasets = set()
        results_html = ""
        for lang, results_for_lang in results.items():
            if not results_for_lang:
                results_html += f"""<p style='font-family: Arial; color:Silver; text-align: left; line-height: 3em'>
No results for language: <b>{lang}</b><hr></p>"""
                continue
            # Per-language hits are wrapped in a collapsible <details> section.
            results_html += f"""
<details>
<summary style='font-family: Arial; color:MediumAquaMarine; text-align: left; line-height: 3em'>
Results for language: <b>{lang}</b><hr>
</summary>
{process_results(results_for_lang, highlight_terms)}
</details>"""
            for hit in results_for_lang:
                _, dataset, _ = hit["docid"].split("/")
                datasets.add(dataset)
        return results_html, list(datasets)
    # Single-language search: collect datasets, render the flat hit list.
    datasets = set()
    for hit in results:
        _, dataset, _ = hit["docid"].split("/")
        datasets.add(dataset)
    return process_results(results, highlight_terms), list(datasets)
def scisearch(query, language, num_results=10):
    """Query the ROOTS search backend and render the results as HTML.

    A query wrapped in double quotes triggers exact (phrase) search against a
    dedicated endpoint; otherwise BM25 search is used against the address
    configured in the environment.

    Args:
        query: Search string; surrounding double quotes request exact search.
        language: Language code, "detect_language", or "all".
        num_results: Maximum number of hits to request.

    Returns:
        Tuple of (results HTML, list of dataset names). On error the HTML
        explains the failure and the dataset list is empty.
    """
    datasets = []
    try:
        query = query.strip()
        exact_search = False
        if query.startswith('"') and query.endswith('"') and len(query) >= 2:
            exact_search = True
            query = query[1:-1]
        else:
            # Collapse internal whitespace for BM25 queries.
            query = " ".join(query.split())
        if not query:
            # Bug fix: return a (html, datasets) tuple like every other path
            # so callers can always unpack two values (the original returned a
            # bare ""). The original's trailing `query is None` check was dead
            # code — .strip() above would already have raised on None.
            return "", []
        post_data = {"query": query, "k": num_results}
        if language != "detect_language":
            post_data["lang"] = language
        # Exact search is served by a dedicated host; BM25 by the service
        # address configured in the environment.
        address = (
            "http://34.105.160.81:8080" if exact_search else os.environ.get("address")
        )
        output = requests.post(
            address,
            headers={"Content-type": "application/json"},
            data=json.dumps(post_data),
            timeout=60,
        )
        payload = json.loads(output.text)
        return (
            process_bm25_match_payload(payload, language)
            if not exact_search
            else process_exact_match_payload(payload, query)
        )
    except Exception as e:
        # Top-level boundary for the UI: show the exception class to the user
        # and log the full traceback server-side.
        results_html = f"""
<p style='font-size:18px; font-family: Arial; color:MediumVioletRed; text-align: center;'>
Raised {type(e).__name__}</p>
<p style='font-size:14px; font-family: Arial; '>
Check if a relevant discussion already exists in the Community tab. If not, please open a discussion.
</p>
"""
        print(e)
        print(traceback.format_exc())
        return results_html, datasets
def flag(query, language, num_results, issue_description):
    """Report a problematic search to the backend for later review.

    Posts the query, language, requested result count and the user's free-form
    description to the search service with "flag": True. Best-effort: any
    failure is logged and swallowed so the UI never breaks on flagging.

    Returns:
        An empty string, used to clear the flagging textbox in the UI.
    """
    try:
        post_data = {
            "query": query,
            "k": num_results,
            "flag": True,
            "description": issue_description,
        }
        if language != "detect_language":
            post_data["lang"] = language
        requests.post(
            os.environ.get("address"),
            headers={"Content-type": "application/json"},
            data=json.dumps(post_data),
            timeout=120,
        )
    except Exception:
        # Bug fix: the original bare `except:` also swallowed SystemExit and
        # KeyboardInterrupt and hid the cause; keep the best-effort behavior
        # but log the traceback. (Also dropped the unused parse of the
        # response body.)
        print("Error flagging")
        print(traceback.format_exc())
    return ""
# Markdown shown at the top of the app (rendered by gr.Markdown below).
# NOTE(review): "πΈ π ... π πΈ" in the heading appears to be mojibake of
# decorative emoji — confirm the intended characters against the deployed
# Space before changing the string.
description = """# <p style="text-align: center;"> πΈ π ROOTS search tool π πΈ </p>
The ROOTS corpus was developed during the [BigScience workshop](https://bigscience.huggingface.co/) for the purpose
of training the Multilingual Large Language Model [BLOOM](https://huggingface.co/bigscience/bloom). This tool allows
you to search through the ROOTS corpus. We serve a BM25 index for each language or group of languages included in
ROOTS. You can read more about the details of the tool design
[here](https://huggingface.co/spaces/bigscience-data/scisearch/blob/main/roots_search_tool_specs.pdf). For more
information and instructions on how to access the full corpus check [this form](https://forms.gle/qyYswbEL5kA23Wu99)."""
if __name__ == "__main__":
    # Build the Gradio UI: query box, language dropdown, result-count slider,
    # a submit button, a hidden dataset filter, the results pane, and a hidden
    # flagging form that appears after the first search.
    demo = gr.Blocks(
        css=".underline-on-hover:hover { text-decoration: underline; } .flagging { font-size:12px; color:Silver; }"
    )
    with demo:
        with gr.Row():
            gr.Markdown(value=description)
        with gr.Row():
            query = gr.Textbox(
                lines=1,
                max_lines=1,
                placeholder="Put your query in double quotes for exact search.",
                label="Query",
            )
        with gr.Row():
            # Language codes served by the backend, plus the special
            # "detect_language" and "all" modes handled in scisearch().
            lang = gr.Dropdown(
                choices=[
                    "ar",
                    "ca",
                    "code",
                    "en",
                    "es",
                    "eu",
                    "fr",
                    "id",
                    "indic",
                    "nigercongo",
                    "pt",
                    "vi",
                    "zh",
                    "detect_language",
                    "all",
                ],
                value="en",
                label="Language",
            )
        with gr.Row():
            k = gr.Slider(1, 100, value=10, step=1, label="Max Results")
        with gr.Row():
            # NOTE(review): this triple-quoted string is commented-out UI code
            # (an "Exact Search" checkbox) kept as a no-op expression; exact
            # search is instead triggered by quoting the query.
            """
            with gr.Column(scale=1):
                exact_search = gr.Checkbox(
                    value=False, label="Exact Search", variant="compact"
                )
            """
            with gr.Column(scale=4):
                submit_btn = gr.Button("Submit")
        # Hidden until the first search returns; then populated with the
        # dataset names found in the results.
        with gr.Row(visible=False) as datasets_filter:
            available_datasets = gr.Dropdown(
                type="value",
                # NOTE(review): placeholder choices — replaced with real
                # dataset names by submit() after each search.
                choices=["ran", "swam", "ate", "slept"],
                label="Datasets",
                multiselect=True,
            )
        with gr.Row():
            results = gr.HTML(label="Results")
        # Flagging UI: hidden until a search has run; the button clears the
        # textbox by returning "" from flag().
        with gr.Column(visible=False) as flagging_form:
            flag_txt = gr.Textbox(
                lines=1,
                placeholder="Type here...",
                label="""If you choose to flag your search, we will save the query, language and the number of results
you requested. Please consider adding relevant additional context below:""",
            )
            flag_btn = gr.Button("Flag Results")
            flag_btn.click(flag, inputs=[query, lang, k, flag_txt], outputs=[flag_txt])

        def submit(query, lang, k, dropdown_input):
            # Run the search, then reveal the flagging form and dataset filter
            # and refresh the filter's choices with the datasets found.
            print("submitting", query, lang, k)
            query = query.strip()
            if query is None or query == "":
                return "", ""
            results_html, datasets = scisearch(query, lang, k)
            print(datasets)
            return {
                results: results_html,
                flagging_form: gr.update(visible=True),
                datasets_filter: gr.update(visible=True),
                # NOTE(review): gr.Dropdown.update is the Gradio 3.x API;
                # Gradio 4+ replaces it with gr.Dropdown(choices=...).
                available_datasets: gr.Dropdown.update(choices=datasets),
            }

        def filter_datasets():
            # NOTE(review): placeholder — dataset filtering is wired up below
            # but not implemented.
            pass

        # Fire the same handler from Enter in the query box and the button.
        query.submit(
            fn=submit,
            inputs=[query, lang, k, available_datasets],
            outputs=[results, flagging_form, datasets_filter, available_datasets],
        )
        submit_btn.click(
            submit,
            inputs=[query, lang, k, available_datasets],
            outputs=[results, flagging_form, datasets_filter, available_datasets],
        )
        available_datasets.change(filter_datasets, inputs=[], outputs=[])
    demo.launch(enable_queue=True, debug=True)