Raise an error when some but not all values passed to the first layer call arg are symbolic. This setting can cause functional models to be constructed incorrectly.
Support for this will be added when we enable the KerasTensors refactoring. Addreses GitHub Issue #40638 PiperOrigin-RevId: 321639068 Change-Id: Iebf0e1198018fe44b1f60673bd991a9262ecef7d
This commit is contained in:
parent
0aa1d61fad
commit
7d310df2de
tensorflow/python/keras/engine
@ -3215,7 +3215,19 @@ def _in_functional_construction_mode(inputs, args, kwargs, input_list): # pylin
|
||||
for tensor in nest.flatten([inputs, args, kwargs]))
|
||||
else:
|
||||
if context.executing_eagerly():
|
||||
return all(tf_utils.is_symbolic_tensor(t) for t in input_list)
|
||||
all_inputs_symbolic = all(
|
||||
tf_utils.is_symbolic_tensor(t) for t in input_list)
|
||||
if (any(tf_utils.is_symbolic_tensor(t) for t in nest.flatten(
|
||||
[inputs, args, kwargs])) and not all_inputs_symbolic):
|
||||
raise ValueError('It appears you are trying to construct a '
|
||||
'functional model, but not all of the inputs in '
|
||||
'the first positional argument of your layer call '
|
||||
'are symbolic tensors. '
|
||||
'(Input objects, or the output of another layer) '
|
||||
'Functional models cannot correctly track layers '
|
||||
'unless all values in the first call argument '
|
||||
'are symbolic.')
|
||||
return all_inputs_symbolic
|
||||
else:
|
||||
return (base_layer_utils.is_in_keras_graph() or
|
||||
all(hasattr(t, '_keras_history') for t in input_list))
|
||||
|
@ -931,6 +931,72 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
|
||||
# Check that second input was correctly added to first.
|
||||
self.assertEqual(history.history['loss'][0], 0.0)
|
||||
|
||||
@combinations.generate(combinations.times(
|
||||
combinations.keras_mode_combinations(mode='eager'),
|
||||
combinations.combine(use_keras_tensors=False)))
|
||||
def test_only_some_in_first_arg_derived_from_keras_layer(self):
|
||||
class MyAddAll(layers.Layer):
|
||||
|
||||
def call(self, inputs):
|
||||
x = inputs[0]
|
||||
for inp in inputs[1:]:
|
||||
if inp is not None:
|
||||
x = x + inp
|
||||
return x
|
||||
|
||||
input1 = input_layer_lib.Input(10)
|
||||
input2 = input_layer_lib.Input(10)
|
||||
layer = MyAddAll()
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, 'construct a functional'):
|
||||
layer([0.0, input1, None, input2, None])
|
||||
|
||||
@combinations.generate(combinations.times(
|
||||
combinations.keras_mode_combinations(mode='eager'),
|
||||
combinations.combine(use_keras_tensors=True)))
|
||||
def test_only_some_in_first_arg_derived_from_keras_layer_keras_tensors(self):
|
||||
# This functionality is unsupported in v1 graphs
|
||||
|
||||
class MyAddAll(layers.Layer):
|
||||
|
||||
def call(self, inputs):
|
||||
x = inputs[0]
|
||||
for inp in inputs[1:]:
|
||||
if inp is not None:
|
||||
x = x + inp
|
||||
return x
|
||||
|
||||
input1 = input_layer_lib.Input(10)
|
||||
input2 = input_layer_lib.Input(10)
|
||||
layer = MyAddAll()
|
||||
outputs = layer([0.0, input1, None, input2, None])
|
||||
model = training_lib.Model([input1, input2], outputs)
|
||||
self.assertIn(layer, model.layers)
|
||||
model.compile(
|
||||
'sgd',
|
||||
'mse',
|
||||
run_eagerly=testing_utils.should_run_eagerly())
|
||||
history = model.fit(
|
||||
x=[3 * np.ones((10, 10)), 7 * np.ones((10, 10))],
|
||||
y=10 * np.ones((10, 10)),
|
||||
batch_size=2)
|
||||
# Check that second input was correctly added to first.
|
||||
self.assertEqual(history.history['loss'][0], 0.0)
|
||||
|
||||
# Check serialization.
|
||||
model = training_lib.Model.from_config(
|
||||
model.get_config(), custom_objects={'MyAddAll': MyAddAll})
|
||||
model.compile(
|
||||
'sgd',
|
||||
'mse',
|
||||
run_eagerly=testing_utils.should_run_eagerly())
|
||||
history = model.fit(
|
||||
x=[3 * np.ones((10, 10)), 7 * np.ones((10, 10))],
|
||||
y=10 * np.ones((10, 10)),
|
||||
batch_size=2)
|
||||
# Check that second input was correctly added to first.
|
||||
self.assertEqual(history.history['loss'][0], 0.0)
|
||||
|
||||
@combinations.generate(combinations.keras_mode_combinations())
|
||||
def test_call_kwarg_derived_from_keras_layer(self):
|
||||
|
||||
@ -1069,7 +1135,8 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
|
||||
input2 = input_layer_lib.Input(10)
|
||||
input3 = input_layer_lib.Input(10)
|
||||
|
||||
outputs = AddAll()(
|
||||
layer = AddAll()
|
||||
outputs = layer(
|
||||
[input1, 4 * array_ops.ones((1, 10))],
|
||||
x3={
|
||||
'a': input2,
|
||||
@ -1077,6 +1144,7 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
|
||||
'c': 5 * array_ops.ones((1, 10))
|
||||
})
|
||||
model = training_lib.Model([input1, input2, input3], outputs)
|
||||
self.assertIn(layer, model.layers)
|
||||
model.compile(
|
||||
'sgd',
|
||||
'mse',
|
||||
|
Loading…
Reference in New Issue
Block a user