When a non-KerasTensor Tensor is passed to the tensor
argument of keras.Input or InputLayer, make a KerasTensor directly from that tensor rather than erroring out.
PiperOrigin-RevId: 324688220 Change-Id: I2b06682f8ea706be4e36e0b8807c0f07bec55a4e
This commit is contained in:
parent
d84acd6e45
commit
64c753b9ee
tensorflow/python/keras/engine
@ -545,6 +545,23 @@ tf_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "input_layer_test",
|
||||
size = "medium",
|
||||
srcs = ["input_layer_test.py"],
|
||||
python_version = "PY3",
|
||||
shard_count = 3,
|
||||
tags = [
|
||||
"nomac", # TODO(mihaimaruseac): b/127695564
|
||||
],
|
||||
deps = [
|
||||
":base_layer",
|
||||
":engine",
|
||||
"//tensorflow/python/keras:testing_utils",
|
||||
"//tensorflow/python/keras/utils:layer_utils",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "functional_test",
|
||||
size = "medium",
|
||||
|
@ -135,8 +135,9 @@ class Functional(training_lib.Model):
|
||||
(isinstance(self._nested_inputs, (list, tuple, dict)) and
|
||||
not any(nest.is_nested(t) for t in self._nested_inputs)))
|
||||
|
||||
if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs):
|
||||
base_layer_utils.create_keras_history(self._nested_outputs)
|
||||
if not keras_tensor.keras_tensors_enabled():
|
||||
if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs):
|
||||
base_layer_utils.create_keras_history(self._nested_outputs)
|
||||
|
||||
self._validate_graph_inputs_and_outputs()
|
||||
|
||||
|
@ -76,8 +76,9 @@ class InputLayer(base_layer.Layer):
|
||||
batch_size: Optional input batch size (integer or None).
|
||||
dtype: Optional datatype of the input. When not provided, the Keras
|
||||
default float type will be used.
|
||||
input_tensor: Optional tensor to use as layer input
|
||||
instead of creating a placeholder.
|
||||
input_tensor: Optional tensor to use as layer input. If set, the layer
|
||||
will use the `tf.TypeSpec` of this tensor rather
|
||||
than creating a new placeholder tensor.
|
||||
sparse: Boolean, whether the placeholder created is meant to be sparse.
|
||||
Default to False.
|
||||
ragged: Boolean, whether the placeholder created is meant to be ragged.
|
||||
@ -162,19 +163,15 @@ class InputLayer(base_layer.Layer):
|
||||
self.is_placeholder = True
|
||||
self._batch_input_shape = batch_input_shape
|
||||
else:
|
||||
raise_eager_tensor_error = False
|
||||
if keras_tensor.keras_tensors_enabled():
|
||||
if (not isinstance(input_tensor, keras_tensor.KerasTensor) and
|
||||
not tf_utils.is_symbolic_tensor(input_tensor)):
|
||||
raise_eager_tensor_error = True
|
||||
if not isinstance(input_tensor, keras_tensor.KerasTensor):
|
||||
input_tensor = keras_tensor.keras_tensor_from_tensor(input_tensor)
|
||||
else:
|
||||
if not tf_utils.is_symbolic_tensor(input_tensor):
|
||||
raise_eager_tensor_error = True
|
||||
if raise_eager_tensor_error:
|
||||
raise ValueError('You should not pass an EagerTensor to `Input`. '
|
||||
'For example, instead of creating an '
|
||||
'InputLayer, you should instantiate your model and '
|
||||
'directly call it on your input.')
|
||||
raise ValueError('You should not pass an EagerTensor to `Input`. '
|
||||
'For example, instead of creating an '
|
||||
'InputLayer, you should instantiate your model and '
|
||||
'directly call it on your input.')
|
||||
self.is_placeholder = False
|
||||
try:
|
||||
self._batch_input_shape = tuple(input_tensor.shape.as_list())
|
||||
@ -245,7 +242,8 @@ def Input( # pylint: disable=invalid-name
|
||||
if `sparse` is False, sparse tensors can still be passed into the
|
||||
input - they will be densified with a default value of 0.
|
||||
tensor: Optional existing tensor to wrap into the `Input` layer.
|
||||
If set, the layer will not create a placeholder tensor.
|
||||
If set, the layer will use the `tf.TypeSpec` of this tensor rather
|
||||
than creating a new placeholder tensor.
|
||||
ragged: A boolean specifying whether the placeholder to be created is
|
||||
ragged. Only one of 'ragged' and 'sparse' can be True. In this case,
|
||||
values of 'None' in the 'shape' argument represent ragged dimensions.
|
||||
|
148
tensorflow/python/keras/engine/input_layer_test.py
Normal file
148
tensorflow/python/keras/engine/input_layer_test.py
Normal file
@ -0,0 +1,148 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#,============================================================================
|
||||
"""Tests for InputLayer construction."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.keras import combinations
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.keras.engine import functional
|
||||
from tensorflow.python.keras.engine import input_layer as input_layer_lib
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class InputLayerTest(keras_parameterized.TestCase):
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def testBasicOutputShapeNoBatchSize(self):
|
||||
# Create a Keras Input
|
||||
x = input_layer_lib.Input(shape=(32,), name='input_a')
|
||||
self.assertAllEqual(x.shape.as_list(), [None, 32])
|
||||
|
||||
# Verify you can construct and use a model w/ this input
|
||||
model = functional.Functional(x, x * 2.0)
|
||||
self.assertAllEqual(model(array_ops.ones((3, 32))),
|
||||
array_ops.ones((3, 32)) * 2.0)
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def testBasicOutputShapeWithBatchSize(self):
|
||||
# Create a Keras Input
|
||||
x = input_layer_lib.Input(batch_size=6, shape=(32,), name='input_b')
|
||||
self.assertAllEqual(x.shape.as_list(), [6, 32])
|
||||
|
||||
# Verify you can construct and use a model w/ this input
|
||||
model = functional.Functional(x, x * 2.0)
|
||||
self.assertAllEqual(model(array_ops.ones(x.shape)),
|
||||
array_ops.ones(x.shape) * 2.0)
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['eager']))
|
||||
def testBasicOutputShapeNoBatchSizeInTFFunction(self):
|
||||
model = None
|
||||
@def_function.function
|
||||
def run_model(inp):
|
||||
nonlocal model
|
||||
if not model:
|
||||
# Create a Keras Input
|
||||
x = input_layer_lib.Input(shape=(8,), name='input_a')
|
||||
self.assertAllEqual(x.shape.as_list(), [None, 8])
|
||||
|
||||
# Verify you can construct and use a model w/ this input
|
||||
model = functional.Functional(x, x * 2.0)
|
||||
return model(inp)
|
||||
|
||||
self.assertAllEqual(run_model(array_ops.ones((10, 8))),
|
||||
array_ops.ones((10, 8)) * 2.0)
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def testInputTensorArg(self):
|
||||
with testing_utils.use_keras_tensors_scope(True):
|
||||
# Create a Keras Input
|
||||
x = input_layer_lib.Input(tensor=array_ops.zeros((7, 32)))
|
||||
self.assertAllEqual(x.shape.as_list(), [7, 32])
|
||||
|
||||
# Verify you can construct and use a model w/ this input
|
||||
model = functional.Functional(x, x * 2.0)
|
||||
self.assertAllEqual(model(array_ops.ones(x.shape)),
|
||||
array_ops.ones(x.shape) * 2.0)
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['eager']))
|
||||
def testInputTensorArgInTFFunction(self):
|
||||
with testing_utils.use_keras_tensors_scope(True):
|
||||
# We use a mutable model container instead of a model python variable,
|
||||
# because python 2.7 does not have `nonlocal`
|
||||
model_container = {}
|
||||
|
||||
@def_function.function
|
||||
def run_model(inp):
|
||||
if not model_container:
|
||||
# Create a Keras Input
|
||||
x = input_layer_lib.Input(tensor=array_ops.zeros((10, 16)))
|
||||
self.assertAllEqual(x.shape.as_list(), [10, 16])
|
||||
|
||||
# Verify you can construct and use a model w/ this input
|
||||
model_container['model'] = functional.Functional(x, x * 3.0)
|
||||
return model_container['model'](inp)
|
||||
|
||||
self.assertAllEqual(run_model(array_ops.ones((10, 16))),
|
||||
array_ops.ones((10, 16)) * 3.0)
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['eager']))
|
||||
def testCompositeInputTensorArg(self):
|
||||
with testing_utils.use_keras_tensors_scope(True):
|
||||
# Create a Keras Input
|
||||
rt = ragged_tensor.RaggedTensor.from_row_splits(
|
||||
values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
|
||||
x = input_layer_lib.Input(tensor=rt)
|
||||
|
||||
# Verify you can construct and use a model w/ this input
|
||||
model = functional.Functional(x, x * 2)
|
||||
|
||||
# And that the model works
|
||||
rt = ragged_tensor.RaggedTensor.from_row_splits(
|
||||
values=[3, 21, 4, 1, 53, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
|
||||
self.assertAllEqual(model(rt), rt * 2)
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['eager']))
|
||||
def testCompositeInputTensorArgInTFFunction(self):
|
||||
with testing_utils.use_keras_tensors_scope(True):
|
||||
# We use a mutable model container instead of a model python variable,
|
||||
# because python 2.7 does not have `nonlocal`
|
||||
model_container = {}
|
||||
|
||||
@def_function.function
|
||||
def run_model(inp):
|
||||
if not model_container:
|
||||
# Create a Keras Input
|
||||
rt = ragged_tensor.RaggedTensor.from_row_splits(
|
||||
values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
|
||||
x = input_layer_lib.Input(tensor=rt)
|
||||
|
||||
# Verify you can construct and use a model w/ this input
|
||||
model_container['model'] = functional.Functional(x, x * 3)
|
||||
return model_container['model'](inp)
|
||||
|
||||
# And verify the model works
|
||||
rt = ragged_tensor.RaggedTensor.from_row_splits(
|
||||
values=[3, 21, 4, 1, 53, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
|
||||
self.assertAllEqual(run_model(rt), rt * 3)
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
Loading…
Reference in New Issue
Block a user