Do not create a tensor in FixedLossScale constructor.

This fixes a regression in the official tf.keras model.

PiperOrigin-RevId: 242212952
commit c2ed2bca11
parent b70247fbdd
@@ -148,10 +148,10 @@ class FixedLossScale(LossScale):

     Args:
       loss_scale_value: A Python float. Its ideal value varies depending on
-        models to run. Choosing a too small loss_scale might affect model quality;
-        a too big loss_scale might cause inf or nan. There is no single right
-        loss_scale to apply. There is no harm choosing a relatively big number as
-        long as no nan or inf is encountered in training.
+        models to run. Choosing a too small loss_scale might affect model
+        quality; a too big loss_scale might cause inf or nan. There is no single
+        right loss_scale to apply. There is no harm choosing a relatively big
+        number as long as no nan or inf is encountered in training.

     Raises:
       ValueError: If loss_scale is less than 1.
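The reworded docstring's warning about too-large values is concrete in float16 training, where the largest finite value is 65504. Below is a minimal sketch of the overflow it describes; this is an illustration written for this note, not code from the commit.

import tensorflow as tf

# A loss value that fits comfortably in float16.
loss = tf.constant(4.0, dtype=tf.float16)

# Scaling by a representable but oversized factor pushes the product past
# float16's largest finite value (65504), so the result becomes inf.
scaled = loss * tf.constant(32768.0, dtype=tf.float16)
print(scaled)  # tf.Tensor(inf, shape=(), dtype=float16)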
@@ -161,19 +161,23 @@ class FixedLossScale(LossScale):
       raise ValueError('loss_scale_value must be a Python int or float.')
     if loss_scale_value < 1:
       raise ValueError('loss_scale_value must be at least 1.')
-    self._python_loss_scale = float(loss_scale_value)
-    self._tensor_loss_scale = ops.convert_to_tensor(self._python_loss_scale,
-                                                    dtype=dtypes.float32)
+    # It's important we do not create tensors in the constructor, as such
+    # tensors might be on a different device or tf.function vs when the tensor
+    # is used. This would hurt performance. Therefore, we do not create a tensor
+    # from loss_scale_value, but instead leave it as a Python float.
+    # TODO(reedwm): Also do not create tensors in the DynamicLossScale
+    # constructor.
+    self._loss_scale_value = float(loss_scale_value)

   def __call__(self):
-    return self._tensor_loss_scale
+    return ops.convert_to_tensor(self._loss_scale_value)

   def update(self, grads):
     del grads
     return control_flow_ops.no_op(), True

   def get_config(self):
-    return {'loss_scale_value': self._python_loss_scale}
+    return {'loss_scale_value': self._loss_scale_value}


 def _is_all_finite(grads):
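For context, here is a minimal standalone sketch of the pattern this change applies: keep only a Python float in the constructor and convert it to a tensor at call time, so the tensor is created on whatever device, and inside whatever tf.function, is active when the loss scale is actually used. The LazyLossScale class and scaled_loss function are hypothetical names for illustration, not part of the TensorFlow API.

import tensorflow as tf


class LazyLossScale:
  """Hypothetical illustration of the deferred-tensor pattern."""

  def __init__(self, loss_scale_value):
    if loss_scale_value < 1:
      raise ValueError('loss_scale_value must be at least 1.')
    # No tensor is created here: a tensor made at construction time could
    # end up on a different device, or outside the tf.function, where the
    # loss scale is later consumed.
    self._loss_scale_value = float(loss_scale_value)

  def __call__(self):
    # The conversion happens lazily, in the caller's device and function
    # context, mirroring the fix in this commit.
    return tf.convert_to_tensor(self._loss_scale_value)


loss_scale = LazyLossScale(128.0)

@tf.function
def scaled_loss(loss):
  # The scale tensor is materialized inside this traced function.
  return loss * loss_scale()

print(scaled_loss(tf.constant(0.5)))  # tf.Tensor(64.0, shape=(), dtype=float32)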