Add support for creating Keras.inputs from arbitrary TypeSpecs. (Including TypeSpecs that don't have a dtype)
PiperOrigin-RevId: 350200835 Change-Id: I06772c1d6ece689f17a72d787dcd12c6f611e7e3
This commit is contained in:
parent
6f1dee7d6c
commit
1b94fd1b8a
@ -35,6 +35,7 @@
|
||||
* Discretization combiner implemented, with additional arg `epsilon`.
|
||||
* Improvements to model saving/loading:
|
||||
* `model.load_weights` now accepts paths to saved models.
|
||||
* Keras inputs can now be created directly from arbitrary `tf.TypeSpecs`.
|
||||
|
||||
* `tf.data`:
|
||||
* Exposing `tf.data.experimental.ExternalStatePolicy`, which can be used
|
||||
|
@ -1337,7 +1337,6 @@ def placeholder(shape=None,
|
||||
if sparse:
|
||||
spec = sparse_tensor.SparseTensorSpec(
|
||||
shape=shape, dtype=dtype)
|
||||
x = keras_tensor.SparseKerasTensor(spec, name=name)
|
||||
elif ragged:
|
||||
ragged_rank = 0
|
||||
for i in range(1, len(shape)):
|
||||
@ -1349,12 +1348,10 @@ def placeholder(shape=None,
|
||||
ragged_rank = i
|
||||
spec = ragged_tensor.RaggedTensorSpec(
|
||||
shape=shape, dtype=dtype, ragged_rank=ragged_rank)
|
||||
|
||||
x = keras_tensor.RaggedKerasTensor(spec, name=name)
|
||||
else:
|
||||
spec = tensor_spec.TensorSpec(
|
||||
shape=shape, dtype=dtype, name=name)
|
||||
x = keras_tensor.KerasTensor(spec, name=name)
|
||||
x = keras_tensor.keras_tensor_from_type_spec(spec, name=name)
|
||||
else:
|
||||
with get_graph().as_default():
|
||||
if sparse:
|
||||
|
@ -27,6 +27,7 @@ import warnings
|
||||
from six.moves import zip # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.keras import backend
|
||||
from tensorflow.python.keras.engine import base_layer
|
||||
@ -643,8 +644,11 @@ class Functional(training_lib.Model):
|
||||
# Dtype casting.
|
||||
tensor = math_ops.cast(tensor, dtype=ref_input.dtype)
|
||||
elif tf_utils.is_extension_type(tensor):
|
||||
# Dtype casting.
|
||||
tensor = math_ops.cast(tensor, dtype=ref_input.dtype)
|
||||
# Dtype casting (If the extension type has a non-variant dtype and
|
||||
# supports being cast)
|
||||
ref_input_dtype = getattr(ref_input, 'dtype', None)
|
||||
if ref_input_dtype is not None and ref_input_dtype != dtypes.variant:
|
||||
tensor = math_ops.cast(tensor, dtype=ref_input_dtype)
|
||||
|
||||
return tensor
|
||||
|
||||
|
@ -32,6 +32,13 @@ from tensorflow.python.keras.utils import tf_utils
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
|
||||
def _assert_other_arg_none(arg_name, arg):
|
||||
if arg is not None:
|
||||
raise ValueError('When `type_spec` is not None, all other args '
|
||||
'except `name` must be None, '
|
||||
'but %s is not None.' % arg_name)
|
||||
|
||||
|
||||
@keras_export('keras.layers.InputLayer')
|
||||
class InputLayer(base_layer.Layer):
|
||||
"""Layer to be used as an entry point into a Network (a graph of layers).
|
||||
@ -85,6 +92,9 @@ class InputLayer(base_layer.Layer):
|
||||
ragged dimensions. For more information about RaggedTensors, see
|
||||
[this guide](https://www.tensorflow.org/guide/ragged_tensors).
|
||||
Default to False.
|
||||
type_spec: A `tf.TypeSpec` object to create Input from. This `tf.TypeSpec`
|
||||
represents the entire batch. When provided, all other args except
|
||||
name must be None.
|
||||
name: Optional name of the layer (string).
|
||||
"""
|
||||
|
||||
@ -93,10 +103,18 @@ class InputLayer(base_layer.Layer):
|
||||
batch_size=None,
|
||||
dtype=None,
|
||||
input_tensor=None,
|
||||
sparse=False,
|
||||
sparse=None,
|
||||
name=None,
|
||||
ragged=False,
|
||||
ragged=None,
|
||||
type_spec=None,
|
||||
**kwargs):
|
||||
self._init_input_shape = input_shape
|
||||
self._init_batch_size = batch_size
|
||||
self._init_dtype = dtype
|
||||
self._init_sparse = sparse
|
||||
self._init_ragged = ragged
|
||||
self._init_type_spec = type_spec
|
||||
|
||||
strategy = distribution_strategy_context.get_strategy()
|
||||
if strategy and batch_size is not None and \
|
||||
distributed_training_utils.global_batch_size_supported(strategy):
|
||||
@ -135,8 +153,8 @@ class InputLayer(base_layer.Layer):
|
||||
(input_tensor.dtype, dtype))
|
||||
super(InputLayer, self).__init__(dtype=dtype, name=name)
|
||||
self.built = True
|
||||
self.sparse = sparse
|
||||
self.ragged = ragged
|
||||
self.sparse = True if sparse else False
|
||||
self.ragged = True if ragged else False
|
||||
self.batch_size = batch_size
|
||||
self.supports_masking = True
|
||||
|
||||
@ -145,7 +163,32 @@ class InputLayer(base_layer.Layer):
|
||||
elif isinstance(input_shape, int):
|
||||
input_shape = (input_shape,)
|
||||
|
||||
if input_tensor is None:
|
||||
if type_spec is not None:
|
||||
args_that_must_be_none = [
|
||||
('(input_)shape', self._init_input_shape),
|
||||
('batch_size', self._init_batch_size),
|
||||
('dtype', self._init_dtype),
|
||||
('input_tensor', input_tensor),
|
||||
('sparse', self._init_sparse),
|
||||
('ragged', self._init_ragged),
|
||||
]
|
||||
for arg_name, arg in args_that_must_be_none:
|
||||
_assert_other_arg_none(arg_name, arg)
|
||||
if not keras_tensor.keras_tensors_enabled():
|
||||
raise ValueError('Creating Keras inputs from a type_spec is only '
|
||||
'supported when eager execution is enabled.')
|
||||
input_tensor = keras_tensor.keras_tensor_from_type_spec(type_spec)
|
||||
if isinstance(input_tensor, keras_tensor.SparseKerasTensor):
|
||||
self.sparse = True
|
||||
if isinstance(input_tensor, keras_tensor.RaggedKerasTensor):
|
||||
self.ragged = True
|
||||
self.is_placeholder = True
|
||||
try:
|
||||
self._batch_input_shape = tuple(input_tensor.shape.as_list())
|
||||
except ValueError:
|
||||
# If the shape cannot be represented as a tuple (e.g. unknown rank)
|
||||
self._batch_input_shape = None
|
||||
elif input_tensor is None:
|
||||
if input_shape is not None:
|
||||
batch_input_shape = (batch_size,) + tuple(input_shape)
|
||||
else:
|
||||
@ -190,13 +233,19 @@ class InputLayer(base_layer.Layer):
|
||||
shape=input_tensor.shape, dtype=input_tensor.dtype, name=self.name)
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
'batch_input_shape': self._batch_input_shape,
|
||||
'dtype': self.dtype,
|
||||
'sparse': self.sparse,
|
||||
'ragged': self.ragged,
|
||||
'name': self.name
|
||||
}
|
||||
if self._init_type_spec is not None:
|
||||
config = {
|
||||
'name': self.name,
|
||||
'type_spec': self._init_type_spec
|
||||
}
|
||||
else:
|
||||
config = {
|
||||
'batch_input_shape': self._batch_input_shape,
|
||||
'dtype': self.dtype,
|
||||
'sparse': self.sparse,
|
||||
'ragged': self.ragged,
|
||||
'name': self.name,
|
||||
}
|
||||
return config
|
||||
|
||||
@property
|
||||
@ -210,13 +259,14 @@ def Input( # pylint: disable=invalid-name
|
||||
batch_size=None,
|
||||
name=None,
|
||||
dtype=None,
|
||||
sparse=False,
|
||||
sparse=None,
|
||||
tensor=None,
|
||||
ragged=False,
|
||||
ragged=None,
|
||||
type_spec=None,
|
||||
**kwargs):
|
||||
"""`Input()` is used to instantiate a Keras tensor.
|
||||
|
||||
A Keras tensor is a TensorFlow symbolic tensor object,
|
||||
A Keras tensor is a symbolic tensor-like object,
|
||||
which we augment with certain attributes that allow us to build a Keras model
|
||||
just by knowing the inputs and outputs of the model.
|
||||
|
||||
@ -248,6 +298,8 @@ def Input( # pylint: disable=invalid-name
|
||||
values of 'None' in the 'shape' argument represent ragged dimensions.
|
||||
For more information about RaggedTensors, see
|
||||
[this guide](https://www.tensorflow.org/guide/ragged_tensors).
|
||||
type_spec: A `tf.TypeSpec` object to create the input placeholder from.
|
||||
When provided, all other args except name must be None.
|
||||
**kwargs: deprecated arguments support. Supports `batch_shape` and
|
||||
`batch_input_shape`.
|
||||
|
||||
@ -264,8 +316,8 @@ def Input( # pylint: disable=invalid-name
|
||||
```
|
||||
|
||||
Note that even if eager execution is enabled,
|
||||
`Input` produces a symbolic tensor (i.e. a placeholder).
|
||||
This symbolic tensor can be used with other
|
||||
`Input` produces a symbolic tensor-like object (i.e. a placeholder).
|
||||
This symbolic tensor-like object can be used with other
|
||||
TensorFlow ops, as such:
|
||||
|
||||
```python
|
||||
@ -273,11 +325,29 @@ def Input( # pylint: disable=invalid-name
|
||||
y = tf.square(x)
|
||||
```
|
||||
|
||||
However, the resulting model will not track any variables that were
|
||||
used as inputs to TensorFlow ops. All variable usages must happen within
|
||||
Keras layers to make sure they will be tracked by the model's weights.
|
||||
|
||||
The Keras Input can also create a placeholder from an arbitrary `tf.TypeSpec`,
|
||||
e.g:
|
||||
|
||||
```python
|
||||
x = Input(type_spec=tf.RaggedTensorSpec(shape=[None, None],
|
||||
dtype=tf.float32, ragged_rank=1))
|
||||
y = x.values
|
||||
model = Model(x, y)
|
||||
```
|
||||
When passing an arbitrary `tf.TypeSpec`, it must represent the signature of an
|
||||
entire batch instead of just one example.
|
||||
|
||||
Raises:
|
||||
ValueError: If both `sparse` and `ragged` are provided.
|
||||
ValueError: If both `shape` and (`batch_input_shape` or `batch_shape`) are
|
||||
provided.
|
||||
ValueError: If both `shape` and `tensor` are None.
|
||||
ValueError: If `shape`, `tensor` and `type_spec` are None.
|
||||
ValueError: If arguments besides `type_spec` are non-None while `type_spec`
|
||||
is passed.
|
||||
ValueError: if any unrecognized parameters are provided.
|
||||
"""
|
||||
if sparse and ragged:
|
||||
@ -285,16 +355,18 @@ def Input( # pylint: disable=invalid-name
|
||||
'Cannot set both sparse and ragged to True in a Keras input.')
|
||||
|
||||
input_layer_config = {'name': name, 'dtype': dtype, 'sparse': sparse,
|
||||
'ragged': ragged, 'input_tensor': tensor}
|
||||
'ragged': ragged, 'input_tensor': tensor,
|
||||
'type_spec': type_spec}
|
||||
|
||||
batch_input_shape = kwargs.pop('batch_input_shape',
|
||||
kwargs.pop('batch_shape', None))
|
||||
if shape is not None and batch_input_shape is not None:
|
||||
raise ValueError('Only provide the `shape` OR `batch_input_shape` argument '
|
||||
'to Input, not both at the same time.')
|
||||
if batch_input_shape is None and shape is None and tensor is None:
|
||||
raise ValueError('Please provide to Input either a `shape`'
|
||||
' or a `tensor` argument. Note that '
|
||||
if (batch_input_shape is None and shape is None and tensor is None
|
||||
and type_spec is None):
|
||||
raise ValueError('Please provide to Input a `shape`'
|
||||
' or a `tensor` or a `type_spec` argument. Note that '
|
||||
'`shape` does not include the batch '
|
||||
'dimension.')
|
||||
if kwargs:
|
||||
|
@ -19,16 +19,105 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import composite_tensor
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import type_spec
|
||||
from tensorflow.python.keras import combinations
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.keras.engine import functional
|
||||
from tensorflow.python.keras.engine import input_layer as input_layer_lib
|
||||
from tensorflow.python.keras.layers import core
|
||||
from tensorflow.python.keras.saving import model_config
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class TwoTensors(composite_tensor.CompositeTensor):
|
||||
"""A simple value type to test TypeSpec.
|
||||
|
||||
Contains two tensors (x, y) and a string (color). The color value is a
|
||||
stand-in for any extra type metadata we might need to store.
|
||||
|
||||
This value type contains no single dtype.
|
||||
"""
|
||||
|
||||
def __init__(self, x, y, color='red', assign_variant_dtype=False):
|
||||
assert isinstance(color, str)
|
||||
self.x = ops.convert_to_tensor_v2_with_dispatch(x)
|
||||
self.y = ops.convert_to_tensor_v2_with_dispatch(y)
|
||||
self.color = color
|
||||
self.shape = tensor_shape.TensorShape(None)
|
||||
self._shape = tensor_shape.TensorShape(None)
|
||||
if assign_variant_dtype:
|
||||
self.dtype = dtypes.variant
|
||||
self._assign_variant_dtype = assign_variant_dtype
|
||||
|
||||
def _type_spec(self):
|
||||
return TwoTensorsSpecNoOneDtype(
|
||||
self.x.shape, self.x.dtype, self.y.shape,
|
||||
self.y.dtype, color=self.color,
|
||||
assign_variant_dtype=self._assign_variant_dtype)
|
||||
|
||||
|
||||
def as_shape(shape):
|
||||
"""Converts the given object to a TensorShape."""
|
||||
if isinstance(shape, tensor_shape.TensorShape):
|
||||
return shape
|
||||
else:
|
||||
return tensor_shape.TensorShape(shape)
|
||||
|
||||
|
||||
@type_spec.register('tf.TwoTensorsSpec')
|
||||
class TwoTensorsSpecNoOneDtype(type_spec.TypeSpec):
|
||||
"""A TypeSpec for the TwoTensors value type."""
|
||||
|
||||
def __init__(
|
||||
self, x_shape, x_dtype, y_shape, y_dtype, color='red',
|
||||
assign_variant_dtype=False):
|
||||
self.x_shape = as_shape(x_shape)
|
||||
self.x_dtype = dtypes.as_dtype(x_dtype)
|
||||
self.y_shape = as_shape(y_shape)
|
||||
self.y_dtype = dtypes.as_dtype(y_dtype)
|
||||
self.color = color
|
||||
self.shape = tensor_shape.TensorShape(None)
|
||||
self._shape = tensor_shape.TensorShape(None)
|
||||
if assign_variant_dtype:
|
||||
self.dtype = dtypes.variant
|
||||
self._assign_variant_dtype = assign_variant_dtype
|
||||
|
||||
value_type = property(lambda self: TwoTensors)
|
||||
|
||||
@property
|
||||
def _component_specs(self):
|
||||
return (tensor_spec.TensorSpec(self.x_shape, self.x_dtype),
|
||||
tensor_spec.TensorSpec(self.y_shape, self.y_dtype))
|
||||
|
||||
def _to_components(self, value):
|
||||
return (value.x, value.y)
|
||||
|
||||
def _from_components(self, components):
|
||||
x, y = components
|
||||
return TwoTensors(x, y, self.color)
|
||||
|
||||
def _serialize(self):
|
||||
return (self.x_shape, self.x_dtype, self.y_shape, self.y_dtype, self.color)
|
||||
|
||||
@classmethod
|
||||
def from_value(cls, value):
|
||||
return cls(value.x.shape, value.x.dtype, value.y.shape, value.y.dtype,
|
||||
value.color)
|
||||
|
||||
|
||||
type_spec.register_type_spec_from_value_converter(
|
||||
TwoTensors, TwoTensorsSpecNoOneDtype.from_value)
|
||||
|
||||
|
||||
class InputLayerTest(keras_parameterized.TestCase):
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
@ -144,5 +233,152 @@ class InputLayerTest(keras_parameterized.TestCase):
|
||||
values=[3, 21, 4, 1, 53, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
|
||||
self.assertAllEqual(run_model(rt), rt * 3)
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['eager']))
|
||||
def testNoMixingArgsWithTypeSpecArg(self):
|
||||
with testing_utils.use_keras_tensors_scope(True):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'all other args except `name` must be None'):
|
||||
input_layer_lib.Input(
|
||||
shape=(4, 7),
|
||||
type_spec=tensor_spec.TensorSpec((2, 7, 32), dtypes.float32))
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'all other args except `name` must be None'):
|
||||
input_layer_lib.Input(
|
||||
batch_size=4,
|
||||
type_spec=tensor_spec.TensorSpec((7, 32), dtypes.float32))
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'all other args except `name` must be None'):
|
||||
input_layer_lib.Input(
|
||||
dtype=dtypes.int64,
|
||||
type_spec=tensor_spec.TensorSpec((7, 32), dtypes.float32))
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'all other args except `name` must be None'):
|
||||
input_layer_lib.Input(
|
||||
sparse=True,
|
||||
type_spec=tensor_spec.TensorSpec((7, 32), dtypes.float32))
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'all other args except `name` must be None'):
|
||||
input_layer_lib.Input(
|
||||
ragged=True,
|
||||
type_spec=tensor_spec.TensorSpec((7, 32), dtypes.float32))
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['eager']))
|
||||
def testTypeSpecArg(self):
|
||||
with testing_utils.use_keras_tensors_scope(True):
|
||||
# Create a Keras Input
|
||||
x = input_layer_lib.Input(
|
||||
type_spec=tensor_spec.TensorSpec((7, 32), dtypes.float32))
|
||||
self.assertAllEqual(x.shape.as_list(), [7, 32])
|
||||
|
||||
# Verify you can construct and use a model w/ this input
|
||||
model = functional.Functional(x, x * 2.0)
|
||||
self.assertAllEqual(model(array_ops.ones(x.shape)),
|
||||
array_ops.ones(x.shape) * 2.0)
|
||||
|
||||
# Test serialization / deserialization
|
||||
model = functional.Functional.from_config(model.get_config())
|
||||
self.assertAllEqual(model(array_ops.ones(x.shape)),
|
||||
array_ops.ones(x.shape) * 2.0)
|
||||
|
||||
model = model_config.model_from_json(model.to_json())
|
||||
self.assertAllEqual(model(array_ops.ones(x.shape)),
|
||||
array_ops.ones(x.shape) * 2.0)
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['eager']))
|
||||
def testTypeSpecArgInTFFunction(self):
|
||||
with testing_utils.use_keras_tensors_scope(True):
|
||||
# We use a mutable model container instead of a model python variable,
|
||||
# because python 2.7 does not have `nonlocal`
|
||||
model_container = {}
|
||||
|
||||
@def_function.function
|
||||
def run_model(inp):
|
||||
if not model_container:
|
||||
# Create a Keras Input
|
||||
x = input_layer_lib.Input(
|
||||
type_spec=tensor_spec.TensorSpec((10, 16), dtypes.float32))
|
||||
self.assertAllEqual(x.shape.as_list(), [10, 16])
|
||||
|
||||
# Verify you can construct and use a model w/ this input
|
||||
model_container['model'] = functional.Functional(x, x * 3.0)
|
||||
return model_container['model'](inp)
|
||||
|
||||
self.assertAllEqual(run_model(array_ops.ones((10, 16))),
|
||||
array_ops.ones((10, 16)) * 3.0)
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['eager']))
|
||||
def testCompositeTypeSpecArg(self):
|
||||
with testing_utils.use_keras_tensors_scope(True):
|
||||
# Create a Keras Input
|
||||
rt = ragged_tensor.RaggedTensor.from_row_splits(
|
||||
values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
|
||||
x = input_layer_lib.Input(type_spec=rt._type_spec)
|
||||
|
||||
# Verify you can construct and use a model w/ this input
|
||||
model = functional.Functional(x, x * 2)
|
||||
|
||||
# And that the model works
|
||||
rt = ragged_tensor.RaggedTensor.from_row_splits(
|
||||
values=[3, 21, 4, 1, 53, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
|
||||
self.assertAllEqual(model(rt), rt * 2)
|
||||
|
||||
# Test serialization / deserialization
|
||||
model = functional.Functional.from_config(model.get_config())
|
||||
self.assertAllEqual(model(rt), rt * 2)
|
||||
model = model_config.model_from_json(model.to_json())
|
||||
self.assertAllEqual(model(rt), rt * 2)
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['eager']))
|
||||
def testCompositeTypeSpecArgInTFFunction(self):
|
||||
with testing_utils.use_keras_tensors_scope(True):
|
||||
# We use a mutable model container instead of a model pysthon variable,
|
||||
# because python 2.7 does not have `nonlocal`
|
||||
model_container = {}
|
||||
|
||||
@def_function.function
|
||||
def run_model(inp):
|
||||
if not model_container:
|
||||
# Create a Keras Input
|
||||
rt = ragged_tensor.RaggedTensor.from_row_splits(
|
||||
values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
|
||||
x = input_layer_lib.Input(type_spec=rt._type_spec)
|
||||
|
||||
# Verify you can construct and use a model w/ this input
|
||||
model_container['model'] = functional.Functional(x, x * 3)
|
||||
return model_container['model'](inp)
|
||||
|
||||
# And verify the model works
|
||||
rt = ragged_tensor.RaggedTensor.from_row_splits(
|
||||
values=[3, 21, 4, 1, 53, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
|
||||
self.assertAllEqual(run_model(rt), rt * 3)
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['eager']))
|
||||
def testCompositeTypeSpecArgWithoutDtype(self):
|
||||
with testing_utils.use_keras_tensors_scope(True):
|
||||
for assign_variant_dtype in [False, True]:
|
||||
# Create a Keras Input
|
||||
spec = TwoTensorsSpecNoOneDtype(
|
||||
(1, 2, 3), dtypes.float32, (1, 2, 3), dtypes.int64,
|
||||
assign_variant_dtype=assign_variant_dtype)
|
||||
x = input_layer_lib.Input(type_spec=spec)
|
||||
|
||||
def lambda_fn(tensors):
|
||||
return (math_ops.cast(tensors.x, dtypes.float64)
|
||||
+ math_ops.cast(tensors.y, dtypes.float64))
|
||||
# Verify you can construct and use a model w/ this input
|
||||
model = functional.Functional(x, core.Lambda(lambda_fn)(x))
|
||||
|
||||
# And that the model works
|
||||
two_tensors = TwoTensors(array_ops.ones((1, 2, 3)) * 2.0,
|
||||
array_ops.ones(1, 2, 3))
|
||||
self.assertAllEqual(model(two_tensors), lambda_fn(two_tensors))
|
||||
|
||||
# Test serialization / deserialization
|
||||
model = functional.Functional.from_config(model.get_config())
|
||||
self.assertAllEqual(model(two_tensors), lambda_fn(two_tensors))
|
||||
model = model_config.model_from_json(model.to_json())
|
||||
self.assertAllEqual(model(two_tensors), lambda_fn(two_tensors))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -205,6 +205,10 @@ class KerasTensor(object):
|
||||
type_spec = type_spec_module.type_spec_from_value(tensor)
|
||||
return cls(type_spec, name=name)
|
||||
|
||||
@classmethod
|
||||
def from_type_spec(cls, type_spec, name=None):
|
||||
return cls(type_spec=type_spec, name=name)
|
||||
|
||||
def _to_placeholder(self):
|
||||
"""Convert this KerasTensor to a placeholder in a graph."""
|
||||
# If there is an inferred value for this tensor, inject the inferred value
|
||||
@ -538,6 +542,11 @@ class UserRegisteredTypeKerasTensor(KerasTensor):
|
||||
def from_tensor(cls, tensor):
|
||||
return cls(tensor)
|
||||
|
||||
@classmethod
|
||||
def from_type_spec(cls, type_spec, name=None):
|
||||
raise NotImplementedError('You cannot instantiate a KerasTensor '
|
||||
'directly from TypeSpec: %s' % type_spec)
|
||||
|
||||
def _to_placeholder(self):
|
||||
return self._user_registered_symbolic_object
|
||||
|
||||
@ -608,3 +617,17 @@ def keras_tensor_from_tensor(tensor):
|
||||
if hasattr(tensor, '_keras_mask'):
|
||||
out._keras_mask = keras_tensor_from_tensor(tensor._keras_mask) # pylint: disable=protected-access
|
||||
return out
|
||||
|
||||
|
||||
def keras_tensor_from_type_spec(type_spec, name=None):
|
||||
"""Convert a TypeSpec to a representative KerasTensor."""
|
||||
# Create a specialized KerasTensor that supports instance methods,
|
||||
# operators, and additional value inference if possible
|
||||
keras_tensor_cls = None
|
||||
value_type = type_spec.value_type
|
||||
for tensor_type, cls in keras_tensor_classes:
|
||||
if issubclass(value_type, tensor_type):
|
||||
keras_tensor_cls = cls
|
||||
break
|
||||
|
||||
return keras_tensor_cls.from_type_spec(type_spec, name=name)
|
||||
|
@ -19,8 +19,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
|
||||
from tensorflow.python.keras.saving.saved_model import json_utils
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
# pylint: disable=g-import-not-at-top
|
||||
@ -126,6 +125,6 @@ def model_from_json(json_string, custom_objects=None):
|
||||
Returns:
|
||||
A Keras model instance (uncompiled).
|
||||
"""
|
||||
config = json.loads(json_string)
|
||||
config = json_utils.decode(json_string)
|
||||
from tensorflow.python.keras.layers import deserialize # pylint: disable=g-import-not-at-top
|
||||
return deserialize(config, custom_objects=custom_objects)
|
||||
|
@ -129,7 +129,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'input_shape\', \'batch_size\', \'dtype\', \'input_tensor\', \'sparse\', \'name\', \'ragged\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
|
||||
argspec: "args=[\'self\', \'input_shape\', \'batch_size\', \'dtype\', \'input_tensor\', \'sparse\', \'name\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_loss"
|
||||
|
@ -434,7 +434,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "Input"
|
||||
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
|
||||
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add"
|
||||
|
@ -86,6 +86,6 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "Input"
|
||||
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
|
||||
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
}
|
||||
|
@ -129,7 +129,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'input_shape\', \'batch_size\', \'dtype\', \'input_tensor\', \'sparse\', \'name\', \'ragged\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
|
||||
argspec: "args=[\'self\', \'input_shape\', \'batch_size\', \'dtype\', \'input_tensor\', \'sparse\', \'name\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_loss"
|
||||
|
@ -426,7 +426,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "Input"
|
||||
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
|
||||
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add"
|
||||
|
@ -86,6 +86,6 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "Input"
|
||||
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
|
||||
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user