Do not create a tensor in FixedLossScale constructor.

This fixes a regression in the official tf.keras model.

PiperOrigin-RevId: 242212952
Reed Wanderman-Milne 2019-04-05 16:12:42 -07:00 committed by TensorFlower Gardener
parent b70247fbdd
commit c2ed2bca11


@@ -148,10 +148,10 @@ class FixedLossScale(LossScale):
   Args:
     loss_scale_value: A Python float. Its ideal value varies depending on
-      models to run. Choosing a too small loss_scale might affect model quality;
-      a too big loss_scale might cause inf or nan. There is no single right
-      loss_scale to apply. There is no harm choosing a relatively big number as
-      long as no nan or inf is encountered in training.
+      models to run. Choosing a too small loss_scale might affect model
+      quality; a too big loss_scale might cause inf or nan. There is no single
+      right loss_scale to apply. There is no harm choosing a relatively big
+      number as long as no nan or inf is encountered in training.

   Raises:
     ValueError: If loss_scale is less than 1.
@@ -161,19 +161,23 @@ class FixedLossScale(LossScale):
       raise ValueError('loss_scale_value must be a Python int or float.')
     if loss_scale_value < 1:
       raise ValueError('loss_scale_value must be at least 1.')
-    self._python_loss_scale = float(loss_scale_value)
-    self._tensor_loss_scale = ops.convert_to_tensor(self._python_loss_scale,
-                                                    dtype=dtypes.float32)
+    # It's important we do not create tensors in the constructor, as such
+    # tensors might be on a different device or tf.function vs when the tensor
+    # is used. This would hurt performance. Therefore, we do not create a tensor
+    # from loss_scale_value, but instead leave it as a Python float.
+    # TODO(reedwm): Also do not create tensors in the DynamicLossScale
+    # constructor.
+    self._loss_scale_value = float(loss_scale_value)

   def __call__(self):
-    return self._tensor_loss_scale
+    return ops.convert_to_tensor(self._loss_scale_value)

   def update(self, grads):
     del grads
     return control_flow_ops.no_op(), True

   def get_config(self):
-    return {'loss_scale_value': self._python_loss_scale}
+    return {'loss_scale_value': self._loss_scale_value}


 def _is_all_finite(grads):
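The pattern this commit adopts can be sketched in plain Python: store the raw float in the constructor and convert it only at call time, so the eventual tensor is created in whatever device or tf.function context is active when the loss scale is used. This is a minimal standalone sketch, not the TensorFlow source; `fake_convert_to_tensor` is a hypothetical stand-in for `ops.convert_to_tensor` so the example runs without TensorFlow.

```python
def fake_convert_to_tensor(value):
    # Hypothetical stand-in for ops.convert_to_tensor; just returns a float.
    return float(value)


class FixedLossScaleSketch:
    """Sketch of a fixed loss scale that defers tensor creation."""

    def __init__(self, loss_scale_value):
        if not isinstance(loss_scale_value, (int, float)):
            raise ValueError('loss_scale_value must be a Python int or float.')
        if loss_scale_value < 1:
            raise ValueError('loss_scale_value must be at least 1.')
        # No tensor is created here: the value stays a Python float, matching
        # the commit's fix, so conversion happens in the caller's context.
        self._loss_scale_value = float(loss_scale_value)

    def __call__(self):
        # Conversion is deferred to use time.
        return fake_convert_to_tensor(self._loss_scale_value)

    def get_config(self):
        return {'loss_scale_value': self._loss_scale_value}


scale = FixedLossScaleSketch(128)
print(scale())            # 128.0
print(scale.get_config()) # {'loss_scale_value': 128.0}
```

In the real class, each `__call__` pays a small conversion cost, but avoids caching a tensor bound to the wrong device or function graph, which is the performance trade-off the commit comment describes.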