Pass experimental_relax_shapes to instance methods...

decorated with `tf.function`.

Currently the `experimental_relax_shapes` is not passed,
so the instance methods ignore this argument.

Fixes #34905.

Note that the `experimental_relax_shapes` was named differently in
`def_function.Function` and `function.Function`, so the one
in `def_function` was renamed to start with underscore. It is a private
field, so it should be fine.
This commit is contained in:
Milan Straka 2019-12-11 09:06:59 +01:00
parent 115ea3db34
commit e08ae84598
4 changed files with 29 additions and 5 deletions

View File

@ -405,7 +405,7 @@ class Function(object):
self._implements = experimental_implements
self._autograph = autograph
self._experimental_autograph_options = experimental_autograph_options
self.experimental_relax_shapes = experimental_relax_shapes
self._experimental_relax_shapes = experimental_relax_shapes
self._experimental_compile = experimental_compile
self._created_variables = None # GUARDED_BY(self._lock)
self._stateful_fn = None # GUARDED_BY(self._lock)
@ -458,7 +458,7 @@ class Function(object):
attributes=attributes,
autograph=self._autograph,
experimental_autograph_options=self._experimental_autograph_options,
experimental_relax_shapes=self.experimental_relax_shapes)
experimental_relax_shapes=self._experimental_relax_shapes)
def _initialize(self, args, kwds, add_initializers_to=None):
"""Initializes, on the first call.
@ -514,7 +514,7 @@ class Function(object):
autograph=self._autograph,
experimental_implements=self._implements,
experimental_autograph_options=self._experimental_autograph_options,
experimental_relax_shapes=self.experimental_relax_shapes,
experimental_relax_shapes=self._experimental_relax_shapes,
experimental_compile=self._experimental_compile)
def _decorate(self, decorator):

View File

@ -681,7 +681,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
self.assertEqual(autograph, cloned._autograph)
self.assertEqual(implements, cloned._implements)
self.assertEqual(autograph_options, cloned._experimental_autograph_options)
self.assertEqual(relax_shapes, cloned.experimental_relax_shapes)
self.assertEqual(relax_shapes, cloned._experimental_relax_shapes)
self.assertEqual(compile_, cloned._experimental_compile)
# This test does not run with XLA JIT support linked in so we can only check

View File

@ -3275,7 +3275,8 @@ def class_method_to_instance_method(original_function, instance):
tf_decorator.make_decorator(bound_method, bound_method_wrapper),
name=original_function._name,
autograph=original_function._autograph,
input_signature=original_function.input_signature)
input_signature=original_function.input_signature,
experimental_relax_shapes=original_function._experimental_relax_shapes)
# pylint: enable=protected-access
# And we wrap the function with tf_decorator so inspection works correctly

View File

@ -323,6 +323,29 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
self.assertTrue(unknown_dim[0])
self.assertLen(total_function_cache(func), 2)
def testInputShapeRelaxationOnInstanceMethod(self):
# Test that experimental_relax_shapes is passed during
# instance method bounding.
unknown_dim = [False]
class Foo(object):
@def_function.function(experimental_relax_shapes=True)
def func(self, a):
if a._shape_tuple()[0] is None:
unknown_dim[0] = True
return a + 1
foo = Foo()
foo.func(constant_op.constant([]))
self.assertFalse(unknown_dim[0])
foo.func(constant_op.constant([1.0]))
self.assertFalse(unknown_dim[0])
foo.func(constant_op.constant([1.0, 2.0]))
self.assertTrue(unknown_dim[0])
def testCapturesVariables(self):
a = variables.Variable(1.0, trainable=False)
b = variables.Variable(1.0)