Relax the error about functional api construction w/ a mix of symbolic and non-symbolic tensors for built-in layers (such as layers.add and layers.multiply where using constants is a common user pattern)
PiperOrigin-RevId: 321698209 Change-Id: Ief13e59aec91b787361a7760318ecd47870d938f
This commit is contained in:
parent
7d310df2de
commit
7e0abd9c89
@ -921,7 +921,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
|||||||
# >> inputs = tf.keras.Input(10)
|
# >> inputs = tf.keras.Input(10)
|
||||||
# >> outputs = MyLayer()(inputs) # Functional construction mode.
|
# >> outputs = MyLayer()(inputs) # Functional construction mode.
|
||||||
# >> model = tf.keras.Model(inputs, outputs)
|
# >> model = tf.keras.Model(inputs, outputs)
|
||||||
if _in_functional_construction_mode(inputs, args, kwargs, input_list):
|
if _in_functional_construction_mode(self, inputs, args, kwargs, input_list):
|
||||||
return self._functional_construction_call(inputs, args, kwargs,
|
return self._functional_construction_call(inputs, args, kwargs,
|
||||||
input_list)
|
input_list)
|
||||||
|
|
||||||
@ -3205,7 +3205,7 @@ class AddMetric(Layer):
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
def _in_functional_construction_mode(inputs, args, kwargs, input_list): # pylint: disable=unused-argument
|
def _in_functional_construction_mode(layer, inputs, args, kwargs, input_list): # pylint: disable=unused-argument
|
||||||
"""Check the arguments to see if we are constructing a functional model."""
|
"""Check the arguments to see if we are constructing a functional model."""
|
||||||
if keras_tensor.keras_tensors_enabled():
|
if keras_tensor.keras_tensors_enabled():
|
||||||
# We are constructing a functional model if any of the inputs
|
# We are constructing a functional model if any of the inputs
|
||||||
@ -3217,15 +3217,16 @@ def _in_functional_construction_mode(inputs, args, kwargs, input_list): # pylin
|
|||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
all_inputs_symbolic = all(
|
all_inputs_symbolic = all(
|
||||||
tf_utils.is_symbolic_tensor(t) for t in input_list)
|
tf_utils.is_symbolic_tensor(t) for t in input_list)
|
||||||
if (any(tf_utils.is_symbolic_tensor(t) for t in nest.flatten(
|
if (base_layer_utils.is_subclassed(layer) and
|
||||||
[inputs, args, kwargs])) and not all_inputs_symbolic):
|
any(tf_utils.is_symbolic_tensor(t) for t in nest.flatten(
|
||||||
|
[inputs, args, kwargs])) and not all_inputs_symbolic):
|
||||||
raise ValueError('It appears you are trying to construct a '
|
raise ValueError('It appears you are trying to construct a '
|
||||||
'functional model, but not all of the inputs in '
|
'functional model, but not all of the inputs in '
|
||||||
'the first positional argument of your layer call '
|
'the first positional argument of your layer call '
|
||||||
'are symbolic tensors. '
|
'are symbolic tensors. '
|
||||||
'(Input objects, or the output of another layer) '
|
'(Input objects, or the output of another layer) '
|
||||||
'Functional models cannot correctly track layers '
|
'Functional models cannot correctly track custom '
|
||||||
'unless all values in the first call argument '
|
'layers unless all values in the first call argument '
|
||||||
'are symbolic.')
|
'are symbolic.')
|
||||||
return all_inputs_symbolic
|
return all_inputs_symbolic
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user