Support Layer.get_config with a mixed_[b]float16 dtype policy.
Note: saving a model with Model.save() is still unsupported.

PiperOrigin-RevId: 272759486
commit 3c30766e1b
parent 0c25833dfe
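
For illustration (not part of the diff below), a rough sketch of the round-trip this change enables. `MyLayer` is a stand-in for any Keras layer subclass; the expected config value mirrors the new tests further down:

    from tensorflow.python.keras.mixed_precision.experimental import policy

    layer = MyLayer(dtype=policy.Policy('mixed_float16'))
    config = layer.get_config()
    # config['dtype'] is now the serialized policy rather than a plain dtype string:
    #   {'class_name': 'Policy', 'config': {'name': 'mixed_float16'}}
    restored = MyLayer.from_config(config)
    # _set_dtype_policy() recognizes the dict and deserializes it back into a Policy,
    # so the restored layer keeps float32 variables and computes in float16.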
@@ -468,8 +468,10 @@ class Layer(module.Module):
     config = {'name': self.name, 'trainable': self.trainable}
     if hasattr(self, '_batch_input_shape'):
       config['batch_input_shape'] = self._batch_input_shape
+    # TODO(reedwm): Remove the hasattr(self, 'dtype') check. All layers have a
+    # dtype.
     if hasattr(self, 'dtype'):
-      config['dtype'] = self.dtype
+      config['dtype'] = policy.serialize(self._dtype_policy)
     if hasattr(self, 'dynamic'):
       # Only include `dynamic` in the `config` if it is `True`
       if self.dynamic:
@@ -484,7 +486,6 @@ class Layer(module.Module):
     if len(extra_args) > 1 and hasattr(self.get_config, '_is_default'):
       raise NotImplementedError('Layers with arguments in `__init__` must '
                                 'override `get_config`.')
-    # TODO(reedwm): Handle serializing self._dtype_policy.
     return config

   @classmethod
@@ -1677,6 +1678,8 @@ class Layer(module.Module):
     """Sets self._dtype_policy."""
     if isinstance(dtype, policy.Policy):
       self._dtype_policy = dtype
+    elif isinstance(dtype, dict):
+      self._dtype_policy = policy.deserialize(dtype)
     elif dtype:
       self._dtype_policy = policy.Policy(dtypes.as_dtype(dtype).name)
     else:
@@ -94,6 +94,9 @@ class AddLayer(AssertTypeLayer):
       **kwargs: Passed to AssertTypeLayer constructor.
     """
     self._regularizer = regularizer
+    if isinstance(regularizer, dict):
+      self._regularizer = regularizers.deserialize(regularizer,
+                                                   custom_objects=globals())
     self._use_operator = use_operator
     self._var_name = var_name
     super(AddLayer, self).__init__(**kwargs)
@@ -116,8 +119,7 @@ class AddLayer(AssertTypeLayer):

   def get_config(self):
     config = super(AddLayer, self).get_config()
-    assert self._regularizer is None, (
-        'regularizer must be None to get config for AddLayer')
+    config['regularizer'] = regularizers.serialize(self._regularizer)
     config['use_operator'] = self._use_operator
     config['var_name'] = self._var_name
     config['assert_type'] = self._assert_type
@@ -159,6 +161,9 @@ class IdentityRegularizer(regularizers.Regularizer):
     assert x.dtype == dtypes.float32
     return array_ops.identity(x)

+  def get_config(self):
+    return {}
+

 # If called outside any strategy.scope() calls, this will return the default
 # strategy.
@@ -429,6 +434,83 @@ class KerasLayerTest(keras_parameterized.TestCase):
     self._test_checkpointing_layer_weights(
         strategy_fn, mixed_prec_when_saving=False, mixed_prec_when_loading=True)

+  @parameterized.named_parameters(*TESTCASES)
+  @test_util.run_in_graph_and_eager_modes
+  @testing_utils.enable_v2_dtype_behavior
+  def test_config(self, strategy_fn):
+    x = constant_op.constant([1.], dtype=dtypes.float16)
+    with strategy_fn().scope():
+      for layer, dtype in (
+          (AddLayer(), 'float32'),
+          (AddLayer(dtype='float64'), 'float64'),
+          (AddLayer(dtype=policy.Policy('float64')), 'float64')):
+        config = layer.get_config()
+        self.assertEqual(config['dtype'], dtype)
+        self.assertIsInstance(config['dtype'], str)
+        layer = AddLayer.from_config(config)
+        self.assertEqual(layer.dtype, dtype)
+        self.assertEqual(layer(x).dtype, dtype)
+        self.assertEqual(layer.v.dtype, dtype)
+
+      layer = AddLayer(dtype=policy.Policy('mixed_float16'))
+      config = layer.get_config()
+      self.assertEqual(config['dtype'],
+                       {'class_name': 'Policy',
+                        'config': {'name': 'mixed_float16'}})
+      layer = AddLayer.from_config(config)
+      self.assertEqual(layer.dtype, 'float32')
+      self.assertEqual(layer(x).dtype, 'float16')
+      self.assertEqual(layer.v.dtype, 'float32')
+
+      layer = AddLayer(dtype=policy.Policy('mixed_float16', loss_scale=None))
+      config = layer.get_config()
+      self.assertEqual(config['dtype'],
+                       {'class_name': 'Policy',
+                        'config': {'name': 'mixed_float16',
+                                   'loss_scale': None}})
+      layer = AddLayer.from_config(config)
+      self.assertEqual(layer.dtype, 'float32')
+      self.assertEqual(layer(x).dtype, 'float16')
+      self.assertEqual(layer.v.dtype, 'float32')
+
+      layer = AddLayer(dtype=policy.Policy('float64', loss_scale=2.))
+      config = layer.get_config()
+      self.assertEqual(config['dtype'],
+                       {'class_name': 'Policy',
+                        'config': {'name': 'float64',
+                                   'loss_scale': {
+                                       'class_name': 'FixedLossScale',
+                                       'config': {'loss_scale_value': 2.0}}}})
+      layer = AddLayer.from_config(config)
+      self.assertEqual(layer.dtype, 'float64')
+      self.assertEqual(layer(x).dtype, 'float64')
+      self.assertEqual(layer.v.dtype, 'float64')
+
+      layer = AddLayer(dtype=policy.Policy('infer'))
+      config = layer.get_config()
+      self.assertIsNone(config['dtype'])
+      layer = AddLayer.from_config(config)
+      # If a layer is serialized with the "infer" policy, when deserialized into
+      # TF 2 it will have the global policy instead of "infer". This is because
+      # "infer" is serialized into None, and passing dtype=None in TensorFlow 2
+      # indicates to use the global policy.
+      self.assertEqual(layer.dtype, 'float32')
+      self.assertEqual(layer(x).dtype, 'float32')
+      self.assertEqual(layer.v.dtype, 'float32')
+
+      layer = AddLayer(dtype=policy.Policy('infer', loss_scale=2.))
+      config = layer.get_config()
+      self.assertEqual(config['dtype'],
+                       {'class_name': 'Policy',
+                        'config': {'name': 'infer',
+                                   'loss_scale': {
+                                       'class_name': 'FixedLossScale',
+                                       'config': {'loss_scale_value': 2.0}}}})
+      layer = AddLayer.from_config(config)
+      self.assertEqual(layer.dtype, None)
+      self.assertEqual(layer(x).dtype, 'float16')
+      self.assertEqual(layer.v.dtype, 'float16')
+
   @test_util.run_in_graph_and_eager_modes
   def test_delete_variable(self):
     layer = base_layer.Layer(dtype=policy.Policy('mixed_float16'))
@@ -485,6 +567,15 @@ class KerasModelTest(keras_parameterized.TestCase):
          'testcase_name': 'infer',
          'strategy_fn': create_mirrored_strategy,
          'policy_name': 'mixed_float16'
+      }, {
+          'testcase_name': 'get_config',
+          'strategy_fn': create_mirrored_strategy,
+          'get_config': True
+      }, {
+          'testcase_name': 'get_config_with_regularizer',
+          'strategy_fn': create_mirrored_strategy,
+          'get_config': True,
+          'use_regularizer': True,
      }, {
          'testcase_name': 'norun_distributed',
          'strategy_fn': create_mirrored_strategy,
@@ -496,6 +587,7 @@ class KerasModelTest(keras_parameterized.TestCase):
                  use_operator=False,
                  use_regularizer=False,
                  policy_name='mixed_float16',
+                 get_config=False,
                  experimental_run_tf_function=True):
    if not self._is_strategy_supported(strategy_fn, check_model_type=True):
      return
@@ -521,6 +613,12 @@ class KerasModelTest(keras_parameterized.TestCase):
       layer_list += [layer, cast_f32_layer]
       model = testing_utils.get_model_from_layers(
           layer_list, input_shape=(1,), input_dtype=dtypes.float16)
+      if get_config:
+        config = model.get_config()
+        model = model.__class__.from_config(
+            config, custom_objects={'AddLayer': AddLayer})
+        (layer,) = (layer for layer in model.layers
+                    if isinstance(layer, AddLayer))

       def loss_fn(y_true, y_pred):
         del y_true
@@ -703,6 +801,15 @@ class KerasModelTest(keras_parameterized.TestCase):
          'testcase_name': 'pass_loss_scale_to_policy',
          'strategy_fn': create_mirrored_strategy,
          'pass_loss_scale_to_policy': True,
+      }, {
+          'testcase_name': 'get_config',
+          'strategy_fn': create_mirrored_strategy,
+          'get_config': True,
+      }, {
+          'testcase_name': 'get_config_and_pass_loss_scale_to_policy',
+          'strategy_fn': create_mirrored_strategy,
+          'get_config': True,
+          'pass_loss_scale_to_policy': True,
      }, {
          'testcase_name': 'norun_distributed',
          'strategy_fn': create_mirrored_strategy,
@@ -711,6 +818,7 @@ class KerasModelTest(keras_parameterized.TestCase):
  def test_dynamic_loss_scaling(self,
                                strategy_fn,
                                pass_loss_scale_to_policy=False,
+                               get_config=False,
                                experimental_run_tf_function=True):
    if not self._is_strategy_supported(strategy_fn):
      return
@@ -746,6 +854,12 @@ class KerasModelTest(keras_parameterized.TestCase):
       y = core.Lambda(identity_with_grad_check_fn)(y)
       y = math_ops.cast(y, dtypes.float32)
       model = models.Model(inputs=x, outputs=y)
+      if get_config:
+        config = model.get_config()
+        model = model.__class__.from_config(
+            config, custom_objects={'AddLayer': AddLayer})
+        (layer,) = (layer for layer in model.layers
+                    if isinstance(layer, AddLayer))

       def loss_fn(y_true, y_pred):
         del y_true
@@ -980,6 +1094,8 @@ class KerasModelTest(keras_parameterized.TestCase):
  def test_save_model_with_dynamic_loss_scaling(self, strategy_fn, h5=False):
    if not self._is_strategy_supported(strategy_fn):
      return
+    # TODO(reedwm): Support and test saving model with a mixed_[b]float16 policy
+    # as well.
    strategy = strategy_fn()
    if (isinstance(strategy, mirrored_strategy.MirroredStrategy) and
        not context.executing_eagerly()):
@@ -24,8 +24,9 @@ import six
 from tensorflow.python.framework import dtypes
 from tensorflow.python.keras import backend
 from tensorflow.python.keras.engine import base_layer_utils
+from tensorflow.python.keras.mixed_precision.experimental import loss_scale as keras_loss_scale_module
+from tensorflow.python.keras.utils import generic_utils
 from tensorflow.python.platform import tf_logging
-from tensorflow.python.training.experimental import loss_scale as loss_scale_module
 from tensorflow.python.training.experimental import mixed_precision_global_state
 from tensorflow.python.util.tf_export import keras_export

@@ -306,12 +307,15 @@ class Policy(object):

     if loss_scale == USE_DEFAULT:
       loss_scale = 'dynamic' if name == 'mixed_float16' else None
+      self._using_default_loss_scale = True
+    else:
+      self._using_default_loss_scale = False
     if loss_scale and self._compute_dtype not in (None, 'float16'):
       tf_logging.warn('Creating a Policy with a loss scale is only useful for '
                       'float16 policies. You passed loss_scale=%r for policy '
                       '%s. Consider not passing any loss_scale instead.' %
                       (loss_scale, name))
-    self._loss_scale = loss_scale_module.get(loss_scale)
+    self._loss_scale = keras_loss_scale_module.get(loss_scale)

   def _parse_name(self, name):
     """Parses a Policy name into a compute and variable dtype.
@@ -426,6 +430,25 @@ class Policy(object):
   def __repr__(self):
     return '<Policy "%s", loss_scale=%s>' % (self._name, self.loss_scale)

+  def get_config(self):
+    config = {
+        'name': self.name
+    }
+    if not self._using_default_loss_scale:
+      # We only include the loss scale if the default loss scale is not used.
+      # This allows us to change the loss scale config format without breaking
+      # users who use the default loss scale.
+      config['loss_scale'] = keras_loss_scale_module.serialize(self.loss_scale)
+    return config
+
+  @classmethod
+  def from_config(cls, config, custom_objects=None):
+    if 'loss_scale' in config and isinstance(config['loss_scale'], dict):
+      config = config.copy()
+      config['loss_scale'] = keras_loss_scale_module.deserialize(
+          config['loss_scale'], custom_objects=custom_objects)
+    return cls(**config)
+

 def with_input_dtype(policy, dtype):
   """Copies "infer" `policy`, adding `dtype` to it.
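
A rough usage sketch of the two methods added above (not part of the diff; values mirror the policy_test.py cases further down):

    p = Policy('mixed_float16', loss_scale=2.)
    p.get_config()
    # -> {'name': 'mixed_float16',
    #     'loss_scale': {'class_name': 'FixedLossScale', 'config': {'loss_scale_value': 2.0}}}
    Policy.from_config(p.get_config())  # deserializes the nested loss scale, then calls __init__

    Policy('mixed_float16').get_config()
    # -> {'name': 'mixed_float16'}  (the default loss scale is omitted so its
    #    serialization format can change later without breaking saved configs)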
@@ -556,3 +579,55 @@ def policy_scope(policy):
     yield
   finally:
     set_policy(old_policy)
+
+
+def _is_convertible_to_dtype(dtype):
+  try:
+    dtypes.as_dtype(dtype)
+    return True
+  except TypeError:
+    return False
+
+
+def _policy_equivalent_to_dtype(policy):
+  """Returns True if the Policy is equivalent to a single dtype.
+
+  A policy is equivalent to a single dtype if the policy's compute and variable
+  dtypes are the same and the policy does not cause the layer/model to have
+  additional behavior, such as loss scaling.
+
+  The "infer" policy is considered equivalent to a single dtype.
+
+  Args:
+    policy: A Policy.
+
+  Returns:
+    True, if the policy is equivalent to a single dtype.
+  """
+  # We use type() instead of isinstance because a subclass of Policy is never
+  # equivalent to a dtype.
+  return (type(policy) == Policy and  # pylint: disable=unidiomatic-typecheck
+          list(policy.get_config().keys()) == ['name'] and
+          (policy.name == 'infer' or _is_convertible_to_dtype(policy.name)))
+
+
+def serialize(policy):
+  if _policy_equivalent_to_dtype(policy):
+    # We return either None or the policy name for compatibility with older
+    # versions of Keras. If the policy name is returned, it is a dtype string
+    # such as 'float32'.
+    return None if policy.name == 'infer' else policy.name
+  return generic_utils.serialize_keras_object(policy)
+
+
+def deserialize(config, custom_objects=None):
+  if isinstance(config, str) and _is_convertible_to_dtype(config):
+    return Policy(config)
+  if config is None:
+    return Policy('infer')
+  module_objects = {'Policy': Policy}
+  return generic_utils.deserialize_keras_object(
+      config,
+      module_objects=module_objects,
+      custom_objects=custom_objects,
+      printable_module_name='dtype policy')
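
A rough sketch of how the module-level helpers added above behave (not part of the diff; it restates the test_serialization cases below):

    serialize(Policy('float32'))        # -> 'float32'  (plain dtype string, readable by older Keras)
    serialize(Policy('infer'))          # -> None
    serialize(Policy('mixed_float16'))  # -> {'class_name': 'Policy', 'config': {'name': 'mixed_float16'}}

    deserialize('float32')              # -> Policy('float32')
    deserialize(None)                   # -> Policy('infer')
    deserialize(serialize(Policy('mixed_bfloat16')))  # round-trips through the generic Keras object format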
@@ -227,6 +227,97 @@ class PolicyTest(test.TestCase):
                      'infer_with_float32_vars')
     self.assertEqual(mp_policy.global_policy().name, default_policy)

+  @testing_utils.enable_v2_dtype_behavior
+  def test_config(self):
+    for policy in (
+        mp_policy.Policy('float16'),
+        mp_policy.Policy('float32'),
+        mp_policy.Policy('int16'),
+        mp_policy.Policy('mixed_float16'),
+        mp_policy.Policy('mixed_bfloat16'),
+        mp_policy.Policy('infer'),
+        mp_policy.Policy('infer_float32_vars'),
+        mp_policy.Policy('float16_with_float32_vars'),
+        mp_policy.Policy('float32', loss_scale=2.),
+        mp_policy.Policy('float32', loss_scale=None),
+        mp_policy.Policy('mixed_float16', loss_scale=2.),
+        mp_policy.Policy('mixed_float16', loss_scale=None),
+        mp_policy.Policy('mixed_bfloat16', loss_scale=2.),
+        mp_policy.Policy('mixed_bfloat16', loss_scale=None),
+    ):
+      config = policy.get_config()
+      new_policy = mp_policy.Policy.from_config(config)
+      # Comparing strings is the easiest way to ensure the policies are the
+      # same, as policy does not override the == operator.
+      self.assertEqual(str(policy), str(new_policy))
+
+  @testing_utils.enable_v2_dtype_behavior
+  def test_serialization(self):
+    # Test policies that are equivalent to a single dtype
+    for policy_name in 'float16', 'float32', 'int8', 'string', 'bool':
+      policy = mp_policy.Policy(policy_name)
+      config = mp_policy.serialize(policy)
+      self.assertEqual(config, policy_name)
+      new_policy = mp_policy.deserialize(config)
+      self.assertEqual(str(policy), str(new_policy))
+
+    # Test "infer" policy
+    policy = mp_policy.Policy('infer')
+    config = mp_policy.serialize(policy)
+    self.assertIsNone(config)
+    new_policy = mp_policy.deserialize(config)
+    self.assertEqual(str(policy), str(new_policy))
+
+    class MyPolicy(mp_policy.Policy):
+      pass
+
+    # Test policies that do not override the loss scale
+    for policy in (
+        mp_policy.Policy('mixed_float16'),
+        mp_policy.Policy('mixed_bfloat16'),
+        mp_policy.Policy('infer_with_float32_vars'),
+        mp_policy.Policy('float16_with_float32_vars'),
+        MyPolicy('float32')
+    ):
+      config = mp_policy.serialize(policy)
+      self.assertEqual(config, {'class_name': policy.__class__.__name__,
+                                'config': {'name': policy.name}})
+      new_policy = mp_policy.deserialize(config,
+                                         custom_objects={'MyPolicy': MyPolicy})
+      self.assertEqual(str(policy), str(new_policy))
+
+    # Test policies that override the loss scale
+    for policy in (
+        mp_policy.Policy('float32', loss_scale=2.),
+        mp_policy.Policy('float32', loss_scale=None),
+        mp_policy.Policy('mixed_float16', loss_scale=2.),
+        mp_policy.Policy('mixed_float16', loss_scale=None),
+        mp_policy.Policy('mixed_bfloat16', loss_scale=2.),
+        mp_policy.Policy('mixed_bfloat16', loss_scale=None),
+        mp_policy.Policy('infer_with_float32_vars', loss_scale=2.),
+        mp_policy.Policy('infer_with_float32_vars', loss_scale=None),
+        mp_policy.Policy('float16_with_float32_vars', loss_scale=2.),
+        mp_policy.Policy('float16_with_float32_vars', loss_scale=None),
+    ):
+      config = mp_policy.serialize(policy)
+      expected_loss_scale_config = None
+      if policy.loss_scale:
+        expected_loss_scale_config = {
+            'class_name': 'FixedLossScale',
+            'config': {'loss_scale_value': 2.}
+        }
+      self.assertEqual(
+          config, {
+              'class_name': policy.__class__.__name__,
+              'config': {
+                  'name': policy.name,
+                  'loss_scale': expected_loss_scale_config
+              }
+          })
+      new_policy = mp_policy.deserialize(
+          config, custom_objects={'MyPolicy': MyPolicy})
+      self.assertEqual(str(policy), str(new_policy))
+
   @testing_utils.enable_v2_dtype_behavior
   def test_error_if_graph_rewrite_enabled(self):
     try:
@@ -43,7 +43,7 @@ def create_identity_with_grad_check_fn(expected_gradient, expected_dtype=None):
       certain value.
   """
   @custom_gradient.custom_gradient
-  def identity_with_grad_check(x):
+  def _identity_with_grad_check(x):
     """Function that asserts its gradient has a certain value."""
     x = array_ops.identity(x)
     def grad(dx):
@@ -57,6 +57,10 @@ def create_identity_with_grad_check_fn(expected_gradient, expected_dtype=None):
       dx = array_ops.identity(dx)
       return dx
     return x, grad
+  # Keras sometimes has trouble serializing Lambda layers with a decorated
+  # function. So we define and return a non-decorated function.
+  def identity_with_grad_check(x):
+    return _identity_with_grad_check(x)
   return identity_with_grad_check


@@ -77,7 +81,7 @@ def create_identity_with_nan_gradients_fn(have_nan_gradients):
       `have_nan_gradients` is True.
   """
   @custom_gradient.custom_gradient
-  def identity_with_nan_gradients(x):
+  def _identity_with_nan_gradients(x):
     """Function whose gradient is NaN iff `have_nan_gradients` is True."""
     x = array_ops.identity(x)
     def grad(dx):
@@ -93,4 +97,8 @@ def create_identity_with_nan_gradients_fn(have_nan_gradients):
         lambda: dx
     )
     return x, grad
+  # Keras sometimes has trouble serializing Lambda layers with a decorated
+  # function. So we define and return a non-decorated function.
+  def identity_with_nan_gradients(x):
+    return _identity_with_nan_gradients(x)
   return identity_with_nan_gradients
@@ -26,4 +26,12 @@ tf_class {
     name: "__init__"
     argspec: "args=[\'self\', \'name\', \'loss_scale\'], varargs=None, keywords=None, defaults=[\'USE_DEFAULT\'], "
   }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
 }
@@ -26,4 +26,12 @@ tf_class {
     name: "__init__"
    argspec: "args=[\'self\', \'name\', \'loss_scale\'], varargs=None, keywords=None, defaults=[\'USE_DEFAULT\'], "
   }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
 }