From c2ed2bca11d3f470037413c38d53ab0be8afd1c6 Mon Sep 17 00:00:00 2001
From: Reed Wanderman-Milne
Date: Fri, 5 Apr 2019 16:12:42 -0700
Subject: [PATCH] Do not create a tensor in FixedLossScale constructor.

This fixes a performance regression in the official tf.keras model.

PiperOrigin-RevId: 242212952
---
 .../experimental/loss_scale.py | 22 +++++++++++--------
 1 file changed, 13 insertions(+), 9 deletions(-)

diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale.py b/tensorflow/python/keras/mixed_precision/experimental/loss_scale.py
index 1c58e80753c..d686eabd849 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/loss_scale.py
@@ -148,10 +148,10 @@ class FixedLossScale(LossScale):
 
     Args:
       loss_scale_value: A Python float. Its ideal value varies depending on
-        models to run. Choosing a too small loss_scale might affect model quality;
-        a too big loss_scale might cause inf or nan. There is no single right
-        loss_scale to apply. There is no harm choosing a relatively big number as
-        long as no nan or inf is encountered in training.
+        models to run. Choosing a loss_scale that is too small may hurt model
+        quality; one that is too big may cause inf or nan. There is no single
+        right loss_scale to apply. There is no harm in choosing a relatively
+        big number as long as no nan or inf is encountered in training.
 
     Raises:
       ValueError: If loss_scale is less than 1.
@@ -161,19 +161,23 @@ class FixedLossScale(LossScale):
       raise ValueError('loss_scale_value must be a Python int or float.')
     if loss_scale_value < 1:
       raise ValueError('loss_scale_value must be at least 1.')
-    self._python_loss_scale = float(loss_scale_value)
-    self._tensor_loss_scale = ops.convert_to_tensor(self._python_loss_scale,
-                                                    dtype=dtypes.float32)
+    # It's important we do not create tensors in the constructor, as such
+    # tensors might be on a different device, or in a different tf.function,
+    # than where they are used, hurting performance. We therefore do not
+    # create a tensor from loss_scale_value, but leave it as a Python float.
+    # TODO(reedwm): Also do not create tensors in the DynamicLossScale
+    # constructor.
+    self._loss_scale_value = float(loss_scale_value)
 
   def __call__(self):
-    return self._tensor_loss_scale
+    return ops.convert_to_tensor(self._loss_scale_value)
 
   def update(self, grads):
     del grads
     return control_flow_ops.no_op(), True
 
   def get_config(self):
-    return {'loss_scale_value': self._python_loss_scale}
+    return {'loss_scale_value': self._loss_scale_value}
 
 
 def _is_all_finite(grads):
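
For context, here is the pattern the patch applies, reduced to a standalone
sketch. This is not the real TensorFlow class: FixedLossScaleSketch and the
scale_loss helper are hypothetical names for illustration, and TF 2.x-style
public APIs (tf.convert_to_tensor, tf.function) are assumed in place of the
internal ops module used in the patched file. The constructor stores only a
Python float; the tensor is created at call time, inside whatever device and
function context the caller is running in.

import tensorflow as tf


class FixedLossScaleSketch(object):
  """Minimal sketch of the deferred-tensor-creation pattern in this patch."""

  def __init__(self, loss_scale_value):
    if not isinstance(loss_scale_value, (int, float)):
      raise ValueError('loss_scale_value must be a Python int or float.')
    if loss_scale_value < 1:
      raise ValueError('loss_scale_value must be at least 1.')
    # No tensor is created here: only a plain Python float is stored, so
    # construction does not pin the value to any device or graph.
    self._loss_scale_value = float(loss_scale_value)

  def __call__(self):
    # Tensor creation is deferred to call time, so the constant is built in
    # the device/function context where it is actually consumed.
    return tf.convert_to_tensor(self._loss_scale_value, dtype=tf.float32)

  def get_config(self):
    # Serializes the plain float, mirroring get_config in the patched class.
    return {'loss_scale_value': self._loss_scale_value}


loss_scale = FixedLossScaleSketch(128.0)


@tf.function
def scale_loss(loss):  # hypothetical helper, for illustration
  # The scale tensor materializes inside this function's graph instead of
  # being captured from a tensor created back at construction time.
  return loss * loss_scale()


print(scale_loss(tf.constant(2.0)))  # tf.Tensor(256.0, shape=(), dtype=float32)

With the pre-patch behavior, the tensor built in __init__ would be captured
into every graph or device context that used it, which is the performance
cost the added code comment describes.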