Spaces:
Runtime error
Runtime error
| # Copyright 2020 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import shutil | |
| import tempfile | |
| import unittest | |
| from unittest.mock import patch | |
| from transformers import ( | |
| DefaultFlowCallback, | |
| IntervalStrategy, | |
| PrinterCallback, | |
| ProgressCallback, | |
| Trainer, | |
| TrainerCallback, | |
| TrainingArguments, | |
| is_torch_available, | |
| ) | |
| from transformers.testing_utils import require_torch | |
| if is_torch_available(): | |
| from transformers.trainer import DEFAULT_CALLBACKS | |
| from .test_trainer import RegressionDataset, RegressionModelConfig, RegressionPreTrainedModel | |
class MyTestTrainerCallback(TrainerCallback):
    """A callback that records, in order, the name of every event it receives."""

    def __init__(self):
        # Chronological log of event names; the tests read this list directly.
        self.events = []

    def _record(self, event_name):
        # Single place where the event log is extended.
        self.events.append(event_name)

    def on_init_end(self, args, state, control, **kwargs):
        self._record("on_init_end")

    def on_train_begin(self, args, state, control, **kwargs):
        self._record("on_train_begin")

    def on_train_end(self, args, state, control, **kwargs):
        self._record("on_train_end")

    def on_epoch_begin(self, args, state, control, **kwargs):
        self._record("on_epoch_begin")

    def on_epoch_end(self, args, state, control, **kwargs):
        self._record("on_epoch_end")

    def on_step_begin(self, args, state, control, **kwargs):
        self._record("on_step_begin")

    def on_step_end(self, args, state, control, **kwargs):
        self._record("on_step_end")

    def on_evaluate(self, args, state, control, **kwargs):
        self._record("on_evaluate")

    def on_predict(self, args, state, control, **kwargs):
        self._record("on_predict")

    def on_save(self, args, state, control, **kwargs):
        self._record("on_save")

    def on_log(self, args, state, control, **kwargs):
        self._record("on_log")

    def on_prediction_step(self, args, state, control, **kwargs):
        self._record("on_prediction_step")
@require_torch
class TrainerCallbackTest(unittest.TestCase):
    """Tests for how `Trainer` registers callbacks and the order in which it fires their events.

    The fixtures used here (`DEFAULT_CALLBACKS`, `RegressionDataset`, ...) are only imported when torch is
    available, hence the `require_torch` guard on the whole class.
    """

    def setUp(self):
        # Each test gets its own scratch directory, used as the `TrainingArguments` output dir.
        self.output_dir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.output_dir)

    def get_trainer(self, a=0, b=0, train_len=64, eval_len=64, callbacks=None, disable_tqdm=False, **kwargs):
        """Build a `Trainer` around a small regression model.

        Args:
            a, b: Forwarded to `RegressionModelConfig`.
            train_len, eval_len: Lengths of the train / eval `RegressionDataset`.
            callbacks: Passed as-is to `Trainer(callbacks=...)`.
            disable_tqdm: Forwarded to `TrainingArguments`.
            **kwargs: Any additional `TrainingArguments` keyword arguments.

        Returns:
            The configured `Trainer`, writing to `self.output_dir`.
        """
        # disable_tqdm in TrainingArguments has a flaky default since it depends on the level of logging. We make sure
        # it's set to False since the tests later on depend on its value.
        train_dataset = RegressionDataset(length=train_len)
        eval_dataset = RegressionDataset(length=eval_len)
        config = RegressionModelConfig(a=a, b=b)
        model = RegressionPreTrainedModel(config)

        args = TrainingArguments(self.output_dir, disable_tqdm=disable_tqdm, report_to=[], **kwargs)
        return Trainer(
            model,
            args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            callbacks=callbacks,
        )

    def check_callbacks_equality(self, cbs1, cbs2):
        """Assert that two callback lists match, ignoring order.

        Entries may be callback classes or callback instances. A class and an instance match when the instance's
        class is that class; two classes or two instances are compared with plain equality.
        """
        self.assertEqual(len(cbs1), len(cbs2))

        def class_name(cb):
            return cb.__name__ if isinstance(cb, type) else cb.__class__.__name__

        # Order doesn't matter
        for cb1, cb2 in zip(sorted(cbs1, key=class_name), sorted(cbs2, key=class_name)):
            cb1_is_class = isinstance(cb1, type)
            cb2_is_class = isinstance(cb2, type)
            if cb1_is_class and not cb2_is_class:
                self.assertEqual(cb1, cb2.__class__)
            elif cb2_is_class and not cb1_is_class:
                self.assertEqual(cb1.__class__, cb2)
            else:
                self.assertEqual(cb1, cb2)

    def get_expected_events(self, trainer):
        """Return the event names a recording callback should have seen after `trainer.train()`.

        The sequence is derived from the trainer's dataloader lengths, `trainer.state.num_train_epochs` and the
        logging / evaluation / save settings found in `trainer.args`.
        """
        expected_events = ["on_init_end", "on_train_begin"]
        step = 0
        # Steps per epoch come from the *train* dataloader; the eval dataloader only determines how many
        # prediction steps a single evaluation emits.
        train_dl_len = len(trainer.get_train_dataloader())
        eval_dl_len = len(trainer.get_eval_dataloader())
        evaluation_events = ["on_prediction_step"] * eval_dl_len + ["on_log", "on_evaluate"]

        for _ in range(trainer.state.num_train_epochs):
            expected_events.append("on_epoch_begin")
            for _ in range(train_dl_len):
                step += 1
                expected_events += ["on_step_begin", "on_step_end"]
                if step % trainer.args.logging_steps == 0:
                    expected_events.append("on_log")
                if trainer.args.eval_strategy == IntervalStrategy.STEPS and step % trainer.args.eval_steps == 0:
                    expected_events += evaluation_events.copy()
                if step % trainer.args.save_steps == 0:
                    expected_events.append("on_save")
            expected_events.append("on_epoch_end")
            if trainer.args.eval_strategy == IntervalStrategy.EPOCH:
                expected_events += evaluation_events.copy()
        expected_events += ["on_log", "on_train_end"]
        return expected_events

    def _train_and_check_events(self, **trainer_kwargs):
        """Train with a recording callback and compare its event log to `get_expected_events`."""
        trainer = self.get_trainer(callbacks=[MyTestTrainerCallback], **trainer_kwargs)
        trainer.train()
        # The callback passed at init sits second to last in the handler's list in this configuration.
        events = trainer.callback_handler.callbacks[-2].events
        self.assertEqual(events, self.get_expected_events(trainer))

    def test_init_callback(self):
        trainer = self.get_trainer()
        expected_callbacks = DEFAULT_CALLBACKS.copy() + [ProgressCallback]
        self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)

        # Callbacks passed at init are added to the default callbacks
        trainer = self.get_trainer(callbacks=[MyTestTrainerCallback])
        expected_callbacks.append(MyTestTrainerCallback)
        self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)

        # TrainingArguments.disable_tqdm controls if use ProgressCallback or PrinterCallback
        trainer = self.get_trainer(disable_tqdm=True)
        expected_callbacks = DEFAULT_CALLBACKS.copy() + [PrinterCallback]
        self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)

    def test_add_remove_callback(self):
        # `expected_callbacks` is mutated in lockstep with the trainer operations below.
        expected_callbacks = DEFAULT_CALLBACKS.copy() + [ProgressCallback]
        trainer = self.get_trainer()

        # We can add, pop, or remove by class name
        trainer.remove_callback(DefaultFlowCallback)
        expected_callbacks.remove(DefaultFlowCallback)
        self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)

        trainer = self.get_trainer()
        cb = trainer.pop_callback(DefaultFlowCallback)
        self.assertEqual(cb.__class__, DefaultFlowCallback)
        self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)

        trainer.add_callback(DefaultFlowCallback)
        expected_callbacks.insert(0, DefaultFlowCallback)
        self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)

        # We can also add, pop, or remove by instance
        trainer = self.get_trainer()
        cb = trainer.callback_handler.callbacks[0]
        trainer.remove_callback(cb)
        expected_callbacks.remove(DefaultFlowCallback)
        self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)

        trainer = self.get_trainer()
        cb1 = trainer.callback_handler.callbacks[0]
        cb2 = trainer.pop_callback(cb1)
        self.assertEqual(cb1, cb2)
        self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)

        trainer.add_callback(cb1)
        expected_callbacks.insert(0, DefaultFlowCallback)
        self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)

    def test_event_flow(self):
        import warnings

        # XXX: for now ignore scatter_gather warnings in this test since it's not relevant to what's being tested
        # The filter is scoped with `catch_warnings` so it does not leak into other tests in the same process.
        with warnings.catch_warnings():
            warnings.simplefilter(action="ignore", category=UserWarning)

            self._train_and_check_events()

            # Independent log/save/eval
            self._train_and_check_events(logging_steps=5)
            self._train_and_check_events(save_steps=5)
            self._train_and_check_events(eval_steps=5, eval_strategy="steps")
            self._train_and_check_events(eval_strategy="epoch")

            # A bit of everything
            self._train_and_check_events(
                logging_steps=3,
                save_steps=10,
                eval_steps=5,
                eval_strategy="steps",
            )

        # warning should be emitted for duplicated callbacks
        with patch("transformers.trainer_callback.logger.warning") as warn_mock:
            self.get_trainer(
                callbacks=[MyTestTrainerCallback, MyTestTrainerCallback],
            )
            # Fail with a clear message if nothing was logged, instead of a TypeError on `call_args` being None.
            warn_mock.assert_called()
            self.assertIn(str(MyTestTrainerCallback), warn_mock.call_args[0][0])