Merge pull request #37837 from itamarst:pickleable-tf.function
PiperOrigin-RevId: 304625100 Change-Id: I506e9c14b3362a0a2d5eaf7a3ac79b9a1e1fc0b3
This commit is contained in:
commit
b65ce98d5d
@ -422,6 +422,19 @@ class Function(object):
|
||||
self._input_signature = input_signature
|
||||
self._call_counter = _CallCounter(FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY)
|
||||
|
||||
def __getstate__(self):
|
||||
"""Custom pickling, to omit unpickleable objects."""
|
||||
result = self.__dict__.copy()
|
||||
del result["_lock"]
|
||||
del result["_descriptor_cache"]
|
||||
return result
|
||||
|
||||
def __setstate__(self, state):
|
||||
"""Restore from pickled state."""
|
||||
self.__dict__ = state
|
||||
self._lock = threading.Lock()
|
||||
self._descriptor_cache = weakref.WeakKeyDictionary()
|
||||
|
||||
def _defun_with_scope(self, scope):
|
||||
"""Creates a defun wrapped inside a variable creator scope."""
|
||||
|
||||
|
@ -19,6 +19,7 @@ from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import itertools
|
||||
import pickle
|
||||
import re
|
||||
import weakref
|
||||
|
||||
@ -68,6 +69,10 @@ class _ModelWithOptimizer(training.Model):
|
||||
return {'loss': loss}
|
||||
|
||||
|
||||
def undecorated_function(x):
|
||||
return x * 3.
|
||||
|
||||
|
||||
class _HasDecoratedMethod(object):
|
||||
|
||||
@def_function.function
|
||||
@ -747,6 +752,41 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
# If the graph is deleted, then an exception is raised on reading `captures`
|
||||
self.assertEmpty(graph.captures)
|
||||
|
||||
@parameterized.parameters(*itertools.product(
|
||||
(None, (tensor_spec.TensorSpec([]),)), # input_signature
|
||||
(True, False), # autograph
|
||||
(None, converter.Feature.ALL), # autograph_options
|
||||
(None, 'foo.bar'), # implements
|
||||
(None, True, False), # relax_shapes
|
||||
))
|
||||
def test_pickle(self, input_signature, autograph, autograph_options,
|
||||
implements, relax_shapes):
|
||||
"""@function objects can be pickled and unpickled."""
|
||||
# Can't pickle functions in __main__:
|
||||
from tensorflow.python.eager.def_function_test import undecorated_function
|
||||
original_py_function = undecorated_function
|
||||
|
||||
func = def_function.function(
|
||||
func=original_py_function,
|
||||
input_signature=input_signature,
|
||||
autograph=autograph,
|
||||
experimental_implements=implements,
|
||||
experimental_autograph_options=autograph_options,
|
||||
experimental_relax_shapes=relax_shapes,
|
||||
)
|
||||
|
||||
cloned = pickle.loads(pickle.dumps(func))
|
||||
|
||||
self.assertEqual(func._name, cloned._name)
|
||||
self.assertEqual(input_signature, cloned._input_signature)
|
||||
self.assertEqual(autograph, cloned._autograph)
|
||||
self.assertEqual(implements, cloned._implements)
|
||||
self.assertEqual(autograph_options, cloned._experimental_autograph_options)
|
||||
self.assertEqual(relax_shapes, cloned._experimental_relax_shapes)
|
||||
|
||||
x = array_ops.ones([])
|
||||
self.assertEqual(self.evaluate(cloned(x)), self.evaluate(func(x)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution()
|
||||
|
Loading…
Reference in New Issue
Block a user