Pass experimental_relax_shapes to instance methods...
decorated with `tf.function`. Currently the `experimental_relax_shapes` is not passed, so the instance methods ignore this argument. Fixes #34905. Note that the `experimental_relax_shapes` was named differently in `def_function.Function` and `function.Function`, so the one in `def_function` was renamed to start with underscore. It is a private field, so it should be fine.
This commit is contained in:
parent
115ea3db34
commit
e08ae84598
@ -405,7 +405,7 @@ class Function(object):
|
||||
self._implements = experimental_implements
|
||||
self._autograph = autograph
|
||||
self._experimental_autograph_options = experimental_autograph_options
|
||||
self.experimental_relax_shapes = experimental_relax_shapes
|
||||
self._experimental_relax_shapes = experimental_relax_shapes
|
||||
self._experimental_compile = experimental_compile
|
||||
self._created_variables = None # GUARDED_BY(self._lock)
|
||||
self._stateful_fn = None # GUARDED_BY(self._lock)
|
||||
@ -458,7 +458,7 @@ class Function(object):
|
||||
attributes=attributes,
|
||||
autograph=self._autograph,
|
||||
experimental_autograph_options=self._experimental_autograph_options,
|
||||
experimental_relax_shapes=self.experimental_relax_shapes)
|
||||
experimental_relax_shapes=self._experimental_relax_shapes)
|
||||
|
||||
def _initialize(self, args, kwds, add_initializers_to=None):
|
||||
"""Initializes, on the first call.
|
||||
@ -514,7 +514,7 @@ class Function(object):
|
||||
autograph=self._autograph,
|
||||
experimental_implements=self._implements,
|
||||
experimental_autograph_options=self._experimental_autograph_options,
|
||||
experimental_relax_shapes=self.experimental_relax_shapes,
|
||||
experimental_relax_shapes=self._experimental_relax_shapes,
|
||||
experimental_compile=self._experimental_compile)
|
||||
|
||||
def _decorate(self, decorator):
|
||||
|
@ -681,7 +681,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertEqual(autograph, cloned._autograph)
|
||||
self.assertEqual(implements, cloned._implements)
|
||||
self.assertEqual(autograph_options, cloned._experimental_autograph_options)
|
||||
self.assertEqual(relax_shapes, cloned.experimental_relax_shapes)
|
||||
self.assertEqual(relax_shapes, cloned._experimental_relax_shapes)
|
||||
self.assertEqual(compile_, cloned._experimental_compile)
|
||||
|
||||
# This test does not run with XLA JIT support linked in so we can only check
|
||||
|
@ -3275,7 +3275,8 @@ def class_method_to_instance_method(original_function, instance):
|
||||
tf_decorator.make_decorator(bound_method, bound_method_wrapper),
|
||||
name=original_function._name,
|
||||
autograph=original_function._autograph,
|
||||
input_signature=original_function.input_signature)
|
||||
input_signature=original_function.input_signature,
|
||||
experimental_relax_shapes=original_function._experimental_relax_shapes)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
# And we wrap the function with tf_decorator so inspection works correctly
|
||||
|
@ -323,6 +323,29 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertTrue(unknown_dim[0])
|
||||
self.assertLen(total_function_cache(func), 2)
|
||||
|
||||
def testInputShapeRelaxationOnInstanceMethod(self):
|
||||
# Test that experimental_relax_shapes is passed during
|
||||
# instance method bounding.
|
||||
unknown_dim = [False]
|
||||
|
||||
class Foo(object):
|
||||
|
||||
@def_function.function(experimental_relax_shapes=True)
|
||||
def func(self, a):
|
||||
if a._shape_tuple()[0] is None:
|
||||
unknown_dim[0] = True
|
||||
return a + 1
|
||||
|
||||
foo = Foo()
|
||||
foo.func(constant_op.constant([]))
|
||||
self.assertFalse(unknown_dim[0])
|
||||
|
||||
foo.func(constant_op.constant([1.0]))
|
||||
self.assertFalse(unknown_dim[0])
|
||||
|
||||
foo.func(constant_op.constant([1.0, 2.0]))
|
||||
self.assertTrue(unknown_dim[0])
|
||||
|
||||
def testCapturesVariables(self):
|
||||
a = variables.Variable(1.0, trainable=False)
|
||||
b = variables.Variable(1.0)
|
||||
|
Loading…
Reference in New Issue
Block a user