Relax the error about functional api construction w/ a mix of symbolic and non-symbolic tensors for built-in layers (such as layers.add and layers.multiply where using constants is a common user pattern)
PiperOrigin-RevId: 321698209 Change-Id: Ief13e59aec91b787361a7760318ecd47870d938f
This commit is contained in:
parent
7d310df2de
commit
7e0abd9c89
@ -921,7 +921,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
# >> inputs = tf.keras.Input(10)
|
||||
# >> outputs = MyLayer()(inputs) # Functional construction mode.
|
||||
# >> model = tf.keras.Model(inputs, outputs)
|
||||
if _in_functional_construction_mode(inputs, args, kwargs, input_list):
|
||||
if _in_functional_construction_mode(self, inputs, args, kwargs, input_list):
|
||||
return self._functional_construction_call(inputs, args, kwargs,
|
||||
input_list)
|
||||
|
||||
@ -3205,7 +3205,7 @@ class AddMetric(Layer):
|
||||
return config
|
||||
|
||||
|
||||
def _in_functional_construction_mode(inputs, args, kwargs, input_list): # pylint: disable=unused-argument
|
||||
def _in_functional_construction_mode(layer, inputs, args, kwargs, input_list): # pylint: disable=unused-argument
|
||||
"""Check the arguments to see if we are constructing a functional model."""
|
||||
if keras_tensor.keras_tensors_enabled():
|
||||
# We are constructing a functional model if any of the inputs
|
||||
@ -3217,15 +3217,16 @@ def _in_functional_construction_mode(inputs, args, kwargs, input_list): # pylin
|
||||
if context.executing_eagerly():
|
||||
all_inputs_symbolic = all(
|
||||
tf_utils.is_symbolic_tensor(t) for t in input_list)
|
||||
if (any(tf_utils.is_symbolic_tensor(t) for t in nest.flatten(
|
||||
[inputs, args, kwargs])) and not all_inputs_symbolic):
|
||||
if (base_layer_utils.is_subclassed(layer) and
|
||||
any(tf_utils.is_symbolic_tensor(t) for t in nest.flatten(
|
||||
[inputs, args, kwargs])) and not all_inputs_symbolic):
|
||||
raise ValueError('It appears you are trying to construct a '
|
||||
'functional model, but not all of the inputs in '
|
||||
'the first positional argument of your layer call '
|
||||
'are symbolic tensors. '
|
||||
'(Input objects, or the output of another layer) '
|
||||
'Functional models cannot correctly track layers '
|
||||
'unless all values in the first call argument '
|
||||
'Functional models cannot correctly track custom '
|
||||
'layers unless all values in the first call argument '
|
||||
'are symbolic.')
|
||||
return all_inputs_symbolic
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user