"""Agentic RAG demo over a general-knowledge dataset, served with Gradio.

Sample question:
    pass me some fun general facts from the retriever
"""

import gradio as gr
import os
from smolagents import Tool, CodeAgent, TransformersModel, stream_to_gradio, HfApiModel
import spaces
from dotenv import load_dotenv
import datasets
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import HuggingFaceDatasetLoader
import chromadb
from chromadb.utils import embedding_functions

load_dotenv()


@spaces.GPU
def dummy():
    # NOTE(review): intentionally a no-op; presumably present so the Hugging Face
    # ZeroGPU runtime finds at least one @spaces.GPU function — confirm before removing.
    pass


class RetrieverTool(Tool):
    """Semantic-search retriever over a persistent ChromaDB collection, exposed as a smolagents tool.

    The vector store has to live as an attribute of the tool, so this cannot be
    written with the simpler ``@tool`` decorator. Documents are embedded with
    Chroma's ``DefaultEmbeddingFunction`` and ranked by cosine distance. For better
    retrieval accuracy a stronger embedding model can be swapped in; see the MTEB
    leaderboard for rankings.

    If a ``model`` is supplied, the retrieved text is also summarized by that
    model and the summary is appended to the tool output.
    """

    name = "retriever"
    description = """Uses semantic search to retrieve the parts of the general-knowledge documents that could be most relevant to answer your query.
    Afterwards, this tool summarizes the findings from the extracted documents
    """
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

    def __init__(self, docs: list[Document], model=None, **kwargs):
        """Index ``docs`` into the persistent ``demo_docs`` Chroma collection.

        Args:
            docs: Documents whose ``page_content`` is upserted into the collection.
            model: Optional callable chat model (e.g. an ``HfApiModel``) used by
                ``forward`` to summarize the retrieved text. When ``None``,
                ``forward`` returns only the retrieved text.
            **kwargs: Forwarded unchanged to ``Tool.__init__``.
        """
        super().__init__(**kwargs)
        chroma_data_path = "chroma_data/"
        os.makedirs(chroma_data_path, exist_ok=True)
        collection_name = "demo_docs"

        client = chromadb.PersistentClient(path=chroma_data_path)
        collection = client.get_or_create_collection(
            name=collection_name,
            embedding_function=embedding_functions.DefaultEmbeddingFunction(),
            metadata={"hnsw:space": "cosine"},
        )
        # Chroma rejects empty id lists, so only upsert when there is something to index.
        if docs:
            # Deterministic ids: re-running the app updates the same entries in the
            # persisted collection rather than adding duplicates.
            collection.upsert(
                documents=[doc.page_content for doc in docs],
                ids=[f"id{i}" for i in range(len(docs))],
            )
        self.collection = collection
        # Injected explicitly instead of relying on a module-level global `model`.
        self.summary_model = model

    def forward(self, query: str) -> str:
        """Return the 5 closest documents for ``query``, plus a model summary when a model was given.

        Raises:
            TypeError: if ``query`` is not a string.
        """
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")
        results = self.collection.query(query_texts=[query], n_results=5)
        # query() takes a batch of query texts; index [0] selects the results of our single query.
        retrieved_text = "\nRetrieved documents:\n" + "".join(
            f"\n\n===== Document {doc_id} =====\n" + doc
            for doc_id, doc in zip(results["ids"][0], results["documents"][0])
        )
        if self.summary_model is None:
            return retrieved_text
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "summaries this text:" + retrieved_text}
                ],
            }
        ]
        return retrieved_text + "\n" + self.summary_model(messages).content


if __name__ == "__main__":
    # Alternative loading that keeps each question as document metadata:
    # knowledge_base = datasets.load_dataset("MuskumPillerum/General-Knowledge", split="train")
    # print(knowledge_base.column_names)
    # source_docs = [
    #     Document(page_content=doc["Answer"], metadata={"question": doc["Question"]})
    #     for doc in knowledge_base
    # ]

    # Only the first 100 loaded documents (the "Answer" column) are indexed.
    source_docs = HuggingFaceDatasetLoader(
        "MuskumPillerum/General-Knowledge", "Answer"
    ).load()[:100]

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
        add_start_index=True,
        strip_whitespace=True,
        separators=["\n\n", "\n", ".", " ", ""],
    )
    docs_processed = text_splitter.split_documents(source_docs)

    # Not working at the moment:
    # model = TransformersModel(
    #     # model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
    #     device_map="cuda",
    #     model_id="meta-llama/Llama-3.2-3B-Instruct",
    # )
    model = HfApiModel(
        model_id="meta-llama/Llama-3.2-3B-Instruct",
        token=os.getenv("my_first_agents_hf_tokens"),
    )

    # The model is created first so it can be handed to the tool for summarization.
    retriever_tool = RetrieverTool(docs_processed, model=model)

    agent = CodeAgent(
        tools=[retriever_tool],
        model=model,
        max_steps=10,
        verbosity_level=10,
    )

    def enter_message(new_message, conversation_history):
        """Append the user's message, then stream the agent's messages into the chat."""
        conversation_history.append(gr.ChatMessage(role="user", content=new_message))
        # Show the user's message and clear the textbox before the agent starts working.
        yield "", conversation_history
        for msg in stream_to_gradio(agent, new_message):
            conversation_history.append(msg)
            yield "", conversation_history

    def clear_message(chat_history: list):
        """Reset the agent's memory; return an empty chat history and an empty textbox."""
        agent.memory.reset()
        return [], ""

    with gr.Blocks() as b:
        gr.Markdown("# Demo agentic rag on some general knowledge.")
        chatbot = gr.Chatbot(type="messages", height=1000)
        textbox = gr.Textbox(
            lines=1,
            label="chat message (with default sample question)",
            value="pass me some fun general facts from the retreiver tool",
        )
        with gr.Row():
            clear_messages_button = gr.ClearButton([textbox, chatbot])
            stop_generating_button = gr.Button("stop generating")
            enter_button = gr.Button("enter")

        reply_button_click_event = enter_button.click(
            enter_message, [textbox, chatbot], [textbox, chatbot]
        )
        submit_event = textbox.submit(
            enter_message, [textbox, chatbot], [textbox, chatbot]
        )
        # Clearing also cancels any in-flight generation so it cannot repopulate the chat.
        clear_messages_button.click(
            fn=clear_message,
            inputs=chatbot,
            outputs=[chatbot, textbox],
            cancels=[reply_button_click_event, submit_event],
        )
        stop_generating_button.click(cancels=[reply_button_click_event, submit_event])

    b.launch()