Allow access SaveOptions through SaveContext
Distributed variables needs to behave differently when tracing functions with different SaveOptions. We need to access SaveContext in tf.distribute code instead of the another way around because we may have a handle to the strategy in saving code. PiperOrigin-RevId: 321815477 Change-Id: Ib69f6d42c60e198c0e8e174f76bc9424e21df5b5
This commit is contained in:
parent
92b5bde9aa
commit
32bb13dace
@ -1145,6 +1145,8 @@ distribute_py_test(
|
||||
deps = [
|
||||
":combinations",
|
||||
":distribute_lib",
|
||||
":distribute_utils",
|
||||
":packed_distributed_variable",
|
||||
":strategy_combinations",
|
||||
":test_util",
|
||||
":tpu_strategy",
|
||||
@ -1174,7 +1176,7 @@ distribute_py_test(
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/eager:test",
|
||||
"//tensorflow/python/saved_model:save_context",
|
||||
"//tensorflow/python/saved_model/model_utils:mode_keys",
|
||||
"//tensorflow/python/saved_model:save_options",
|
||||
"//tensorflow/python/tpu:tpu_lib",
|
||||
"//tensorflow/python/types",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
|
@ -56,6 +56,7 @@ from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables as variables_lib
|
||||
from tensorflow.python.saved_model import save_context
|
||||
from tensorflow.python.saved_model import save_options
|
||||
from tensorflow.python.tpu import tpu_strategy_util
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
from tensorflow.python.training.tracking import util as trackable_utils
|
||||
@ -597,7 +598,8 @@ class PackedDistributedVariableTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertIsInstance(
|
||||
v._packed_variable, packed.PackedDistributedVariable)
|
||||
|
||||
with save_context.save_context():
|
||||
options = save_options.SaveOptions()
|
||||
with save_context.save_context(options):
|
||||
self.assertIsNone(v._packed_variable)
|
||||
|
||||
|
||||
|
@ -302,6 +302,18 @@ py_strict_library(
|
||||
deps = [],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "save_context_test",
|
||||
srcs = ["save_context_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":save_context",
|
||||
":save_options",
|
||||
"//tensorflow/python:extra_py_tests_deps",
|
||||
"//tensorflow/python/eager:test",
|
||||
],
|
||||
)
|
||||
|
||||
py_strict_library(
|
||||
name = "save",
|
||||
srcs = [
|
||||
|
@ -1143,6 +1143,6 @@ def _build_meta_graph(obj,
|
||||
options,
|
||||
meta_graph_def=None):
|
||||
"""Creates a MetaGraph under a SaveContext."""
|
||||
with save_context.save_context():
|
||||
with save_context.save_context(options):
|
||||
return _build_meta_graph_impl(obj, export_dir, signatures, options,
|
||||
meta_graph_def)
|
||||
|
@ -28,12 +28,20 @@ class SaveContext(threading.local):
|
||||
def __init__(self):
|
||||
super(SaveContext, self).__init__()
|
||||
self._in_save_context = False
|
||||
self._options = None
|
||||
|
||||
def enter_save_context(self):
|
||||
def options(self):
|
||||
if not self.in_save_context():
|
||||
raise ValueError("not in a SaveContext")
|
||||
return self._options
|
||||
|
||||
def enter_save_context(self, options):
|
||||
self._in_save_context = True
|
||||
self._options = options
|
||||
|
||||
def exit_save_context(self):
|
||||
self._in_save_context = False
|
||||
self._options = None
|
||||
|
||||
def in_save_context(self):
|
||||
return self._in_save_context
|
||||
@ -42,8 +50,10 @@ _save_context = SaveContext()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def save_context():
|
||||
_save_context.enter_save_context()
|
||||
def save_context(options):
|
||||
if in_save_context():
|
||||
raise ValueError("already in a SaveContext")
|
||||
_save_context.enter_save_context(options)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
@ -54,3 +64,7 @@ def in_save_context():
|
||||
"""Returns whether under a save context."""
|
||||
return _save_context.in_save_context()
|
||||
|
||||
|
||||
def get_save_options():
|
||||
"""Returns the save options if under a save context."""
|
||||
return _save_context.options()
|
||||
|
87
tensorflow/python/saved_model/save_context_test.py
Normal file
87
tensorflow/python/saved_model/save_context_test.py
Normal file
@ -0,0 +1,87 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Test for SaveContext."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import threading
|
||||
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.saved_model import save_context
|
||||
from tensorflow.python.saved_model import save_options
|
||||
|
||||
|
||||
class SaveContextTest(test.TestCase):
|
||||
|
||||
def test_multi_thread(self):
|
||||
self.assertFalse(save_context.in_save_context())
|
||||
with self.assertRaisesRegex(ValueError, 'not in a SaveContext'):
|
||||
save_context.get_save_options()
|
||||
|
||||
options = save_options.SaveOptions(save_debug_info=True)
|
||||
with save_context.save_context(options):
|
||||
self.assertTrue(save_context.in_save_context())
|
||||
self.assertTrue(save_context.get_save_options().save_debug_info)
|
||||
|
||||
entered_context_in_thread = threading.Event()
|
||||
continue_thread = threading.Event()
|
||||
|
||||
def thread_fn():
|
||||
self.assertFalse(save_context.in_save_context())
|
||||
with self.assertRaisesRegex(ValueError, 'not in a SaveContext'):
|
||||
save_context.get_save_options()
|
||||
|
||||
options = save_options.SaveOptions(save_debug_info=False)
|
||||
with save_context.save_context(options):
|
||||
self.assertTrue(save_context.in_save_context())
|
||||
# save_debug_info has a different value in this thread.
|
||||
self.assertFalse(save_context.get_save_options().save_debug_info)
|
||||
entered_context_in_thread.set()
|
||||
continue_thread.wait()
|
||||
|
||||
self.assertFalse(save_context.in_save_context())
|
||||
with self.assertRaisesRegex(ValueError, 'not in a SaveContext'):
|
||||
save_context.get_save_options()
|
||||
|
||||
t = threading.Thread(target=thread_fn)
|
||||
t.start()
|
||||
|
||||
entered_context_in_thread.wait()
|
||||
# Another thread shouldn't affect this thread.
|
||||
self.assertTrue(save_context.in_save_context())
|
||||
self.assertTrue(save_context.get_save_options().save_debug_info)
|
||||
|
||||
continue_thread.set()
|
||||
t.join()
|
||||
# Another thread exiting SaveContext shouldn't affect this thread.
|
||||
self.assertTrue(save_context.in_save_context())
|
||||
self.assertTrue(save_context.get_save_options().save_debug_info)
|
||||
|
||||
self.assertFalse(save_context.in_save_context())
|
||||
with self.assertRaisesRegex(ValueError, 'not in a SaveContext'):
|
||||
save_context.get_save_options()
|
||||
|
||||
def test_enter_multiple(self):
|
||||
options = save_options.SaveOptions()
|
||||
with self.assertRaisesRegex(ValueError, 'already in a SaveContext'):
|
||||
with save_context.save_context(options):
|
||||
with save_context.save_context(options):
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
Loading…
Reference in New Issue
Block a user