Implement __repr__ for LossScale subclasses.

PiperOrigin-RevId: 262603680
This commit is contained in:
Reed Wanderman-Milne 2019-08-09 11:49:28 -07:00 committed by TensorFlower Gardener
parent 26220ec083
commit d24f662e2a
2 changed files with 36 additions and 0 deletions
tensorflow/python/training/experimental

View File

@@ -227,6 +227,9 @@ class FixedLossScale(LossScale):
del grads
return control_flow_ops.no_op(), True
def __repr__(self):
  """Returns a debug string such as 'FixedLossScale(123.0)'."""
  value = self._loss_scale_value
  return 'FixedLossScale(%s)' % value
def get_config(self):
  """Returns a config dict from which this loss scale can be recreated."""
  config = {'loss_scale_value': self._loss_scale_value}
  return config
@@ -376,6 +379,17 @@ class DynamicLossScale(LossScale):
should_apply_gradients = is_finite
return update_op, should_apply_gradients
def __repr__(self):
  """Returns a string representation of this DynamicLossScale.

  When executing eagerly the live variable values (current loss scale and
  good-step count) are included; in graph mode, reading those variables
  would require a Session, so only the constructor arguments are shown.
  """
  if not context.executing_eagerly():
    return ('DynamicLossScale(initial_loss_scale=%s, increment_period=%s, '
            'multiplier=%s)' %
            (self.initial_loss_scale, self.increment_period, self.multiplier))
  return ('DynamicLossScale(current_loss_scale=%s, num_good_steps=%s, '
          'initial_loss_scale=%s, increment_period=%s, multiplier=%s)' %
          (self._current_loss_scale.numpy(), self._num_good_steps.numpy(),
           self.initial_loss_scale, self.increment_period, self.multiplier))
def get_config(self):
return {
'initial_loss_scale': self.initial_loss_scale,

View File

@@ -92,6 +92,11 @@ class FixedLossScaleTest(test.TestCase):
scalar = loss_scale_module.FixedLossScale(123)
self.assertIsInstance(scalar(), ops.Tensor)
@test_util.run_in_graph_and_eager_modes
def test_repr(self):
  """repr() of a FixedLossScale shows the loss scale as a float."""
  scale = loss_scale_module.FixedLossScale(123)
  self.assertEqual(repr(scale), 'FixedLossScale(123.0)')
def _get_example_iter(inputs):
dataset = dataset_ops.Dataset.from_tensor_slices(inputs)
@@ -302,5 +307,22 @@ class DynamicLossScaleTest(test.TestCase, parameterized.TestCase):
scalar = loss_scale_module.DynamicLossScale()
self.assertIsInstance(scalar(), ops.Tensor)
@parameterized.named_parameters(*TESTCASES)
@test_util.run_in_graph_and_eager_modes
def test_repr(self, strategy_fn):
  """repr() includes the variable values only when executing eagerly."""
  with strategy_fn().scope():
    scale = loss_scale_module.DynamicLossScale(
        initial_loss_scale=1, increment_period=2, multiplier=3)
    if context.executing_eagerly():
      expected = ('DynamicLossScale(current_loss_scale=1.0, '
                  'num_good_steps=0, initial_loss_scale=1.0, '
                  'increment_period=2, multiplier=3.0)')
    else:
      expected = ('DynamicLossScale(initial_loss_scale=1.0, '
                  'increment_period=2, multiplier=3.0)')
    self.assertEqual(repr(scale), expected)
# Standard TensorFlow test entry point: runs all TestCase classes in this file.
if __name__ == '__main__':
  test.main()