Support Layer.get_config with a mixed_[b]float16 dtype policy.

Note that saving a model with Model.save() is still unsupported.

PiperOrigin-RevId: 272759486
Reed Wanderman-Milne 2019-10-03 15:52:24 -07:00 committed by TensorFlower Gardener
parent 0c25833dfe
commit 3c30766e1b
7 changed files with 317 additions and 8 deletions
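
For orientation, here is a minimal usage sketch of what this change enables (illustrative only, not part of the diff; the layer type and shapes are arbitrary, and it assumes the TF 2 dtype behavior exercised by the tests below):

import tensorflow as tf

# A layer constructed with a mixed_float16 policy can now be round-tripped
# through get_config()/from_config().
mixed = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
layer = tf.keras.layers.Dense(4, dtype=mixed)
config = layer.get_config()
# The policy is serialized into the config as a dict rather than a bare dtype:
#   config['dtype'] == {'class_name': 'Policy', 'config': {'name': 'mixed_float16'}}
restored = tf.keras.layers.Dense.from_config(config)
print(restored.dtype)                                     # 'float32' (variable dtype)
print(restored(tf.ones((1, 1), dtype='float16')).dtype)   # float16 (compute dtype)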

View File

@@ -468,8 +468,10 @@ class Layer(module.Module):
config = {'name': self.name, 'trainable': self.trainable}
if hasattr(self, '_batch_input_shape'):
config['batch_input_shape'] = self._batch_input_shape
# TODO(reedwm): Remove the hasattr(self, 'dtype') check. All layers have a
# dtype.
if hasattr(self, 'dtype'):
config['dtype'] = self.dtype
config['dtype'] = policy.serialize(self._dtype_policy)
if hasattr(self, 'dynamic'):
# Only include `dynamic` in the `config` if it is `True`
if self.dynamic:
@@ -484,7 +486,6 @@ class Layer(module.Module):
if len(extra_args) > 1 and hasattr(self.get_config, '_is_default'):
raise NotImplementedError('Layers with arguments in `__init__` must '
'override `get_config`.')
# TODO(reedwm): Handle serializing self._dtype_policy.
return config
@classmethod
@@ -1677,6 +1678,8 @@ class Layer(module.Module):
"""Sets self._dtype_policy."""
if isinstance(dtype, policy.Policy):
self._dtype_policy = dtype
elif isinstance(dtype, dict):
self._dtype_policy = policy.deserialize(dtype)
elif dtype:
self._dtype_policy = policy.Policy(dtypes.as_dtype(dtype).name)
else:
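
A hedged sketch of what the new dict branch in _set_dtype_policy above makes possible (hypothetical standalone usage, assuming TF 2 dtype behavior; from_config relies on exactly this path):

import tensorflow as tf

# The dict form produced by policy.serialize() can be passed straight to a
# layer constructor and is deserialized back into a Policy internally.
serialized = {'class_name': 'Policy', 'config': {'name': 'mixed_float16'}}
layer = tf.keras.layers.Dense(2, dtype=serialized)
print(layer.dtype)  # 'float32', the variable dtype of the restored policy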

View File

@@ -94,6 +94,9 @@ class AddLayer(AssertTypeLayer):
**kwargs: Passed to AssertTypeLayer constructor.
"""
self._regularizer = regularizer
if isinstance(regularizer, dict):
self._regularizer = regularizers.deserialize(regularizer,
custom_objects=globals())
self._use_operator = use_operator
self._var_name = var_name
super(AddLayer, self).__init__(**kwargs)
@@ -116,8 +119,7 @@ class AddLayer(AssertTypeLayer):
def get_config(self):
config = super(AddLayer, self).get_config()
assert self._regularizer is None, (
'regularizer must be None to get config for AddLayer')
config['regularizer'] = regularizers.serialize(self._regularizer)
config['use_operator'] = self._use_operator
config['var_name'] = self._var_name
config['assert_type'] = self._assert_type
@@ -159,6 +161,9 @@ class IdentityRegularizer(regularizers.Regularizer):
assert x.dtype == dtypes.float32
return array_ops.identity(x)
def get_config(self):
return {}
# If called outside any strategy.scope() calls, this will return the default
# strategy.
@@ -429,6 +434,83 @@ class KerasLayerTest(keras_parameterized.TestCase):
self._test_checkpointing_layer_weights(
strategy_fn, mixed_prec_when_saving=False, mixed_prec_when_loading=True)
@parameterized.named_parameters(*TESTCASES)
@test_util.run_in_graph_and_eager_modes
@testing_utils.enable_v2_dtype_behavior
def test_config(self, strategy_fn):
x = constant_op.constant([1.], dtype=dtypes.float16)
with strategy_fn().scope():
for layer, dtype in (
(AddLayer(), 'float32'),
(AddLayer(dtype='float64'), 'float64'),
(AddLayer(dtype=policy.Policy('float64')), 'float64')):
config = layer.get_config()
self.assertEqual(config['dtype'], dtype)
self.assertIsInstance(config['dtype'], str)
layer = AddLayer.from_config(config)
self.assertEqual(layer.dtype, dtype)
self.assertEqual(layer(x).dtype, dtype)
self.assertEqual(layer.v.dtype, dtype)
layer = AddLayer(dtype=policy.Policy('mixed_float16'))
config = layer.get_config()
self.assertEqual(config['dtype'],
{'class_name': 'Policy',
'config': {'name': 'mixed_float16'}})
layer = AddLayer.from_config(config)
self.assertEqual(layer.dtype, 'float32')
self.assertEqual(layer(x).dtype, 'float16')
self.assertEqual(layer.v.dtype, 'float32')
layer = AddLayer(dtype=policy.Policy('mixed_float16', loss_scale=None))
config = layer.get_config()
self.assertEqual(config['dtype'],
{'class_name': 'Policy',
'config': {'name': 'mixed_float16',
'loss_scale': None}})
layer = AddLayer.from_config(config)
self.assertEqual(layer.dtype, 'float32')
self.assertEqual(layer(x).dtype, 'float16')
self.assertEqual(layer.v.dtype, 'float32')
layer = AddLayer(dtype=policy.Policy('float64', loss_scale=2.))
config = layer.get_config()
self.assertEqual(config['dtype'],
{'class_name': 'Policy',
'config': {'name': 'float64',
'loss_scale': {
'class_name': 'FixedLossScale',
'config': {'loss_scale_value': 2.0}}}})
layer = AddLayer.from_config(config)
self.assertEqual(layer.dtype, 'float64')
self.assertEqual(layer(x).dtype, 'float64')
self.assertEqual(layer.v.dtype, 'float64')
layer = AddLayer(dtype=policy.Policy('infer'))
config = layer.get_config()
self.assertIsNone(config['dtype'])
layer = AddLayer.from_config(config)
# If a layer is serialized with the "infer" policy, when deserialized into
# TF 2 it will have the global policy instead of "infer". This is because
# "infer" is serialized into None, and passing dtype=None in TensorFlow 2
# indicates to use the global policy.
self.assertEqual(layer.dtype, 'float32')
self.assertEqual(layer(x).dtype, 'float32')
self.assertEqual(layer.v.dtype, 'float32')
layer = AddLayer(dtype=policy.Policy('infer', loss_scale=2.))
config = layer.get_config()
self.assertEqual(config['dtype'],
{'class_name': 'Policy',
'config': {'name': 'infer',
'loss_scale': {
'class_name': 'FixedLossScale',
'config': {'loss_scale_value': 2.0}}}})
layer = AddLayer.from_config(config)
self.assertEqual(layer.dtype, None)
self.assertEqual(layer(x).dtype, 'float16')
self.assertEqual(layer.v.dtype, 'float16')
@test_util.run_in_graph_and_eager_modes
def test_delete_variable(self):
layer = base_layer.Layer(dtype=policy.Policy('mixed_float16'))
@@ -485,6 +567,15 @@ class KerasModelTest(keras_parameterized.TestCase):
'testcase_name': 'infer',
'strategy_fn': create_mirrored_strategy,
'policy_name': 'mixed_float16'
}, {
'testcase_name': 'get_config',
'strategy_fn': create_mirrored_strategy,
'get_config': True
}, {
'testcase_name': 'get_config_with_regularizer',
'strategy_fn': create_mirrored_strategy,
'get_config': True,
'use_regularizer': True,
}, {
'testcase_name': 'norun_distributed',
'strategy_fn': create_mirrored_strategy,
@@ -496,6 +587,7 @@ class KerasModelTest(keras_parameterized.TestCase):
use_operator=False,
use_regularizer=False,
policy_name='mixed_float16',
get_config=False,
experimental_run_tf_function=True):
if not self._is_strategy_supported(strategy_fn, check_model_type=True):
return
@@ -521,6 +613,12 @@ class KerasModelTest(keras_parameterized.TestCase):
layer_list += [layer, cast_f32_layer]
model = testing_utils.get_model_from_layers(
layer_list, input_shape=(1,), input_dtype=dtypes.float16)
if get_config:
config = model.get_config()
model = model.__class__.from_config(
config, custom_objects={'AddLayer': AddLayer})
(layer,) = (layer for layer in model.layers
if isinstance(layer, AddLayer))
def loss_fn(y_true, y_pred):
del y_true
@@ -703,6 +801,15 @@ class KerasModelTest(keras_parameterized.TestCase):
'testcase_name': 'pass_loss_scale_to_policy',
'strategy_fn': create_mirrored_strategy,
'pass_loss_scale_to_policy': True,
}, {
'testcase_name': 'get_config',
'strategy_fn': create_mirrored_strategy,
'get_config': True,
}, {
'testcase_name': 'get_config_and_pass_loss_scale_to_policy',
'strategy_fn': create_mirrored_strategy,
'get_config': True,
'pass_loss_scale_to_policy': True,
}, {
'testcase_name': 'norun_distributed',
'strategy_fn': create_mirrored_strategy,
@@ -711,6 +818,7 @@ class KerasModelTest(keras_parameterized.TestCase):
def test_dynamic_loss_scaling(self,
strategy_fn,
pass_loss_scale_to_policy=False,
get_config=False,
experimental_run_tf_function=True):
if not self._is_strategy_supported(strategy_fn):
return
@@ -746,6 +854,12 @@ class KerasModelTest(keras_parameterized.TestCase):
y = core.Lambda(identity_with_grad_check_fn)(y)
y = math_ops.cast(y, dtypes.float32)
model = models.Model(inputs=x, outputs=y)
if get_config:
config = model.get_config()
model = model.__class__.from_config(
config, custom_objects={'AddLayer': AddLayer})
(layer,) = (layer for layer in model.layers
if isinstance(layer, AddLayer))
def loss_fn(y_true, y_pred):
del y_true
@@ -980,6 +1094,8 @@ class KerasModelTest(keras_parameterized.TestCase):
def test_save_model_with_dynamic_loss_scaling(self, strategy_fn, h5=False):
if not self._is_strategy_supported(strategy_fn):
return
# TODO(reedwm): Support and test saving model with a mixed_[b]float16 policy
# as well.
strategy = strategy_fn()
if (isinstance(strategy, mirrored_strategy.MirroredStrategy) and
not context.executing_eagerly()):
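
The new 'get_config' test cases above round-trip a whole model through get_config()/from_config(). A trimmed-down sketch of that pattern with built-in layers only (illustrative; custom layers such as AddLayer additionally need to be passed via custom_objects, as the tests do):

import tensorflow as tf

mixed = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
inputs = tf.keras.Input((8,), dtype='float16')
outputs = tf.keras.layers.Dense(4, dtype=mixed)(inputs)
model = tf.keras.Model(inputs, outputs)

config = model.get_config()
# Rebuilding works because each layer's serialized dtype policy is now part of
# its config. Model.save() is still unsupported, as noted in the commit message.
restored = tf.keras.Model.from_config(config)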

View File

@@ -24,8 +24,9 @@ import six
from tensorflow.python.framework import dtypes
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.mixed_precision.experimental import loss_scale as keras_loss_scale_module
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.platform import tf_logging
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
from tensorflow.python.training.experimental import mixed_precision_global_state
from tensorflow.python.util.tf_export import keras_export
@@ -306,12 +307,15 @@ class Policy(object):
if loss_scale == USE_DEFAULT:
loss_scale = 'dynamic' if name == 'mixed_float16' else None
self._using_default_loss_scale = True
else:
self._using_default_loss_scale = False
if loss_scale and self._compute_dtype not in (None, 'float16'):
tf_logging.warn('Creating a Policy with a loss scale is only useful for '
'float16 policies. You passed loss_scale=%r for policy '
'%s. Consider not passing any loss_scale instead.' %
(loss_scale, name))
self._loss_scale = loss_scale_module.get(loss_scale)
self._loss_scale = keras_loss_scale_module.get(loss_scale)
def _parse_name(self, name):
"""Parses a Policy name into a compute and variable dtype.
@@ -426,6 +430,25 @@ class Policy(object):
def __repr__(self):
return '<Policy "%s", loss_scale=%s>' % (self._name, self.loss_scale)
def get_config(self):
config = {
'name': self.name
}
if not self._using_default_loss_scale:
# We only include the loss scale if the default loss scale is not used.
# This allows us to change the loss scale config format without breaking
# users who use the default loss scale.
config['loss_scale'] = keras_loss_scale_module.serialize(self.loss_scale)
return config
@classmethod
def from_config(cls, config, custom_objects=None):
if 'loss_scale' in config and isinstance(config['loss_scale'], dict):
config = config.copy()
config['loss_scale'] = keras_loss_scale_module.deserialize(
config['loss_scale'], custom_objects=custom_objects)
return cls(**config)
def with_input_dtype(policy, dtype):
"""Copies "infer" `policy`, adding `dtype` to it.
@@ -556,3 +579,55 @@ def policy_scope(policy):
yield
finally:
set_policy(old_policy)
def _is_convertible_to_dtype(dtype):
try:
dtypes.as_dtype(dtype)
return True
except TypeError:
return False
def _policy_equivalent_to_dtype(policy):
"""Returns True if the Policy is equivalent to a single dtype.
A policy is equivalent to a single dtype if the policy's compute and variable
dtypes are the same and the policy does not cause the layer/model to have
additional behavior, such as loss scaling.
The "infer" policy is considered equivalent to a single dtype.
Args:
policy: A Policy.
Returns:
True, if the policy is equivalent to a single dtype.
"""
# We use type() instead of isinstance because a subclass of Policy is never
# equivalent to a dtype.
return (type(policy) == Policy and # pylint: disable=unidiomatic-typecheck
list(policy.get_config().keys()) == ['name'] and
(policy.name == 'infer' or _is_convertible_to_dtype(policy.name)))
def serialize(policy):
if _policy_equivalent_to_dtype(policy):
# We return either None or the policy name for compatibility with older
# versions of Keras. If the policy name is returned, it is a dtype string
# such as 'float32'.
return None if policy.name == 'infer' else policy.name
return generic_utils.serialize_keras_object(policy)
def deserialize(config, custom_objects=None):
if isinstance(config, str) and _is_convertible_to_dtype(config):
return Policy(config)
if config is None:
return Policy('infer')
module_objects = {'Policy': Policy}
return generic_utils.deserialize_keras_object(
config,
module_objects=module_objects,
custom_objects=custom_objects,
printable_module_name='dtype policy')
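
A minimal round-trip sketch for the new (de)serialization hooks (hypothetical usage; Policy.get_config/from_config are the newly exported methods reflected in the golden API files below, while serialize()/deserialize() are module-level helpers used by Layer.get_config):

import tensorflow as tf

Policy = tf.keras.mixed_precision.experimental.Policy

p = Policy('float64', loss_scale=2.)
config = p.get_config()
# {'name': 'float64',
#  'loss_scale': {'class_name': 'FixedLossScale',
#                 'config': {'loss_scale_value': 2.0}}}
restored = Policy.from_config(config)
# Policy does not define __eq__, so compare string representations instead.
assert str(p) == str(restored)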

View File

@@ -227,6 +227,97 @@ class PolicyTest(test.TestCase):
'infer_with_float32_vars')
self.assertEqual(mp_policy.global_policy().name, default_policy)
@testing_utils.enable_v2_dtype_behavior
def test_config(self):
for policy in (
mp_policy.Policy('float16'),
mp_policy.Policy('float32'),
mp_policy.Policy('int16'),
mp_policy.Policy('mixed_float16'),
mp_policy.Policy('mixed_bfloat16'),
mp_policy.Policy('infer'),
mp_policy.Policy('infer_float32_vars'),
mp_policy.Policy('float16_with_float32_vars'),
mp_policy.Policy('float32', loss_scale=2.),
mp_policy.Policy('float32', loss_scale=None),
mp_policy.Policy('mixed_float16', loss_scale=2.),
mp_policy.Policy('mixed_float16', loss_scale=None),
mp_policy.Policy('mixed_bfloat16', loss_scale=2.),
mp_policy.Policy('mixed_bfloat16', loss_scale=None),
):
config = policy.get_config()
new_policy = mp_policy.Policy.from_config(config)
# Comparing strings is the easiest way to ensure the policies are the
# same, as policy does not override the == operator.
self.assertEqual(str(policy), str(new_policy))
@testing_utils.enable_v2_dtype_behavior
def test_serialization(self):
# Test policies that are equivalent to a single dtype
for policy_name in 'float16', 'float32', 'int8', 'string', 'bool':
policy = mp_policy.Policy(policy_name)
config = mp_policy.serialize(policy)
self.assertEqual(config, policy_name)
new_policy = mp_policy.deserialize(config)
self.assertEqual(str(policy), str(new_policy))
# Test "infer" policy
policy = mp_policy.Policy('infer')
config = mp_policy.serialize(policy)
self.assertIsNone(config)
new_policy = mp_policy.deserialize(config)
self.assertEqual(str(policy), str(new_policy))
class MyPolicy(mp_policy.Policy):
pass
# Test policies that do not override the loss scale
for policy in (
mp_policy.Policy('mixed_float16'),
mp_policy.Policy('mixed_bfloat16'),
mp_policy.Policy('infer_with_float32_vars'),
mp_policy.Policy('float16_with_float32_vars'),
MyPolicy('float32')
):
config = mp_policy.serialize(policy)
self.assertEqual(config, {'class_name': policy.__class__.__name__,
'config': {'name': policy.name}})
new_policy = mp_policy.deserialize(config,
custom_objects={'MyPolicy': MyPolicy})
self.assertEqual(str(policy), str(new_policy))
# Test policies that override the loss scale
for policy in (
mp_policy.Policy('float32', loss_scale=2.),
mp_policy.Policy('float32', loss_scale=None),
mp_policy.Policy('mixed_float16', loss_scale=2.),
mp_policy.Policy('mixed_float16', loss_scale=None),
mp_policy.Policy('mixed_bfloat16', loss_scale=2.),
mp_policy.Policy('mixed_bfloat16', loss_scale=None),
mp_policy.Policy('infer_with_float32_vars', loss_scale=2.),
mp_policy.Policy('infer_with_float32_vars', loss_scale=None),
mp_policy.Policy('float16_with_float32_vars', loss_scale=2.),
mp_policy.Policy('float16_with_float32_vars', loss_scale=None),
):
config = mp_policy.serialize(policy)
expected_loss_scale_config = None
if policy.loss_scale:
expected_loss_scale_config = {
'class_name': 'FixedLossScale',
'config': {'loss_scale_value': 2.}
}
self.assertEqual(
config, {
'class_name': policy.__class__.__name__,
'config': {
'name': policy.name,
'loss_scale': expected_loss_scale_config
}
})
new_policy = mp_policy.deserialize(
config, custom_objects={'MyPolicy': MyPolicy})
self.assertEqual(str(policy), str(new_policy))
@testing_utils.enable_v2_dtype_behavior
def test_error_if_graph_rewrite_enabled(self):
try:

View File

@@ -43,7 +43,7 @@ def create_identity_with_grad_check_fn(expected_gradient, expected_dtype=None):
certain value.
"""
@custom_gradient.custom_gradient
def identity_with_grad_check(x):
def _identity_with_grad_check(x):
"""Function that asserts it's gradient has a certain value."""
x = array_ops.identity(x)
def grad(dx):
@@ -57,6 +57,10 @@ def create_identity_with_grad_check_fn(expected_gradient, expected_dtype=None):
dx = array_ops.identity(dx)
return dx
return x, grad
# Keras sometimes has trouble serializing Lambda layers with a decorated
# function. So we define and return a non-decorated function.
def identity_with_grad_check(x):
return _identity_with_grad_check(x)
return identity_with_grad_check
@@ -77,7 +81,7 @@ def create_identity_with_nan_gradients_fn(have_nan_gradients):
`have_nan_gradients` is True.
"""
@custom_gradient.custom_gradient
def identity_with_nan_gradients(x):
def _identity_with_nan_gradients(x):
"""Function whose gradient is NaN iff `have_nan_gradients` is True."""
x = array_ops.identity(x)
def grad(dx):
@@ -93,4 +97,8 @@ def create_identity_with_nan_gradients_fn(have_nan_gradients):
lambda: dx
)
return x, grad
# Keras sometimes has trouble serializing Lambda layers with a decorated
# function. So we define and return a non-decorated function.
def identity_with_nan_gradients(x):
return _identity_with_nan_gradients(x)
return identity_with_nan_gradients
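
A hedged illustration of the workaround applied above (hypothetical names): Keras can have trouble serializing a Lambda layer whose function is a @custom_gradient-decorated object, so the decorated function is hidden behind a plain, undecorated Python function.

import tensorflow as tf

@tf.custom_gradient
def _double_grad_identity(x):
  def grad(dy):
    return 2. * dy  # arbitrary gradient, purely for illustration
  return tf.identity(x), grad

def double_grad_identity(x):
  # Plain wrapper that a Lambda layer can serialize.
  return _double_grad_identity(x)

layer = tf.keras.layers.Lambda(double_grad_identity)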

View File

@@ -26,4 +26,12 @@ tf_class {
name: "__init__"
argspec: "args=[\'self\', \'name\', \'loss_scale\'], varargs=None, keywords=None, defaults=[\'USE_DEFAULT\'], "
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "get_config"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@@ -26,4 +26,12 @@ tf_class {
name: "__init__"
argspec: "args=[\'self\', \'name\', \'loss_scale\'], varargs=None, keywords=None, defaults=[\'USE_DEFAULT\'], "
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "get_config"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}