[C++ gradients] Add experimental python bindings for unified API's Tape.
See unified_api_test for usage. PiperOrigin-RevId: 330992900 Change-Id: I1e7b2fc3ce6ab7d53c8f05a9807e7a03f6e59715
This commit is contained in:
parent
8e789c3872
commit
896635c9cb
@ -103,9 +103,13 @@ filegroup(
|
||||
"c_api_unified_experimental.h",
|
||||
"c_api_unified_experimental_internal.h",
|
||||
"dlpack.h",
|
||||
"gradients.h",
|
||||
"gradients_internal.h",
|
||||
"immediate_execution_context.h",
|
||||
"immediate_execution_operation.h",
|
||||
"immediate_execution_tensor_handle.h",
|
||||
"mnist_gradients_testutil.h",
|
||||
"tape.h",
|
||||
"tfe_cancellation_manager_internal.h",
|
||||
"tfe_executor_internal.h",
|
||||
"tfe_monitoring_internal.h",
|
||||
@ -164,31 +168,6 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gradients",
|
||||
srcs = [
|
||||
"gradients.cc",
|
||||
"gradients_internal.h",
|
||||
],
|
||||
hdrs = [
|
||||
"gradients.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":abstract_context",
|
||||
":abstract_operation",
|
||||
":abstract_tensor_handle",
|
||||
":c_api_unified_internal",
|
||||
":tape",
|
||||
"//tensorflow/core/common_runtime/eager:attr_builder",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gradients_internal",
|
||||
srcs = [
|
||||
@ -266,9 +245,7 @@ cc_library(
|
||||
":gradients_internal",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/c:tf_tensor",
|
||||
"//tensorflow/c/experimental/ops:array_ops",
|
||||
"//tensorflow/c/experimental/ops:math_ops",
|
||||
"//tensorflow/c/experimental/ops:nn_ops",
|
||||
"//tensorflow/c/experimental/ops",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
"//tensorflow/core/platform:status",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
|
@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_
|
||||
#define TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_
|
||||
#include <memory>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
@ -148,3 +150,5 @@ Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx);
|
||||
} // namespace internal
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_
|
||||
|
@ -16,7 +16,7 @@ cc_library(
|
||||
"//tensorflow/c/eager:abstract_operation",
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:c_api_unified_internal",
|
||||
"//tensorflow/c/eager:gradients",
|
||||
"//tensorflow/c/eager:gradients_internal",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
],
|
||||
)
|
||||
@ -34,7 +34,7 @@ cc_library(
|
||||
"//tensorflow/c/eager:abstract_operation",
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:c_api_unified_internal",
|
||||
"//tensorflow/c/eager:gradients",
|
||||
"//tensorflow/c/eager:gradients_internal",
|
||||
"//tensorflow/c/experimental/ops:array_ops",
|
||||
"//tensorflow/c/experimental/ops:math_ops",
|
||||
"//tensorflow/c/experimental/ops:nn_ops",
|
||||
@ -55,10 +55,40 @@ cc_library(
|
||||
"//tensorflow/c/eager:abstract_operation",
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:c_api_unified_internal",
|
||||
"//tensorflow/c/eager:gradients",
|
||||
"//tensorflow/c/eager:gradients_internal",
|
||||
"//tensorflow/c/experimental/ops:array_ops",
|
||||
"//tensorflow/c/experimental/ops:math_ops",
|
||||
"//tensorflow/c/experimental/ops:nn_ops",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gradients",
|
||||
hdrs = [
|
||||
"array_grad.h",
|
||||
"math_grad.h",
|
||||
"nn_grad.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":array_grad",
|
||||
":math_grad",
|
||||
":nn_grad",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "pywrap_required_hdrs",
|
||||
srcs = [
|
||||
"array_grad.h",
|
||||
"math_grad.h",
|
||||
"nn_grad.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/core:__pkg__",
|
||||
"//tensorflow/python:__pkg__",
|
||||
],
|
||||
)
|
||||
|
@ -68,6 +68,11 @@ cc_library(
|
||||
|
||||
cc_library(
|
||||
name = "ops",
|
||||
hdrs = [
|
||||
"array_ops.h",
|
||||
"math_ops.h",
|
||||
"nn_ops.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
|
@ -682,6 +682,7 @@ tf_python_pybind_extension(
|
||||
"//tensorflow/c:headers",
|
||||
"//tensorflow/c/eager:headers",
|
||||
"//tensorflow/c/eager:pywrap_required_hdrs",
|
||||
"//tensorflow/c/experimental/ops:pywrap_required_hdrs",
|
||||
"//tensorflow/core/common_runtime/eager:pywrap_required_hdrs",
|
||||
"//tensorflow/core/distributed_runtime:pywrap_required_hdrs",
|
||||
"//tensorflow/core/distributed_runtime/eager:pywrap_required_hdrs",
|
||||
@ -702,6 +703,7 @@ tf_python_pybind_extension(
|
||||
"//tensorflow/core/common_runtime:core_cpu_headers_lib",
|
||||
"//tensorflow/core:lib_headers_for_pybind",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
] + if_static(
|
||||
extra_deps = [
|
||||
"//tensorflow/core/protobuf:eager_service_proto_cc",
|
||||
@ -5939,6 +5941,8 @@ pywrap_tensorflow_macro(
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/c/eager:c_api_experimental",
|
||||
"//tensorflow/c/experimental/ops",
|
||||
"//tensorflow/c/experimental/gradients",
|
||||
"//tensorflow/c/eager:mnist_gradients_testutil",
|
||||
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
|
||||
"//tensorflow/core/data/service:server_lib",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_rpc_factory_registration",
|
||||
@ -7448,6 +7452,7 @@ cc_library(
|
||||
"//tensorflow/c/eager:headers",
|
||||
"//tensorflow/c/eager:pywrap_required_hdrs",
|
||||
"//tensorflow/c/experimental/ops:pywrap_required_hdrs",
|
||||
"//tensorflow/c/experimental/gradients:pywrap_required_hdrs",
|
||||
"//tensorflow/core/common_runtime/eager:pywrap_required_hdrs",
|
||||
"//tensorflow/python/eager:pywrap_required_hdrs",
|
||||
"//tensorflow/core/common_runtime:core_cpu_lib_headers",
|
||||
@ -7484,6 +7489,7 @@ tf_python_pybind_extension(
|
||||
"//tensorflow/c:headers",
|
||||
"//tensorflow/c/eager:headers",
|
||||
"//tensorflow/c/eager:pywrap_required_hdrs",
|
||||
"//tensorflow/c/experimental/ops:pywrap_required_hdrs",
|
||||
"//tensorflow/core/common_runtime/eager:pywrap_required_hdrs",
|
||||
"//tensorflow/core/distributed_runtime:pywrap_required_hdrs",
|
||||
"//tensorflow/core/distributed_runtime/eager:pywrap_required_hdrs",
|
||||
@ -7511,6 +7517,7 @@ tf_python_pybind_extension(
|
||||
"//tensorflow/core:lib_headers_for_pybind",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
] + if_static(
|
||||
extra_deps = [
|
||||
"//tensorflow/core/protobuf:eager_service_proto_cc",
|
||||
|
@ -27,6 +27,22 @@ tf_python_pybind_extension(
|
||||
],
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "_tape",
|
||||
srcs = ["tape.cc"],
|
||||
features = ["-layering_check"],
|
||||
module_name = "_tape",
|
||||
deps = [
|
||||
"//tensorflow/c/eager:tfe_tensorhandle_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
"//tensorflow/python:pybind11_lib",
|
||||
"//tensorflow/python:unified_api_pywrap_required_headers",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "_math_ops",
|
||||
srcs = ["math_ops.cc"],
|
||||
@ -44,10 +60,31 @@ tf_python_pybind_extension(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "gradient_registry",
|
||||
srcs = ["gradient_registry.py"],
|
||||
deps = [":_tape"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "math_ops",
|
||||
srcs = ["math_ops.py"],
|
||||
deps = ["_math_ops"],
|
||||
deps = [
|
||||
":_math_ops",
|
||||
":gradient_registry",
|
||||
":tape_stack",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tape",
|
||||
srcs = ["tape.py"],
|
||||
deps = [
|
||||
":_tape",
|
||||
":context_stack",
|
||||
":tape_stack",
|
||||
"//tensorflow/python/data/util:nest",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
@ -66,6 +103,12 @@ py_library(
|
||||
deps = [":thread_local_stack"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "tape_stack",
|
||||
srcs = ["tape_stack.py"],
|
||||
deps = [":thread_local_stack"],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "unified_api_test",
|
||||
size = "small",
|
||||
@ -82,6 +125,7 @@ cuda_py_test(
|
||||
":context_stack",
|
||||
":def_function",
|
||||
":math_ops",
|
||||
":tape",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
|
@ -0,0 +1,27 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Global GradientRegistry."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework.experimental import _tape
|
||||
|
||||
_GRADIENT_REGISTRY_GLOBAL = _tape.GradientRegistry()
|
||||
|
||||
|
||||
def get_global_registry():
|
||||
return _GRADIENT_REGISTRY_GLOBAL
|
@ -23,22 +23,39 @@ limitations under the License.
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "tensorflow/c/eager/abstract_context.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/gradients.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_status.h"
|
||||
|
||||
// This library provides helpers for running ops and recording them on a
|
||||
// GradientTape. This is currently needed because the tape does not provide
|
||||
// an implementation of the abstract execution APIs but that will change.
|
||||
// TODO(b/168209775): Remove this and its imported symbols once the tape
|
||||
// execution context is ready.
|
||||
#include "tensorflow/c/eager/mnist_gradients_testutil.h"
|
||||
|
||||
using tensorflow::AbstractContext;
|
||||
using tensorflow::AbstractTensorHandle;
|
||||
using tensorflow::ops::Add;
|
||||
using tensorflow::gradients::GradientRegistry;
|
||||
using tensorflow::gradients::Tape;
|
||||
|
||||
namespace tensorflow {
|
||||
PYBIND11_MODULE(_math_ops, m) {
|
||||
m.def("add", [](AbstractContext* ctx, AbstractTensorHandle* a,
|
||||
AbstractTensorHandle* b, const char* name) {
|
||||
AbstractTensorHandle* b, const char* name, Tape* tape,
|
||||
GradientRegistry* registry) {
|
||||
int num_outputs = 1;
|
||||
std::vector<AbstractTensorHandle*> outputs(1);
|
||||
if (!name) {
|
||||
name = "Add";
|
||||
}
|
||||
MaybeRaiseRegisteredFromStatus(
|
||||
Add(ctx, {a, b}, absl::MakeSpan(outputs), name));
|
||||
if (!tape) {
|
||||
MaybeRaiseRegisteredFromStatus(
|
||||
ops::Add(ctx, {a, b}, absl::MakeSpan(outputs), name));
|
||||
} else {
|
||||
MaybeRaiseRegisteredFromStatus(gradients::internal::Add(
|
||||
ctx, tape, {a, b}, absl::MakeSpan(outputs), *registry));
|
||||
}
|
||||
return outputs[0];
|
||||
});
|
||||
}
|
||||
} // namespace tensorflow
|
||||
|
@ -20,8 +20,12 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework.experimental import _math_ops
|
||||
from tensorflow.python.framework.experimental import context_stack as context
|
||||
from tensorflow.python.framework.experimental import gradient_registry
|
||||
from tensorflow.python.framework.experimental import tape_stack
|
||||
|
||||
|
||||
def add(a, b, name=None):
|
||||
ctx = context.get_default()
|
||||
return _math_ops.add(ctx, a, b, name)
|
||||
tape = tape_stack.get_default()
|
||||
grad_registry = gradient_registry.get_global_registry()
|
||||
return _math_ops.add(ctx, a, b, name, tape, grad_registry)
|
||||
|
79
tensorflow/python/framework/experimental/tape.cc
Normal file
79
tensorflow/python/framework/experimental/tape.cc
Normal file
@ -0,0 +1,79 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "tensorflow/c/eager/gradients.h"
|
||||
#include "tensorflow/c/experimental/gradients/math_grad.h"
|
||||
#include "tensorflow/c/experimental/gradients/nn_grad.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_status.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
|
||||
Status RegisterGradients(GradientRegistry* registry) {
|
||||
TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("MatMul", MatMulRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("Relu", ReluRegisterer));
|
||||
TF_RETURN_IF_ERROR(
|
||||
registry->Register("SparseSoftmaxCrossEntropyWithLogits",
|
||||
SparseSoftmaxCrossEntropyLossRegisterer));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(_tape, m) {
|
||||
py::class_<Tape>(m, "Tape")
|
||||
.def(py::init([](bool persistent) { return new Tape(persistent); }))
|
||||
.def("Watch",
|
||||
[](Tape* self, AbstractTensorHandle* t) { self->Watch(ToId(t)); })
|
||||
.def("ComputeGradient",
|
||||
[](Tape* self, TapeVSpace* vspace,
|
||||
std::vector<AbstractTensorHandle*> target_tensors,
|
||||
std::vector<AbstractTensorHandle*> source_tensors,
|
||||
std::vector<AbstractTensorHandle*> output_gradients) {
|
||||
std::vector<int64> target_tensor_ids;
|
||||
std::vector<int64> source_tensor_ids;
|
||||
target_tensor_ids.reserve(target_tensors.size());
|
||||
source_tensor_ids.reserve(source_tensors.size());
|
||||
for (auto t : target_tensors) {
|
||||
target_tensor_ids.emplace_back(ToId(t));
|
||||
}
|
||||
for (auto t : source_tensors) {
|
||||
source_tensor_ids.emplace_back(ToId(t));
|
||||
}
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
std::vector<AbstractTensorHandle*> results;
|
||||
Status s = self->ComputeGradient(
|
||||
*vspace, target_tensor_ids, source_tensor_ids,
|
||||
source_tensors_that_are_targets, output_gradients, &results,
|
||||
/*build_default_zeros_grads=*/false);
|
||||
MaybeRaiseRegisteredFromStatus(s);
|
||||
return results;
|
||||
});
|
||||
py::class_<TapeVSpace>(m, "TapeVSpace")
|
||||
.def(py::init([](AbstractContext* ctx) { return new TapeVSpace(ctx); }));
|
||||
py::class_<GradientRegistry>(m, "GradientRegistry").def(py::init([]() {
|
||||
auto registry = new GradientRegistry();
|
||||
MaybeRaiseRegisteredFromStatus(RegisterGradients(registry));
|
||||
return registry;
|
||||
}));
|
||||
}
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
54
tensorflow/python/framework/experimental/tape.py
Normal file
54
tensorflow/python/framework/experimental/tape.py
Normal file
@ -0,0 +1,54 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Experimental impl for GradientTape using unified APIs, for testing only."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework.experimental import _tape
|
||||
from tensorflow.python.framework.experimental import context_stack
|
||||
from tensorflow.python.framework.experimental import tape_stack
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
class GradientTape(object):
|
||||
"""GradientTape using the unified API."""
|
||||
|
||||
def __init__(self, persistent=False):
|
||||
self._c_tape = _tape.Tape(persistent)
|
||||
|
||||
def watch(self, t):
|
||||
self._c_tape.Watch(t)
|
||||
|
||||
# TODO(srbs): Add support for unconnected_gradients.
|
||||
def gradient(self, targets, sources, output_gradients=None):
|
||||
ctx = context_stack.get_default()
|
||||
vspace = _tape.TapeVSpace(ctx)
|
||||
flat_targets = nest.flatten(targets)
|
||||
flat_sources = nest.flatten(sources)
|
||||
out_grads = self._c_tape.ComputeGradient(vspace, flat_targets, flat_sources,
|
||||
output_gradients or [])
|
||||
return nest.pack_sequence_as(sources, out_grads)
|
||||
|
||||
def __enter__(self):
|
||||
"""Enters a context inside which operations are recorded on this tape."""
|
||||
if tape_stack.get_default():
|
||||
raise ValueError("Nested tapes are not supported yet.")
|
||||
tape_stack.push(self._c_tape)
|
||||
return self
|
||||
|
||||
def __exit__(self, typ, value, traceback):
|
||||
tape_stack.pop()
|
35
tensorflow/python/framework/experimental/tape_stack.py
Normal file
35
tensorflow/python/framework/experimental/tape_stack.py
Normal file
@ -0,0 +1,35 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Thread-local context manager stack."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework.experimental import thread_local_stack
|
||||
|
||||
_default_tape_stack = thread_local_stack.ThreadLocalStack()
|
||||
|
||||
|
||||
def get_default():
|
||||
return _default_tape_stack.peek()
|
||||
|
||||
|
||||
def push(tape):
|
||||
_default_tape_stack.push(tape)
|
||||
|
||||
|
||||
def pop():
|
||||
_default_tape_stack.pop()
|
@ -24,6 +24,7 @@ from tensorflow.python.framework.experimental import _unified_api
|
||||
from tensorflow.python.framework.experimental import context_stack as context_lib
|
||||
from tensorflow.python.framework.experimental import def_function
|
||||
from tensorflow.python.framework.experimental import math_ops
|
||||
from tensorflow.python.framework.experimental import tape as tape_lib
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
NewImmediateExecutionContext = _unified_api.NewImmediateExecutionContext
|
||||
@ -57,6 +58,38 @@ class UnifiedApiTest(test.TestCase, parameterized.TestCase):
|
||||
eager_output = model(a, b)
|
||||
self.assertAllEqual(eager_output.numpy(), 3.0)
|
||||
|
||||
@parameterized.named_parameters([
|
||||
("EagerGraph", False, False),
|
||||
("EagerMlir", False, True),
|
||||
# TODO(srbs): Enable for TFRT. Segfaults right now.
|
||||
# ("TfrtGraph", True, False),
|
||||
# ("TfrtMlir", True, True),
|
||||
])
|
||||
def testAddGrad(self, use_tfrt, use_mlir):
|
||||
if use_mlir:
|
||||
SetTracingImplementation("mlir")
|
||||
|
||||
def model(a, b):
|
||||
with tape_lib.GradientTape() as tape:
|
||||
tape.watch(a)
|
||||
tape.watch(b)
|
||||
result = math_ops.add(a, b)
|
||||
grads = tape.gradient(result, [a, b])
|
||||
return grads
|
||||
|
||||
eager_ctx = NewImmediateExecutionContext(use_tfrt)
|
||||
with context_lib.set_default(eager_ctx):
|
||||
a = eager_ctx.CreateFloatScalarHandle(1.)
|
||||
b = eager_ctx.CreateFloatScalarHandle(2.)
|
||||
|
||||
func_outputs = def_function.function(model)(a, b)
|
||||
self.assertAllEqual(func_outputs[0].numpy(), 1.0)
|
||||
self.assertAllEqual(func_outputs[1].numpy(), 1.0)
|
||||
|
||||
eager_outputs = model(a, b)
|
||||
self.assertAllEqual(eager_outputs[0].numpy(), 1.0)
|
||||
self.assertAllEqual(eager_outputs[1].numpy(), 1.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user