Fix to allow Functional Models to be constructed inside a tf.function (if
constuction is guarded so as to only happen on first trace). PiperOrigin-RevId: 314883284 Change-Id: Id10bdad87c4cba56ff6c6122bd2f92f96958a75a
This commit is contained in:
parent
3dcbf3c4fd
commit
a39ac9de5f
|
@ -3097,7 +3097,8 @@ def _in_functional_construction_mode(inputs, args, kwargs, input_list): # pylin
|
|||
if context.executing_eagerly():
|
||||
return all(tf_utils.is_symbolic_tensor(t) for t in input_list)
|
||||
else:
|
||||
return base_layer_utils.is_in_keras_graph()
|
||||
return (base_layer_utils.is_in_keras_graph() or
|
||||
all(hasattr(t, '_keras_history') for t in input_list))
|
||||
|
||||
|
||||
def _convert_numpy_or_python_types(x):
|
||||
|
|
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
|||
import numpy as np
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
|
@ -1168,6 +1169,28 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
|
|||
self.assertIn('trackable', model._unconditional_dependency_names)
|
||||
self.assertEqual(model.trackable, model._lookup_dependency('trackable'))
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def test_model_construction_in_tf_function(self):
|
||||
|
||||
d = {'model': None}
|
||||
|
||||
@def_function.function
|
||||
def fn(x):
|
||||
if d['model'] is None:
|
||||
# Check that Functional can be built in a `tf.function`.
|
||||
inputs = input_layer_lib.Input(10)
|
||||
outputs = layers.Dense(1)(inputs)
|
||||
model = functional.Functional(inputs, outputs)
|
||||
d['model'] = model
|
||||
else:
|
||||
model = d['model']
|
||||
|
||||
return model(x)
|
||||
|
||||
x = array_ops.ones((10, 10))
|
||||
y = fn(x)
|
||||
self.assertEqual(y.shape.as_list(), [10, 1])
|
||||
|
||||
|
||||
class DeferredModeTest(keras_parameterized.TestCase):
|
||||
|
||||
|
|
Loading…
Reference in New Issue