Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| from transformers import pipeline | |
| from transformers import PegasusForConditionalGeneration, PegasusTokenizer | |
def summarize(data: str, modelname: str) -> str:
    """Summarize ``data`` with the model selected by ``modelname``.

    Args:
        data: Plain text to summarize.
        modelname: ``'Bart'`` selects ``facebook/bart-large-cnn`` and
            ``'Pegasus'`` selects ``google/pegasus-xsum``; any other value
            falls back to ``lidiya/bart-large-xsum-samsum``.

    Returns:
        The generated summary text, or an empty string when ``data`` is
        empty or contains only whitespace.
    """
    # Nothing to summarize: skip the (expensive) model load entirely and
    # return an empty result, which callers can test with ``len(...)``.
    if not data or not data.strip():
        return ""

    if modelname == 'Pegasus':
        model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum")
        tokenizer = PegasusTokenizer.from_pretrained("google/pegasus-xsum")
        # Create tokens - number representation of our text
        tokens = tokenizer(data, truncation=True, padding="longest", return_tensors="pt")
        summary = model.generate(**tokens)
        # Drop special tokens (padding / end-of-sequence markers) so they do
        # not leak into the text shown to the user.
        return tokenizer.decode(summary[0], skip_special_tokens=True)

    # Both remaining choices run through the generic summarization pipeline
    # and differ only in the checkpoint that is loaded.
    if modelname == 'Bart':
        checkpoint = "facebook/bart-large-cnn"
    else:
        checkpoint = "lidiya/bart-large-xsum-samsum"

    summarizer = pipeline("summarization", model=checkpoint)
    output = summarizer(
        data,
        max_length=300,
        min_length=30,
        do_sample=False,
        # Truncate over-long input, matching what the Pegasus branch does.
        truncation=True,
    )
    return output[0]["summary_text"]
# ---------------------------------------------------------------------------
# Streamlit page: pick a model, provide text (upload or paste), summarize it.
# Streamlit re-runs this script from top to bottom on every interaction.
# ---------------------------------------------------------------------------
import re

# Word limit advertised in the text-area label below.
MAX_WORDS = 500

st.sidebar.title("Text Summarization")
uploaded_file = st.file_uploader("Choose a file", help=" you can choose .txt file")

data = ""
output = ""
if uploaded_file is not None:
    # To read file as bytes:
    bytes_data = uploaded_file.getvalue()
    try:
        data = bytes_data.decode("utf-8")
    except UnicodeDecodeError:
        # A non-UTF-8 upload should produce a message, not crash the app.
        st.error("Could not read the uploaded file as UTF-8 text.")

modelname = st.radio(
    "Choose your model",
    ["Bart", "Pegasus", "Meeting summary(bart-large-cnn-samsum)"],
    help=" you can choose between 3 models to summarize your text. More to come!",
)

col1, col2 = st.columns(2)

with col1:
    st.header("Copy paste your text or Upload file")
    with st.expander("Text to summarize", expanded=True):
        if uploaded_file is not None:
            # Show the uploaded text as read-only content.
            st.write(data)
        else:
            data = st.text_area("Paste your text below (max 500 words)", height=510)

    # Count words as runs of word characters and tell the user when the
    # advertised limit is exceeded (the count was previously discarded).
    res = len(re.findall(r"\w+", data))
    if res > MAX_WORDS:
        st.warning(
            f"Your text has {res} words; the recommended maximum is "
            f"{MAX_WORDS}. Long inputs may be cut off or fail to summarize."
        )

    Summarizebtn = st.button("Summarize")
    if Summarizebtn:
        output = summarize(data, modelname)

with col2:
    st.header("Summary")
    if len(output) > 0:
        with st.expander("", expanded=True):
            st.write(output)
    elif Summarizebtn:
        # NOTE(review): balloons are shown when the button was pressed but no
        # summary text came back -- confirm this is the intended feedback.
        st.balloons()