From 821b2fc9f10a59a21723dba5994dea194cfd5cce Mon Sep 17 00:00:00 2001
From: Reed Wanderman-Milne
Date: Fri, 9 Aug 2019 11:49:28 -0700
Subject: [PATCH] Implement __repr__ for LossScale subclasses.

PiperOrigin-RevId: 262603680
---
 .../training/experimental/loss_scale.py      | 14 ++++++++++++
 .../training/experimental/loss_scale_test.py | 22 +++++++++++++++++++
 2 files changed, 36 insertions(+)

diff --git a/tensorflow/python/training/experimental/loss_scale.py b/tensorflow/python/training/experimental/loss_scale.py
index 46da10183df..46f52f0a955 100644
--- a/tensorflow/python/training/experimental/loss_scale.py
+++ b/tensorflow/python/training/experimental/loss_scale.py
@@ -227,6 +227,9 @@ class FixedLossScale(LossScale):
     del grads
     return control_flow_ops.no_op(), True
 
+  def __repr__(self):
+    return 'FixedLossScale(%s)' % self._loss_scale_value
+
   def get_config(self):
     return {'loss_scale_value': self._loss_scale_value}
 
@@ -376,6 +379,17 @@ class DynamicLossScale(LossScale):
       should_apply_gradients = is_finite
     return update_op, should_apply_gradients
 
+  def __repr__(self):
+    if context.executing_eagerly():
+      return ('DynamicLossScale(current_loss_scale=%s, num_good_steps=%s, '
+              'initial_loss_scale=%s, increment_period=%s, multiplier=%s)' %
+              (self._current_loss_scale.numpy(), self._num_good_steps.numpy(),
+               self.initial_loss_scale, self.increment_period, self.multiplier))
+    else:
+      return ('DynamicLossScale(initial_loss_scale=%s, increment_period=%s, '
+              'multiplier=%s)' %
+              (self.initial_loss_scale, self.increment_period, self.multiplier))
+
   def get_config(self):
     return {
         'initial_loss_scale': self.initial_loss_scale,
diff --git a/tensorflow/python/training/experimental/loss_scale_test.py b/tensorflow/python/training/experimental/loss_scale_test.py
index c3e18a18422..e4a11144041 100644
--- a/tensorflow/python/training/experimental/loss_scale_test.py
+++ b/tensorflow/python/training/experimental/loss_scale_test.py
@@ -92,6 +92,11 @@ class FixedLossScaleTest(test.TestCase):
     scalar = loss_scale_module.FixedLossScale(123)
     self.assertIsInstance(scalar(), ops.Tensor)
 
+  @test_util.run_in_graph_and_eager_modes
+  def test_repr(self):
+    loss_scale = loss_scale_module.FixedLossScale(123)
+    self.assertEqual(repr(loss_scale), 'FixedLossScale(123.0)')
+
 
 def _get_example_iter(inputs):
   dataset = dataset_ops.Dataset.from_tensor_slices(inputs)
@@ -302,5 +307,22 @@ class DynamicLossScaleTest(test.TestCase, parameterized.TestCase):
     scalar = loss_scale_module.DynamicLossScale()
     self.assertIsInstance(scalar(), ops.Tensor)
 
+  @parameterized.named_parameters(*TESTCASES)
+  @test_util.run_in_graph_and_eager_modes
+  def test_repr(self, strategy_fn):
+    with strategy_fn().scope():
+      loss_scale = loss_scale_module.DynamicLossScale(
+          initial_loss_scale=1, increment_period=2, multiplier=3)
+      if context.executing_eagerly():
+        self.assertEqual(repr(loss_scale),
+                         'DynamicLossScale(current_loss_scale=1.0, '
+                         'num_good_steps=0, initial_loss_scale=1.0, '
+                         'increment_period=2, multiplier=3.0)')
+      else:
+        self.assertEqual(repr(loss_scale),
+                         'DynamicLossScale(initial_loss_scale=1.0, '
+                         'increment_period=2, multiplier=3.0)')
+
+
 if __name__ == '__main__':
   test.main()
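
Reviewer note (not part of the patch): a minimal sketch of exercising the new
__repr__ methods in eager mode, mirroring the strings the tests assert. It uses
the internal import path from loss_scale_test.py; any public alias (e.g.
tf.train.experimental.FixedLossScale) is an assumption that may vary by TF
version, so treat this as illustrative rather than definitive.

  from tensorflow.python.training.experimental import loss_scale as loss_scale_module

  # FixedLossScale normalizes its value to a float, so 123 prints as 123.0.
  fixed = loss_scale_module.FixedLossScale(123)
  print(repr(fixed))  # FixedLossScale(123.0)

  # In eager mode the repr includes live variable state (current_loss_scale,
  # num_good_steps); in graph mode those variables cannot be read without a
  # session, so the repr falls back to just the constructor arguments.
  dynamic = loss_scale_module.DynamicLossScale(
      initial_loss_scale=1, increment_period=2, multiplier=3)
  print(repr(dynamic))
  # DynamicLossScale(current_loss_scale=1.0, num_good_steps=0,
  # initial_loss_scale=1.0, increment_period=2, multiplier=3.0)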