From 082ca0493ecde58a799b14f5b68847a86f6b4f2f Mon Sep 17 00:00:00 2001 From: Saurabh Saxena Date: Thu, 10 Sep 2020 12:05:03 -0700 Subject: [PATCH] [C++ gradients] Move thread_local_stack to its own lib so that it can re-used for tape stack. PiperOrigin-RevId: 330987171 Change-Id: Ibbd068168b460f977882c65b45f7ef2e7e76f2fe --- .../python/framework/experimental/BUILD | 6 +++ .../framework/experimental/context_stack.py | 22 +---------- .../experimental/thread_local_stack.py | 39 +++++++++++++++++++ 3 files changed, 47 insertions(+), 20 deletions(-) create mode 100644 tensorflow/python/framework/experimental/thread_local_stack.py diff --git a/tensorflow/python/framework/experimental/BUILD b/tensorflow/python/framework/experimental/BUILD index 5340d0549bb..4b7f4fd2b04 100644 --- a/tensorflow/python/framework/experimental/BUILD +++ b/tensorflow/python/framework/experimental/BUILD @@ -55,9 +55,15 @@ py_library( srcs = ["def_function.py"], ) +py_library( + name = "thread_local_stack", + srcs = ["thread_local_stack.py"], +) + py_library( name = "context_stack", srcs = ["context_stack.py"], + deps = [":thread_local_stack"], ) cuda_py_test( diff --git a/tensorflow/python/framework/experimental/context_stack.py b/tensorflow/python/framework/experimental/context_stack.py index 44968d631c9..7e29c1fb36e 100644 --- a/tensorflow/python/framework/experimental/context_stack.py +++ b/tensorflow/python/framework/experimental/context_stack.py @@ -19,28 +19,10 @@ from __future__ import division from __future__ import print_function import contextlib -import threading +from tensorflow.python.framework.experimental import thread_local_stack -# TODO(srbs): Move this to C++. -class _ThreadLocalStack(threading.local): - """A thread-local stack of objects for providing implicit defaults.""" - - def __init__(self): - super(_ThreadLocalStack, self).__init__() - self._stack = [] - - def peek(self): - return self._stack[-1] if self._stack else None - - def push(self, ctx): - return self._stack.append(ctx) - - def pop(self): - self._stack.pop() - - -_default_ctx_stack = _ThreadLocalStack() +_default_ctx_stack = thread_local_stack.ThreadLocalStack() def get_default(): diff --git a/tensorflow/python/framework/experimental/thread_local_stack.py b/tensorflow/python/framework/experimental/thread_local_stack.py new file mode 100644 index 00000000000..7042f32902c --- /dev/null +++ b/tensorflow/python/framework/experimental/thread_local_stack.py @@ -0,0 +1,39 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Thread-local stack.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import threading + + +# TODO(srbs): Move this to C++. +class ThreadLocalStack(threading.local): + """A thread-local stack of objects for providing implicit defaults.""" + + def __init__(self): + super(ThreadLocalStack, self).__init__() + self._stack = [] + + def peek(self): + return self._stack[-1] if self._stack else None + + def push(self, ctx): + return self._stack.append(ctx) + + def pop(self): + self._stack.pop()