Spaces:
Runtime error
Martijn van Beers committed
Commit ab7830f · 1 Parent(s): 733749d
Hack it up to do multiple explanations
Adds in explanations from captum's LayerIntegratedGradients
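
For context: captum's LayerIntegratedGradients attributes a prediction to the outputs of a chosen layer by integrating gradients along a path from a baseline input to the actual input. Below is a minimal, self-contained sketch of that pattern as this commit applies it; the checkpoint name and encoder layer index follow the diff, while the helper name forward_logits and the example sentence are illustrative, not part of the commit.

# Minimal sketch of the LayerIntegratedGradients pattern wired up in this commit.
# Model checkpoint and layer index follow the diff below; helper names and the
# example sentence are illustrative.
import torch
from captum.attr import LayerIntegratedGradients
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("textattack/roberta-base-SST-2")
model = AutoModelForSequenceClassification.from_pretrained("textattack/roberta-base-SST-2")
model.eval()

def forward_logits(input_ids, attention_mask=None):
    # captum calls this repeatedly with interpolated activations; return raw logits
    return model(input_ids, attention_mask=attention_mask).logits

encoding = tokenizer(["This movie was great"], return_tensors="pt")
input_ids = encoding["input_ids"]
attention_mask = encoding["attention_mask"]

# attribute the positive-sentiment logit to the outputs of encoder layer 8
lig = LayerIntegratedGradients(forward_logits, model.roberta.encoder.layer[8])
attributions = lig.attribute(
    inputs=input_ids,
    baselines=tokenizer.unk_token_id,       # scalar baseline token id, as in the diff
    additional_forward_args=(attention_mask,),
    target=1,                               # SST-2 class 1 = positive
    n_steps=20,
)

# collapse the hidden dimension to one score per token and normalise
scores = attributions.sum(dim=-1).squeeze(0)
scores = scores / torch.norm(scores)
tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
print(list(zip(tokens, scores.tolist())))
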
app.py
CHANGED
@@ -5,10 +5,11 @@ sys.path.append("BERT_explainability")
 
 import torch
 
+from transformers import AutoModelForSequenceClassification
 from BERT_explainability.ExplanationGenerator import Generator
 from BERT_explainability.roberta2 import RobertaForSequenceClassification
 from transformers import AutoTokenizer
-
+from captum.attr import LayerIntegratedGradients
 from captum.attr import visualization
 import torch
 
@@ -39,6 +40,7 @@ model = RobertaForSequenceClassification.from_pretrained(
     "textattack/roberta-base-SST-2"
 ).to(device)
 model.eval()
+model2 = AutoModelForSequenceClassification.from_pretrained("textattack/roberta-base-SST-2")
 tokenizer = AutoTokenizer.from_pretrained("textattack/roberta-base-SST-2")
 # initialize the explanations generator
 explanations = Generator(model, "roberta")
@@ -151,7 +153,7 @@ def visualize_text(datarecords, legend=True):
     return html
 
 
-def show_explanation(model, input_ids, attention_mask, index=None, start_layer=0):
+def show_explanation(model, input_ids, attention_mask, index=None, start_layer=8):
     # generate an explanation for the input
     output, expl = generate_relevance(
         model, input_ids, attention_mask, index=index, start_layer=start_layer
@@ -177,32 +179,87 @@ def show_explanation(model, input_ids, attention_mask, index=None, start_layer=0
         tokens = tokenizer.convert_ids_to_tokens(input_ids[record].flatten())[
             1 : 0 - ((attention_mask[record] == 0).sum().item() + 1)
         ]
-        vis_data_records.append(list(zip(tokens, nrm.tolist())))
+        # vis_data_records.append(list(zip(tokens, nrm.tolist())))
         #print([(tokens[i], nrm[i].item()) for i in range(len(tokens))])
-
-
-
-
-
-
-
-
-
-
-
-
-
-    return vis_data_records
+        vis_data_records.append(
+            visualization.VisualizationDataRecord(
+                nrm,
+                output[record][classification],
+                classification,
+                classification,
+                index,
+                1,
+                tokens,
+                1,
+            )
+        )
+    return visualize_text(vis_data_records)
+    # return vis_data_records
+
+def custom_forward(inputs, attention_mask=None, pos=0):
+    # print("inputs", inputs.shape)
+    result = model2(inputs, attention_mask=attention_mask, return_dict=True)
+    preds = result.logits
+    # print("preds", preds.shape)
+    return preds
 
+def summarize_attributions(attributions):
+    attributions = attributions.sum(dim=-1).squeeze(0)
+    attributions = attributions / torch.norm(attributions)
+    return attributions
+
+
+def run_attribution_model(input_ids, attention_mask, ref_token_id=tokenizer.unk_token_id, layer=None, steps=20):
+    try:
+        output = model2(input_ids=input_ids, attention_mask=attention_mask)[0]
+        index = output.argmax(axis=-1).detach().cpu().numpy()
+
+        ablator = LayerIntegratedGradients(custom_forward, layer)
+        input_tensor = input_ids
+        attention_mask = attention_mask
+        attributions = ablator.attribute(
+            inputs=input_ids,
+            baselines=ref_token_id,
+            additional_forward_args=(attention_mask),
+            target=1,
+            n_steps=steps,
+        )
+        attributions = summarize_attributions(attributions).unsqueeze_(0)
+    finally:
+        pass
+    vis_data_records = []
+    print("IN", input_ids.size())
+    print("ATTR", attributions.shape)
+    for record in range(input_ids.size(0)):
+        classification = output[record].argmax(dim=-1).item()
+        class_name = classifications[classification]
+        attr = attributions[record]
+        tokens = tokenizer.convert_ids_to_tokens(input_ids[record].flatten())[
+            1 : 0 - ((attention_mask[record] == 0).sum().item() + 1)
+        ]
+        print("TOK", len(tokens), attr.shape)
+        vis_data_records.append(
+            visualization.VisualizationDataRecord(
+                attr,
+                output[record][classification],
+                classification,
+                classification,
+                index,
+                1,
+                tokens,
+                1,
+            )
+        )
+    return visualize_text(vis_data_records)
 
 def sentence_sentiment(input_text):
     text_batch = [input_text]
     encoding = tokenizer(text_batch, return_tensors="pt")
     input_ids = encoding["input_ids"].to(device)
     attention_mask = encoding["attention_mask"].to(device)
-
-
-    return
+    layer = getattr(model2.roberta.encoder.layer, "8")
+    output = run_attribution_model(input_ids, attention_mask, layer=layer)
+    return output
 
 def sentiment_explanation_hila(input_text):
     text_batch = [input_text]
@@ -216,27 +273,19 @@ def sentiment_explanation_hila(input_text):
     return show_explanation(model, input_ids, attention_mask)
 
 hila = gradio.Interface(
-    fn=
+    fn=sentiment_explanation_hila,
     inputs="text",
-    outputs="
-    title="RoBERTa Explanability",
-    description="Quick demo of a version of [Hila Chefer's](https://github.com/hila-chefer) [Transformer-Explanability](https://github.com/hila-chefer/Transformer-Explainability/) but without the layerwise relevance propagation (as in [Transformer-MM_explainability](https://github.com/hila-chefer/Transformer-MM-Explainability/)) for a RoBERTa model.",
-    examples=[
-        [
-            "This movie was the best movie I have ever seen! some scenes were ridiculous, but acting was great"
-        ],
-        [
-            "I really didn't like this movie. Some of the actors were good, but overall the movie was boring"
-        ],
-    ],
-    interpretation=sentiment_explanation_hila
+    outputs="html",
 )
-shap = gradio.Interface(
+lig = gradio.Interface(
     fn=sentence_sentiment,
     inputs="text",
-    outputs="
-
-
+    outputs="html",
+)
+
+iface = gradio.Parallel(hila, lig,
+    title="RoBERTa Explanability",
+    description="Quick comparison demo of explainability for sentiment prediction with RoBERTa. The outputs are from:\n\n* a version of [Hila Chefer's](https://github.com/hila-chefer) [Transformer-Explanability](https://github.com/hila-chefer/Transformer-Explainability/) but without the layerwise relevance propagation (as in [Transformer-MM_explainability](https://github.com/hila-chefer/Transformer-MM-Explainability/)) for a RoBERTa model.\n* [captum](https://captum.ai/)'s LayerIntegratedGradients",
     examples=[
         [
             "This movie was the best movie I have ever seen! some scenes were ridiculous, but acting was great"
@@ -245,8 +294,5 @@ shap = gradio.Interface(
             "I really didn't like this movie. Some of the actors were good, but overall the movie was boring"
         ],
     ],
-    interpretation="shap"
 )
-
-iface = gradio.Parallel(hila, shap)
 iface.launch()
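
As a usage note, after this change the two explainers are exposed as two text-in, HTML-out Gradio Interfaces rendered side by side. A stripped-down sketch of that wiring follows; the explanation functions are placeholders for sentiment_explanation_hila and sentence_sentiment, and it assumes a Gradio release that still ships gradio.Parallel, as this Space uses.

# Stripped-down sketch of the resulting app layout: two text-in / HTML-out
# Interfaces combined side by side. The explanation functions here are
# placeholders, not the real explainers from app.py.
import gradio

def explain_relevance(text):
    return f"<p>relevance-propagation explanation for: {text}</p>"

def explain_lig(text):
    return f"<p>LayerIntegratedGradients explanation for: {text}</p>"

hila = gradio.Interface(fn=explain_relevance, inputs="text", outputs="html")
lig = gradio.Interface(fn=explain_lig, inputs="text", outputs="html")

iface = gradio.Parallel(hila, lig, title="RoBERTa Explanability")
iface.launch()
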