Relax the error about functional api construction w/ a mix of symbolic and non-symbolic tensors for built-in layers (such as layers.add and layers.multiply where using constants is a common user pattern)

PiperOrigin-RevId: 321698209
Change-Id: Ief13e59aec91b787361a7760318ecd47870d938f
This commit is contained in:
Tomer Kaftan 2020-07-16 19:47:51 -07:00
parent 7d310df2de
commit 7e0abd9c89

View File

@ -921,7 +921,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
# >> inputs = tf.keras.Input(10)
# >> outputs = MyLayer()(inputs) # Functional construction mode.
# >> model = tf.keras.Model(inputs, outputs)
if _in_functional_construction_mode(inputs, args, kwargs, input_list):
if _in_functional_construction_mode(self, inputs, args, kwargs, input_list):
return self._functional_construction_call(inputs, args, kwargs,
input_list)
@ -3205,7 +3205,7 @@ class AddMetric(Layer):
return config
def _in_functional_construction_mode(inputs, args, kwargs, input_list): # pylint: disable=unused-argument
def _in_functional_construction_mode(layer, inputs, args, kwargs, input_list): # pylint: disable=unused-argument
"""Check the arguments to see if we are constructing a functional model."""
if keras_tensor.keras_tensors_enabled():
# We are constructing a functional model if any of the inputs
@ -3217,15 +3217,16 @@ def _in_functional_construction_mode(inputs, args, kwargs, input_list): # pylin
if context.executing_eagerly():
all_inputs_symbolic = all(
tf_utils.is_symbolic_tensor(t) for t in input_list)
if (any(tf_utils.is_symbolic_tensor(t) for t in nest.flatten(
[inputs, args, kwargs])) and not all_inputs_symbolic):
if (base_layer_utils.is_subclassed(layer) and
any(tf_utils.is_symbolic_tensor(t) for t in nest.flatten(
[inputs, args, kwargs])) and not all_inputs_symbolic):
raise ValueError('It appears you are trying to construct a '
'functional model, but not all of the inputs in '
'the first positional argument of your layer call '
'are symbolic tensors. '
'(Input objects, or the output of another layer) '
'Functional models cannot correctly track layers '
'unless all values in the first call argument '
'Functional models cannot correctly track custom '
'layers unless all values in the first call argument '
'are symbolic.')
return all_inputs_symbolic
else: