import gradio as gr
import os
from smolagents import Tool, CodeAgent, TransformersModel, stream_to_gradio, HfApiModel
import spaces
from dotenv import load_dotenv
import datasets
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import HuggingFaceDatasetLoader
import chromadb
from chromadb.utils import embedding_functions
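
# Agentic RAG demo: a smolagents CodeAgent answers questions using a
# Chroma-backed retriever tool over a general-knowledge dataset, served with Gradio.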
| """ | |
| sample questions | |
| pass me some fun general facts from the retreiver | |
| """ | |

load_dotenv()


def dummy():  # unused placeholder (kept from the original app)
    pass

class RetrieverTool(Tool):
    """Semantic-search retriever backed by a Chroma vector store.

    Since we need to attach a vector DB as an attribute of the tool,
    we cannot simply use the plain tool constructor with the @tool decorator.
    Retrieval uses Chroma's default embedding function, which is fast; for more
    accurate retrieval, swap in a stronger embedding model (see the MTEB
    leaderboard for accuracy rankings).
    """

    name = "retriever"
    description = """Uses semantic search to retrieve the parts of the knowledge base
    that could be most relevant to answer your query.
    Afterwards, this tool summarizes the findings from the retrieved documents."""
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"
    def __init__(self, docs: list[Document], **kwargs):
        super().__init__(**kwargs)
        chroma_data_path = "chroma_data/"
        os.makedirs(chroma_data_path, exist_ok=True)
        collection_name = "demo_docs"
        # Chroma's default embedding function (an ONNX all-MiniLM-L6-v2 model).
        embedding_func = embedding_functions.DefaultEmbeddingFunction()
        client = chromadb.PersistentClient(path=chroma_data_path)
        collection = client.get_or_create_collection(
            name=collection_name,
            embedding_function=embedding_func,
            metadata={"hnsw:space": "cosine"},  # cosine distance for the HNSW index
        )
        collection.upsert(
            documents=[doc.page_content for doc in docs],
            ids=[f"id{i}" for i in range(len(docs))],
        )
        self.collection = collection
    def forward(self, query: str) -> str:
        assert isinstance(query, str), "Your search query must be a string"
        docs = self.collection.query(query_texts=[query], n_results=5)
        retrieved_text = "\nRetrieved documents:\n" + "".join(
            f"\n\n===== Document {doc_id} =====\n{doc}"
            for doc_id, doc in zip(docs["ids"][0], docs["documents"][0])
        )
        # Ask the module-level model (defined in the __main__ block below) to
        # summarize the retrieved passages, and return both.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Summarize this text: " + retrieved_text}
                ],
            }
        ]
        return retrieved_text + "\n" + model(messages).content
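
# Standalone usage sketch (hypothetical query, outside the agent loop); note that
# forward() also needs the module-level `model` defined in the __main__ block:
#   tool = RetrieverTool(docs_processed)
#   print(tool.forward("fun facts about the solar system"))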

if __name__ == "__main__":
    # Alternative: build the Documents by hand from the raw dataset.
    # knowledge_base = datasets.load_dataset("MuskumPillerum/General-Knowledge", split="train")
    # print(knowledge_base.column_names)
    # source_docs = [
    #     Document(page_content=doc["Answer"], metadata={"question": doc["Question"]})
    #     for doc in knowledge_base
    # ]

    # Load the first 100 answers of the dataset as documents.
    source_docs = HuggingFaceDatasetLoader(
        "MuskumPillerum/General-Knowledge", "Answer"
    ).load()[:100]
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
        add_start_index=True,
        strip_whitespace=True,
        separators=["\n\n", "\n", ".", " ", ""],
    )
    docs_processed = text_splitter.split_documents(source_docs)
    retriever_tool = RetrieverTool(docs_processed)
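
    # e.g. inspect the first chunk (hypothetical sanity check):
    #   print(docs_processed[0].page_content[:200])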

    # Local TransformersModel variant (not working at the moment):
    # model = TransformersModel(
    #     # model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
    #     device_map="cuda",
    #     model_id="meta-llama/Llama-3.2-3B-Instruct",
    # )

    # Remote model served through the Hugging Face Inference API.
    model = HfApiModel(
        model_id="meta-llama/Llama-3.2-3B-Instruct",
        token=os.getenv("my_first_agents_hf_tokens"),
    )
    agent = CodeAgent(
        tools=[retriever_tool],
        model=model,
        max_steps=10,
        verbosity_level=10,
    )
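
    # Quick sanity check without the UI (hypothetical):
    #   agent.run("pass me some fun general facts from the retriever")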

    def enter_message(new_message, conversation_history):
        conversation_history.append(gr.ChatMessage(role="user", content=new_message))
        yield "", conversation_history  # show the user message immediately
        for msg in stream_to_gradio(agent, new_message):
            conversation_history.append(msg)
            yield "", conversation_history

    def clear_message(chat_history: list):
        agent.memory.reset()
        # list.clear() returns None, so clear first and return the emptied list.
        chat_history.clear()
        return chat_history, ""

    with gr.Blocks() as b:
        gr.Markdown("# Demo agentic RAG on some general knowledge.")
        chatbot = gr.Chatbot(type="messages", height=1000)
        textbox = gr.Textbox(
            lines=1,
            label="chat message (with default sample question)",
            value="pass me some fun general facts from the retriever tool",
        )
        with gr.Row():
            clear_messages_button = gr.ClearButton([textbox, chatbot])
            stop_generating_button = gr.Button("stop generating")
            enter_button = gr.Button("enter")
        reply_button_click_event = enter_button.click(
            enter_message, [textbox, chatbot], [textbox, chatbot]
        )
        submit_event = textbox.submit(enter_message, [textbox, chatbot], [textbox, chatbot])
        clear_messages_button.click(
            fn=clear_message,
            inputs=chatbot,
            outputs=[chatbot, textbox],
            cancels=[reply_button_click_event, submit_event],
        )
        # fn=None: this click only cancels the in-flight generation events.
        stop_generating_button.click(fn=None, cancels=[reply_button_click_event, submit_event])

    b.launch()