Support Layer.get_config with a mixed_[b]float16 dtype policy.
Note saving a model with Model.save() is still unsupported. PiperOrigin-RevId: 272759486
This commit is contained in:
parent
0c25833dfe
commit
3c30766e1b
@ -468,8 +468,10 @@ class Layer(module.Module):
|
||||
config = {'name': self.name, 'trainable': self.trainable}
|
||||
if hasattr(self, '_batch_input_shape'):
|
||||
config['batch_input_shape'] = self._batch_input_shape
|
||||
# TODO(reedwm): Remove the hasattr(self, 'dtype') check. All layers have a
|
||||
# dtype.
|
||||
if hasattr(self, 'dtype'):
|
||||
config['dtype'] = self.dtype
|
||||
config['dtype'] = policy.serialize(self._dtype_policy)
|
||||
if hasattr(self, 'dynamic'):
|
||||
# Only include `dynamic` in the `config` if it is `True`
|
||||
if self.dynamic:
|
||||
@ -484,7 +486,6 @@ class Layer(module.Module):
|
||||
if len(extra_args) > 1 and hasattr(self.get_config, '_is_default'):
|
||||
raise NotImplementedError('Layers with arguments in `__init__` must '
|
||||
'override `get_config`.')
|
||||
# TODO(reedwm): Handle serializing self._dtype_policy.
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
@ -1677,6 +1678,8 @@ class Layer(module.Module):
|
||||
"""Sets self._dtype_policy."""
|
||||
if isinstance(dtype, policy.Policy):
|
||||
self._dtype_policy = dtype
|
||||
elif isinstance(dtype, dict):
|
||||
self._dtype_policy = policy.deserialize(dtype)
|
||||
elif dtype:
|
||||
self._dtype_policy = policy.Policy(dtypes.as_dtype(dtype).name)
|
||||
else:
|
||||
|
@ -94,6 +94,9 @@ class AddLayer(AssertTypeLayer):
|
||||
**kwargs: Passed to AssertTypeLayer constructor.
|
||||
"""
|
||||
self._regularizer = regularizer
|
||||
if isinstance(regularizer, dict):
|
||||
self._regularizer = regularizers.deserialize(regularizer,
|
||||
custom_objects=globals())
|
||||
self._use_operator = use_operator
|
||||
self._var_name = var_name
|
||||
super(AddLayer, self).__init__(**kwargs)
|
||||
@ -116,8 +119,7 @@ class AddLayer(AssertTypeLayer):
|
||||
|
||||
def get_config(self):
|
||||
config = super(AddLayer, self).get_config()
|
||||
assert self._regularizer is None, (
|
||||
'regularizer must be None to get config for AddLayer')
|
||||
config['regularizer'] = regularizers.serialize(self._regularizer)
|
||||
config['use_operator'] = self._use_operator
|
||||
config['var_name'] = self._var_name
|
||||
config['assert_type'] = self._assert_type
|
||||
@ -159,6 +161,9 @@ class IdentityRegularizer(regularizers.Regularizer):
|
||||
assert x.dtype == dtypes.float32
|
||||
return array_ops.identity(x)
|
||||
|
||||
def get_config(self):
|
||||
return {}
|
||||
|
||||
|
||||
# If called outside any strategy.scope() calls, this will return the default
|
||||
# strategy.
|
||||
@ -429,6 +434,83 @@ class KerasLayerTest(keras_parameterized.TestCase):
|
||||
self._test_checkpointing_layer_weights(
|
||||
strategy_fn, mixed_prec_when_saving=False, mixed_prec_when_loading=True)
|
||||
|
||||
@parameterized.named_parameters(*TESTCASES)
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@testing_utils.enable_v2_dtype_behavior
|
||||
def test_config(self, strategy_fn):
|
||||
x = constant_op.constant([1.], dtype=dtypes.float16)
|
||||
with strategy_fn().scope():
|
||||
for layer, dtype in (
|
||||
(AddLayer(), 'float32'),
|
||||
(AddLayer(dtype='float64'), 'float64'),
|
||||
(AddLayer(dtype=policy.Policy('float64')), 'float64')):
|
||||
config = layer.get_config()
|
||||
self.assertEqual(config['dtype'], dtype)
|
||||
self.assertIsInstance(config['dtype'], str)
|
||||
layer = AddLayer.from_config(config)
|
||||
self.assertEqual(layer.dtype, dtype)
|
||||
self.assertEqual(layer(x).dtype, dtype)
|
||||
self.assertEqual(layer.v.dtype, dtype)
|
||||
|
||||
layer = AddLayer(dtype=policy.Policy('mixed_float16'))
|
||||
config = layer.get_config()
|
||||
self.assertEqual(config['dtype'],
|
||||
{'class_name': 'Policy',
|
||||
'config': {'name': 'mixed_float16'}})
|
||||
layer = AddLayer.from_config(config)
|
||||
self.assertEqual(layer.dtype, 'float32')
|
||||
self.assertEqual(layer(x).dtype, 'float16')
|
||||
self.assertEqual(layer.v.dtype, 'float32')
|
||||
|
||||
layer = AddLayer(dtype=policy.Policy('mixed_float16', loss_scale=None))
|
||||
config = layer.get_config()
|
||||
self.assertEqual(config['dtype'],
|
||||
{'class_name': 'Policy',
|
||||
'config': {'name': 'mixed_float16',
|
||||
'loss_scale': None}})
|
||||
layer = AddLayer.from_config(config)
|
||||
self.assertEqual(layer.dtype, 'float32')
|
||||
self.assertEqual(layer(x).dtype, 'float16')
|
||||
self.assertEqual(layer.v.dtype, 'float32')
|
||||
|
||||
layer = AddLayer(dtype=policy.Policy('float64', loss_scale=2.))
|
||||
config = layer.get_config()
|
||||
self.assertEqual(config['dtype'],
|
||||
{'class_name': 'Policy',
|
||||
'config': {'name': 'float64',
|
||||
'loss_scale': {
|
||||
'class_name': 'FixedLossScale',
|
||||
'config': {'loss_scale_value': 2.0}}}})
|
||||
layer = AddLayer.from_config(config)
|
||||
self.assertEqual(layer.dtype, 'float64')
|
||||
self.assertEqual(layer(x).dtype, 'float64')
|
||||
self.assertEqual(layer.v.dtype, 'float64')
|
||||
|
||||
layer = AddLayer(dtype=policy.Policy('infer'))
|
||||
config = layer.get_config()
|
||||
self.assertIsNone(config['dtype'])
|
||||
layer = AddLayer.from_config(config)
|
||||
# If a layer is serialized with the "infer" policy, when deserialized into
|
||||
# TF 2 it will have the global policy instead of "infer". This is because
|
||||
# "infer" is serialized into None, and passing dtype=None in TensorFlow 2
|
||||
# indicates to use the global policy.
|
||||
self.assertEqual(layer.dtype, 'float32')
|
||||
self.assertEqual(layer(x).dtype, 'float32')
|
||||
self.assertEqual(layer.v.dtype, 'float32')
|
||||
|
||||
layer = AddLayer(dtype=policy.Policy('infer', loss_scale=2.))
|
||||
config = layer.get_config()
|
||||
self.assertEqual(config['dtype'],
|
||||
{'class_name': 'Policy',
|
||||
'config': {'name': 'infer',
|
||||
'loss_scale': {
|
||||
'class_name': 'FixedLossScale',
|
||||
'config': {'loss_scale_value': 2.0}}}})
|
||||
layer = AddLayer.from_config(config)
|
||||
self.assertEqual(layer.dtype, None)
|
||||
self.assertEqual(layer(x).dtype, 'float16')
|
||||
self.assertEqual(layer.v.dtype, 'float16')
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_delete_variable(self):
|
||||
layer = base_layer.Layer(dtype=policy.Policy('mixed_float16'))
|
||||
@ -485,6 +567,15 @@ class KerasModelTest(keras_parameterized.TestCase):
|
||||
'testcase_name': 'infer',
|
||||
'strategy_fn': create_mirrored_strategy,
|
||||
'policy_name': 'mixed_float16'
|
||||
}, {
|
||||
'testcase_name': 'get_config',
|
||||
'strategy_fn': create_mirrored_strategy,
|
||||
'get_config': True
|
||||
}, {
|
||||
'testcase_name': 'get_config_with_regularizer',
|
||||
'strategy_fn': create_mirrored_strategy,
|
||||
'get_config': True,
|
||||
'use_regularizer': True,
|
||||
}, {
|
||||
'testcase_name': 'norun_distributed',
|
||||
'strategy_fn': create_mirrored_strategy,
|
||||
@ -496,6 +587,7 @@ class KerasModelTest(keras_parameterized.TestCase):
|
||||
use_operator=False,
|
||||
use_regularizer=False,
|
||||
policy_name='mixed_float16',
|
||||
get_config=False,
|
||||
experimental_run_tf_function=True):
|
||||
if not self._is_strategy_supported(strategy_fn, check_model_type=True):
|
||||
return
|
||||
@ -521,6 +613,12 @@ class KerasModelTest(keras_parameterized.TestCase):
|
||||
layer_list += [layer, cast_f32_layer]
|
||||
model = testing_utils.get_model_from_layers(
|
||||
layer_list, input_shape=(1,), input_dtype=dtypes.float16)
|
||||
if get_config:
|
||||
config = model.get_config()
|
||||
model = model.__class__.from_config(
|
||||
config, custom_objects={'AddLayer': AddLayer})
|
||||
(layer,) = (layer for layer in model.layers
|
||||
if isinstance(layer, AddLayer))
|
||||
|
||||
def loss_fn(y_true, y_pred):
|
||||
del y_true
|
||||
@ -703,6 +801,15 @@ class KerasModelTest(keras_parameterized.TestCase):
|
||||
'testcase_name': 'pass_loss_scale_to_policy',
|
||||
'strategy_fn': create_mirrored_strategy,
|
||||
'pass_loss_scale_to_policy': True,
|
||||
}, {
|
||||
'testcase_name': 'get_config',
|
||||
'strategy_fn': create_mirrored_strategy,
|
||||
'get_config': True,
|
||||
}, {
|
||||
'testcase_name': 'get_config_and_pass_loss_scale_to_policy',
|
||||
'strategy_fn': create_mirrored_strategy,
|
||||
'get_config': True,
|
||||
'pass_loss_scale_to_policy': True,
|
||||
}, {
|
||||
'testcase_name': 'norun_distributed',
|
||||
'strategy_fn': create_mirrored_strategy,
|
||||
@ -711,6 +818,7 @@ class KerasModelTest(keras_parameterized.TestCase):
|
||||
def test_dynamic_loss_scaling(self,
|
||||
strategy_fn,
|
||||
pass_loss_scale_to_policy=False,
|
||||
get_config=False,
|
||||
experimental_run_tf_function=True):
|
||||
if not self._is_strategy_supported(strategy_fn):
|
||||
return
|
||||
@ -746,6 +854,12 @@ class KerasModelTest(keras_parameterized.TestCase):
|
||||
y = core.Lambda(identity_with_grad_check_fn)(y)
|
||||
y = math_ops.cast(y, dtypes.float32)
|
||||
model = models.Model(inputs=x, outputs=y)
|
||||
if get_config:
|
||||
config = model.get_config()
|
||||
model = model.__class__.from_config(
|
||||
config, custom_objects={'AddLayer': AddLayer})
|
||||
(layer,) = (layer for layer in model.layers
|
||||
if isinstance(layer, AddLayer))
|
||||
|
||||
def loss_fn(y_true, y_pred):
|
||||
del y_true
|
||||
@ -980,6 +1094,8 @@ class KerasModelTest(keras_parameterized.TestCase):
|
||||
def test_save_model_with_dynamic_loss_scaling(self, strategy_fn, h5=False):
|
||||
if not self._is_strategy_supported(strategy_fn):
|
||||
return
|
||||
# TODO(reedwm): Support and test saving model with a mixed_[b]float16 policy
|
||||
# as well.
|
||||
strategy = strategy_fn()
|
||||
if (isinstance(strategy, mirrored_strategy.MirroredStrategy) and
|
||||
not context.executing_eagerly()):
|
||||
|
@ -24,8 +24,9 @@ import six
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.keras import backend
|
||||
from tensorflow.python.keras.engine import base_layer_utils
|
||||
from tensorflow.python.keras.mixed_precision.experimental import loss_scale as keras_loss_scale_module
|
||||
from tensorflow.python.keras.utils import generic_utils
|
||||
from tensorflow.python.platform import tf_logging
|
||||
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
|
||||
from tensorflow.python.training.experimental import mixed_precision_global_state
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
@ -306,12 +307,15 @@ class Policy(object):
|
||||
|
||||
if loss_scale == USE_DEFAULT:
|
||||
loss_scale = 'dynamic' if name == 'mixed_float16' else None
|
||||
self._using_default_loss_scale = True
|
||||
else:
|
||||
self._using_default_loss_scale = False
|
||||
if loss_scale and self._compute_dtype not in (None, 'float16'):
|
||||
tf_logging.warn('Creating a Policy with a loss scale is only useful for '
|
||||
'float16 policies. You passed loss_scale=%r for policy '
|
||||
'%s. Consider not passing any loss_scale instead.' %
|
||||
(loss_scale, name))
|
||||
self._loss_scale = loss_scale_module.get(loss_scale)
|
||||
self._loss_scale = keras_loss_scale_module.get(loss_scale)
|
||||
|
||||
def _parse_name(self, name):
|
||||
"""Parses a Policy name into a compute and variable dtype.
|
||||
@ -426,6 +430,25 @@ class Policy(object):
|
||||
def __repr__(self):
|
||||
return '<Policy "%s", loss_scale=%s>' % (self._name, self.loss_scale)
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
'name': self.name
|
||||
}
|
||||
if not self._using_default_loss_scale:
|
||||
# We only include the loss scale if the default loss scale is not used.
|
||||
# This allows us to change the loss scale config format without breaking
|
||||
# users who use the default loss scale.
|
||||
config['loss_scale'] = keras_loss_scale_module.serialize(self.loss_scale)
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config, custom_objects=None):
|
||||
if 'loss_scale' in config and isinstance(config['loss_scale'], dict):
|
||||
config = config.copy()
|
||||
config['loss_scale'] = keras_loss_scale_module.deserialize(
|
||||
config['loss_scale'], custom_objects=custom_objects)
|
||||
return cls(**config)
|
||||
|
||||
|
||||
def with_input_dtype(policy, dtype):
|
||||
"""Copies "infer" `policy`, adding `dtype` to it.
|
||||
@ -556,3 +579,55 @@ def policy_scope(policy):
|
||||
yield
|
||||
finally:
|
||||
set_policy(old_policy)
|
||||
|
||||
|
||||
def _is_convertible_to_dtype(dtype):
|
||||
try:
|
||||
dtypes.as_dtype(dtype)
|
||||
return True
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
|
||||
def _policy_equivalent_to_dtype(policy):
|
||||
"""Returns True if the Policy is equivalent to a single dtype.
|
||||
|
||||
A policy is equivalent to a single dtype if the policy's compute and variable
|
||||
dtypes are the same and the policy does not cause the layer/model to have
|
||||
additional behavior, such as loss scaling.
|
||||
|
||||
The "infer" policy is considered equivalent to a single dtype.
|
||||
|
||||
Args:
|
||||
policy: A Policy.
|
||||
|
||||
Returns:
|
||||
True, if the policy is equivalent to a single dtype.
|
||||
"""
|
||||
# We use type() instead of isinstance because a sublcass of Policy is never
|
||||
# equivalent to a dtype.
|
||||
return (type(policy) == Policy and # pylint: disable=unidiomatic-typecheck
|
||||
list(policy.get_config().keys()) == ['name'] and
|
||||
(policy.name == 'infer' or _is_convertible_to_dtype(policy.name)))
|
||||
|
||||
|
||||
def serialize(policy):
|
||||
if _policy_equivalent_to_dtype(policy):
|
||||
# We return either None or the policy name for compatibility with older
|
||||
# versions of Keras. If the policy name is returned, it is a dtype string
|
||||
# such as 'float32'.
|
||||
return None if policy.name == 'infer' else policy.name
|
||||
return generic_utils.serialize_keras_object(policy)
|
||||
|
||||
|
||||
def deserialize(config, custom_objects=None):
|
||||
if isinstance(config, str) and _is_convertible_to_dtype(config):
|
||||
return Policy(config)
|
||||
if config is None:
|
||||
return Policy('infer')
|
||||
module_objects = {'Policy': Policy}
|
||||
return generic_utils.deserialize_keras_object(
|
||||
config,
|
||||
module_objects=module_objects,
|
||||
custom_objects=custom_objects,
|
||||
printable_module_name='dtype policy')
|
||||
|
@ -227,6 +227,97 @@ class PolicyTest(test.TestCase):
|
||||
'infer_with_float32_vars')
|
||||
self.assertEqual(mp_policy.global_policy().name, default_policy)
|
||||
|
||||
@testing_utils.enable_v2_dtype_behavior
|
||||
def test_config(self):
|
||||
for policy in (
|
||||
mp_policy.Policy('float16'),
|
||||
mp_policy.Policy('float32'),
|
||||
mp_policy.Policy('int16'),
|
||||
mp_policy.Policy('mixed_float16'),
|
||||
mp_policy.Policy('mixed_bfloat16'),
|
||||
mp_policy.Policy('infer'),
|
||||
mp_policy.Policy('infer_float32_vars'),
|
||||
mp_policy.Policy('float16_with_float32_vars'),
|
||||
mp_policy.Policy('float32', loss_scale=2.),
|
||||
mp_policy.Policy('float32', loss_scale=None),
|
||||
mp_policy.Policy('mixed_float16', loss_scale=2.),
|
||||
mp_policy.Policy('mixed_float16', loss_scale=None),
|
||||
mp_policy.Policy('mixed_bfloat16', loss_scale=2.),
|
||||
mp_policy.Policy('mixed_bfloat16', loss_scale=None),
|
||||
):
|
||||
config = policy.get_config()
|
||||
new_policy = mp_policy.Policy.from_config(config)
|
||||
# Comparing strings is the easiest way to ensure the policies are the
|
||||
# same, as policy does not override the == operator.
|
||||
self.assertEqual(str(policy), str(new_policy))
|
||||
|
||||
@testing_utils.enable_v2_dtype_behavior
|
||||
def test_serialization(self):
|
||||
# Test policies that are equivalent to a single dtype
|
||||
for policy_name in 'float16', 'float32', 'int8', 'string', 'bool':
|
||||
policy = mp_policy.Policy(policy_name)
|
||||
config = mp_policy.serialize(policy)
|
||||
self.assertEqual(config, policy_name)
|
||||
new_policy = mp_policy.deserialize(config)
|
||||
self.assertEqual(str(policy), str(new_policy))
|
||||
|
||||
# Test "infer" policy
|
||||
policy = mp_policy.Policy('infer')
|
||||
config = mp_policy.serialize(policy)
|
||||
self.assertIsNone(config)
|
||||
new_policy = mp_policy.deserialize(config)
|
||||
self.assertEqual(str(policy), str(new_policy))
|
||||
|
||||
class MyPolicy(mp_policy.Policy):
|
||||
pass
|
||||
|
||||
# Test policies that do not override the loss scale
|
||||
for policy in (
|
||||
mp_policy.Policy('mixed_float16'),
|
||||
mp_policy.Policy('mixed_bfloat16'),
|
||||
mp_policy.Policy('infer_with_float32_vars'),
|
||||
mp_policy.Policy('float16_with_float32_vars'),
|
||||
MyPolicy('float32')
|
||||
):
|
||||
config = mp_policy.serialize(policy)
|
||||
self.assertEqual(config, {'class_name': policy.__class__.__name__,
|
||||
'config': {'name': policy.name}})
|
||||
new_policy = mp_policy.deserialize(config,
|
||||
custom_objects={'MyPolicy': MyPolicy})
|
||||
self.assertEqual(str(policy), str(new_policy))
|
||||
|
||||
# Test policies that override the loss scale
|
||||
for policy in (
|
||||
mp_policy.Policy('float32', loss_scale=2.),
|
||||
mp_policy.Policy('float32', loss_scale=None),
|
||||
mp_policy.Policy('mixed_float16', loss_scale=2.),
|
||||
mp_policy.Policy('mixed_float16', loss_scale=None),
|
||||
mp_policy.Policy('mixed_bfloat16', loss_scale=2.),
|
||||
mp_policy.Policy('mixed_bfloat16', loss_scale=None),
|
||||
mp_policy.Policy('infer_with_float32_vars', loss_scale=2.),
|
||||
mp_policy.Policy('infer_with_float32_vars', loss_scale=None),
|
||||
mp_policy.Policy('float16_with_float32_vars', loss_scale=2.),
|
||||
mp_policy.Policy('float16_with_float32_vars', loss_scale=None),
|
||||
):
|
||||
config = mp_policy.serialize(policy)
|
||||
expected_loss_scale_config = None
|
||||
if policy.loss_scale:
|
||||
expected_loss_scale_config = {
|
||||
'class_name': 'FixedLossScale',
|
||||
'config': {'loss_scale_value': 2.}
|
||||
}
|
||||
self.assertEqual(
|
||||
config, {
|
||||
'class_name': policy.__class__.__name__,
|
||||
'config': {
|
||||
'name': policy.name,
|
||||
'loss_scale': expected_loss_scale_config
|
||||
}
|
||||
})
|
||||
new_policy = mp_policy.deserialize(
|
||||
config, custom_objects={'MyPolicy': MyPolicy})
|
||||
self.assertEqual(str(policy), str(new_policy))
|
||||
|
||||
@testing_utils.enable_v2_dtype_behavior
|
||||
def test_error_if_graph_rewrite_enabled(self):
|
||||
try:
|
||||
|
@ -43,7 +43,7 @@ def create_identity_with_grad_check_fn(expected_gradient, expected_dtype=None):
|
||||
certain value.
|
||||
"""
|
||||
@custom_gradient.custom_gradient
|
||||
def identity_with_grad_check(x):
|
||||
def _identity_with_grad_check(x):
|
||||
"""Function that asserts it's gradient has a certain value."""
|
||||
x = array_ops.identity(x)
|
||||
def grad(dx):
|
||||
@ -57,6 +57,10 @@ def create_identity_with_grad_check_fn(expected_gradient, expected_dtype=None):
|
||||
dx = array_ops.identity(dx)
|
||||
return dx
|
||||
return x, grad
|
||||
# Keras sometimes has trouble serializing Lambda layers with a decorated
|
||||
# function. So we define and return a non-decorated function.
|
||||
def identity_with_grad_check(x):
|
||||
return _identity_with_grad_check(x)
|
||||
return identity_with_grad_check
|
||||
|
||||
|
||||
@ -77,7 +81,7 @@ def create_identity_with_nan_gradients_fn(have_nan_gradients):
|
||||
`have_nan_gradients` is True.
|
||||
"""
|
||||
@custom_gradient.custom_gradient
|
||||
def identity_with_nan_gradients(x):
|
||||
def _identity_with_nan_gradients(x):
|
||||
"""Function whose gradient is NaN iff `have_nan_gradients` is True."""
|
||||
x = array_ops.identity(x)
|
||||
def grad(dx):
|
||||
@ -93,4 +97,8 @@ def create_identity_with_nan_gradients_fn(have_nan_gradients):
|
||||
lambda: dx
|
||||
)
|
||||
return x, grad
|
||||
# Keras sometimes has trouble serializing Lambda layers with a decorated
|
||||
# function. So we define and return a non-decorated function.
|
||||
def identity_with_nan_gradients(x):
|
||||
return _identity_with_nan_gradients(x)
|
||||
return identity_with_nan_gradients
|
||||
|
@ -26,4 +26,12 @@ tf_class {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'name\', \'loss_scale\'], varargs=None, keywords=None, defaults=[\'USE_DEFAULT\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_config"
|
||||
argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "get_config"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
@ -26,4 +26,12 @@ tf_class {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'name\', \'loss_scale\'], varargs=None, keywords=None, defaults=[\'USE_DEFAULT\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_config"
|
||||
argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "get_config"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user