# Hugging Face Spaces demo app.
# (Scraped page header removed; last recorded build status was "Runtime error".)
import streamlit as st
import streamlit.components.v1 as component
from googletrans import Translator

from model import load_model

# from huggingface_hub import snapshot_download
# --- Sidebar controls (shared by both pages) ---
# Which checkpoint to demo; the main body below branches on this value.
page = st.sidebar.selectbox("Model ", ["Finetuned on News data", "Pretrained GPT2"])
# Google Translate client used to show English renderings of generated Sinhala text.
translator = Translator()
# Prompt text; default is Sinhala for "hello" ("ආයුබෝවන්").
seed = st.sidebar.text_input('Starting text', 'ආයුබෝවන්')
# number_input(label, min, max, default)
seq_num = st.sidebar.number_input('Number of sequences to generate ', 1, 20, 5)
max_len = st.sidebar.number_input('Length of a sequence ', 5, 300, 100)
gen_bt = st.sidebar.button('Generate')
def generate(model, tokenizer, seed, seq_num, max_len):
    """Sample ``seq_num`` text continuations of ``seed`` from a causal LM.

    Args:
        model: Hugging Face causal language model exposing ``generate``.
        tokenizer: matching tokenizer exposing ``encode``/``decode``.
        seed: prompt text to continue.
        seq_num: number of independent sequences to return.
        max_len: maximum token length of each generated sequence.

    Returns:
        list[str]: decoded sequences with special tokens stripped.
    """
    input_ids = tokenizer.encode(seed, return_tensors='pt')
    outputs = model.generate(
        input_ids,
        do_sample=True,           # top-k / nucleus sampling, not plain beam search
        max_length=max_len,
        top_k=50,
        top_p=0.95,
        temperature=0.7,
        num_return_sequences=seq_num,
        no_repeat_ngram_size=2,   # forbid verbatim bigram repetition
        early_stopping=True,
    )
    return [tokenizer.decode(out, skip_special_tokens=True) for out in outputs]
def html(body):
    """Render a raw HTML string in the main page via Streamlit markdown."""
    st.markdown(body, unsafe_allow_html=True)
def card_begin_str(sinhala_sentence):
    """Return the opening HTML of a result card containing the Sinhala text.

    Must be paired with :func:`card_end_str`, which closes the two ``<div>``s
    opened here.
    """
    return (
        "<style>div.card{background-color:#023b1d;border-radius: 5px;box-shadow: 0 4px 8px 0 rgba(0,0,0,0.2);transition: 0.3s;} small{ margin: 5px;}</style>"
        '<div class="card">'
        '<div class="container">'
        f'<small><font color="white">{sinhala_sentence}</font></small>'
    )
def card_end_str():
    """Return the closing tags matching the markup opened by ``card_begin_str``."""
    return "</div></div>"
def card(sinhala_sentence, english_sentence):
    """Render one complete card: Sinhala text plus its English translation."""
    parts = [
        card_begin_str(sinhala_sentence),
        f"<p>{english_sentence}</p>",
        card_end_str(),
    ]
    html("".join(parts))
def br(n):
    """Insert ``n`` HTML line breaks into the page."""
    html(n * "<br>")
def card_html(sinhala_sentence, english_sentence):
    """Render a styled result card inside a Streamlit HTML component iframe.

    The CSS from ``./app.css`` is inlined on every call because the component
    runs in a sandboxed iframe that cannot see page-level styles.

    Raises:
        FileNotFoundError: if ``./app.css`` is missing from the working dir.
    """
    with open('./app.css', encoding='utf-8') as f:
        css_file = f.read()
    return component.html(
        f"""
        <style>{css_file}</style>
        <article class="class_1 bg-white rounded-lg p-4 relative">
            <p class="font-bold items-center text-sm text-primary relative mb-1">{sinhala_sentence}</p>
            <div class="flex items-center text-white-400 mb-4">
                <i class="fab fa-google mx-2"></i>
                <small class="text-white-400">English Translations are by Google Translate</small>
            </div>
            <p class="not-italic items-center text-sm text-primary relative mb-4">
                {english_sentence}
            </p>
        </article>
        """
    )
# --- Page body ---------------------------------------------------------------
# The two pages differed only in title/description and checkpoint name; pick
# the checkpoint here, then run one shared generate -> translate -> render
# pipeline instead of duplicating it per branch.
if page == 'Pretrained GPT2':
    st.title('Sinhala Text generation with GPT2')
    st.markdown('A simple demo using [Sinhala-gpt2 model](https://huggingface.co/flax-community/Sinhala-gpt2) trained during hf-flax week')
    model, tokenizer = load_model('flax-community/Sinhala-gpt2')
else:
    st.title('Sinhala Text generation with Finetuned GPT2')
    st.markdown('This model has been [finetuned Sinhala-gpt2 model](https://huggingface.co/keshan/sinhala-gpt2-newswire) with 6000 news articles(~12MB)')
    model, tokenizer = load_model('keshan/sinhala-gpt2-newswire')

if gen_bt:
    try:
        with st.spinner('Generating...'):
            seqs = generate(model, tokenizer, seed, seq_num, max_len)
        st.warning("English sentences were translated by Google Translate.")
        for seq in seqs:
            # Translation is best-effort; failures fall through to st.exception.
            english_sentence = translator.translate(seq, src='si', dest='en').text
            html(card_begin_str(seq))
            st.info(english_sentence)
            html(card_end_str())
    except Exception as e:
        # st.exception expects the exception object (renders the traceback),
        # not a pre-formatted string.
        st.exception(e)
st.markdown('____________')