Merge pull request #37837 from itamarst:pickleable-tf.function

PiperOrigin-RevId: 304625100
Change-Id: I506e9c14b3362a0a2d5eaf7a3ac79b9a1e1fc0b3
This commit is contained in:
TensorFlower Gardener 2020-04-03 08:58:42 -07:00
commit b65ce98d5d
2 changed files with 53 additions and 0 deletions

View File

@ -422,6 +422,19 @@ class Function(object):
self._input_signature = input_signature
self._call_counter = _CallCounter(FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY)
def __getstate__(self):
"""Custom pickling, to omit unpickleable objects."""
result = self.__dict__.copy()
del result["_lock"]
del result["_descriptor_cache"]
return result
def __setstate__(self, state):
"""Restore from pickled state."""
self.__dict__ = state
self._lock = threading.Lock()
self._descriptor_cache = weakref.WeakKeyDictionary()
def _defun_with_scope(self, scope):
"""Creates a defun wrapped inside a variable creator scope."""

View File

@ -19,6 +19,7 @@ from __future__ import print_function
import functools
import itertools
import pickle
import re
import weakref
@ -68,6 +69,10 @@ class _ModelWithOptimizer(training.Model):
return {'loss': loss}
def undecorated_function(x):
return x * 3.
class _HasDecoratedMethod(object):
@def_function.function
@ -747,6 +752,41 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
# If the graph is deleted, then an exception is raised on reading `captures`
self.assertEmpty(graph.captures)
@parameterized.parameters(*itertools.product(
(None, (tensor_spec.TensorSpec([]),)), # input_signature
(True, False), # autograph
(None, converter.Feature.ALL), # autograph_options
(None, 'foo.bar'), # implements
(None, True, False), # relax_shapes
))
def test_pickle(self, input_signature, autograph, autograph_options,
implements, relax_shapes):
"""@function objects can be pickled and unpickled."""
# Can't pickle functions in __main__:
from tensorflow.python.eager.def_function_test import undecorated_function
original_py_function = undecorated_function
func = def_function.function(
func=original_py_function,
input_signature=input_signature,
autograph=autograph,
experimental_implements=implements,
experimental_autograph_options=autograph_options,
experimental_relax_shapes=relax_shapes,
)
cloned = pickle.loads(pickle.dumps(func))
self.assertEqual(func._name, cloned._name)
self.assertEqual(input_signature, cloned._input_signature)
self.assertEqual(autograph, cloned._autograph)
self.assertEqual(implements, cloned._implements)
self.assertEqual(autograph_options, cloned._experimental_autograph_options)
self.assertEqual(relax_shapes, cloned._experimental_relax_shapes)
x = array_ops.ones([])
self.assertEqual(self.evaluate(cloned(x)), self.evaluate(func(x)))
if __name__ == '__main__':
ops.enable_eager_execution()