diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index a0db48b8282..9c809ed7ea1 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -3097,7 +3097,8 @@ def _in_functional_construction_mode(inputs, args, kwargs, input_list): # pylin if context.executing_eagerly(): return all(tf_utils.is_symbolic_tensor(t) for t in input_list) else: - return base_layer_utils.is_in_keras_graph() + return (base_layer_utils.is_in_keras_graph() or + all(hasattr(t, '_keras_history') for t in input_list)) def _convert_numpy_or_python_types(x): diff --git a/tensorflow/python/keras/engine/functional_test.py b/tensorflow/python/keras/engine/functional_test.py index 90fc9f2697f..ff6e46e6750 100644 --- a/tensorflow/python/keras/engine/functional_test.py +++ b/tensorflow/python/keras/engine/functional_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import numpy as np from tensorflow.python.eager import context +from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -1168,6 +1169,28 @@ class NetworkConstructionTest(keras_parameterized.TestCase): self.assertIn('trackable', model._unconditional_dependency_names) self.assertEqual(model.trackable, model._lookup_dependency('trackable')) + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) + def test_model_construction_in_tf_function(self): + + d = {'model': None} + + @def_function.function + def fn(x): + if d['model'] is None: + # Check that Functional can be built in a `tf.function`. + inputs = input_layer_lib.Input(10) + outputs = layers.Dense(1)(inputs) + model = functional.Functional(inputs, outputs) + d['model'] = model + else: + model = d['model'] + + return model(x) + + x = array_ops.ones((10, 10)) + y = fn(x) + self.assertEqual(y.shape.as_list(), [10, 1]) + class DeferredModeTest(keras_parameterized.TestCase):