Allow access SaveOptions through SaveContext

Distributed variables needs to behave differently when tracing functions with
different SaveOptions.

We need to access SaveContext in tf.distribute code instead of the another way
around because we may have a handle to the strategy in saving code.

PiperOrigin-RevId: 321815477
Change-Id: Ib69f6d42c60e198c0e8e174f76bc9424e21df5b5
This commit is contained in:
Ran Chen 2020-07-17 11:11:52 -07:00 committed by TensorFlower Gardener
parent 92b5bde9aa
commit 32bb13dace
6 changed files with 123 additions and 6 deletions

View File

@ -1145,6 +1145,8 @@ distribute_py_test(
deps = [
":combinations",
":distribute_lib",
":distribute_utils",
":packed_distributed_variable",
":strategy_combinations",
":test_util",
":tpu_strategy",
@ -1174,7 +1176,7 @@ distribute_py_test(
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:test",
"//tensorflow/python/saved_model:save_context",
"//tensorflow/python/saved_model/model_utils:mode_keys",
"//tensorflow/python/saved_model:save_options",
"//tensorflow/python/tpu:tpu_lib",
"//tensorflow/python/types",
"@absl_py//absl/testing:parameterized",

View File

@ -56,6 +56,7 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import save_context
from tensorflow.python.saved_model import save_options
from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training.tracking import util as trackable_utils
@ -597,7 +598,8 @@ class PackedDistributedVariableTest(test.TestCase, parameterized.TestCase):
self.assertIsInstance(
v._packed_variable, packed.PackedDistributedVariable)
with save_context.save_context():
options = save_options.SaveOptions()
with save_context.save_context(options):
self.assertIsNone(v._packed_variable)

View File

@ -302,6 +302,18 @@ py_strict_library(
deps = [],
)
tf_py_test(
name = "save_context_test",
srcs = ["save_context_test.py"],
srcs_version = "PY2AND3",
deps = [
":save_context",
":save_options",
"//tensorflow/python:extra_py_tests_deps",
"//tensorflow/python/eager:test",
],
)
py_strict_library(
name = "save",
srcs = [

View File

@ -1143,6 +1143,6 @@ def _build_meta_graph(obj,
options,
meta_graph_def=None):
"""Creates a MetaGraph under a SaveContext."""
with save_context.save_context():
with save_context.save_context(options):
return _build_meta_graph_impl(obj, export_dir, signatures, options,
meta_graph_def)

View File

@ -28,12 +28,20 @@ class SaveContext(threading.local):
def __init__(self):
super(SaveContext, self).__init__()
self._in_save_context = False
self._options = None
def enter_save_context(self):
def options(self):
if not self.in_save_context():
raise ValueError("not in a SaveContext")
return self._options
def enter_save_context(self, options):
self._in_save_context = True
self._options = options
def exit_save_context(self):
self._in_save_context = False
self._options = None
def in_save_context(self):
return self._in_save_context
@ -42,8 +50,10 @@ _save_context = SaveContext()
@contextlib.contextmanager
def save_context():
_save_context.enter_save_context()
def save_context(options):
if in_save_context():
raise ValueError("already in a SaveContext")
_save_context.enter_save_context(options)
try:
yield
finally:
@ -54,3 +64,7 @@ def in_save_context():
"""Returns whether under a save context."""
return _save_context.in_save_context()
def get_save_options():
"""Returns the save options if under a save context."""
return _save_context.options()

View File

@ -0,0 +1,87 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test for SaveContext."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
from tensorflow.python.eager import test
from tensorflow.python.saved_model import save_context
from tensorflow.python.saved_model import save_options
class SaveContextTest(test.TestCase):
def test_multi_thread(self):
self.assertFalse(save_context.in_save_context())
with self.assertRaisesRegex(ValueError, 'not in a SaveContext'):
save_context.get_save_options()
options = save_options.SaveOptions(save_debug_info=True)
with save_context.save_context(options):
self.assertTrue(save_context.in_save_context())
self.assertTrue(save_context.get_save_options().save_debug_info)
entered_context_in_thread = threading.Event()
continue_thread = threading.Event()
def thread_fn():
self.assertFalse(save_context.in_save_context())
with self.assertRaisesRegex(ValueError, 'not in a SaveContext'):
save_context.get_save_options()
options = save_options.SaveOptions(save_debug_info=False)
with save_context.save_context(options):
self.assertTrue(save_context.in_save_context())
# save_debug_info has a different value in this thread.
self.assertFalse(save_context.get_save_options().save_debug_info)
entered_context_in_thread.set()
continue_thread.wait()
self.assertFalse(save_context.in_save_context())
with self.assertRaisesRegex(ValueError, 'not in a SaveContext'):
save_context.get_save_options()
t = threading.Thread(target=thread_fn)
t.start()
entered_context_in_thread.wait()
# Another thread shouldn't affect this thread.
self.assertTrue(save_context.in_save_context())
self.assertTrue(save_context.get_save_options().save_debug_info)
continue_thread.set()
t.join()
# Another thread exiting SaveContext shouldn't affect this thread.
self.assertTrue(save_context.in_save_context())
self.assertTrue(save_context.get_save_options().save_debug_info)
self.assertFalse(save_context.in_save_context())
with self.assertRaisesRegex(ValueError, 'not in a SaveContext'):
save_context.get_save_options()
def test_enter_multiple(self):
options = save_options.SaveOptions()
with self.assertRaisesRegex(ValueError, 'already in a SaveContext'):
with save_context.save_context(options):
with save_context.save_context(options):
pass
if __name__ == '__main__':
test.main()