Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| import numpy as np | |
| import torch | |
| import arxiv | |
def parse_arxiv_id(reference):
    """Extract a bare arXiv identifier from a link or id typed by the user.

    Handles ``2111.10339``, ``https://arxiv.org/abs/2111.10339`` (with or
    without a trailing slash), ``https://arxiv.org/pdf/2111.10339.pdf`` and
    old-style links such as ``https://arxiv.org/abs/hep-th/9901001``.
    Surrounding whitespace (e.g. a trailing newline from the text area) is
    ignored.

    Args:
        reference: Raw text from the "link/id" input box.

    Returns:
        The identifier to pass to ``arxiv.Search(id_list=...)``, or ``""``
        when the input is blank.
    """
    cleaned = reference.strip().rstrip("/")
    for marker in ("/abs/", "/pdf/"):
        if marker in cleaned:
            # Everything after the marker is the id; this keeps old-style ids
            # that themselves contain a slash (e.g. "hep-th/9901001").
            identifier = cleaned.split(marker, 1)[1]
            break
    else:
        # Bare id or another URL shape: fall back to the last path segment.
        identifier = cleaned.rsplit("/", 1)[-1]
    if identifier.endswith(".pdf"):
        identifier = identifier[: -len(".pdf")]
    return identifier


def fetch_arxiv_paper(paper_id):
    """Return ``(abstract, title)`` for the arXiv paper *paper_id*, or ``None``.

    ``None`` means the search completed without a match. Errors raised by the
    ``arxiv`` package (network/API failures) propagate to the caller.
    """
    search = arxiv.Search(id_list=[paper_id])
    result = next(iter(search.results()), None)
    if result is None:
        return None
    return result.summary, result.title


def render_title_html(generated, original=None):
    """Build the HTML snippet that displays the generated (and original) title.

    Both titles are HTML-escaped because the snippet is rendered with
    ``unsafe_allow_html=True`` and the text comes from a model / arXiv.

    Args:
        generated: Title produced by the model.
        original: The paper's real title, or ``None`` when the abstract was
            pasted by the user.

    Returns:
        The HTML string to pass to ``st.markdown``.
    """
    import html  # stdlib; imported locally to keep this helper self-contained

    lines = [
        "<style>",
        "p.a {",
        "font: 20px Courier;",
        "}",
        "</style>",
        f'<p class="a"><b>Title Generated:</b> {html.escape(generated)} </p>',
    ]
    if original is not None:
        lines.append(
            f'<p class="a"><b>Original Title:</b> {html.escape(original)} </p>'
        )
    return "\n".join(lines)


def main():
    """Build the "Title Generator" Streamlit page.

    The user either enters an arXiv link/id, or pastes an abstract directly.
    For a link/id, the abstract and original title are fetched from arXiv.
    A non-blank pasted abstract takes precedence over the link/id.

    On "Submit" the selected seq2seq model generates a title. The original
    title is shown alongside it when the abstract came from arXiv. Lookup
    problems and empty input are reported with ``st.error`` instead of
    crashing the app.
    """
    st.set_page_config(
        layout="wide",
        initial_sidebar_state="auto",
        page_title="Title Generator!",
        page_icon=None,
    )
    st.title("Title Generator: Generate a title from the abstract of a paper")
    st.text("")
    st.text("")
    arxiv_reference = st.text_area(
        "Provide the link/id for an arxiv paper", "https://arxiv.org/abs/2111.10339"
    )
    abstract = st.text_area("...or paste a paper's abstract to generate a title")
    st.text("")
    models_to_choose = [
        "AryanLala/autonlp-Scientific_Title_Generator-34558227",
        "shamikbose89/mt5-small-finetuned-arxiv-cs-finetuned-arxiv-cs-full",
    ]
    # selectbox always returns one of the listed options, so no "no model
    # selected" branch is needed downstream.
    base_model = st.selectbox("Choose a model to generate the title", models_to_choose)

    def load_model():
        """Load the tokenizer and seq2seq model for the selected checkpoint."""
        # NOTE(review): the checkpoint is reloaded on every submission;
        # consider Streamlit's resource caching if the installed version has it.
        tokenizer = AutoTokenizer.from_pretrained(base_model)
        model = AutoModelForSeq2SeqLM.from_pretrained(base_model)
        return model, tokenizer

    def get_summary(text):
        """Generate one title for *text* with the selected model."""
        with st.spinner(text="Processing your request"):
            model, tokenizer = load_model()
            inputs = tokenizer(
                [text], truncation=True, padding="longest", return_tensors="pt"
            )
            output = model.generate(
                **inputs,
                max_length=60,
                num_beams=10,
                num_return_sequences=1,
                # NOTE(review): presumably ignored by plain beam search (no
                # do_sample); kept so generation settings are unchanged.
                temperature=1.5,
            )
            decoded = tokenizer.batch_decode(output, skip_special_tokens=True)
        return decoded[0]

    def resolve_input():
        """Return ``(text, original_title)`` to summarise, or ``None`` after reporting an error."""
        if abstract.strip():
            return abstract, None
        paper_id = parse_arxiv_id(arxiv_reference)
        if not paper_id:
            st.error("The text can't be empty")
            return None
        try:
            paper = fetch_arxiv_paper(paper_id)
        except Exception as err:  # UI boundary: report instead of crashing the app
            st.error(f"arXiv lookup for '{paper_id}' failed: {err}")
            return None
        if paper is None:
            st.error(f"No arXiv paper found for id '{paper_id}'")
        return paper

    # Only contact arXiv / run the model when Submit is clicked, not on every rerun.
    if st.button("Submit"):
        resolved = resolve_input()
        if resolved is not None:
            text, original_title = resolved
            summary = get_summary(text)
            st.markdown(
                render_title_html(summary, original_title), unsafe_allow_html=True
            )

    with st.expander("Additional Information"):
        st.markdown(
            """
The models used were fine-tuned on a subset of data from the [Arxiv Dataset](https://huggingface.co/datasets/arxiv_dataset).
The task of the models is to suggest an appropriate title from the abstract of a scientific paper.
The model [AryanLala/autonlp-Scientific_Title_Generator-34558227](https://huggingface.co/AryanLala/autonlp-Scientific_Title_Generator-34558227) was trained on data
from the cs.AI (Artificial Intelligence) category of papers.
The model [shamikbose89/mt5-small-finetuned-arxiv-cs-finetuned-arxiv-cs-full](https://huggingface.co/shamikbose89/mt5-small-finetuned-arxiv-cs-finetuned-arxiv-cs-full)
was trained on the categories: cs.AI, cs.LG, cs.NI, cs.GR, cs.CL, cs.CV (Artificial Intelligence, Machine Learning, Networking and Internet Architecture, Graphics, Computation and Language, Computer Vision and Pattern Recognition)
Also, <b>Thank you to arXiv for use of its open access interoperability.</b> It allows us to pull the required abstracts from passed ids
""",
            unsafe_allow_html=True,
        )
    st.text('\n')
    st.text('\n')
    st.markdown(
        '''<span style="color:blue; font-size:10px">App created by [@akshay7](https://huggingface.co/akshay7), [@AryanLala](https://huggingface.co/AryanLala) and [@shamikbose89](https://huggingface.co/shamikbose89)
</span>''',
        unsafe_allow_html=True,
    )
# Build the page only when this file is executed as a script (for example via
# `streamlit run`), not when it is imported as a module.
if __name__ == "__main__":
    main()