Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -30,6 +30,8 @@ from transformers.optimization import get_linear_schedule_with_warmup
|
|
| 30 |
from transformers import BertForMaskedLM, AlbertTokenizer
|
| 31 |
from transformers import AutoConfig
|
| 32 |
from transformers import MegatronBertForMaskedLM
|
|
|
|
|
|
|
| 33 |
import argparse
|
| 34 |
import copy
|
| 35 |
import streamlit as st
|
|
@@ -297,9 +299,12 @@ class UniMCModel(nn.Module):
|
|
| 297 |
self.config = AutoConfig.from_pretrained(pre_train_dir)
|
| 298 |
if self.config.model_type == 'megatron-bert':
|
| 299 |
self.bert = MegatronBertForMaskedLM.from_pretrained(pre_train_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
else:
|
| 301 |
self.bert = BertForMaskedLM.from_pretrained(pre_train_dir)
|
| 302 |
-
|
| 303 |
self.loss_func = torch.nn.CrossEntropyLoss()
|
| 304 |
self.yes_token = yes_token
|
| 305 |
|
|
@@ -626,54 +631,82 @@ def load_model(model_path):
|
|
| 626 |
model = UniMCPipelines(args)
|
| 627 |
return model
|
| 628 |
|
| 629 |
-
|
| 630 |
def main():
|
| 631 |
|
| 632 |
text_dict={
|
| 633 |
-
'
|
| 634 |
-
'
|
| 635 |
-
'
|
| 636 |
-
'
|
| 637 |
-
'
|
| 638 |
}
|
| 639 |
|
| 640 |
question_dict={
|
| 641 |
-
'
|
| 642 |
-
'
|
| 643 |
-
'
|
| 644 |
-
'
|
| 645 |
-
'
|
| 646 |
}
|
| 647 |
|
| 648 |
choice_dict={
|
| 649 |
-
'
|
| 650 |
-
'
|
| 651 |
-
'
|
| 652 |
-
'
|
| 653 |
-
'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 654 |
}
|
| 655 |
|
| 656 |
|
| 657 |
|
| 658 |
st.subheader("UniMC Zero-shot 体验")
|
| 659 |
|
| 660 |
-
st.sidebar.header("
|
| 661 |
sbform = st.sidebar.form("固定参数设置")
|
| 662 |
-
language = sbform.selectbox('
|
| 663 |
-
sbform.form_submit_button("
|
| 664 |
|
| 665 |
-
if
|
| 666 |
model = load_model('IDEA-CCNL/Erlangshen-UniMC-RoBERTa-110M-Chinese')
|
| 667 |
else:
|
| 668 |
-
model = load_model('IDEA-CCNL/Erlangshen-UniMC-
|
| 669 |
|
| 670 |
-
st.info("
|
| 671 |
-
model_type = st.selectbox('
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 672 |
|
| 673 |
-
|
| 674 |
-
sentences = st.text_area("请输入句子:", text_dict[model_type])
|
| 675 |
-
question = st.text_input("请输入问题(不输入问题也可以):", "")
|
| 676 |
-
choice = st.text_input("输入标签(以中文;分割):", choice_dict[model_type])
|
| 677 |
choice = choice.split(';')
|
| 678 |
|
| 679 |
data = [{"texta": sentences,
|
|
@@ -683,15 +716,13 @@ def main():
|
|
| 683 |
"answer": "", "label": 0,
|
| 684 |
"id": 0}]
|
| 685 |
|
| 686 |
-
|
| 687 |
-
|
| 688 |
-
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
"**Enter a text** above and **press the button** to predict the category."
|
| 694 |
-
)
|
| 695 |
|
| 696 |
|
| 697 |
|
|
|
|
| 30 |
from transformers import BertForMaskedLM, AlbertTokenizer
|
| 31 |
from transformers import AutoConfig
|
| 32 |
from transformers import MegatronBertForMaskedLM
|
| 33 |
+
from modeling_deberta_v2 import DebertaV2ForMaskedLM
|
| 34 |
+
from modeling_albert import AlbertForMaskedLM
|
| 35 |
import argparse
|
| 36 |
import copy
|
| 37 |
import streamlit as st
|
|
|
|
| 299 |
self.config = AutoConfig.from_pretrained(pre_train_dir)
|
| 300 |
if self.config.model_type == 'megatron-bert':
|
| 301 |
self.bert = MegatronBertForMaskedLM.from_pretrained(pre_train_dir)
|
| 302 |
+
elif self.config.model_type == 'deberta-v2':
|
| 303 |
+
self.bert = DebertaV2ForMaskedLM.from_pretrained(pre_train_dir)
|
| 304 |
+
elif self.config.model_type == 'albert':
|
| 305 |
+
self.bert = AlbertForMaskedLM.from_pretrained(pre_train_dir)
|
| 306 |
else:
|
| 307 |
self.bert = BertForMaskedLM.from_pretrained(pre_train_dir)
|
|
|
|
| 308 |
self.loss_func = torch.nn.CrossEntropyLoss()
|
| 309 |
self.yes_token = yes_token
|
| 310 |
|
|
|
|
| 631 |
model = UniMCPipelines(args)
|
| 632 |
return model
|
| 633 |
|
|
|
|
| 634 |
def main():
|
| 635 |
|
| 636 |
text_dict={
|
| 637 |
+
'Text classification「文本分类」':"彭于晏不着急,胡歌不着急,那我也不着急",
|
| 638 |
+
'Sentiment「情感分析」':"刚买iphone13 pro 还不到一个月,天天死机最差的一次购物体验",
|
| 639 |
+
'Similarity「语义匹配」':"今天心情不好",
|
| 640 |
+
'NLI 「自然语言推理」':"小明正在上高中",
|
| 641 |
+
'Multiple Choice「多项式阅读理解」':"女:您看这件衣服挺不错的,质量好,价钱也不贵。\n男:再看看吧。",
|
| 642 |
}
|
| 643 |
|
| 644 |
question_dict={
|
| 645 |
+
'Text classification「文本分类」':"这是什么类型的新闻?",
|
| 646 |
+
'Sentiment「情感分析」':"",
|
| 647 |
+
'Similarity「语义匹配」':"",
|
| 648 |
+
'NLI 「自然语言推理」':"",
|
| 649 |
+
'Multiple Choice「多项式阅读理解」':"这个男的是什么意思?",
|
| 650 |
}
|
| 651 |
|
| 652 |
choice_dict={
|
| 653 |
+
'Text classification「文本分类」':"故事;文化;娱乐;体育;财经;房产;汽车;教育;科技",
|
| 654 |
+
'Sentiment「情感分析」':"这是一条好评;这是一条差评",
|
| 655 |
+
'Similarity「语义匹配」':"可以理解为:我很不开心;不能理解为:我很不开心",
|
| 656 |
+
'NLI 「自然语言推理」':"可以推断出:小明是一个初中生;不能推断出:小明是一个初中生;很难推断出:小明是一个初中生",
|
| 657 |
+
'Multiple Choice「多项式阅读理解」':"不想要这件;衣服挺好的;衣服质量不好",
|
| 658 |
+
}
|
| 659 |
+
|
| 660 |
+
text_dict_en={
|
| 661 |
+
'Text classification「文本分类」':"Henkel AG & Company KGaA operates worldwide with leading brands and technologies in three business areas: Laundry & Home Care Beauty Care and Adhesive Technologies. Henkel is the name behind some of America’s favorite brands.",
|
| 662 |
+
'Sentiment「情感分析」':"a gorgeous , high-spirited musical from india that exquisitely blends music , dance , song , and high drama . ",
|
| 663 |
+
'Similarity「语义匹配」':"Ricky Clemons ' brief , troubled Missouri basketball career is over .",
|
| 664 |
+
'NLI 「自然语言推理」':"That was then, and then's gone. It's now now. I don't mean I 've done a sudden transformation.",
|
| 665 |
+
'Multiple Choice「多项式阅读理解」':"A huge crowd is in the stands in an arena. A man throws a javelin. Photographers take pictures in the background. several men",
|
| 666 |
+
}
|
| 667 |
+
|
| 668 |
+
question_dict_en={
|
| 669 |
+
'Text classification「文本分类」':"",
|
| 670 |
+
'Sentiment「情感分析」':"",
|
| 671 |
+
'Similarity「语义匹配」':"",
|
| 672 |
+
'NLI 「自然语言推理」':"",
|
| 673 |
+
'Multiple Choice「多项式��读理解」':"",
|
| 674 |
+
}
|
| 675 |
+
|
| 676 |
+
choice_dict_en={
|
| 677 |
+
'Text classification「文本分类」':"Company;Educational Institution;Artist;Athlete;Office Holder",
|
| 678 |
+
'Sentiment「情感分析」':"it's great;it's terrible",
|
| 679 |
+
'Similarity「语义匹配」':"That can be interpreted as Missouri kicked Ricky Clemons off its team , ending his troubled career there .;That cannot be interpreted as Missouri kicked Ricky Clemons off its team , ending his troubled career there .",
|
| 680 |
+
'NLI 「自然语言推理」':"we can infer that she has done a sudden transformation;we can not infer that she has done a sudden transformation;it is diffcult for us to infer that she has done a sudden transformation",
|
| 681 |
+
'Multiple Choice「多项式阅读理解」':"are water boarding in a river.;are shown throwing balls.;challenge the man to jump onto the rope.;run to where the javelin lands.",
|
| 682 |
}
|
| 683 |
|
| 684 |
|
| 685 |
|
| 686 |
st.subheader("UniMC Zero-shot 体验")
|
| 687 |
|
| 688 |
+
st.sidebar.header("Configuration「参数配置」")
|
| 689 |
sbform = st.sidebar.form("固定参数设置")
|
| 690 |
+
language = sbform.selectbox('Select a language「选择语言」', ['中文「Chinese」', 'English「英文」'])
|
| 691 |
+
sbform.form_submit_button("Submit configuration「提交配置」")
|
| 692 |
|
| 693 |
+
if '中文' in language:
|
| 694 |
model = load_model('IDEA-CCNL/Erlangshen-UniMC-RoBERTa-110M-Chinese')
|
| 695 |
else:
|
| 696 |
+
model = load_model('IDEA-CCNL/Erlangshen-UniMC-Albert-235M-English')
|
| 697 |
|
| 698 |
+
st.info("Please input the following information「请输入以下信息...」")
|
| 699 |
+
model_type = st.selectbox('Select task type「选择任务类型」',['Text classification「文本分类」','Sentiment「情感分析」','Similarity「语义匹配」','NLI 「自然语言推理」','Multiple Choice「多项式阅读理解」'])
|
| 700 |
+
|
| 701 |
+
if '中文' in language:
|
| 702 |
+
sentences = st.text_area("Please input the context「请输入句子」", text_dict[model_type])
|
| 703 |
+
question = st.text_input("Please input the question「请输入问题(不输入问题也可以)」", question_dict[model_type])
|
| 704 |
+
choice = st.text_input("Please input the label「输入标签(以中文;分割)」", choice_dict[model_type])
|
| 705 |
+
else:
|
| 706 |
+
sentences = st.text_area("Please input the context「请输入句子」", text_dict_en[model_type])
|
| 707 |
+
question = st.text_input("Please input the question「请输入问题(不输入问题也可以)」", question_dict_en[model_type])
|
| 708 |
+
choice = st.text_input("Please input the label「输入标签(以中文;分割)」", choice_dict[model_type])
|
| 709 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 710 |
choice = choice.split(';')
|
| 711 |
|
| 712 |
data = [{"texta": sentences,
|
|
|
|
| 716 |
"answer": "", "label": 0,
|
| 717 |
"id": 0}]
|
| 718 |
|
| 719 |
+
|
| 720 |
+
start=time.time()
|
| 721 |
+
result = model.predict(data, cuda=False)
|
| 722 |
+
st.success(f"Prediction is successful, consumes {str(time.time()-start)} seconds")
|
| 723 |
+
st.json(result[0])
|
| 724 |
+
f1.form_submit_button("Submit「点击一下,开始预测!」")
|
| 725 |
+
|
|
|
|
|
|
|
| 726 |
|
| 727 |
|
| 728 |
|