From 1b94fd1b8aa4d15b0bc0d2204b8ba57e5f7d61d2 Mon Sep 17 00:00:00 2001 From: Tomer Kaftan Date: Tue, 5 Jan 2021 12:54:23 -0800 Subject: [PATCH] Add support for creating Keras.inputs from arbitrary TypeSpecs. (Including TypeSpecs that don't have a dtype) PiperOrigin-RevId: 350200835 Change-Id: I06772c1d6ece689f17a72d787dcd12c6f611e7e3 --- RELEASE.md | 1 + tensorflow/python/keras/backend.py | 5 +- tensorflow/python/keras/engine/functional.py | 8 +- tensorflow/python/keras/engine/input_layer.py | 116 +++++++-- .../python/keras/engine/input_layer_test.py | 236 ++++++++++++++++++ .../python/keras/engine/keras_tensor.py | 23 ++ .../python/keras/saving/model_config.py | 5 +- ...tensorflow.keras.layers.-input-layer.pbtxt | 2 +- .../golden/v1/tensorflow.keras.layers.pbtxt | 2 +- .../api/golden/v1/tensorflow.keras.pbtxt | 2 +- ...tensorflow.keras.layers.-input-layer.pbtxt | 2 +- .../golden/v2/tensorflow.keras.layers.pbtxt | 2 +- .../api/golden/v2/tensorflow.keras.pbtxt | 2 +- 13 files changed, 369 insertions(+), 37 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index 8f484eb5c06..c141ed035a0 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -35,6 +35,7 @@ * Discretization combiner implemented, with additional arg `epsilon`. * Improvements to model saving/loading: * `model.load_weights` now accepts paths to saved models. + * Keras inputs can now be created directly from arbitrary `tf.TypeSpecs`. * `tf.data`: * Exposing `tf.data.experimental.ExternalStatePolicy`, which can be used diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index a5b077be984..9925ad240b9 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -1337,7 +1337,6 @@ def placeholder(shape=None, if sparse: spec = sparse_tensor.SparseTensorSpec( shape=shape, dtype=dtype) - x = keras_tensor.SparseKerasTensor(spec, name=name) elif ragged: ragged_rank = 0 for i in range(1, len(shape)): @@ -1349,12 +1348,10 @@ def placeholder(shape=None, ragged_rank = i spec = ragged_tensor.RaggedTensorSpec( shape=shape, dtype=dtype, ragged_rank=ragged_rank) - - x = keras_tensor.RaggedKerasTensor(spec, name=name) else: spec = tensor_spec.TensorSpec( shape=shape, dtype=dtype, name=name) - x = keras_tensor.KerasTensor(spec, name=name) + x = keras_tensor.keras_tensor_from_type_spec(spec, name=name) else: with get_graph().as_default(): if sparse: diff --git a/tensorflow/python/keras/engine/functional.py b/tensorflow/python/keras/engine/functional.py index cbb82d220a0..e3efd50fb3c 100644 --- a/tensorflow/python/keras/engine/functional.py +++ b/tensorflow/python/keras/engine/functional.py @@ -27,6 +27,7 @@ import warnings from six.moves import zip # pylint: disable=redefined-builtin from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.keras import backend from tensorflow.python.keras.engine import base_layer @@ -643,8 +644,11 @@ class Functional(training_lib.Model): # Dtype casting. tensor = math_ops.cast(tensor, dtype=ref_input.dtype) elif tf_utils.is_extension_type(tensor): - # Dtype casting. - tensor = math_ops.cast(tensor, dtype=ref_input.dtype) + # Dtype casting (If the extension type has a non-variant dtype and + # supports being cast) + ref_input_dtype = getattr(ref_input, 'dtype', None) + if ref_input_dtype is not None and ref_input_dtype != dtypes.variant: + tensor = math_ops.cast(tensor, dtype=ref_input_dtype) return tensor diff --git a/tensorflow/python/keras/engine/input_layer.py b/tensorflow/python/keras/engine/input_layer.py index 75e0cc879f3..d9c155a1d84 100644 --- a/tensorflow/python/keras/engine/input_layer.py +++ b/tensorflow/python/keras/engine/input_layer.py @@ -32,6 +32,13 @@ from tensorflow.python.keras.utils import tf_utils from tensorflow.python.util.tf_export import keras_export +def _assert_other_arg_none(arg_name, arg): + if arg is not None: + raise ValueError('When `type_spec` is not None, all other args ' + 'except `name` must be None, ' + 'but %s is not None.' % arg_name) + + @keras_export('keras.layers.InputLayer') class InputLayer(base_layer.Layer): """Layer to be used as an entry point into a Network (a graph of layers). @@ -85,6 +92,9 @@ class InputLayer(base_layer.Layer): ragged dimensions. For more information about RaggedTensors, see [this guide](https://www.tensorflow.org/guide/ragged_tensors). Default to False. + type_spec: A `tf.TypeSpec` object to create Input from. This `tf.TypeSpec` + represents the entire batch. When provided, all other args except + name must be None. name: Optional name of the layer (string). """ @@ -93,10 +103,18 @@ class InputLayer(base_layer.Layer): batch_size=None, dtype=None, input_tensor=None, - sparse=False, + sparse=None, name=None, - ragged=False, + ragged=None, + type_spec=None, **kwargs): + self._init_input_shape = input_shape + self._init_batch_size = batch_size + self._init_dtype = dtype + self._init_sparse = sparse + self._init_ragged = ragged + self._init_type_spec = type_spec + strategy = distribution_strategy_context.get_strategy() if strategy and batch_size is not None and \ distributed_training_utils.global_batch_size_supported(strategy): @@ -135,8 +153,8 @@ class InputLayer(base_layer.Layer): (input_tensor.dtype, dtype)) super(InputLayer, self).__init__(dtype=dtype, name=name) self.built = True - self.sparse = sparse - self.ragged = ragged + self.sparse = True if sparse else False + self.ragged = True if ragged else False self.batch_size = batch_size self.supports_masking = True @@ -145,7 +163,32 @@ class InputLayer(base_layer.Layer): elif isinstance(input_shape, int): input_shape = (input_shape,) - if input_tensor is None: + if type_spec is not None: + args_that_must_be_none = [ + ('(input_)shape', self._init_input_shape), + ('batch_size', self._init_batch_size), + ('dtype', self._init_dtype), + ('input_tensor', input_tensor), + ('sparse', self._init_sparse), + ('ragged', self._init_ragged), + ] + for arg_name, arg in args_that_must_be_none: + _assert_other_arg_none(arg_name, arg) + if not keras_tensor.keras_tensors_enabled(): + raise ValueError('Creating Keras inputs from a type_spec is only ' + 'supported when eager execution is enabled.') + input_tensor = keras_tensor.keras_tensor_from_type_spec(type_spec) + if isinstance(input_tensor, keras_tensor.SparseKerasTensor): + self.sparse = True + if isinstance(input_tensor, keras_tensor.RaggedKerasTensor): + self.ragged = True + self.is_placeholder = True + try: + self._batch_input_shape = tuple(input_tensor.shape.as_list()) + except ValueError: + # If the shape cannot be represented as a tuple (e.g. unknown rank) + self._batch_input_shape = None + elif input_tensor is None: if input_shape is not None: batch_input_shape = (batch_size,) + tuple(input_shape) else: @@ -190,13 +233,19 @@ class InputLayer(base_layer.Layer): shape=input_tensor.shape, dtype=input_tensor.dtype, name=self.name) def get_config(self): - config = { - 'batch_input_shape': self._batch_input_shape, - 'dtype': self.dtype, - 'sparse': self.sparse, - 'ragged': self.ragged, - 'name': self.name - } + if self._init_type_spec is not None: + config = { + 'name': self.name, + 'type_spec': self._init_type_spec + } + else: + config = { + 'batch_input_shape': self._batch_input_shape, + 'dtype': self.dtype, + 'sparse': self.sparse, + 'ragged': self.ragged, + 'name': self.name, + } return config @property @@ -210,13 +259,14 @@ def Input( # pylint: disable=invalid-name batch_size=None, name=None, dtype=None, - sparse=False, + sparse=None, tensor=None, - ragged=False, + ragged=None, + type_spec=None, **kwargs): """`Input()` is used to instantiate a Keras tensor. - A Keras tensor is a TensorFlow symbolic tensor object, + A Keras tensor is a symbolic tensor-like object, which we augment with certain attributes that allow us to build a Keras model just by knowing the inputs and outputs of the model. @@ -248,6 +298,8 @@ def Input( # pylint: disable=invalid-name values of 'None' in the 'shape' argument represent ragged dimensions. For more information about RaggedTensors, see [this guide](https://www.tensorflow.org/guide/ragged_tensors). + type_spec: A `tf.TypeSpec` object to create the input placeholder from. + When provided, all other args except name must be None. **kwargs: deprecated arguments support. Supports `batch_shape` and `batch_input_shape`. @@ -264,8 +316,8 @@ def Input( # pylint: disable=invalid-name ``` Note that even if eager execution is enabled, - `Input` produces a symbolic tensor (i.e. a placeholder). - This symbolic tensor can be used with other + `Input` produces a symbolic tensor-like object (i.e. a placeholder). + This symbolic tensor-like object can be used with other TensorFlow ops, as such: ```python @@ -273,11 +325,29 @@ def Input( # pylint: disable=invalid-name y = tf.square(x) ``` + However, the resulting model will not track any variables that were + used as inputs to TensorFlow ops. All variable usages must happen within + Keras layers to make sure they will be tracked by the model's weights. + + The Keras Input can also create a placeholder from an arbitrary `tf.TypeSpec`, + e.g: + + ```python + x = Input(type_spec=tf.RaggedTensorSpec(shape=[None, None], + dtype=tf.float32, ragged_rank=1)) + y = x.values + model = Model(x, y) + ``` + When passing an arbitrary `tf.TypeSpec`, it must represent the signature of an + entire batch instead of just one example. + Raises: ValueError: If both `sparse` and `ragged` are provided. ValueError: If both `shape` and (`batch_input_shape` or `batch_shape`) are provided. - ValueError: If both `shape` and `tensor` are None. + ValueError: If `shape`, `tensor` and `type_spec` are None. + ValueError: If arguments besides `type_spec` are non-None while `type_spec` + is passed. ValueError: if any unrecognized parameters are provided. """ if sparse and ragged: @@ -285,16 +355,18 @@ def Input( # pylint: disable=invalid-name 'Cannot set both sparse and ragged to True in a Keras input.') input_layer_config = {'name': name, 'dtype': dtype, 'sparse': sparse, - 'ragged': ragged, 'input_tensor': tensor} + 'ragged': ragged, 'input_tensor': tensor, + 'type_spec': type_spec} batch_input_shape = kwargs.pop('batch_input_shape', kwargs.pop('batch_shape', None)) if shape is not None and batch_input_shape is not None: raise ValueError('Only provide the `shape` OR `batch_input_shape` argument ' 'to Input, not both at the same time.') - if batch_input_shape is None and shape is None and tensor is None: - raise ValueError('Please provide to Input either a `shape`' - ' or a `tensor` argument. Note that ' + if (batch_input_shape is None and shape is None and tensor is None + and type_spec is None): + raise ValueError('Please provide to Input a `shape`' + ' or a `tensor` or a `type_spec` argument. Note that ' '`shape` does not include the batch ' 'dimension.') if kwargs: diff --git a/tensorflow/python/keras/engine/input_layer_test.py b/tensorflow/python/keras/engine/input_layer_test.py index 1b15f34458c..52140e0dec5 100644 --- a/tensorflow/python/keras/engine/input_layer_test.py +++ b/tensorflow/python/keras/engine/input_layer_test.py @@ -19,16 +19,105 @@ from __future__ import division from __future__ import print_function from tensorflow.python.eager import def_function +from tensorflow.python.framework import composite_tensor +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import type_spec from tensorflow.python.keras import combinations from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils from tensorflow.python.keras.engine import functional from tensorflow.python.keras.engine import input_layer as input_layer_lib +from tensorflow.python.keras.layers import core +from tensorflow.python.keras.saving import model_config from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import test +class TwoTensors(composite_tensor.CompositeTensor): + """A simple value type to test TypeSpec. + + Contains two tensors (x, y) and a string (color). The color value is a + stand-in for any extra type metadata we might need to store. + + This value type contains no single dtype. + """ + + def __init__(self, x, y, color='red', assign_variant_dtype=False): + assert isinstance(color, str) + self.x = ops.convert_to_tensor_v2_with_dispatch(x) + self.y = ops.convert_to_tensor_v2_with_dispatch(y) + self.color = color + self.shape = tensor_shape.TensorShape(None) + self._shape = tensor_shape.TensorShape(None) + if assign_variant_dtype: + self.dtype = dtypes.variant + self._assign_variant_dtype = assign_variant_dtype + + def _type_spec(self): + return TwoTensorsSpecNoOneDtype( + self.x.shape, self.x.dtype, self.y.shape, + self.y.dtype, color=self.color, + assign_variant_dtype=self._assign_variant_dtype) + + +def as_shape(shape): + """Converts the given object to a TensorShape.""" + if isinstance(shape, tensor_shape.TensorShape): + return shape + else: + return tensor_shape.TensorShape(shape) + + +@type_spec.register('tf.TwoTensorsSpec') +class TwoTensorsSpecNoOneDtype(type_spec.TypeSpec): + """A TypeSpec for the TwoTensors value type.""" + + def __init__( + self, x_shape, x_dtype, y_shape, y_dtype, color='red', + assign_variant_dtype=False): + self.x_shape = as_shape(x_shape) + self.x_dtype = dtypes.as_dtype(x_dtype) + self.y_shape = as_shape(y_shape) + self.y_dtype = dtypes.as_dtype(y_dtype) + self.color = color + self.shape = tensor_shape.TensorShape(None) + self._shape = tensor_shape.TensorShape(None) + if assign_variant_dtype: + self.dtype = dtypes.variant + self._assign_variant_dtype = assign_variant_dtype + + value_type = property(lambda self: TwoTensors) + + @property + def _component_specs(self): + return (tensor_spec.TensorSpec(self.x_shape, self.x_dtype), + tensor_spec.TensorSpec(self.y_shape, self.y_dtype)) + + def _to_components(self, value): + return (value.x, value.y) + + def _from_components(self, components): + x, y = components + return TwoTensors(x, y, self.color) + + def _serialize(self): + return (self.x_shape, self.x_dtype, self.y_shape, self.y_dtype, self.color) + + @classmethod + def from_value(cls, value): + return cls(value.x.shape, value.x.dtype, value.y.shape, value.y.dtype, + value.color) + + +type_spec.register_type_spec_from_value_converter( + TwoTensors, TwoTensorsSpecNoOneDtype.from_value) + + class InputLayerTest(keras_parameterized.TestCase): @combinations.generate(combinations.combine(mode=['graph', 'eager'])) @@ -144,5 +233,152 @@ class InputLayerTest(keras_parameterized.TestCase): values=[3, 21, 4, 1, 53, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8]) self.assertAllEqual(run_model(rt), rt * 3) + @combinations.generate(combinations.combine(mode=['eager'])) + def testNoMixingArgsWithTypeSpecArg(self): + with testing_utils.use_keras_tensors_scope(True): + with self.assertRaisesRegexp( + ValueError, 'all other args except `name` must be None'): + input_layer_lib.Input( + shape=(4, 7), + type_spec=tensor_spec.TensorSpec((2, 7, 32), dtypes.float32)) + with self.assertRaisesRegexp( + ValueError, 'all other args except `name` must be None'): + input_layer_lib.Input( + batch_size=4, + type_spec=tensor_spec.TensorSpec((7, 32), dtypes.float32)) + with self.assertRaisesRegexp( + ValueError, 'all other args except `name` must be None'): + input_layer_lib.Input( + dtype=dtypes.int64, + type_spec=tensor_spec.TensorSpec((7, 32), dtypes.float32)) + with self.assertRaisesRegexp( + ValueError, 'all other args except `name` must be None'): + input_layer_lib.Input( + sparse=True, + type_spec=tensor_spec.TensorSpec((7, 32), dtypes.float32)) + with self.assertRaisesRegexp( + ValueError, 'all other args except `name` must be None'): + input_layer_lib.Input( + ragged=True, + type_spec=tensor_spec.TensorSpec((7, 32), dtypes.float32)) + + @combinations.generate(combinations.combine(mode=['eager'])) + def testTypeSpecArg(self): + with testing_utils.use_keras_tensors_scope(True): + # Create a Keras Input + x = input_layer_lib.Input( + type_spec=tensor_spec.TensorSpec((7, 32), dtypes.float32)) + self.assertAllEqual(x.shape.as_list(), [7, 32]) + + # Verify you can construct and use a model w/ this input + model = functional.Functional(x, x * 2.0) + self.assertAllEqual(model(array_ops.ones(x.shape)), + array_ops.ones(x.shape) * 2.0) + + # Test serialization / deserialization + model = functional.Functional.from_config(model.get_config()) + self.assertAllEqual(model(array_ops.ones(x.shape)), + array_ops.ones(x.shape) * 2.0) + + model = model_config.model_from_json(model.to_json()) + self.assertAllEqual(model(array_ops.ones(x.shape)), + array_ops.ones(x.shape) * 2.0) + + @combinations.generate(combinations.combine(mode=['eager'])) + def testTypeSpecArgInTFFunction(self): + with testing_utils.use_keras_tensors_scope(True): + # We use a mutable model container instead of a model python variable, + # because python 2.7 does not have `nonlocal` + model_container = {} + + @def_function.function + def run_model(inp): + if not model_container: + # Create a Keras Input + x = input_layer_lib.Input( + type_spec=tensor_spec.TensorSpec((10, 16), dtypes.float32)) + self.assertAllEqual(x.shape.as_list(), [10, 16]) + + # Verify you can construct and use a model w/ this input + model_container['model'] = functional.Functional(x, x * 3.0) + return model_container['model'](inp) + + self.assertAllEqual(run_model(array_ops.ones((10, 16))), + array_ops.ones((10, 16)) * 3.0) + + @combinations.generate(combinations.combine(mode=['eager'])) + def testCompositeTypeSpecArg(self): + with testing_utils.use_keras_tensors_scope(True): + # Create a Keras Input + rt = ragged_tensor.RaggedTensor.from_row_splits( + values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8]) + x = input_layer_lib.Input(type_spec=rt._type_spec) + + # Verify you can construct and use a model w/ this input + model = functional.Functional(x, x * 2) + + # And that the model works + rt = ragged_tensor.RaggedTensor.from_row_splits( + values=[3, 21, 4, 1, 53, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8]) + self.assertAllEqual(model(rt), rt * 2) + + # Test serialization / deserialization + model = functional.Functional.from_config(model.get_config()) + self.assertAllEqual(model(rt), rt * 2) + model = model_config.model_from_json(model.to_json()) + self.assertAllEqual(model(rt), rt * 2) + + @combinations.generate(combinations.combine(mode=['eager'])) + def testCompositeTypeSpecArgInTFFunction(self): + with testing_utils.use_keras_tensors_scope(True): + # We use a mutable model container instead of a model pysthon variable, + # because python 2.7 does not have `nonlocal` + model_container = {} + + @def_function.function + def run_model(inp): + if not model_container: + # Create a Keras Input + rt = ragged_tensor.RaggedTensor.from_row_splits( + values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8]) + x = input_layer_lib.Input(type_spec=rt._type_spec) + + # Verify you can construct and use a model w/ this input + model_container['model'] = functional.Functional(x, x * 3) + return model_container['model'](inp) + + # And verify the model works + rt = ragged_tensor.RaggedTensor.from_row_splits( + values=[3, 21, 4, 1, 53, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8]) + self.assertAllEqual(run_model(rt), rt * 3) + + @combinations.generate(combinations.combine(mode=['eager'])) + def testCompositeTypeSpecArgWithoutDtype(self): + with testing_utils.use_keras_tensors_scope(True): + for assign_variant_dtype in [False, True]: + # Create a Keras Input + spec = TwoTensorsSpecNoOneDtype( + (1, 2, 3), dtypes.float32, (1, 2, 3), dtypes.int64, + assign_variant_dtype=assign_variant_dtype) + x = input_layer_lib.Input(type_spec=spec) + + def lambda_fn(tensors): + return (math_ops.cast(tensors.x, dtypes.float64) + + math_ops.cast(tensors.y, dtypes.float64)) + # Verify you can construct and use a model w/ this input + model = functional.Functional(x, core.Lambda(lambda_fn)(x)) + + # And that the model works + two_tensors = TwoTensors(array_ops.ones((1, 2, 3)) * 2.0, + array_ops.ones(1, 2, 3)) + self.assertAllEqual(model(two_tensors), lambda_fn(two_tensors)) + + # Test serialization / deserialization + model = functional.Functional.from_config(model.get_config()) + self.assertAllEqual(model(two_tensors), lambda_fn(two_tensors)) + model = model_config.model_from_json(model.to_json()) + self.assertAllEqual(model(two_tensors), lambda_fn(two_tensors)) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/engine/keras_tensor.py b/tensorflow/python/keras/engine/keras_tensor.py index fbefbe73fb1..23beb78f826 100644 --- a/tensorflow/python/keras/engine/keras_tensor.py +++ b/tensorflow/python/keras/engine/keras_tensor.py @@ -205,6 +205,10 @@ class KerasTensor(object): type_spec = type_spec_module.type_spec_from_value(tensor) return cls(type_spec, name=name) + @classmethod + def from_type_spec(cls, type_spec, name=None): + return cls(type_spec=type_spec, name=name) + def _to_placeholder(self): """Convert this KerasTensor to a placeholder in a graph.""" # If there is an inferred value for this tensor, inject the inferred value @@ -538,6 +542,11 @@ class UserRegisteredTypeKerasTensor(KerasTensor): def from_tensor(cls, tensor): return cls(tensor) + @classmethod + def from_type_spec(cls, type_spec, name=None): + raise NotImplementedError('You cannot instantiate a KerasTensor ' + 'directly from TypeSpec: %s' % type_spec) + def _to_placeholder(self): return self._user_registered_symbolic_object @@ -608,3 +617,17 @@ def keras_tensor_from_tensor(tensor): if hasattr(tensor, '_keras_mask'): out._keras_mask = keras_tensor_from_tensor(tensor._keras_mask) # pylint: disable=protected-access return out + + +def keras_tensor_from_type_spec(type_spec, name=None): + """Convert a TypeSpec to a representative KerasTensor.""" + # Create a specialized KerasTensor that supports instance methods, + # operators, and additional value inference if possible + keras_tensor_cls = None + value_type = type_spec.value_type + for tensor_type, cls in keras_tensor_classes: + if issubclass(value_type, tensor_type): + keras_tensor_cls = cls + break + + return keras_tensor_cls.from_type_spec(type_spec, name=name) diff --git a/tensorflow/python/keras/saving/model_config.py b/tensorflow/python/keras/saving/model_config.py index ab9eb2816d4..e203f45e093 100644 --- a/tensorflow/python/keras/saving/model_config.py +++ b/tensorflow/python/keras/saving/model_config.py @@ -19,8 +19,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import json - +from tensorflow.python.keras.saving.saved_model import json_utils from tensorflow.python.util.tf_export import keras_export # pylint: disable=g-import-not-at-top @@ -126,6 +125,6 @@ def model_from_json(json_string, custom_objects=None): Returns: A Keras model instance (uncompiled). """ - config = json.loads(json_string) + config = json_utils.decode(json_string) from tensorflow.python.keras.layers import deserialize # pylint: disable=g-import-not-at-top return deserialize(config, custom_objects=custom_objects) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-input-layer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-input-layer.pbtxt index 774a1c23255..aaf5dc16b45 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-input-layer.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-input-layer.pbtxt @@ -129,7 +129,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'input_shape\', \'batch_size\', \'dtype\', \'input_tensor\', \'sparse\', \'name\', \'ragged\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], " + argspec: "args=[\'self\', \'input_shape\', \'batch_size\', \'dtype\', \'input_tensor\', \'sparse\', \'name\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "add_loss" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.pbtxt index 35714912b04..5172fdea53b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.pbtxt @@ -434,7 +434,7 @@ tf_module { } member_method { name: "Input" - argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], " + argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "add" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.pbtxt index 9cb5ef1dcb1..c83d9ad5752 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.pbtxt @@ -86,6 +86,6 @@ tf_module { } member_method { name: "Input" - argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], " + argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-input-layer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-input-layer.pbtxt index 774a1c23255..aaf5dc16b45 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-input-layer.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-input-layer.pbtxt @@ -129,7 +129,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'input_shape\', \'batch_size\', \'dtype\', \'input_tensor\', \'sparse\', \'name\', \'ragged\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], " + argspec: "args=[\'self\', \'input_shape\', \'batch_size\', \'dtype\', \'input_tensor\', \'sparse\', \'name\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "add_loss" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.pbtxt index 078c7ec8a67..869b5bdec3c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.pbtxt @@ -426,7 +426,7 @@ tf_module { } member_method { name: "Input" - argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], " + argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "add" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.pbtxt index 9cb5ef1dcb1..c83d9ad5752 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.pbtxt @@ -86,6 +86,6 @@ tf_module { } member_method { name: "Input" - argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], " + argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } }