Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| from streamlit_chat import message | |
def get_pipe():
    """Load the ajoublue GPT-2 dialogue checkpoint from the Hugging Face Hub.

    Returns:
        A ``(model, tokenizer)`` tuple for
        ``heegyu/ajoublue-gpt2-medium-dialog``.
    """
    # Imported inside the function, so transformers is only loaded when the
    # model is actually requested.
    from transformers import AutoTokenizer, AutoModelForCausalLM

    checkpoint = "heegyu/ajoublue-gpt2-medium-dialog"
    tok = AutoTokenizer.from_pretrained(checkpoint)
    lm = AutoModelForCausalLM.from_pretrained(checkpoint)
    return lm, tok
def get_response(tokenizer, model, history, max_context: int = 7, bot_id: str = '1'):
    """Generate the bot's next reply for a dialogue history.

    Each past utterance is serialized as ``"{speaker}: {text}</s>"``, where
    the speaker label is ``0`` for even positions in ``history`` and ``1``
    for odd ones. Only the last ``max_context`` turns are kept, and the
    prompt ends with ``"{bot_id}: "`` so the model continues as the bot.

    Args:
        tokenizer: Tokenizer for the dialogue model; called with
            ``return_tensors="pt"`` and used to decode generated ids.
        model: Causal language model exposing ``generate``.
        history: Alternating user/bot utterances, oldest first.
        max_context: Maximum number of past turns placed in the prompt;
            values <= 0 mean no past turns.
        bot_id: Speaker label the model should continue as.

    Returns:
        The generated reply, cut at the first turn separator, with newlines
        removed.
    """
    turns = [f"{i % 2}: {text}</s>" for i, text in enumerate(history)]
    # Keep only the most recent turns. max(..., 0) avoids the `[-0:]` trap,
    # where a zero limit would otherwise keep the entire history.
    turns = turns[max(len(turns) - max_context, 0):]
    context = "".join(turns) + f"{bot_id}: "

    inputs = tokenizer(context, return_tensors="pt")
    prompt_len = inputs["input_ids"].shape[1]
    generation_args = dict(
        max_new_tokens=128,
        # min_length includes the prompt, so this asks for >= 5 new tokens.
        min_length=prompt_len + 5,
        # NOTE(review): assumes id 2 is the tokenizer's "</s>" turn
        # separator -- confirm against the tokenizer vocab.
        eos_token_id=2,
        do_sample=True,
        top_p=0.95,
        temperature=1.35,
    )
    outputs = model.generate(**inputs, **generation_args)

    print("Context:", tokenizer.decode(inputs["input_ids"][0]))
    # Decode only the newly generated ids. Slicing the decoded full sequence
    # by len(context) breaks whenever decoding does not reproduce the prompt
    # character for character (BOS token, whitespace clean-up, ...).
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=False)
    print("Response:", response)

    # The bot's turn ends at the first "</s>"; anything after it belongs to
    # further turns and must not leak into this reply.
    response = response.split("</s>")[0].split("<s>")[0]
    return response.replace("\n", "")
st.title("ajoublue-gpt2-medium 한국어 대화 모델 demo")

# Streamlit re-executes this whole script on every widget interaction.
# Wrapping the loader in Streamlit's resource cache (st.cache_resource, or
# st.experimental_singleton on older releases) keeps one loaded model per
# server process instead of reloading it on every rerun.
_cache_resource = getattr(st, "cache_resource", None) or getattr(st, "experimental_singleton", None)
_load_model = _cache_resource(get_pipe) if _cache_resource is not None else get_pipe
with st.spinner("loading model..."):
    model, tokenizer = _load_model()

if 'message_history' not in st.session_state:
    st.session_state.message_history = []
history = st.session_state.message_history

# Even positions are user turns (shown on the user side), odd are bot replies.
for i, message_ in enumerate(history):
    message(message_, is_user=i % 2 == 0, key=i)

input_ = st.text_input("아무 말이나 해보세요", value="")
# text_input keeps its value across reruns, so an input equal to the most
# recent user turn (history[-2]) has already been answered and is skipped.
if input_ and (len(history) < 2 or history[-2] != input_):
    with st.spinner("대답을 생성중입니다..."):
        response = get_response(tokenizer, model, history + [input_])
    # Store the user turn and the reply together, so a failed generation can
    # never leave an unanswered user turn that flips the user/bot parity.
    st.session_state.message_history.extend([input_, response])
    # st.experimental_rerun was superseded by st.rerun in newer Streamlit.
    (getattr(st, "rerun", None) or st.experimental_rerun)()