Make @tf.function-wrapped functions pickleable.

This commit is contained in:
Itamar Turner-Trauring 2020-03-23 13:56:00 -04:00
parent c58518c893
commit 93ce587122
2 changed files with 51 additions and 0 deletions

View File

@ -422,6 +422,19 @@ class Function(object):
self._input_signature = input_signature
self._call_counter = _CallCounter(FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY)
def __getstate__(self):
"""Custom pickling, to omit unpickleable objects."""
result = self.__dict__.copy()
del result["_lock"]
del result["_descriptor_cache"]
return result
def __setstate__(self, state):
"""Restore from pickled state."""
self.__dict__ = state
self._lock = threading.Lock()
self._descriptor_cache = weakref.WeakKeyDictionary()
def _defun_with_scope(self, scope):
"""Creates a defun wrapped inside a variable creator scope."""

View File

@ -19,6 +19,7 @@ from __future__ import print_function
import functools
import itertools
import pickle
import re
import weakref
@ -67,6 +68,8 @@ class _ModelWithOptimizer(training.Model):
self.optimizer.apply_gradients(zip(gradients, trainable_variables))
return {'loss': loss}
def undecorated_function(x):
return x * 3.
class _HasDecoratedMethod(object):
@ -747,6 +750,41 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
# If the graph is deleted, then an exception is raised on reading `captures`
self.assertEmpty(graph.captures)
@parameterized.parameters(*itertools.product(
(None, (tensor_spec.TensorSpec([]),)), # input_signature
(True, False), # autograph
(None, converter.Feature.ALL), # autograph_options
(None, 'foo.bar'), # implements
(None, True, False), # relax_shapes
))
def test_pickle(self, input_signature, autograph, autograph_options, implements,
relax_shapes):
"""@function objects can be pickled and unpickled."""
# Can't pickle functions in __main__:
from tensorflow.python.eager.def_function_test import undecorated_function
original_py_function = undecorated_function
func = def_function.function(
func=original_py_function,
input_signature=input_signature,
autograph=autograph,
experimental_implements=implements,
experimental_autograph_options=autograph_options,
experimental_relax_shapes=relax_shapes,
)
cloned = pickle.loads(pickle.dumps(func))
self.assertEqual(func._name, cloned._name)
self.assertEqual(input_signature, cloned._input_signature)
self.assertEqual(autograph, cloned._autograph)
self.assertEqual(implements, cloned._implements)
self.assertEqual(autograph_options, cloned._experimental_autograph_options)
self.assertEqual(relax_shapes, cloned._experimental_relax_shapes)
x = array_ops.ones([])
self.assertEqual(self.evaluate(cloned(x)),
self.evaluate(func(x)))
if __name__ == '__main__':
ops.enable_eager_execution()