diff --git a/tensorflow/python/keras/engine/BUILD b/tensorflow/python/keras/engine/BUILD index c71069b3657..0d2ddb46049 100644 --- a/tensorflow/python/keras/engine/BUILD +++ b/tensorflow/python/keras/engine/BUILD @@ -545,6 +545,23 @@ tf_py_test( ], ) +tf_py_test( + name = "input_layer_test", + size = "medium", + srcs = ["input_layer_test.py"], + python_version = "PY3", + shard_count = 3, + tags = [ + "nomac", # TODO(mihaimaruseac): b/127695564 + ], + deps = [ + ":base_layer", + ":engine", + "//tensorflow/python/keras:testing_utils", + "//tensorflow/python/keras/utils:layer_utils", + ], +) + tf_py_test( name = "functional_test", size = "medium", diff --git a/tensorflow/python/keras/engine/functional.py b/tensorflow/python/keras/engine/functional.py index 71d6faa71b6..8422bf923d8 100644 --- a/tensorflow/python/keras/engine/functional.py +++ b/tensorflow/python/keras/engine/functional.py @@ -135,8 +135,9 @@ class Functional(training_lib.Model): (isinstance(self._nested_inputs, (list, tuple, dict)) and not any(nest.is_nested(t) for t in self._nested_inputs))) - if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs): - base_layer_utils.create_keras_history(self._nested_outputs) + if not keras_tensor.keras_tensors_enabled(): + if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs): + base_layer_utils.create_keras_history(self._nested_outputs) self._validate_graph_inputs_and_outputs() diff --git a/tensorflow/python/keras/engine/input_layer.py b/tensorflow/python/keras/engine/input_layer.py index 4818c5c59a7..33f9320e516 100644 --- a/tensorflow/python/keras/engine/input_layer.py +++ b/tensorflow/python/keras/engine/input_layer.py @@ -76,8 +76,9 @@ class InputLayer(base_layer.Layer): batch_size: Optional input batch size (integer or None). dtype: Optional datatype of the input. When not provided, the Keras default float type will be used. - input_tensor: Optional tensor to use as layer input - instead of creating a placeholder. + input_tensor: Optional tensor to use as layer input. If set, the layer + will use the `tf.TypeSpec` of this tensor rather + than creating a new placeholder tensor. sparse: Boolean, whether the placeholder created is meant to be sparse. Default to False. ragged: Boolean, whether the placeholder created is meant to be ragged. @@ -162,19 +163,15 @@ class InputLayer(base_layer.Layer): self.is_placeholder = True self._batch_input_shape = batch_input_shape else: - raise_eager_tensor_error = False if keras_tensor.keras_tensors_enabled(): - if (not isinstance(input_tensor, keras_tensor.KerasTensor) and - not tf_utils.is_symbolic_tensor(input_tensor)): - raise_eager_tensor_error = True + if not isinstance(input_tensor, keras_tensor.KerasTensor): + input_tensor = keras_tensor.keras_tensor_from_tensor(input_tensor) else: if not tf_utils.is_symbolic_tensor(input_tensor): - raise_eager_tensor_error = True - if raise_eager_tensor_error: - raise ValueError('You should not pass an EagerTensor to `Input`. ' - 'For example, instead of creating an ' - 'InputLayer, you should instantiate your model and ' - 'directly call it on your input.') + raise ValueError('You should not pass an EagerTensor to `Input`. ' + 'For example, instead of creating an ' + 'InputLayer, you should instantiate your model and ' + 'directly call it on your input.') self.is_placeholder = False try: self._batch_input_shape = tuple(input_tensor.shape.as_list()) @@ -245,7 +242,8 @@ def Input( # pylint: disable=invalid-name if `sparse` is False, sparse tensors can still be passed into the input - they will be densified with a default value of 0. tensor: Optional existing tensor to wrap into the `Input` layer. - If set, the layer will not create a placeholder tensor. + If set, the layer will use the `tf.TypeSpec` of this tensor rather + than creating a new placeholder tensor. ragged: A boolean specifying whether the placeholder to be created is ragged. Only one of 'ragged' and 'sparse' can be True. In this case, values of 'None' in the 'shape' argument represent ragged dimensions. diff --git a/tensorflow/python/keras/engine/input_layer_test.py b/tensorflow/python/keras/engine/input_layer_test.py new file mode 100644 index 00000000000..1b15f34458c --- /dev/null +++ b/tensorflow/python/keras/engine/input_layer_test.py @@ -0,0 +1,148 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#,============================================================================ +"""Tests for InputLayer construction.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.eager import def_function +from tensorflow.python.keras import combinations +from tensorflow.python.keras import keras_parameterized +from tensorflow.python.keras import testing_utils +from tensorflow.python.keras.engine import functional +from tensorflow.python.keras.engine import input_layer as input_layer_lib +from tensorflow.python.ops import array_ops +from tensorflow.python.ops.ragged import ragged_tensor +from tensorflow.python.platform import test + + +class InputLayerTest(keras_parameterized.TestCase): + + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) + def testBasicOutputShapeNoBatchSize(self): + # Create a Keras Input + x = input_layer_lib.Input(shape=(32,), name='input_a') + self.assertAllEqual(x.shape.as_list(), [None, 32]) + + # Verify you can construct and use a model w/ this input + model = functional.Functional(x, x * 2.0) + self.assertAllEqual(model(array_ops.ones((3, 32))), + array_ops.ones((3, 32)) * 2.0) + + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) + def testBasicOutputShapeWithBatchSize(self): + # Create a Keras Input + x = input_layer_lib.Input(batch_size=6, shape=(32,), name='input_b') + self.assertAllEqual(x.shape.as_list(), [6, 32]) + + # Verify you can construct and use a model w/ this input + model = functional.Functional(x, x * 2.0) + self.assertAllEqual(model(array_ops.ones(x.shape)), + array_ops.ones(x.shape) * 2.0) + + @combinations.generate(combinations.combine(mode=['eager'])) + def testBasicOutputShapeNoBatchSizeInTFFunction(self): + model = None + @def_function.function + def run_model(inp): + nonlocal model + if not model: + # Create a Keras Input + x = input_layer_lib.Input(shape=(8,), name='input_a') + self.assertAllEqual(x.shape.as_list(), [None, 8]) + + # Verify you can construct and use a model w/ this input + model = functional.Functional(x, x * 2.0) + return model(inp) + + self.assertAllEqual(run_model(array_ops.ones((10, 8))), + array_ops.ones((10, 8)) * 2.0) + + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) + def testInputTensorArg(self): + with testing_utils.use_keras_tensors_scope(True): + # Create a Keras Input + x = input_layer_lib.Input(tensor=array_ops.zeros((7, 32))) + self.assertAllEqual(x.shape.as_list(), [7, 32]) + + # Verify you can construct and use a model w/ this input + model = functional.Functional(x, x * 2.0) + self.assertAllEqual(model(array_ops.ones(x.shape)), + array_ops.ones(x.shape) * 2.0) + + @combinations.generate(combinations.combine(mode=['eager'])) + def testInputTensorArgInTFFunction(self): + with testing_utils.use_keras_tensors_scope(True): + # We use a mutable model container instead of a model python variable, + # because python 2.7 does not have `nonlocal` + model_container = {} + + @def_function.function + def run_model(inp): + if not model_container: + # Create a Keras Input + x = input_layer_lib.Input(tensor=array_ops.zeros((10, 16))) + self.assertAllEqual(x.shape.as_list(), [10, 16]) + + # Verify you can construct and use a model w/ this input + model_container['model'] = functional.Functional(x, x * 3.0) + return model_container['model'](inp) + + self.assertAllEqual(run_model(array_ops.ones((10, 16))), + array_ops.ones((10, 16)) * 3.0) + + @combinations.generate(combinations.combine(mode=['eager'])) + def testCompositeInputTensorArg(self): + with testing_utils.use_keras_tensors_scope(True): + # Create a Keras Input + rt = ragged_tensor.RaggedTensor.from_row_splits( + values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8]) + x = input_layer_lib.Input(tensor=rt) + + # Verify you can construct and use a model w/ this input + model = functional.Functional(x, x * 2) + + # And that the model works + rt = ragged_tensor.RaggedTensor.from_row_splits( + values=[3, 21, 4, 1, 53, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8]) + self.assertAllEqual(model(rt), rt * 2) + + @combinations.generate(combinations.combine(mode=['eager'])) + def testCompositeInputTensorArgInTFFunction(self): + with testing_utils.use_keras_tensors_scope(True): + # We use a mutable model container instead of a model python variable, + # because python 2.7 does not have `nonlocal` + model_container = {} + + @def_function.function + def run_model(inp): + if not model_container: + # Create a Keras Input + rt = ragged_tensor.RaggedTensor.from_row_splits( + values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8]) + x = input_layer_lib.Input(tensor=rt) + + # Verify you can construct and use a model w/ this input + model_container['model'] = functional.Functional(x, x * 3) + return model_container['model'](inp) + + # And verify the model works + rt = ragged_tensor.RaggedTensor.from_row_splits( + values=[3, 21, 4, 1, 53, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8]) + self.assertAllEqual(run_model(rt), rt * 3) + +if __name__ == '__main__': + test.main()