diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 6e3a6814608..2d5aab83665 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -921,7 +921,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector): # >> inputs = tf.keras.Input(10) # >> outputs = MyLayer()(inputs) # Functional construction mode. # >> model = tf.keras.Model(inputs, outputs) - if _in_functional_construction_mode(inputs, args, kwargs, input_list): + if _in_functional_construction_mode(self, inputs, args, kwargs, input_list): return self._functional_construction_call(inputs, args, kwargs, input_list) @@ -3205,7 +3205,7 @@ class AddMetric(Layer): return config -def _in_functional_construction_mode(inputs, args, kwargs, input_list): # pylint: disable=unused-argument +def _in_functional_construction_mode(layer, inputs, args, kwargs, input_list): # pylint: disable=unused-argument """Check the arguments to see if we are constructing a functional model.""" if keras_tensor.keras_tensors_enabled(): # We are constructing a functional model if any of the inputs @@ -3217,15 +3217,16 @@ def _in_functional_construction_mode(inputs, args, kwargs, input_list): # pylin if context.executing_eagerly(): all_inputs_symbolic = all( tf_utils.is_symbolic_tensor(t) for t in input_list) - if (any(tf_utils.is_symbolic_tensor(t) for t in nest.flatten( - [inputs, args, kwargs])) and not all_inputs_symbolic): + if (base_layer_utils.is_subclassed(layer) and + any(tf_utils.is_symbolic_tensor(t) for t in nest.flatten( + [inputs, args, kwargs])) and not all_inputs_symbolic): raise ValueError('It appears you are trying to construct a ' 'functional model, but not all of the inputs in ' 'the first positional argument of your layer call ' 'are symbolic tensors. ' '(Input objects, or the output of another layer) ' - 'Functional models cannot correctly track layers ' - 'unless all values in the first call argument ' + 'Functional models cannot correctly track custom ' + 'layers unless all values in the first call argument ' 'are symbolic.') return all_inputs_symbolic else: