Spaces:
Build error
Build error
| # Copyright 2020 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """ | |
| Adapted from https://github.com/huggingface/transformers/blob/f93c90d21749b61bd89152a7fe99a839df29ed94/src/transformers/debug_utils.py | |
| """ | |
| import json | |
| from transformers.utils import ExplicitEnum, is_torch_available, logging | |
| from m4.training.utils import get_stats | |
| if is_torch_available(): | |
| import torch | |
| logger = logging.get_logger(__name__) | |
class ActivationTracker:
    """
    This debug class helps detect and understand where the model starts getting very large or very small, and more
    importantly `nan` or `inf` activation elements.

    This class plugs a forward hook into every module of the model and records the activation statistics into a
    list of dictionaries: `jsonl_stats`.

    Recording is only active during training, not during validation, and when `trace_activation` is set to True
    (see `activate_hooks` / `deactivate_hooks`). Tracing is disabled until `activate_hooks` is called.

    In practice, since this tracking requires additional computation, we only track activations every X steps.
    In the case of gradient accumulation, all the batches being accumulated are being recorded and identified by
    the `batch_idx` key (see `fill_in_batch_idx`).

    Args:
        model (`nn.Module`):
            The model to debug.
        abort_after_batch_num (`int`, *optional*):
            If set, a `ValueError` is raised from the forward hook once more than this many batches have finished.
    """

    def __init__(
        self,
        model,
        abort_after_batch_num=None,
    ):
        self.model = model
        self.is_validation = False
        # Must exist before the first forward pass: `forward_hook` reads it unconditionally.
        # Tracing stays off until `activate_hooks()` is called.
        self.trace_activation = False
        self.abort_after_batch_num = abort_after_batch_num
        self.jsonl_stats = []
        self.batch_number = 0
        self.detected_overflow = False

        self.analyse_model()
        self.register_forward_hook()

    def analyse_model(self):
        """Build `self.module_names`, mapping each module object to its fully qualified name."""
        # extract the fully qualified module names, to be able to report at run time. e.g.:
        # encoder.block.2.layer.0.SelfAttention.o
        #
        # for shared weights only the first shared module name will be registered
        self.module_names = {m: name for name, m in self.model.named_modules()}

    def analyse_variable(self, var, ctx, current_module_stats):
        """
        Merge the statistics of `var` into `current_module_stats` and flag nan/inf values.

        Non-tensor values are ignored.

        Args:
            var: the value to inspect (only tensors are analysed).
            ctx (`str`): label identifying the value, e.g. `"input[0]"` or `"output"`.
            current_module_stats (`dict`): stats gathered so far for the current module; updated in place.

        Returns:
            `dict`: the same `current_module_stats` object.
        """
        if torch.is_tensor(var):
            # `get_stats` is a project helper; its result is merged as a dict into the module's record.
            dict_stats = get_stats(var, ctx)
            current_module_stats.update(dict_stats)
            if detect_overflow(var, ctx):
                self.detected_overflow = True
        return current_module_stats

    def create_frame(self, module, input, output):
        """
        Collect the stats of one module call and append them to `jsonl_stats`.

        Args:
            module: the module whose forward pass just ran.
            input: the module inputs as passed to the forward hook (a tuple, or a single value).
            output: the module output: a single value, a tuple, or a tuple containing tuples.
        """
        module_name = f"{self.module_names[module]}"
        module_type = f"{module.__class__.__name__}"
        current_module_stats = {}

        # inputs
        if isinstance(input, tuple):
            for i, x in enumerate(input):
                current_module_stats = self.analyse_variable(x, f"input[{i}]", current_module_stats)
        else:
            current_module_stats = self.analyse_variable(input, "input", current_module_stats)

        # outputs
        if isinstance(output, tuple):
            for i, x in enumerate(output):
                # possibly a tuple of tuples
                if isinstance(x, tuple):
                    for j, y in enumerate(x):
                        current_module_stats = self.analyse_variable(y, f"output[{i}][{j}]", current_module_stats)
                else:
                    current_module_stats = self.analyse_variable(x, f"output[{i}]", current_module_stats)
        else:
            current_module_stats = self.analyse_variable(output, "output", current_module_stats)

        if not current_module_stats:
            return

        # When we activate gradient checkpointing, the forward hook will be called twice for some (not all) modules.
        # That will lead to double (repeated) entries in the list.
        # This is a hack to avoid these double entries.
        # NOTE(review): the check spans every record currently in `jsonl_stats`, so a module is recorded at most
        # once until `reset_jsonl_stats()` is called — confirm this is what gradient accumulation expects.
        already_recorded = any(
            r["name"] == module_name and r["type"] == module_type for r in self.jsonl_stats
        )
        if not already_recorded:
            self.jsonl_stats.append(
                {
                    "name": module_name,
                    "type": module_type,
                    **current_module_stats,
                }
            )

    def register_forward_hook(self):
        """Register `forward_hook` on the model and all of its submodules."""
        self.model.apply(self._register_forward_hook)

    def _register_forward_hook(self, module):
        """Register `forward_hook` on a single module."""
        module.register_forward_hook(self.forward_hook)

    def forward_hook(self, module, input, output):
        """
        Hook called after each module's forward pass.

        Counts finished batches, records activation stats when tracing is enabled during training, and aborts
        when requested.

        Raises:
            ValueError: if `inf`/`nan` was detected while tracing, or if `abort_after_batch_num` was exceeded.
        """
        # - input is a tuple of packed inputs (could be non-Tensors)
        # - output could be a Tensor or a tuple of Tensors and non-Tensors
        trace_activation = self.trace_activation

        # count batch numbers - the very first forward hook of the batch will be called when the
        # batch completes - i.e. it gets called very last - we know this batch has finished
        if module is self.model:
            self.batch_number += 1

        if trace_activation and not self.is_validation:
            self.create_frame(module, input, output)

            if self.detected_overflow:
                # now we can abort, as it's pointless to continue running
                raise ValueError(
                    "DebugUnderflowOverflow: inf/nan detected, aborting as there is no point running further. "
                    "Please scroll up above this traceback to see the activation values prior to this event."
                )

        # abort after certain batch if requested to do so
        if self.abort_after_batch_num is not None and self.batch_number > self.abort_after_batch_num:
            raise ValueError(
                f"DebugUnderflowOverflow: aborting after {self.batch_number} batches due to"
                f" `abort_after_batch_num={self.abort_after_batch_num}` arg"
            )

    def fill_in_batch_idx(self, batch_idx):
        """
        Tag every record that has no `batch_idx` yet with `batch_idx`.

        Raises:
            ValueError: if an already-tagged record carries a `batch_idx` larger than the one given.
        """
        for r in self.jsonl_stats:
            if "batch_idx" not in r:
                r["batch_idx"] = batch_idx
            elif r["batch_idx"] > batch_idx:
                raise ValueError("`batch_idx` should be increasing")

    def dump_stats(self, log_activations_filename, curr_opt_step):
        """
        Append all recorded stats to `log_activations_filename`, one JSON object per line.

        Each record is tagged in place with `"step": curr_opt_step` before being written.
        """
        with open(log_activations_filename, "a", encoding="utf-8") as file:
            # append stats to file
            for r in self.jsonl_stats:
                r["step"] = curr_opt_step
                file.write(json.dumps(r) + "\n")

    def reset_jsonl_stats(self):
        """Drop all recorded stats."""
        self.jsonl_stats = []

    def activate_hooks(self):
        """Enable activation recording in the forward hooks."""
        self.trace_activation = True

    def deactivate_hooks(self):
        """Disable activation recording in the forward hooks."""
        self.trace_activation = False

    def is_eval(self):
        """Mark the tracker as being in validation: nothing is recorded."""
        self.is_validation = True

    def is_train(self):
        """Mark the tracker as being in training: recording is allowed again."""
        self.is_validation = False
def detect_overflow(var, ctx):
    """
    Check a tensor for `nan` / `inf` entries and print a message for each kind found.

    Best called right after the operation that produced or modified the tensor, so the report points at the
    place where the overflow/underflow appeared.

    The function also carries a few extra diagnostics that are switched off (`if 0:`); flip them on by hand when
    you need to monitor large elements or min/max/variance/mean.

    Args:
        var: the tensor variable to check
        ctx: the message to print as a context

    Return:
        `True` if `inf` or `nan` was detected, `False` otherwise
    """
    found_nan = torch.isnan(var).any().item()
    if found_nan:
        print(f"{ctx} has nans")

    found_inf = torch.isinf(var).any().item()
    if found_inf:
        print(f"{ctx} has infs")

    # Disabled by default: count the elements whose magnitude reaches 100 / 1000 / 10000.
    if 0:
        n100 = var[torch.ge(var.abs(), 100)]
        if n100.numel() > 0:
            print(f"{ctx}: n100={n100.numel()}")
        n1000 = var[torch.ge(var.abs(), 1000)]
        if n1000.numel() > 0:
            print(f"{ctx}: n1000={n1000.numel()}")
        n10000 = var[torch.ge(var.abs(), 10000)]
        if n10000.numel() > 0:
            print(f"{ctx}: n10000={n10000.numel()}")

    # Disabled by default: print the value range.
    if 0:
        print(f"min={var.min():9.2e} max={var.max():9.2e}")

    # Disabled by default: print range, variance and mean together with the context.
    if 0:
        print(f"min={var.min():9.2e} max={var.max():9.2e} var={var.var():9.2e} mean={var.mean():9.2e} ({ctx})")

    return bool(found_nan or found_inf)
class DebugOption(ExplicitEnum):
    """
    String-valued identifiers for the debug features that can be requested.
    """

    # NOTE(review): presumably selects the nan/inf activation checks — confirm against the caller.
    UNDERFLOW_OVERFLOW = "underflow_overflow"
    # NOTE(review): presumably enables TPU metrics debugging; no use is visible in this section.
    TPU_METRICS_DEBUG = "tpu_metrics_debug"