From 93ce587122252ca062547090f71cac4657ddf538 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Mon, 23 Mar 2020 13:56:00 -0400 Subject: [PATCH 1/2] Make @tf.function-wrapped functions pickleable. --- tensorflow/python/eager/def_function.py | 13 +++++++ tensorflow/python/eager/def_function_test.py | 38 ++++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py index 6f50325015f..1ed603cab6a 100644 --- a/tensorflow/python/eager/def_function.py +++ b/tensorflow/python/eager/def_function.py @@ -422,6 +422,19 @@ class Function(object): self._input_signature = input_signature self._call_counter = _CallCounter(FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY) + def __getstate__(self): + """Custom pickling, to omit unpickleable objects.""" + result = self.__dict__.copy() + del result["_lock"] + del result["_descriptor_cache"] + return result + + def __setstate__(self, state): + """Restore from pickled state.""" + self.__dict__ = state + self._lock = threading.Lock() + self._descriptor_cache = weakref.WeakKeyDictionary() + def _defun_with_scope(self, scope): """Creates a defun wrapped inside a variable creator scope.""" diff --git a/tensorflow/python/eager/def_function_test.py b/tensorflow/python/eager/def_function_test.py index 998350f695d..6a8582183bc 100644 --- a/tensorflow/python/eager/def_function_test.py +++ b/tensorflow/python/eager/def_function_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import functools import itertools +import pickle import re import weakref @@ -67,6 +68,8 @@ class _ModelWithOptimizer(training.Model): self.optimizer.apply_gradients(zip(gradients, trainable_variables)) return {'loss': loss} +def undecorated_function(x): + return x * 3. class _HasDecoratedMethod(object): @@ -747,6 +750,41 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): # If the graph is deleted, then an exception is raised on reading `captures` self.assertEmpty(graph.captures) + @parameterized.parameters(*itertools.product( + (None, (tensor_spec.TensorSpec([]),)), # input_signature + (True, False), # autograph + (None, converter.Feature.ALL), # autograph_options + (None, 'foo.bar'), # implements + (None, True, False), # relax_shapes + )) + def test_pickle(self, input_signature, autograph, autograph_options, implements, + relax_shapes): + """@function objects can be pickled and unpickled.""" + # Can't pickle functions in __main__: + from tensorflow.python.eager.def_function_test import undecorated_function + original_py_function = undecorated_function + + func = def_function.function( + func=original_py_function, + input_signature=input_signature, + autograph=autograph, + experimental_implements=implements, + experimental_autograph_options=autograph_options, + experimental_relax_shapes=relax_shapes, + ) + + cloned = pickle.loads(pickle.dumps(func)) + + self.assertEqual(func._name, cloned._name) + self.assertEqual(input_signature, cloned._input_signature) + self.assertEqual(autograph, cloned._autograph) + self.assertEqual(implements, cloned._implements) + self.assertEqual(autograph_options, cloned._experimental_autograph_options) + self.assertEqual(relax_shapes, cloned._experimental_relax_shapes) + + x = array_ops.ones([]) + self.assertEqual(self.evaluate(cloned(x)), + self.evaluate(func(x))) if __name__ == '__main__': ops.enable_eager_execution() From 9274fe4678f120de6081d5f9f4395c2181ca33dc Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Wed, 25 Mar 2020 13:07:45 -0400 Subject: [PATCH 2/2] Try to pacify pylint. --- tensorflow/python/eager/def_function_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/eager/def_function_test.py b/tensorflow/python/eager/def_function_test.py index 6a8582183bc..7c06d7b8ae0 100644 --- a/tensorflow/python/eager/def_function_test.py +++ b/tensorflow/python/eager/def_function_test.py @@ -757,8 +757,8 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): (None, 'foo.bar'), # implements (None, True, False), # relax_shapes )) - def test_pickle(self, input_signature, autograph, autograph_options, implements, - relax_shapes): + def test_pickle(self, input_signature, autograph, autograph_options, + implements, relax_shapes): """@function objects can be pickled and unpickled.""" # Can't pickle functions in __main__: from tensorflow.python.eager.def_function_test import undecorated_function