Merge pull request #41483 from tomerk/cherrypicks_DC0YA
Raise an error when some but not all values passed to the first layer…
commit f923fa474b
@@ -921,7 +921,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
     # >> inputs = tf.keras.Input(10)
     # >> outputs = MyLayer()(inputs)  # Functional construction mode.
     # >> model = tf.keras.Model(inputs, outputs)
-    if _in_functional_construction_mode(inputs, args, kwargs, input_list):
+    if _in_functional_construction_mode(self, inputs, args, kwargs, input_list):
       return self._functional_construction_call(inputs, args, kwargs,
                                                 input_list)

@@ -3205,7 +3205,7 @@ class AddMetric(Layer):
     return config


-def _in_functional_construction_mode(inputs, args, kwargs, input_list):  # pylint: disable=unused-argument
+def _in_functional_construction_mode(layer, inputs, args, kwargs, input_list):  # pylint: disable=unused-argument
   """Check the arguments to see if we are constructing a functional model."""
   if keras_tensor.keras_tensors_enabled():
     # We are constructing a functional model if any of the inputs
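For reference, "symbolic tensors" in the check below means placeholders created by tf.keras.Input or the outputs of other layers, exactly as the new error message spells out; concrete eager values and plain Python numbers are not symbolic. A minimal sketch of the distinction, using only the public tf.keras API (variable names are illustrative, not part of this change):

import tensorflow as tf

sym = tf.keras.Input(shape=(10,))          # Symbolic: a Keras graph placeholder.
also_sym = tf.keras.layers.Dense(4)(sym)   # Symbolic: the output of another layer.
eager = tf.ones((1, 10))                   # Not symbolic: a concrete eager tensor.
plain = 0.0                                # Not symbolic: a plain Python float.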
@@ -3215,7 +3215,20 @@ def _in_functional_construction_mode(inputs, args, kwargs, input_list):  # pylin
         for tensor in nest.flatten([inputs, args, kwargs]))
   else:
     if context.executing_eagerly():
-      return all(tf_utils.is_symbolic_tensor(t) for t in input_list)
+      all_inputs_symbolic = all(
+          tf_utils.is_symbolic_tensor(t) for t in input_list)
+      if (base_layer_utils.is_subclassed(layer) and
+          any(tf_utils.is_symbolic_tensor(t) for t in nest.flatten(
+              [inputs, args, kwargs])) and not all_inputs_symbolic):
+        raise ValueError('It appears you are trying to construct a '
+                         'functional model, but not all of the inputs in '
+                         'the first positional argument of your layer call '
+                         'are symbolic tensors. '
+                         '(Input objects, or the output of another layer) '
+                         'Functional models cannot correctly track custom '
+                         'layers unless all values in the first call argument '
+                         'are symbolic.')
+      return all_inputs_symbolic
     else:
       return (base_layer_utils.is_in_keras_graph() or
               all(hasattr(t, '_keras_history') for t in input_list))

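To make the effect of the new branch concrete: on the legacy eager path (KerasTensors disabled, matching the first test added below), calling a user-subclassed layer with a first positional argument that mixes symbolic and non-symbolic values now raises the ValueError above, while an all-symbolic first argument still builds a functional model that tracks the layer. A minimal sketch assuming the public tf.keras API; MyAddAll mirrors the layer defined in the tests:

import tensorflow as tf

class MyAddAll(tf.keras.layers.Layer):

  def call(self, inputs):
    x = inputs[0]
    for inp in inputs[1:]:
      if inp is not None:
        x = x + inp
    return x

input1 = tf.keras.Input(shape=(10,))
input2 = tf.keras.Input(shape=(10,))

# All-symbolic first argument: functional construction works and the custom
# layer is tracked by the resulting model.
layer = MyAddAll()
outputs = layer([input1, input2])
model = tf.keras.Model([input1, input2], outputs)
assert layer in model.layers

# Mixed first argument (a Python float alongside symbolic inputs): on the
# legacy eager path this now raises the ValueError added above. With
# KerasTensors enabled, the second test below shows the mixed call is
# supported instead.
# MyAddAll()([0.0, input1, None, input2, None])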
@@ -932,6 +932,72 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
     # Check that second input was correctly added to first.
     self.assertEqual(history.history['loss'][0], 0.0)

+  @combinations.generate(combinations.times(
+      combinations.keras_mode_combinations(mode='eager'),
+      combinations.combine(use_keras_tensors=False)))
+  def test_only_some_in_first_arg_derived_from_keras_layer(self):
+
+    class MyAddAll(layers.Layer):
+
+      def call(self, inputs):
+        x = inputs[0]
+        for inp in inputs[1:]:
+          if inp is not None:
+            x = x + inp
+        return x
+
+    input1 = input_layer_lib.Input(10)
+    input2 = input_layer_lib.Input(10)
+    layer = MyAddAll()
+
+    with self.assertRaisesRegexp(ValueError, 'construct a functional'):
+      layer([0.0, input1, None, input2, None])
+
+  @combinations.generate(combinations.times(
+      combinations.keras_mode_combinations(mode='eager'),
+      combinations.combine(use_keras_tensors=True)))
+  def test_only_some_in_first_arg_derived_from_keras_layer_keras_tensors(self):
+    # This functionality is unsupported in v1 graphs
+
+    class MyAddAll(layers.Layer):
+
+      def call(self, inputs):
+        x = inputs[0]
+        for inp in inputs[1:]:
+          if inp is not None:
+            x = x + inp
+        return x
+
+    input1 = input_layer_lib.Input(10)
+    input2 = input_layer_lib.Input(10)
+    layer = MyAddAll()
+    outputs = layer([0.0, input1, None, input2, None])
+    model = training_lib.Model([input1, input2], outputs)
+    self.assertIn(layer, model.layers)
+    model.compile(
+        'sgd',
+        'mse',
+        run_eagerly=testing_utils.should_run_eagerly())
+    history = model.fit(
+        x=[3 * np.ones((10, 10)), 7 * np.ones((10, 10))],
+        y=10 * np.ones((10, 10)),
+        batch_size=2)
+    # Check that second input was correctly added to first.
+    self.assertEqual(history.history['loss'][0], 0.0)
+
+    # Check serialization.
+    model = training_lib.Model.from_config(
+        model.get_config(), custom_objects={'MyAddAll': MyAddAll})
+    model.compile(
+        'sgd',
+        'mse',
+        run_eagerly=testing_utils.should_run_eagerly())
+    history = model.fit(
+        x=[3 * np.ones((10, 10)), 7 * np.ones((10, 10))],
+        y=10 * np.ones((10, 10)),
+        batch_size=2)
+    # Check that second input was correctly added to first.
+    self.assertEqual(history.history['loss'][0], 0.0)

   @combinations.generate(combinations.keras_mode_combinations())
   def test_call_kwarg_derived_from_keras_layer(self):

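The second test above also round-trips the model through get_config/from_config; when a functional model contains a custom layer, the class has to be supplied via custom_objects on reload. A compact, self-contained sketch of that pattern using the public tf.keras API (MyAddAll is simplified here to two inputs; names are illustrative):

import tensorflow as tf

class MyAddAll(tf.keras.layers.Layer):

  def call(self, inputs):
    return inputs[0] + inputs[1]

inputs = [tf.keras.Input(shape=(10,)), tf.keras.Input(shape=(10,))]
model = tf.keras.Model(inputs, MyAddAll()(inputs))

# Rebuilding from config cannot resolve 'MyAddAll' unless the class is
# passed explicitly through custom_objects.
rebuilt = tf.keras.Model.from_config(
    model.get_config(), custom_objects={'MyAddAll': MyAddAll})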
@@ -1070,7 +1136,8 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
     input2 = input_layer_lib.Input(10)
     input3 = input_layer_lib.Input(10)

-    outputs = AddAll()(
+    layer = AddAll()
+    outputs = layer(
         [input1, 4 * array_ops.ones((1, 10))],
         x3={
             'a': input2,
@@ -1078,6 +1145,7 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
             'c': 5 * array_ops.ones((1, 10))
         })
     model = training_lib.Model([input1, input2, input3], outputs)
+    self.assertIn(layer, model.layers)
     model.compile(
         'sgd',
         'mse',
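These last two hunks keep a handle on the AddAll layer so the test can assert that the functional model tracks it, which is the guarantee the new error message protects for custom layers. For user code that previously mixed plain constants into the first call argument, one possible adaptation (a sketch, not something this PR prescribes) is to keep the constant on the layer itself so the call argument stays all-symbolic:

import tensorflow as tf

class MyAddAllWithBias(tf.keras.layers.Layer):

  def __init__(self, bias=0.0, **kwargs):
    super().__init__(**kwargs)
    self.bias = bias  # The constant lives on the layer, not in the call args.

  def call(self, inputs):
    x = inputs[0] + self.bias
    for inp in inputs[1:]:
      if inp is not None:
        x = x + inp
    return x

input1 = tf.keras.Input(shape=(10,))
input2 = tf.keras.Input(shape=(10,))
outputs = MyAddAllWithBias(bias=1.0)([input1, input2])
model = tf.keras.Model([input1, input2], outputs)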