Do not create a tensor in FixedLossScale constructor.

This fixes a performance regression in the official tf.keras models.

PiperOrigin-RevId: 242212952
Reed Wanderman-Milne 2019-04-05 16:12:42 -07:00 committed by TensorFlower Gardener
parent b70247fbdd
commit c2ed2bca11


@@ -148,10 +148,10 @@ class FixedLossScale(LossScale):
 
     Args:
       loss_scale_value: A Python float. Its ideal value varies depending on
-        models to run. Choosing a too small loss_scale might affect model quality;
-        a too big loss_scale might cause inf or nan. There is no single right
-        loss_scale to apply. There is no harm choosing a relatively big number as
-        long as no nan or inf is encountered in training.
+        models to run. Choosing a too small loss_scale might affect model
+        quality; a too big loss_scale might cause inf or nan. There is no single
+        right loss_scale to apply. There is no harm choosing a relatively big
+        number as long as no nan or inf is encountered in training.
 
     Raises:
       ValueError: If loss_scale is less than 1.
@@ -161,19 +161,23 @@ class FixedLossScale(LossScale):
       raise ValueError('loss_scale_value must be a Python int or float.')
     if loss_scale_value < 1:
       raise ValueError('loss_scale_value must be at least 1.')
-    self._python_loss_scale = float(loss_scale_value)
-    self._tensor_loss_scale = ops.convert_to_tensor(self._python_loss_scale,
-                                                    dtype=dtypes.float32)
+    # It's important not to create tensors in the constructor, since they
+    # might end up on a different device, or in a different tf.function,
+    # than where they are used, hurting performance. So we keep
+    # loss_scale_value as a Python float rather than converting it here.
+    # TODO(reedwm): Also do not create tensors in the DynamicLossScale
+    # constructor.
+    self._loss_scale_value = float(loss_scale_value)
 
   def __call__(self):
-    return self._tensor_loss_scale
+    return ops.convert_to_tensor(self._loss_scale_value)
 
   def update(self, grads):
     del grads
     return control_flow_ops.no_op(), True
 
   def get_config(self):
-    return {'loss_scale_value': self._python_loss_scale}
+    return {'loss_scale_value': self._loss_scale_value}
 
 
 def _is_all_finite(grads):
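
Beyond the diff itself, the pattern generalizes: convert to a tensor at use time, not at construction time. Below is a minimal runnable sketch of the idea; the EagerScale/LazyScale classes and scaled_loss function are hypothetical illustrations of the pattern, not TensorFlow API.

import tensorflow as tf


class EagerScale:
  """Anti-pattern: the tensor is pinned to whatever device/function
  context is active when the constructor runs."""

  def __init__(self, value):
    self._scale = tf.convert_to_tensor(float(value), dtype=tf.float32)

  def __call__(self):
    return self._scale


class LazyScale:
  """The pattern this commit adopts: store a Python float and convert
  on each call, so the tensor is created where it is consumed."""

  def __init__(self, value):
    self._value = float(value)

  def __call__(self):
    return tf.convert_to_tensor(self._value, dtype=tf.float32)


@tf.function
def scaled_loss(loss, scale):
  # With LazyScale, conversion happens here, inside the traced function
  # and on its device, rather than at construction time.
  return loss * scale()


print(scaled_loss(tf.constant(2.0), LazyScale(128.0)))  # tf.Tensor(256.0)

Converting a Python float on each call is cheap (inside a tf.function it is embedded as a constant during tracing), whereas a tensor created in the constructor can force a copy to the consuming device on every use, which appears to be the performance regression the commit message refers to.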