Merge pull request #37837 from itamarst:pickleable-tf.function
PiperOrigin-RevId: 304625100 Change-Id: I506e9c14b3362a0a2d5eaf7a3ac79b9a1e1fc0b3
This commit is contained in:
commit
b65ce98d5d
@ -422,6 +422,19 @@ class Function(object):
|
|||||||
self._input_signature = input_signature
|
self._input_signature = input_signature
|
||||||
self._call_counter = _CallCounter(FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY)
|
self._call_counter = _CallCounter(FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY)
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
"""Custom pickling, to omit unpickleable objects."""
|
||||||
|
result = self.__dict__.copy()
|
||||||
|
del result["_lock"]
|
||||||
|
del result["_descriptor_cache"]
|
||||||
|
return result
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
"""Restore from pickled state."""
|
||||||
|
self.__dict__ = state
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
self._descriptor_cache = weakref.WeakKeyDictionary()
|
||||||
|
|
||||||
def _defun_with_scope(self, scope):
|
def _defun_with_scope(self, scope):
|
||||||
"""Creates a defun wrapped inside a variable creator scope."""
|
"""Creates a defun wrapped inside a variable creator scope."""
|
||||||
|
|
||||||
|
@ -19,6 +19,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import functools
|
import functools
|
||||||
import itertools
|
import itertools
|
||||||
|
import pickle
|
||||||
import re
|
import re
|
||||||
import weakref
|
import weakref
|
||||||
|
|
||||||
@ -68,6 +69,10 @@ class _ModelWithOptimizer(training.Model):
|
|||||||
return {'loss': loss}
|
return {'loss': loss}
|
||||||
|
|
||||||
|
|
||||||
|
def undecorated_function(x):
|
||||||
|
return x * 3.
|
||||||
|
|
||||||
|
|
||||||
class _HasDecoratedMethod(object):
|
class _HasDecoratedMethod(object):
|
||||||
|
|
||||||
@def_function.function
|
@def_function.function
|
||||||
@ -747,6 +752,41 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
# If the graph is deleted, then an exception is raised on reading `captures`
|
# If the graph is deleted, then an exception is raised on reading `captures`
|
||||||
self.assertEmpty(graph.captures)
|
self.assertEmpty(graph.captures)
|
||||||
|
|
||||||
|
@parameterized.parameters(*itertools.product(
|
||||||
|
(None, (tensor_spec.TensorSpec([]),)), # input_signature
|
||||||
|
(True, False), # autograph
|
||||||
|
(None, converter.Feature.ALL), # autograph_options
|
||||||
|
(None, 'foo.bar'), # implements
|
||||||
|
(None, True, False), # relax_shapes
|
||||||
|
))
|
||||||
|
def test_pickle(self, input_signature, autograph, autograph_options,
|
||||||
|
implements, relax_shapes):
|
||||||
|
"""@function objects can be pickled and unpickled."""
|
||||||
|
# Can't pickle functions in __main__:
|
||||||
|
from tensorflow.python.eager.def_function_test import undecorated_function
|
||||||
|
original_py_function = undecorated_function
|
||||||
|
|
||||||
|
func = def_function.function(
|
||||||
|
func=original_py_function,
|
||||||
|
input_signature=input_signature,
|
||||||
|
autograph=autograph,
|
||||||
|
experimental_implements=implements,
|
||||||
|
experimental_autograph_options=autograph_options,
|
||||||
|
experimental_relax_shapes=relax_shapes,
|
||||||
|
)
|
||||||
|
|
||||||
|
cloned = pickle.loads(pickle.dumps(func))
|
||||||
|
|
||||||
|
self.assertEqual(func._name, cloned._name)
|
||||||
|
self.assertEqual(input_signature, cloned._input_signature)
|
||||||
|
self.assertEqual(autograph, cloned._autograph)
|
||||||
|
self.assertEqual(implements, cloned._implements)
|
||||||
|
self.assertEqual(autograph_options, cloned._experimental_autograph_options)
|
||||||
|
self.assertEqual(relax_shapes, cloned._experimental_relax_shapes)
|
||||||
|
|
||||||
|
x = array_ops.ones([])
|
||||||
|
self.assertEqual(self.evaluate(cloned(x)), self.evaluate(func(x)))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
ops.enable_eager_execution()
|
ops.enable_eager_execution()
|
||||||
|
Loading…
Reference in New Issue
Block a user