import streamlit as st
import pandas as pd
import numpy as np
import torch
import os

# Prepare the model
from transformers import RobertaForSequenceClassification, AutoTokenizer
# [theme]
# base="dark"
# primaryColor="purple"
# Title
st.header('한국표준산업분류 자동코딩 서비스')  # "Korean Standard Industrial Classification auto-coding service"
# Cache the loaded model so it is not re-initialized on every Streamlit rerun.
# (Assumption: @st.cache_resource from Streamlit >= 1.18; older apps used @st.cache.)
@st.cache_resource
def md_loading():
    ## cpu
    # device = torch.device('cpu')
    tokenizer = AutoTokenizer.from_pretrained('klue/roberta-base')
    model = RobertaForSequenceClassification.from_pretrained('klue/roberta-base', num_labels=495)
    model_checkpoint = 'upsampling_20.bin'
    project_path = './'
    output_model_file = os.path.join(project_path, model_checkpoint)
    model.load_state_dict(torch.load(output_model_file, map_location=torch.device('cpu')))
    model.eval()  # inference only: disables dropout
    label_tbl = np.load('./label_table.npy')  # class index -> KSIC code
    loc_tbl = pd.read_csv('./kisc_table.csv', encoding='utf-8')  # KSIC code -> item name
    print('ready')
    return tokenizer, model, label_tbl, loc_tbl
# Load the model
tokenizer, model, label_tbl, loc_tbl = md_loading()
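# A quick sanity check could confirm that the checkpoint matches the 495-class
# head and the label table (illustrative only, not part of the original app):
# assert model.classifier.out_proj.out_features == 495
# assert len(label_tbl) == 495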
# Text input boxes (commas are stripped because ',' is the field separator below)
business = st.text_input('사업체명').replace(',', '')           # business name
business_work = st.text_input('사업체 하는일').replace(',', '')  # what the business does
work_department = st.text_input('근무부서').replace(',', '')     # department
work_position = st.text_input('직책').replace(',', '')           # job position
what_do_i = st.text_input('내가 하는 일').replace(',', '')       # what I do
# md_input: the single query string fed to the model
md_input = ', '.join([business, business_work, work_department, work_position, what_do_i])
## temporary check
# st.write(md_input)
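# For example (hypothetical values), the five fields above are joined into one
# comma-separated query such as:
#   md_input == '한국전자, 전자부품 제조, 생산부, 사원, 회로 조립'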
# Button
if st.button('확인'):  # "Confirm"
    ## What happens on button click
    ### Run the model
    query_tokens = md_input.split(',')

    # Fixed-length (64) BERT-style inputs, batch size 1
    input_ids = np.zeros(shape=[1, 64])
    attention_mask = np.zeros(shape=[1, 64])

    seq = '[CLS] '
    try:
        for i in range(5):
            seq += query_tokens[i] + ' '
    except IndexError:
        pass
    tokens = tokenizer.tokenize(seq)
    ids = tokenizer.convert_tokens_to_ids(tokens)

    # Truncate to the 64-token maximum, then fill the ids and attention mask
    length = len(ids)
    if length > 64:
        length = 64
    for i in range(length):
        input_ids[0, i] = ids[i]
        attention_mask[0, i] = 1

    input_ids = torch.from_numpy(input_ids).type(torch.long)
    attention_mask = torch.from_numpy(attention_mask).type(torch.long)

    # Inference only, so no gradient tracking is needed
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=None)
    logits = outputs.logits
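    # logits has shape [1, 495]: one raw score per KSIC sub-class. If calibrated
    # confidences were wanted in the UI, a softmax could be applied first
    # (sketch, not in the original app):
    # probs = torch.softmax(logits, dim=1)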
    # # For a single (top-1) prediction
    # arg_idx = torch.argmax(logits, dim=1)
    # print('arg_idx:', arg_idx)
    # num_ans = label_tbl[arg_idx]
    # str_ans = loc_tbl['항목명'][loc_tbl['코드'] == num_ans].values
    # Predict the top k classes
    k = 10
    topk_idx = torch.topk(logits.flatten(), k).indices
    num_ans_topk = label_tbl[topk_idx.numpy()]
    str_ans_topk = [loc_tbl['항목명'][loc_tbl['코드'] == code] for code in num_ans_topk]
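    # Two lookups happen here: label_tbl maps each model class index to a KSIC
    # code, and kisc_table.csv ('코드' -> '항목명') maps that code to its
    # human-readable item name.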
    # print(num_ans, str_ans)
    # print(num_ans_topk)

    # print('사업체명:', query_tokens[0])
    # print('사업체 하는일:', query_tokens[1])
    # print('근무부서:', query_tokens[2])
    # print('직책:', query_tokens[3])
    # print('내가 하는일:', query_tokens[4])
    # print('산업코드 및 분류:', num_ans, str_ans)

    # ans = ''
    # ans1, ans2, ans3 = '', '', ''
    ## Display the model output
    # st.write("산업코드 및 분류:", num_ans, str_ans[0])

    # st.write("세분류 코드")
    # for i in range(k):
    #     st.write(str(i+1) + '순위:', num_ans_topk[i], str_ans_topk[i].iloc[0])

    # print(num_ans)
    # print(str_ans, type(str_ans))
    # Collect the top-k names (first matching row for each code)
    str_ans_topk_list = []
    for i in range(k):
        str_ans_topk_list.append(str_ans_topk[i].iloc[0])
    # print(str_ans_topk_list)
    # Render the top-k predictions as a table
    ans_topk_df = pd.DataFrame({
        'NO': range(1, k + 1),
        '세분류 코드': num_ans_topk,      # sub-class code
        '세분류 명칭': str_ans_topk_list  # sub-class name
    })
    ans_topk_df = ans_topk_df.set_index('NO')
    st.dataframe(ans_topk_df)
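
# To try the app locally, the standard Streamlit invocation applies (assuming
# this script is saved as app.py alongside upsampling_20.bin, label_table.npy
# and kisc_table.csv):
#   streamlit run app.py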