Raise an error when some but not all values passed to the first layer call arg are symbolic. This setting can cause functional models to be constructed incorrectly.

Support for this will be added when we enable the KerasTensors refactoring.

Addreses GitHub Issue #40638

PiperOrigin-RevId: 321639068
Change-Id: Iebf0e1198018fe44b1f60673bd991a9262ecef7d
This commit is contained in:
Tomer Kaftan 2020-07-16 13:54:22 -07:00
parent 0aa1d61fad
commit 7d310df2de
2 changed files with 82 additions and 2 deletions

View File

@ -3215,7 +3215,19 @@ def _in_functional_construction_mode(inputs, args, kwargs, input_list): # pylin
for tensor in nest.flatten([inputs, args, kwargs])) for tensor in nest.flatten([inputs, args, kwargs]))
else: else:
if context.executing_eagerly(): if context.executing_eagerly():
return all(tf_utils.is_symbolic_tensor(t) for t in input_list) all_inputs_symbolic = all(
tf_utils.is_symbolic_tensor(t) for t in input_list)
if (any(tf_utils.is_symbolic_tensor(t) for t in nest.flatten(
[inputs, args, kwargs])) and not all_inputs_symbolic):
raise ValueError('It appears you are trying to construct a '
'functional model, but not all of the inputs in '
'the first positional argument of your layer call '
'are symbolic tensors. '
'(Input objects, or the output of another layer) '
'Functional models cannot correctly track layers '
'unless all values in the first call argument '
'are symbolic.')
return all_inputs_symbolic
else: else:
return (base_layer_utils.is_in_keras_graph() or return (base_layer_utils.is_in_keras_graph() or
all(hasattr(t, '_keras_history') for t in input_list)) all(hasattr(t, '_keras_history') for t in input_list))

View File

@ -931,6 +931,72 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
# Check that second input was correctly added to first. # Check that second input was correctly added to first.
self.assertEqual(history.history['loss'][0], 0.0) self.assertEqual(history.history['loss'][0], 0.0)
@combinations.generate(combinations.times(
combinations.keras_mode_combinations(mode='eager'),
combinations.combine(use_keras_tensors=False)))
def test_only_some_in_first_arg_derived_from_keras_layer(self):
class MyAddAll(layers.Layer):
def call(self, inputs):
x = inputs[0]
for inp in inputs[1:]:
if inp is not None:
x = x + inp
return x
input1 = input_layer_lib.Input(10)
input2 = input_layer_lib.Input(10)
layer = MyAddAll()
with self.assertRaisesRegexp(ValueError, 'construct a functional'):
layer([0.0, input1, None, input2, None])
@combinations.generate(combinations.times(
combinations.keras_mode_combinations(mode='eager'),
combinations.combine(use_keras_tensors=True)))
def test_only_some_in_first_arg_derived_from_keras_layer_keras_tensors(self):
# This functionality is unsupported in v1 graphs
class MyAddAll(layers.Layer):
def call(self, inputs):
x = inputs[0]
for inp in inputs[1:]:
if inp is not None:
x = x + inp
return x
input1 = input_layer_lib.Input(10)
input2 = input_layer_lib.Input(10)
layer = MyAddAll()
outputs = layer([0.0, input1, None, input2, None])
model = training_lib.Model([input1, input2], outputs)
self.assertIn(layer, model.layers)
model.compile(
'sgd',
'mse',
run_eagerly=testing_utils.should_run_eagerly())
history = model.fit(
x=[3 * np.ones((10, 10)), 7 * np.ones((10, 10))],
y=10 * np.ones((10, 10)),
batch_size=2)
# Check that second input was correctly added to first.
self.assertEqual(history.history['loss'][0], 0.0)
# Check serialization.
model = training_lib.Model.from_config(
model.get_config(), custom_objects={'MyAddAll': MyAddAll})
model.compile(
'sgd',
'mse',
run_eagerly=testing_utils.should_run_eagerly())
history = model.fit(
x=[3 * np.ones((10, 10)), 7 * np.ones((10, 10))],
y=10 * np.ones((10, 10)),
batch_size=2)
# Check that second input was correctly added to first.
self.assertEqual(history.history['loss'][0], 0.0)
@combinations.generate(combinations.keras_mode_combinations()) @combinations.generate(combinations.keras_mode_combinations())
def test_call_kwarg_derived_from_keras_layer(self): def test_call_kwarg_derived_from_keras_layer(self):
@ -1069,7 +1135,8 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
input2 = input_layer_lib.Input(10) input2 = input_layer_lib.Input(10)
input3 = input_layer_lib.Input(10) input3 = input_layer_lib.Input(10)
outputs = AddAll()( layer = AddAll()
outputs = layer(
[input1, 4 * array_ops.ones((1, 10))], [input1, 4 * array_ops.ones((1, 10))],
x3={ x3={
'a': input2, 'a': input2,
@ -1077,6 +1144,7 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
'c': 5 * array_ops.ones((1, 10)) 'c': 5 * array_ops.ones((1, 10))
}) })
model = training_lib.Model([input1, input2, input3], outputs) model = training_lib.Model([input1, input2, input3], outputs)
self.assertIn(layer, model.layers)
model.compile( model.compile(
'sgd', 'sgd',
'mse', 'mse',