Merge pull request #41483 from tomerk/cherrypicks_DC0YA

Raise an error when some but not all values passed to the first layer…
This commit is contained in:
Goldie Gadde 2020-07-17 12:07:45 -07:00 committed by GitHub
commit f923fa474b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 85 additions and 4 deletions

View File

@ -921,7 +921,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
# >> inputs = tf.keras.Input(10)
# >> outputs = MyLayer()(inputs) # Functional construction mode.
# >> model = tf.keras.Model(inputs, outputs)
if _in_functional_construction_mode(inputs, args, kwargs, input_list):
if _in_functional_construction_mode(self, inputs, args, kwargs, input_list):
return self._functional_construction_call(inputs, args, kwargs,
input_list)
@ -3205,7 +3205,7 @@ class AddMetric(Layer):
return config
def _in_functional_construction_mode(inputs, args, kwargs, input_list): # pylint: disable=unused-argument
def _in_functional_construction_mode(layer, inputs, args, kwargs, input_list): # pylint: disable=unused-argument
"""Check the arguments to see if we are constructing a functional model."""
if keras_tensor.keras_tensors_enabled():
# We are constructing a functional model if any of the inputs
@ -3215,7 +3215,20 @@ def _in_functional_construction_mode(inputs, args, kwargs, input_list): # pylin
for tensor in nest.flatten([inputs, args, kwargs]))
else:
if context.executing_eagerly():
return all(tf_utils.is_symbolic_tensor(t) for t in input_list)
all_inputs_symbolic = all(
tf_utils.is_symbolic_tensor(t) for t in input_list)
if (base_layer_utils.is_subclassed(layer) and
any(tf_utils.is_symbolic_tensor(t) for t in nest.flatten(
[inputs, args, kwargs])) and not all_inputs_symbolic):
raise ValueError('It appears you are trying to construct a '
'functional model, but not all of the inputs in '
'the first positional argument of your layer call '
'are symbolic tensors. '
'(Input objects, or the output of another layer) '
'Functional models cannot correctly track custom '
'layers unless all values in the first call argument '
'are symbolic.')
return all_inputs_symbolic
else:
return (base_layer_utils.is_in_keras_graph() or
all(hasattr(t, '_keras_history') for t in input_list))

View File

@ -932,6 +932,72 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
# Check that second input was correctly added to first.
self.assertEqual(history.history['loss'][0], 0.0)
@combinations.generate(combinations.times(
combinations.keras_mode_combinations(mode='eager'),
combinations.combine(use_keras_tensors=False)))
def test_only_some_in_first_arg_derived_from_keras_layer(self):
class MyAddAll(layers.Layer):
def call(self, inputs):
x = inputs[0]
for inp in inputs[1:]:
if inp is not None:
x = x + inp
return x
input1 = input_layer_lib.Input(10)
input2 = input_layer_lib.Input(10)
layer = MyAddAll()
with self.assertRaisesRegexp(ValueError, 'construct a functional'):
layer([0.0, input1, None, input2, None])
@combinations.generate(combinations.times(
combinations.keras_mode_combinations(mode='eager'),
combinations.combine(use_keras_tensors=True)))
def test_only_some_in_first_arg_derived_from_keras_layer_keras_tensors(self):
# This functionality is unsupported in v1 graphs
class MyAddAll(layers.Layer):
def call(self, inputs):
x = inputs[0]
for inp in inputs[1:]:
if inp is not None:
x = x + inp
return x
input1 = input_layer_lib.Input(10)
input2 = input_layer_lib.Input(10)
layer = MyAddAll()
outputs = layer([0.0, input1, None, input2, None])
model = training_lib.Model([input1, input2], outputs)
self.assertIn(layer, model.layers)
model.compile(
'sgd',
'mse',
run_eagerly=testing_utils.should_run_eagerly())
history = model.fit(
x=[3 * np.ones((10, 10)), 7 * np.ones((10, 10))],
y=10 * np.ones((10, 10)),
batch_size=2)
# Check that second input was correctly added to first.
self.assertEqual(history.history['loss'][0], 0.0)
# Check serialization.
model = training_lib.Model.from_config(
model.get_config(), custom_objects={'MyAddAll': MyAddAll})
model.compile(
'sgd',
'mse',
run_eagerly=testing_utils.should_run_eagerly())
history = model.fit(
x=[3 * np.ones((10, 10)), 7 * np.ones((10, 10))],
y=10 * np.ones((10, 10)),
batch_size=2)
# Check that second input was correctly added to first.
self.assertEqual(history.history['loss'][0], 0.0)
@combinations.generate(combinations.keras_mode_combinations())
def test_call_kwarg_derived_from_keras_layer(self):
@ -1070,7 +1136,8 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
input2 = input_layer_lib.Input(10)
input3 = input_layer_lib.Input(10)
outputs = AddAll()(
layer = AddAll()
outputs = layer(
[input1, 4 * array_ops.ones((1, 10))],
x3={
'a': input2,
@ -1078,6 +1145,7 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
'c': 5 * array_ops.ones((1, 10))
})
model = training_lib.Model([input1, input2, input3], outputs)
self.assertIn(layer, model.layers)
model.compile(
'sgd',
'mse',