import numpy as np
import csv
from typing import Optional
from urllib.request import urlopen

import gradio as gr
class SentimentTransform:
    def __init__(
        self,
        model_name: str = "cardiffnlp/twitter-roberta-base-sentiment",
        highlight: bool = False,
        positive_sentiment_name: str = "positive",
        max_number_of_shap_documents: Optional[int] = None,
        min_abs_score: float = 0.1,
        sensitivity: float = 0,
        **kwargs,
    ):
        """
        Sentiment Ops.

        Parameters
        ----------
        model_name: str
            The name of the Hugging Face model to use.
        highlight: bool
            Whether to also return per-token SHAP scores (requires `shap`).
        positive_sentiment_name: str
            The label treated as positive when computing the signed score.
        max_number_of_shap_documents: Optional[int]
            Maximum number of highlighted tokens to return.
        min_abs_score: float
            Minimum absolute SHAP score a token needs to be returned.
        sensitivity: float
            Minimum confidence a `neutral` prediction needs to be kept as
            neutral; below it, the next most likely label is used instead.
            If you are dealing with news sources, you probably want a lower
            sensitivity.
        """
        self.model_name = model_name
        self.highlight = highlight
        self.positive_sentiment_name = positive_sentiment_name
        self.max_number_of_shap_documents = max_number_of_shap_documents
        self.min_abs_score = min_abs_score
        self.sensitivity = sensitivity
        for k, v in kwargs.items():
            setattr(self, k, v)
    def preprocess(self, text: str):
        # Mask user handles and URLs, as expected by the TweetEval models.
        new_text = []
        for t in text.split(" "):
            t = "@user" if t.startswith("@") and len(t) > 1 else t
            t = "http" if t.startswith("http") else t
            new_text.append(t)
        return " ".join(new_text)
    @property
    def classifier(self):
        # Lazily build the pipeline on first access so the model is only
        # downloaded when it is actually needed. Exposed as a property because
        # the rest of the class calls `self.classifier(...)` directly.
        if not hasattr(self, "_classifier"):
            import transformers

            self._classifier = transformers.pipeline(
                "text-classification",
                model=self.model_name,
                return_all_scores=True,  # deprecated upstream; top_k=None is the modern equivalent
            )
        return self._classifier
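    # Called on a list of texts, the pipeline returns one list of
    # {"label": ..., "score": ...} dicts per input, e.g.
    # [[{"label": "LABEL_0", "score": 0.1}, {"label": "LABEL_1", ...}, ...]].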
    def _get_label_mapping(self, task: str):
        # Note: this is specific to the TweetEval models.
        labels = []
        mapping_link = f"https://raw.githubusercontent.com/cardiffnlp/tweeteval/main/datasets/{task}/mapping.txt"
        with urlopen(mapping_link) as f:
            html = f.read().decode("utf-8").split("\n")
            csvreader = csv.reader(html, delimiter="\t")
            labels = [row[1] for row in csvreader if len(row) > 1]
        return labels
    @property
    def label_mapping(self):
        # A property, because it is accessed as `self.label_mapping.get(...)`.
        return {"LABEL_0": "negative", "LABEL_1": "neutral", "LABEL_2": "positive"}
    def analyze_sentiment(
        self,
        text,
        highlight: bool = False,
        positive_sentiment_name: str = "positive",
        max_number_of_shap_documents: Optional[int] = None,
        min_abs_score: float = 0.1,
    ):
        if text is None:
            return None
        labels = self.classifier([str(text)], truncation=True, max_length=512)
        ind_max = np.argmax([l["score"] for l in labels[0]])
        sentiment = labels[0][ind_max]["label"]
        max_score = labels[0][ind_max]["score"]
        sentiment = self.label_mapping.get(sentiment, sentiment)
        if sentiment.lower() == "neutral" and max_score > self.sensitivity:
            overall_sentiment = 1e-5
        elif sentiment.lower() == "neutral":
            # Not confident enough to stay neutral: fall back to the next
            # highest-scoring label.
            new_labels = labels[0][:ind_max] + labels[0][(ind_max + 1):]
            new_ind_max = np.argmax([l["score"] for l in new_labels])
            new_max_score = new_labels[new_ind_max]["score"]
            new_sentiment = new_labels[new_ind_max]["label"]
            new_sentiment = self.label_mapping.get(new_sentiment, new_sentiment)
            overall_sentiment = self._calculate_overall_sentiment(
                new_max_score, new_sentiment
            )
        else:
            overall_sentiment = self._calculate_overall_sentiment(max_score, sentiment)
        # Nudge an exact 0 to avoid a downstream bug.
        if overall_sentiment == 0:
            overall_sentiment = 1e-5
        if not highlight:
            return {
                "sentiment": sentiment,
                "overall_sentiment_score": overall_sentiment,
            }
        shap_documents = self.get_shap_values(
            text,
            sentiment_ind=ind_max,
            max_number_of_shap_documents=max_number_of_shap_documents,
            min_abs_score=min_abs_score,
        )
        return {
            "sentiment": sentiment,
            "score": max_score,
            "overall_sentiment": overall_sentiment,
            "highlight_chunk_": shap_documents,
        }
    def _calculate_overall_sentiment(self, score: float, sentiment: str):
        if sentiment.lower().strip() == self.positive_sentiment_name:
            return score
        else:
            return -score
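    # Sign convention: the positive label keeps its score (0.9 -> +0.9),
    # every other label is negated ("negative" at 0.8 -> -0.8).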
    @property
    def explainer(self):
        # shap is an optional dependency, only needed when highlight=True;
        # get_shap_values calls this property.
        if not hasattr(self, "_explainer"):
            try:
                import shap
            except ModuleNotFoundError:
                raise ModuleNotFoundError(
                    "shap is required for highlighting; install it with `pip install shap`"
                )
            self._explainer = shap.Explainer(self.classifier)
        return self._explainer
    def get_shap_values(
        self,
        text: str,
        sentiment_ind: int = 2,
        max_number_of_shap_documents: Optional[int] = None,
        min_abs_score: float = 0.1,
    ):
        """Get per-token SHAP values for the label at `sentiment_ind`.

        Returns a list of {"text": token, "score": contribution} dicts,
        sorted by score in descending order.
        """
        shap_values = self.explainer([text])
        cohorts = {"": shap_values}
        cohort_exps = list(cohorts.values())
        feature_names = cohort_exps[0].feature_names
        values = np.array([cohort_exps[i].values for i in range(len(cohort_exps))])
        shap_docs = [
            {"text": v, "score": f}
            for f, v in zip(
                [x[sentiment_ind] for x in values[0][0].tolist()], feature_names[0]
            )
        ]
        # Sort by contribution, then optionally keep only the top entries.
        shap_docs = sorted(shap_docs, key=lambda x: x["score"], reverse=True)
        if max_number_of_shap_documents is not None:
            shap_docs = shap_docs[:max_number_of_shap_documents]
        return [d for d in shap_docs if abs(d["score"]) > min_abs_score]
    def transform(self, text):
        return self.analyze_sentiment(
            text,
            highlight=self.highlight,
            max_number_of_shap_documents=self.max_number_of_shap_documents,
            min_abs_score=self.min_abs_score,
        )
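# Minimal usage sketch, without the Gradio UI (the score shown is
# illustrative, not an actual model output):
#
#     model = SentimentTransform(sensitivity=0.4)
#     model.transform("I love this!")
#     # -> {"sentiment": "positive", "overall_sentiment_score": 0.98...}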
def sentiment_classifier(text, model_type, sensitivity):
    if model_type == "Social Media Model":
        model_name = "cardiffnlp/twitter-roberta-base-sentiment"
    else:  # "Survey Model" (also the fallback)
        model_name = "j-hartmann/sentiment-roberta-large-english-3-classes"
    model = SentimentTransform(model_name=model_name, sensitivity=sensitivity)
    res_dict = model.transform(text)
    return res_dict["sentiment"], res_dict["overall_sentiment_score"]
demo = gr.Interface(
    fn=sentiment_classifier,
    inputs=[
        gr.Textbox(
            placeholder="Put the text here and click 'Submit' to predict its sentiment",
            label="Input Text",
        ),
        gr.Dropdown(
            ["Social Media Model", "Survey Model"],
            value="Survey Model",
            label="Select the model that you want to use.",
        ),
        gr.Slider(
            0,
            1,
            step=0.01,
            label=(
                "Sensitivity (minimum confidence a `neutral` prediction needs "
                "to be kept as neutral; for news sources you probably want a "
                "lower value)."
            ),
        ),
    ],
    outputs=[gr.Textbox(label="Sentiment"), gr.Textbox(label="Sentiment Score")],
)
demo.launch(debug=True)
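# Assumption: as a Hugging Face Space, this file expects a requirements.txt
# providing gradio, numpy, transformers and a backend such as torch; shap is
# only needed when highlight=True.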