When a non-KerasTensor Tensor is passed to the tensor
argument of keras.Input or InputLayer, make a KerasTensor directly from that tensor rather than erroring out.
PiperOrigin-RevId: 324688220 Change-Id: I2b06682f8ea706be4e36e0b8807c0f07bec55a4e
This commit is contained in:
parent
d84acd6e45
commit
64c753b9ee
@ -545,6 +545,23 @@ tf_py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_py_test(
|
||||||
|
name = "input_layer_test",
|
||||||
|
size = "medium",
|
||||||
|
srcs = ["input_layer_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
shard_count = 3,
|
||||||
|
tags = [
|
||||||
|
"nomac", # TODO(mihaimaruseac): b/127695564
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":base_layer",
|
||||||
|
":engine",
|
||||||
|
"//tensorflow/python/keras:testing_utils",
|
||||||
|
"//tensorflow/python/keras/utils:layer_utils",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_py_test(
|
tf_py_test(
|
||||||
name = "functional_test",
|
name = "functional_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
|
@ -135,6 +135,7 @@ class Functional(training_lib.Model):
|
|||||||
(isinstance(self._nested_inputs, (list, tuple, dict)) and
|
(isinstance(self._nested_inputs, (list, tuple, dict)) and
|
||||||
not any(nest.is_nested(t) for t in self._nested_inputs)))
|
not any(nest.is_nested(t) for t in self._nested_inputs)))
|
||||||
|
|
||||||
|
if not keras_tensor.keras_tensors_enabled():
|
||||||
if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs):
|
if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs):
|
||||||
base_layer_utils.create_keras_history(self._nested_outputs)
|
base_layer_utils.create_keras_history(self._nested_outputs)
|
||||||
|
|
||||||
|
@ -76,8 +76,9 @@ class InputLayer(base_layer.Layer):
|
|||||||
batch_size: Optional input batch size (integer or None).
|
batch_size: Optional input batch size (integer or None).
|
||||||
dtype: Optional datatype of the input. When not provided, the Keras
|
dtype: Optional datatype of the input. When not provided, the Keras
|
||||||
default float type will be used.
|
default float type will be used.
|
||||||
input_tensor: Optional tensor to use as layer input
|
input_tensor: Optional tensor to use as layer input. If set, the layer
|
||||||
instead of creating a placeholder.
|
will use the `tf.TypeSpec` of this tensor rather
|
||||||
|
than creating a new placeholder tensor.
|
||||||
sparse: Boolean, whether the placeholder created is meant to be sparse.
|
sparse: Boolean, whether the placeholder created is meant to be sparse.
|
||||||
Default to False.
|
Default to False.
|
||||||
ragged: Boolean, whether the placeholder created is meant to be ragged.
|
ragged: Boolean, whether the placeholder created is meant to be ragged.
|
||||||
@ -162,15 +163,11 @@ class InputLayer(base_layer.Layer):
|
|||||||
self.is_placeholder = True
|
self.is_placeholder = True
|
||||||
self._batch_input_shape = batch_input_shape
|
self._batch_input_shape = batch_input_shape
|
||||||
else:
|
else:
|
||||||
raise_eager_tensor_error = False
|
|
||||||
if keras_tensor.keras_tensors_enabled():
|
if keras_tensor.keras_tensors_enabled():
|
||||||
if (not isinstance(input_tensor, keras_tensor.KerasTensor) and
|
if not isinstance(input_tensor, keras_tensor.KerasTensor):
|
||||||
not tf_utils.is_symbolic_tensor(input_tensor)):
|
input_tensor = keras_tensor.keras_tensor_from_tensor(input_tensor)
|
||||||
raise_eager_tensor_error = True
|
|
||||||
else:
|
else:
|
||||||
if not tf_utils.is_symbolic_tensor(input_tensor):
|
if not tf_utils.is_symbolic_tensor(input_tensor):
|
||||||
raise_eager_tensor_error = True
|
|
||||||
if raise_eager_tensor_error:
|
|
||||||
raise ValueError('You should not pass an EagerTensor to `Input`. '
|
raise ValueError('You should not pass an EagerTensor to `Input`. '
|
||||||
'For example, instead of creating an '
|
'For example, instead of creating an '
|
||||||
'InputLayer, you should instantiate your model and '
|
'InputLayer, you should instantiate your model and '
|
||||||
@ -245,7 +242,8 @@ def Input( # pylint: disable=invalid-name
|
|||||||
if `sparse` is False, sparse tensors can still be passed into the
|
if `sparse` is False, sparse tensors can still be passed into the
|
||||||
input - they will be densified with a default value of 0.
|
input - they will be densified with a default value of 0.
|
||||||
tensor: Optional existing tensor to wrap into the `Input` layer.
|
tensor: Optional existing tensor to wrap into the `Input` layer.
|
||||||
If set, the layer will not create a placeholder tensor.
|
If set, the layer will use the `tf.TypeSpec` of this tensor rather
|
||||||
|
than creating a new placeholder tensor.
|
||||||
ragged: A boolean specifying whether the placeholder to be created is
|
ragged: A boolean specifying whether the placeholder to be created is
|
||||||
ragged. Only one of 'ragged' and 'sparse' can be True. In this case,
|
ragged. Only one of 'ragged' and 'sparse' can be True. In this case,
|
||||||
values of 'None' in the 'shape' argument represent ragged dimensions.
|
values of 'None' in the 'shape' argument represent ragged dimensions.
|
||||||
|
148
tensorflow/python/keras/engine/input_layer_test.py
Normal file
148
tensorflow/python/keras/engine/input_layer_test.py
Normal file
@ -0,0 +1,148 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#,============================================================================
|
||||||
|
"""Tests for InputLayer construction."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.eager import def_function
|
||||||
|
from tensorflow.python.keras import combinations
|
||||||
|
from tensorflow.python.keras import keras_parameterized
|
||||||
|
from tensorflow.python.keras import testing_utils
|
||||||
|
from tensorflow.python.keras.engine import functional
|
||||||
|
from tensorflow.python.keras.engine import input_layer as input_layer_lib
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops.ragged import ragged_tensor
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class InputLayerTest(keras_parameterized.TestCase):
|
||||||
|
|
||||||
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
|
def testBasicOutputShapeNoBatchSize(self):
|
||||||
|
# Create a Keras Input
|
||||||
|
x = input_layer_lib.Input(shape=(32,), name='input_a')
|
||||||
|
self.assertAllEqual(x.shape.as_list(), [None, 32])
|
||||||
|
|
||||||
|
# Verify you can construct and use a model w/ this input
|
||||||
|
model = functional.Functional(x, x * 2.0)
|
||||||
|
self.assertAllEqual(model(array_ops.ones((3, 32))),
|
||||||
|
array_ops.ones((3, 32)) * 2.0)
|
||||||
|
|
||||||
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
|
def testBasicOutputShapeWithBatchSize(self):
|
||||||
|
# Create a Keras Input
|
||||||
|
x = input_layer_lib.Input(batch_size=6, shape=(32,), name='input_b')
|
||||||
|
self.assertAllEqual(x.shape.as_list(), [6, 32])
|
||||||
|
|
||||||
|
# Verify you can construct and use a model w/ this input
|
||||||
|
model = functional.Functional(x, x * 2.0)
|
||||||
|
self.assertAllEqual(model(array_ops.ones(x.shape)),
|
||||||
|
array_ops.ones(x.shape) * 2.0)
|
||||||
|
|
||||||
|
@combinations.generate(combinations.combine(mode=['eager']))
|
||||||
|
def testBasicOutputShapeNoBatchSizeInTFFunction(self):
|
||||||
|
model = None
|
||||||
|
@def_function.function
|
||||||
|
def run_model(inp):
|
||||||
|
nonlocal model
|
||||||
|
if not model:
|
||||||
|
# Create a Keras Input
|
||||||
|
x = input_layer_lib.Input(shape=(8,), name='input_a')
|
||||||
|
self.assertAllEqual(x.shape.as_list(), [None, 8])
|
||||||
|
|
||||||
|
# Verify you can construct and use a model w/ this input
|
||||||
|
model = functional.Functional(x, x * 2.0)
|
||||||
|
return model(inp)
|
||||||
|
|
||||||
|
self.assertAllEqual(run_model(array_ops.ones((10, 8))),
|
||||||
|
array_ops.ones((10, 8)) * 2.0)
|
||||||
|
|
||||||
|
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||||
|
def testInputTensorArg(self):
|
||||||
|
with testing_utils.use_keras_tensors_scope(True):
|
||||||
|
# Create a Keras Input
|
||||||
|
x = input_layer_lib.Input(tensor=array_ops.zeros((7, 32)))
|
||||||
|
self.assertAllEqual(x.shape.as_list(), [7, 32])
|
||||||
|
|
||||||
|
# Verify you can construct and use a model w/ this input
|
||||||
|
model = functional.Functional(x, x * 2.0)
|
||||||
|
self.assertAllEqual(model(array_ops.ones(x.shape)),
|
||||||
|
array_ops.ones(x.shape) * 2.0)
|
||||||
|
|
||||||
|
@combinations.generate(combinations.combine(mode=['eager']))
|
||||||
|
def testInputTensorArgInTFFunction(self):
|
||||||
|
with testing_utils.use_keras_tensors_scope(True):
|
||||||
|
# We use a mutable model container instead of a model python variable,
|
||||||
|
# because python 2.7 does not have `nonlocal`
|
||||||
|
model_container = {}
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def run_model(inp):
|
||||||
|
if not model_container:
|
||||||
|
# Create a Keras Input
|
||||||
|
x = input_layer_lib.Input(tensor=array_ops.zeros((10, 16)))
|
||||||
|
self.assertAllEqual(x.shape.as_list(), [10, 16])
|
||||||
|
|
||||||
|
# Verify you can construct and use a model w/ this input
|
||||||
|
model_container['model'] = functional.Functional(x, x * 3.0)
|
||||||
|
return model_container['model'](inp)
|
||||||
|
|
||||||
|
self.assertAllEqual(run_model(array_ops.ones((10, 16))),
|
||||||
|
array_ops.ones((10, 16)) * 3.0)
|
||||||
|
|
||||||
|
@combinations.generate(combinations.combine(mode=['eager']))
|
||||||
|
def testCompositeInputTensorArg(self):
|
||||||
|
with testing_utils.use_keras_tensors_scope(True):
|
||||||
|
# Create a Keras Input
|
||||||
|
rt = ragged_tensor.RaggedTensor.from_row_splits(
|
||||||
|
values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
|
||||||
|
x = input_layer_lib.Input(tensor=rt)
|
||||||
|
|
||||||
|
# Verify you can construct and use a model w/ this input
|
||||||
|
model = functional.Functional(x, x * 2)
|
||||||
|
|
||||||
|
# And that the model works
|
||||||
|
rt = ragged_tensor.RaggedTensor.from_row_splits(
|
||||||
|
values=[3, 21, 4, 1, 53, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
|
||||||
|
self.assertAllEqual(model(rt), rt * 2)
|
||||||
|
|
||||||
|
@combinations.generate(combinations.combine(mode=['eager']))
|
||||||
|
def testCompositeInputTensorArgInTFFunction(self):
|
||||||
|
with testing_utils.use_keras_tensors_scope(True):
|
||||||
|
# We use a mutable model container instead of a model python variable,
|
||||||
|
# because python 2.7 does not have `nonlocal`
|
||||||
|
model_container = {}
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def run_model(inp):
|
||||||
|
if not model_container:
|
||||||
|
# Create a Keras Input
|
||||||
|
rt = ragged_tensor.RaggedTensor.from_row_splits(
|
||||||
|
values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
|
||||||
|
x = input_layer_lib.Input(tensor=rt)
|
||||||
|
|
||||||
|
# Verify you can construct and use a model w/ this input
|
||||||
|
model_container['model'] = functional.Functional(x, x * 3)
|
||||||
|
return model_container['model'](inp)
|
||||||
|
|
||||||
|
# And verify the model works
|
||||||
|
rt = ragged_tensor.RaggedTensor.from_row_splits(
|
||||||
|
values=[3, 21, 4, 1, 53, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
|
||||||
|
self.assertAllEqual(run_model(rt), rt * 3)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test.main()
|
Loading…
x
Reference in New Issue
Block a user