Fix to allow Functional Models to be constructed inside a tf.function (if

constuction is guarded so as to only happen on first trace).

PiperOrigin-RevId: 314883284
Change-Id: Id10bdad87c4cba56ff6c6122bd2f92f96958a75a
This commit is contained in:
Thomas O'Malley 2020-06-05 01:00:30 -07:00 committed by TensorFlower Gardener
parent 3dcbf3c4fd
commit a39ac9de5f
2 changed files with 25 additions and 1 deletions

View File

@ -3097,7 +3097,8 @@ def _in_functional_construction_mode(inputs, args, kwargs, input_list): # pylin
if context.executing_eagerly():
return all(tf_utils.is_symbolic_tensor(t) for t in input_list)
else:
return base_layer_utils.is_in_keras_graph()
return (base_layer_utils.is_in_keras_graph() or
all(hasattr(t, '_keras_history') for t in input_list))
def _convert_numpy_or_python_types(x):

View File

@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -1168,6 +1169,28 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
self.assertIn('trackable', model._unconditional_dependency_names)
self.assertEqual(model.trackable, model._lookup_dependency('trackable'))
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_model_construction_in_tf_function(self):
d = {'model': None}
@def_function.function
def fn(x):
if d['model'] is None:
# Check that Functional can be built in a `tf.function`.
inputs = input_layer_lib.Input(10)
outputs = layers.Dense(1)(inputs)
model = functional.Functional(inputs, outputs)
d['model'] = model
else:
model = d['model']
return model(x)
x = array_ops.ones((10, 10))
y = fn(x)
self.assertEqual(y.shape.as_list(), [10, 1])
class DeferredModeTest(keras_parameterized.TestCase):