Implement __repr__ for LossScale subclasses.
PiperOrigin-RevId: 262603680
This commit is contained in:
parent
26220ec083
commit
d24f662e2a
tensorflow/python/training/experimental
@ -227,6 +227,9 @@ class FixedLossScale(LossScale):
|
||||
del grads
|
||||
return control_flow_ops.no_op(), True
|
||||
|
||||
def __repr__(self):
|
||||
return 'FixedLossScale(%s)' % self._loss_scale_value
|
||||
|
||||
def get_config(self):
|
||||
return {'loss_scale_value': self._loss_scale_value}
|
||||
|
||||
@ -376,6 +379,17 @@ class DynamicLossScale(LossScale):
|
||||
should_apply_gradients = is_finite
|
||||
return update_op, should_apply_gradients
|
||||
|
||||
def __repr__(self):
|
||||
if context.executing_eagerly():
|
||||
return ('DynamicLossScale(current_loss_scale=%s, num_good_steps=%s, '
|
||||
'initial_loss_scale=%s, increment_period=%s, multiplier=%s)' %
|
||||
(self._current_loss_scale.numpy(), self._num_good_steps.numpy(),
|
||||
self.initial_loss_scale, self.increment_period, self.multiplier))
|
||||
else:
|
||||
return ('DynamicLossScale(initial_loss_scale=%s, increment_period=%s, '
|
||||
'multiplier=%s)' %
|
||||
(self.initial_loss_scale, self.increment_period, self.multiplier))
|
||||
|
||||
def get_config(self):
|
||||
return {
|
||||
'initial_loss_scale': self.initial_loss_scale,
|
||||
|
@ -92,6 +92,11 @@ class FixedLossScaleTest(test.TestCase):
|
||||
scalar = loss_scale_module.FixedLossScale(123)
|
||||
self.assertIsInstance(scalar(), ops.Tensor)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_repr(self):
|
||||
loss_scale = loss_scale_module.FixedLossScale(123)
|
||||
self.assertEqual(repr(loss_scale), 'FixedLossScale(123.0)')
|
||||
|
||||
|
||||
def _get_example_iter(inputs):
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices(inputs)
|
||||
@ -302,5 +307,22 @@ class DynamicLossScaleTest(test.TestCase, parameterized.TestCase):
|
||||
scalar = loss_scale_module.DynamicLossScale()
|
||||
self.assertIsInstance(scalar(), ops.Tensor)
|
||||
|
||||
@parameterized.named_parameters(*TESTCASES)
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_repr(self, strategy_fn):
|
||||
with strategy_fn().scope():
|
||||
loss_scale = loss_scale_module.DynamicLossScale(
|
||||
initial_loss_scale=1, increment_period=2, multiplier=3)
|
||||
if context.executing_eagerly():
|
||||
self.assertEqual(repr(loss_scale),
|
||||
'DynamicLossScale(current_loss_scale=1.0, '
|
||||
'num_good_steps=0, initial_loss_scale=1.0, '
|
||||
'increment_period=2, multiplier=3.0)')
|
||||
else:
|
||||
self.assertEqual(repr(loss_scale),
|
||||
'DynamicLossScale(initial_loss_scale=1.0, '
|
||||
'increment_period=2, multiplier=3.0)')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user