Do not create a tensor in FixedLossScale constructor.
This fixes a regression in the official tf.keras model. PiperOrigin-RevId: 242212952
This commit is contained in:
parent
b70247fbdd
commit
c2ed2bca11
@@ -148,10 +148,10 @@ class FixedLossScale(LossScale):

     Args:
       loss_scale_value: A Python float. Its ideal value varies depending on
-        models to run. Choosing a too small loss_scale might affect model quality;
-        a too big loss_scale might cause inf or nan. There is no single right
-        loss_scale to apply. There is no harm choosing a relatively big number as
-        long as no nan or inf is encountered in training.
+        models to run. Choosing a too small loss_scale might affect model
+        quality; a too big loss_scale might cause inf or nan. There is no single
+        right loss_scale to apply. There is no harm choosing a relatively big
+        number as long as no nan or inf is encountered in training.

     Raises:
       ValueError: If loss_scale is less than 1.
@@ -161,19 +161,23 @@ class FixedLossScale(LossScale):
       raise ValueError('loss_scale_value must be a Python int or float.')
     if loss_scale_value < 1:
       raise ValueError('loss_scale_value must be at least 1.')
-    self._python_loss_scale = float(loss_scale_value)
-    self._tensor_loss_scale = ops.convert_to_tensor(self._python_loss_scale,
-                                                    dtype=dtypes.float32)
+    # It's important we do not create tensors in the constructor, as such
+    # tensors might be on a different device or tf.function vs when the tensor
+    # is used. This would hurt performance. Therefore, we do not create a tensor
+    # from loss_scale_value, but instead leave it as a Python float.
+    # TODO(reedwm): Also do not create tensors in the DynamicLossScale
+    # constructor.
+    self._loss_scale_value = float(loss_scale_value)

   def __call__(self):
-    return self._tensor_loss_scale
+    return ops.convert_to_tensor(self._loss_scale_value)

   def update(self, grads):
     del grads
     return control_flow_ops.no_op(), True

   def get_config(self):
-    return {'loss_scale_value': self._python_loss_scale}
+    return {'loss_scale_value': self._loss_scale_value}


 def _is_all_finite(grads):
Loading…
Reference in New Issue
Block a user