Update custom_generate/generate.py
custom_generate/generate.py  (+13 -70)
@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn
 
 from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
-from transformers.cache_utils import Cache,
+from transformers.cache_utils import Cache, EncoderDecoderCache
 from transformers.configuration_utils import PretrainedConfig
 from transformers.generation.utils import (
     ALL_CACHE_NAMES,
@@ -249,17 +249,13 @@ def _contrastive_search(
                     f"{model.__class__.__name__} does not support caching and therefore **can't** be used "
                     "for contrastive search."
                 )
-
-
-
-                or (
-                    isinstance(past_key_values, EncoderDecoderCache)
-                    and isinstance(past_key_values.self_attention_cache, DynamicCache)
-                )
+            elif (
+                not isinstance(past_key_values[0], (tuple, torch.Tensor))
+                or past_key_values[0][0].shape[0] != batch_size
             ):
                 raise ValueError(
-                    f"
-                    "
+                    f"{model.__class__.__name__} does not have a standard cache format and therefore **can't** be "
+                    "used for contrastive search without further modifications."
                 )
 
         # contrastive_search main logic start:
@@ -294,24 +290,7 @@ def _contrastive_search(
 
         if not sequential:
             # Replicates the new past_key_values to match the `top_k` candidates
-            past = model_kwargs["past_key_values"]
-            # If it is a static cache, modify it in-place layer after layer to save memory
-            if isinstance(past, DynamicCache) or (
-                isinstance(past, EncoderDecoderCache) and isinstance(past.self_attention_cache, DynamicCache)
-            ):
-                past.batch_repeat_interleave(top_k)
-            else:
-                new_key_values = []
-                for layer in past:
-                    items = []
-                    # item is either the key or the value matrix
-                    for item in layer:
-                        items.append(item.repeat_interleave(top_k, dim=0))
-                    new_key_values.append(tuple(items))
-
-                past = tuple(new_key_values)
-
-            model_kwargs["past_key_values"] = past
+            model_kwargs["past_key_values"].batch_repeat_interleave(top_k)
 
         if sequential:
             all_outputs = []
@@ -325,19 +304,10 @@ def _contrastive_search(
                     output_hidden_states=True,
                     output_attentions=output_attentions,
                 )
-                if isinstance(outputs["past_key_values"], DynamicCache) or (
-                    isinstance(outputs["past_key_values"], EncoderDecoderCache)
-                    and isinstance(outputs["past_key_values"].self_attention_cache, DynamicCache)
-                ):
-                    # Remove past K-V from output since we don't need to stack later
-                    outputs["past_key_values"] = None
-                    # Remove last token from past K-V since we don't want to append it at this point
-                    model_kwargs["past_key_values"].crop(-1)
-                else:
-                    raise ValueError(
-                        f"Unsupported cache type: {type(outputs['past_key_values'])}. Contrastive search requires "
-                        "dynamic cache, so set `cache_implementation='dynamic'` in the generation config."
-                    )
+                # Remove past K-V from output since we don't need to stack later
+                outputs["past_key_values"] = None
+                # Remove last token from past K-V since we don't want to append it at this point
+                model_kwargs["past_key_values"].crop(-1)
 
                 all_outputs.append(outputs)
             outputs = stack_model_outputs(all_outputs, model.config.get_text_config())
@@ -424,22 +394,7 @@ def _contrastive_search(
             next_past_key_values = None
             for possible_cache_name in ALL_CACHE_NAMES:
                 next_past_key_values = next_past_key_values or getattr(outputs, possible_cache_name, None)
-
-            if isinstance(next_past_key_values, DynamicCache) or (
-                isinstance(next_past_key_values, EncoderDecoderCache)
-                and isinstance(next_past_key_values.self_attention_cache, DynamicCache)
-            ):
-                next_past_key_values.batch_select_indices(augmented_idx)
-            else:
-                new_key_values = []
-                for layer in next_past_key_values:
-                    items = []
-                    # item is either the key or the value matrix
-                    for item in layer:
-                        items.append(item[augmented_idx, ...])
-                    new_key_values.append(tuple(items))
-
-                next_past_key_values = tuple(new_key_values)
+            next_past_key_values.batch_select_indices(augmented_idx)
 
             logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :]
             logit_for_next_step = logit_for_next_step.to(input_ids.device)
@@ -503,19 +458,7 @@ def _contrastive_search(
     # Contrastive search works by forward looking at the next token, so we need to exclude it from
     # `past_key_values` to be consistent with the other decoding methods
    if model_kwargs.get("past_key_values") is not None:
-        if isinstance(model_kwargs["past_key_values"], DynamicCache) or (
-            isinstance(model_kwargs["past_key_values"], EncoderDecoderCache)
-            and isinstance(model_kwargs["past_key_values"].self_attention_cache, DynamicCache)
-        ):
-            model_kwargs["past_key_values"].crop(-1)
-        else:
-            past_key_values = []
-            for layer in model_kwargs["past_key_values"]:
-                layer_past_key_values = []
-                for item in layer:
-                    layer_past_key_values.append(item[..., :-1, :])
-                past_key_values.append(tuple(layer_past_key_values))
-            model_kwargs["past_key_values"] = tuple(past_key_values)
+        model_kwargs["past_key_values"].crop(-1)
 
     if model.config.is_encoder_decoder:
         return GenerateEncoderDecoderOutput(
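The rewritten hunks lean entirely on the in-place DynamicCache operations (batch_repeat_interleave, batch_select_indices, crop) instead of the removed legacy tuple loops. Below is a minimal sketch of those three calls, assuming a transformers release that ships them on DynamicCache (roughly v4.42 or later); the "gpt2" checkpoint, the prompt, and the chosen indices are purely illustrative.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Contrastive search looks one token ahead", return_tensors="pt")
cache = DynamicCache()
with torch.no_grad():
    model(**inputs, past_key_values=cache, use_cache=True)  # prefill populates the cache in place

top_k = 4
cache.batch_repeat_interleave(top_k)  # 1 batch row -> top_k rows, one per candidate token
print(cache[0][0].shape[0])           # layer-0 keys now have batch dim top_k (same access pattern as the new check above)
selected = torch.tensor([2])          # pretend candidate 2 won the contrastive score
cache.batch_select_indices(selected)  # keep only the winning row
cache.crop(-1)                        # drop the newest cache entry (the look-ahead token in the real decoding loop)
print(cache.get_seq_length())         # prompt length minus one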
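For completeness, a hypothetical usage sketch of a Hub repository that ships this custom_generate/generate.py. It assumes a transformers release that supports Hub custom generation strategies via the custom_generate argument to generate(), uses "<user>/<repo>" as a placeholder for this repository's id, and assumes the custom method still honours the usual contrastive-search knobs (top_k, penalty_alpha).

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The quick brown fox", return_tensors="pt")

output_ids = model.generate(
    **inputs,
    custom_generate="<user>/<repo>",  # placeholder for the repo containing custom_generate/generate.py
    trust_remote_code=True,           # the decoding loop is downloaded from the Hub and executed
    top_k=4,                          # contrastive search: number of candidate tokens per step
    penalty_alpha=0.6,                # contrastive search: degeneration penalty weight
    max_new_tokens=32,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))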