When a non-KerasTensor Tensor is passed to the tensor argument of keras.Input or InputLayer, make a KerasTensor directly from that tensor rather than erroring out.

PiperOrigin-RevId: 324688220
Change-Id: I2b06682f8ea706be4e36e0b8807c0f07bec55a4e
Author: Tomer Kaftan, 2020-08-03 14:49:34 -07:00 (committed by TensorFlower Gardener)
parent d84acd6e45
commit 64c753b9ee
4 changed files with 179 additions and 15 deletions
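
In practice, the change means an existing eager tensor can be handed to `keras.Input` via the `tensor` argument and is wrapped into a KerasTensor, instead of triggering the EagerTensor error. A minimal, hedged sketch using the public `tf.keras` API rather than the internal test modules below; it assumes a TF build where the KerasTensor code path is enabled (it later became the default):

    import tensorflow as tf

    # An existing (non-Keras) tensor is wrapped into a KerasTensor; the layer
    # adopts its tf.TypeSpec instead of creating a fresh placeholder.
    x = tf.keras.Input(tensor=tf.zeros((7, 32)))
    print(x.shape)  # (7, 32)

    # The resulting KerasTensor can be used to build a functional model.
    model = tf.keras.Model(x, x * 2.0)
    print(model(tf.ones((7, 32))))  # a (7, 32) tensor filled with 2.0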


@@ -545,6 +545,23 @@ tf_py_test(
     ],
 )
 
+tf_py_test(
+    name = "input_layer_test",
+    size = "medium",
+    srcs = ["input_layer_test.py"],
+    python_version = "PY3",
+    shard_count = 3,
+    tags = [
+        "nomac",  # TODO(mihaimaruseac): b/127695564
+    ],
+    deps = [
+        ":base_layer",
+        ":engine",
+        "//tensorflow/python/keras:testing_utils",
+        "//tensorflow/python/keras/utils:layer_utils",
+    ],
+)
+
 tf_py_test(
     name = "functional_test",
     size = "medium",


@@ -135,8 +135,9 @@ class Functional(training_lib.Model):
         (isinstance(self._nested_inputs, (list, tuple, dict)) and
          not any(nest.is_nested(t) for t in self._nested_inputs)))
 
-    if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs):
-      base_layer_utils.create_keras_history(self._nested_outputs)
+    if not keras_tensor.keras_tensors_enabled():
+      if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs):
+        base_layer_utils.create_keras_history(self._nested_outputs)
 
     self._validate_graph_inputs_and_outputs()
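
For context, a hedged sketch (not part of the commit; it uses the public API and assumes the KerasTensor path that is now the default): when KerasTensors are enabled, outputs produced by plain TF ops on symbolic Keras inputs already carry Keras history, which is why the legacy `create_keras_history` pass above only runs on the old code path:

    import tensorflow as tf

    inp = tf.keras.Input(shape=(4,))
    out = tf.abs(inp)  # a raw TF op applied to a symbolic Keras input

    # Under the KerasTensor path the op is auto-wrapped in a layer, so the
    # functional model can be built without an explicit history-backfill step.
    model = tf.keras.Model(inp, out)
    print(model(tf.constant([[-1.0, 2.0, -3.0, 4.0]])))  # [[1. 2. 3. 4.]]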


@@ -76,8 +76,9 @@ class InputLayer(base_layer.Layer):
     batch_size: Optional input batch size (integer or None).
     dtype: Optional datatype of the input. When not provided, the Keras
         default float type will be used.
-    input_tensor: Optional tensor to use as layer input
-        instead of creating a placeholder.
+    input_tensor: Optional tensor to use as layer input. If set, the layer
+        will use the `tf.TypeSpec` of this tensor rather
+        than creating a new placeholder tensor.
     sparse: Boolean, whether the placeholder created is meant to be sparse.
         Default to False.
     ragged: Boolean, whether the placeholder created is meant to be ragged.
@@ -162,19 +163,15 @@ class InputLayer(base_layer.Layer):
       self.is_placeholder = True
       self._batch_input_shape = batch_input_shape
     else:
-      raise_eager_tensor_error = False
       if keras_tensor.keras_tensors_enabled():
-        if (not isinstance(input_tensor, keras_tensor.KerasTensor) and
-            not tf_utils.is_symbolic_tensor(input_tensor)):
-          raise_eager_tensor_error = True
+        if not isinstance(input_tensor, keras_tensor.KerasTensor):
+          input_tensor = keras_tensor.keras_tensor_from_tensor(input_tensor)
       else:
         if not tf_utils.is_symbolic_tensor(input_tensor):
-          raise_eager_tensor_error = True
-      if raise_eager_tensor_error:
-        raise ValueError('You should not pass an EagerTensor to `Input`. '
-                         'For example, instead of creating an '
-                         'InputLayer, you should instantiate your model and '
-                         'directly call it on your input.')
+          raise ValueError('You should not pass an EagerTensor to `Input`. '
+                           'For example, instead of creating an '
+                           'InputLayer, you should instantiate your model and '
+                           'directly call it on your input.')
       self.is_placeholder = False
       try:
         self._batch_input_shape = tuple(input_tensor.shape.as_list())
@@ -245,7 +242,8 @@ def Input(  # pylint: disable=invalid-name
       if `sparse` is False, sparse tensors can still be passed into the
       input - they will be densified with a default value of 0.
     tensor: Optional existing tensor to wrap into the `Input` layer.
-      If set, the layer will not create a placeholder tensor.
+      If set, the layer will use the `tf.TypeSpec` of this tensor rather
+      than creating a new placeholder tensor.
     ragged: A boolean specifying whether the placeholder to be created is
       ragged. Only one of 'ragged' and 'sparse' can be True. In this case,
       values of 'None' in the 'shape' argument represent ragged dimensions.
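
As a public-API analogue of the docstring above and of the new tests below (a hedged sketch; the `tf.ragged.constant` values are illustrative, and a TF build with the KerasTensor path enabled is assumed), the `tensor` argument also accepts composite tensors, and the resulting input adopts the tensor's `tf.TypeSpec`, including its ragged structure:

    import tensorflow as tf

    rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6]])

    # The Input takes on the RaggedTensorSpec of `rt` rather than creating a
    # dense placeholder.
    x = tf.keras.Input(tensor=rt)
    model = tf.keras.Model(x, x * 2)
    print(model(rt))  # every ragged value doubled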


@@ -0,0 +1,148 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for InputLayer construction."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import def_function
from tensorflow.python.keras import combinations
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import functional
from tensorflow.python.keras.engine import input_layer as input_layer_lib
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test


class InputLayerTest(keras_parameterized.TestCase):

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testBasicOutputShapeNoBatchSize(self):
    # Create a Keras Input
    x = input_layer_lib.Input(shape=(32,), name='input_a')
    self.assertAllEqual(x.shape.as_list(), [None, 32])

    # Verify you can construct and use a model w/ this input
    model = functional.Functional(x, x * 2.0)
    self.assertAllEqual(model(array_ops.ones((3, 32))),
                        array_ops.ones((3, 32)) * 2.0)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testBasicOutputShapeWithBatchSize(self):
    # Create a Keras Input
    x = input_layer_lib.Input(batch_size=6, shape=(32,), name='input_b')
    self.assertAllEqual(x.shape.as_list(), [6, 32])

    # Verify you can construct and use a model w/ this input
    model = functional.Functional(x, x * 2.0)
    self.assertAllEqual(model(array_ops.ones(x.shape)),
                        array_ops.ones(x.shape) * 2.0)

  @combinations.generate(combinations.combine(mode=['eager']))
  def testBasicOutputShapeNoBatchSizeInTFFunction(self):
    model = None

    @def_function.function
    def run_model(inp):
      nonlocal model
      if not model:
        # Create a Keras Input
        x = input_layer_lib.Input(shape=(8,), name='input_a')
        self.assertAllEqual(x.shape.as_list(), [None, 8])

        # Verify you can construct and use a model w/ this input
        model = functional.Functional(x, x * 2.0)
      return model(inp)

    self.assertAllEqual(run_model(array_ops.ones((10, 8))),
                        array_ops.ones((10, 8)) * 2.0)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testInputTensorArg(self):
    with testing_utils.use_keras_tensors_scope(True):
      # Create a Keras Input
      x = input_layer_lib.Input(tensor=array_ops.zeros((7, 32)))
      self.assertAllEqual(x.shape.as_list(), [7, 32])

      # Verify you can construct and use a model w/ this input
      model = functional.Functional(x, x * 2.0)
      self.assertAllEqual(model(array_ops.ones(x.shape)),
                          array_ops.ones(x.shape) * 2.0)

  @combinations.generate(combinations.combine(mode=['eager']))
  def testInputTensorArgInTFFunction(self):
    with testing_utils.use_keras_tensors_scope(True):
      # We use a mutable model container instead of a model python variable,
      # because python 2.7 does not have `nonlocal`
      model_container = {}

      @def_function.function
      def run_model(inp):
        if not model_container:
          # Create a Keras Input
          x = input_layer_lib.Input(tensor=array_ops.zeros((10, 16)))
          self.assertAllEqual(x.shape.as_list(), [10, 16])

          # Verify you can construct and use a model w/ this input
          model_container['model'] = functional.Functional(x, x * 3.0)
        return model_container['model'](inp)

      self.assertAllEqual(run_model(array_ops.ones((10, 16))),
                          array_ops.ones((10, 16)) * 3.0)

  @combinations.generate(combinations.combine(mode=['eager']))
  def testCompositeInputTensorArg(self):
    with testing_utils.use_keras_tensors_scope(True):
      # Create a Keras Input
      rt = ragged_tensor.RaggedTensor.from_row_splits(
          values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
      x = input_layer_lib.Input(tensor=rt)

      # Verify you can construct and use a model w/ this input
      model = functional.Functional(x, x * 2)

      # And that the model works
      rt = ragged_tensor.RaggedTensor.from_row_splits(
          values=[3, 21, 4, 1, 53, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
      self.assertAllEqual(model(rt), rt * 2)

  @combinations.generate(combinations.combine(mode=['eager']))
  def testCompositeInputTensorArgInTFFunction(self):
    with testing_utils.use_keras_tensors_scope(True):
      # We use a mutable model container instead of a model python variable,
      # because python 2.7 does not have `nonlocal`
      model_container = {}

      @def_function.function
      def run_model(inp):
        if not model_container:
          # Create a Keras Input
          rt = ragged_tensor.RaggedTensor.from_row_splits(
              values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
          x = input_layer_lib.Input(tensor=rt)

          # Verify you can construct and use a model w/ this input
          model_container['model'] = functional.Functional(x, x * 3)
        return model_container['model'](inp)

      # And verify the model works
      rt = ragged_tensor.RaggedTensor.from_row_splits(
          values=[3, 21, 4, 1, 53, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
      self.assertAllEqual(run_model(rt), rt * 3)


if __name__ == '__main__':
  test.main()