diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py index 6f50325015f..1ed603cab6a 100644 --- a/tensorflow/python/eager/def_function.py +++ b/tensorflow/python/eager/def_function.py @@ -422,6 +422,19 @@ class Function(object): self._input_signature = input_signature self._call_counter = _CallCounter(FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY) + def __getstate__(self): + """Custom pickling, to omit unpickleable objects.""" + result = self.__dict__.copy() + del result["_lock"] + del result["_descriptor_cache"] + return result + + def __setstate__(self, state): + """Restore from pickled state.""" + self.__dict__ = state + self._lock = threading.Lock() + self._descriptor_cache = weakref.WeakKeyDictionary() + def _defun_with_scope(self, scope): """Creates a defun wrapped inside a variable creator scope.""" diff --git a/tensorflow/python/eager/def_function_test.py b/tensorflow/python/eager/def_function_test.py index 998350f695d..6e23f25ec09 100644 --- a/tensorflow/python/eager/def_function_test.py +++ b/tensorflow/python/eager/def_function_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import functools import itertools +import pickle import re import weakref @@ -68,6 +69,10 @@ class _ModelWithOptimizer(training.Model): return {'loss': loss} +def undecorated_function(x): + return x * 3. + + class _HasDecoratedMethod(object): @def_function.function @@ -747,6 +752,41 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): # If the graph is deleted, then an exception is raised on reading `captures` self.assertEmpty(graph.captures) + @parameterized.parameters(*itertools.product( + (None, (tensor_spec.TensorSpec([]),)), # input_signature + (True, False), # autograph + (None, converter.Feature.ALL), # autograph_options + (None, 'foo.bar'), # implements + (None, True, False), # relax_shapes + )) + def test_pickle(self, input_signature, autograph, autograph_options, + implements, relax_shapes): + """@function objects can be pickled and unpickled.""" + # Can't pickle functions in __main__: + from tensorflow.python.eager.def_function_test import undecorated_function + original_py_function = undecorated_function + + func = def_function.function( + func=original_py_function, + input_signature=input_signature, + autograph=autograph, + experimental_implements=implements, + experimental_autograph_options=autograph_options, + experimental_relax_shapes=relax_shapes, + ) + + cloned = pickle.loads(pickle.dumps(func)) + + self.assertEqual(func._name, cloned._name) + self.assertEqual(input_signature, cloned._input_signature) + self.assertEqual(autograph, cloned._autograph) + self.assertEqual(implements, cloned._implements) + self.assertEqual(autograph_options, cloned._experimental_autograph_options) + self.assertEqual(relax_shapes, cloned._experimental_relax_shapes) + + x = array_ops.ones([]) + self.assertEqual(self.evaluate(cloned(x)), self.evaluate(func(x))) + if __name__ == '__main__': ops.enable_eager_execution()