[C++ gradients] Add experimental Python bindings for the unified API's Tape.

See unified_api_test for usage.

PiperOrigin-RevId: 330992900
Change-Id: I1e7b2fc3ce6ab7d53c8f05a9807e7a03f6e59715
Saurabh Saxena 2020-09-10 12:31:56 -07:00 committed by TensorFlower Gardener
parent 8e789c3872
commit 896635c9cb
13 changed files with 353 additions and 37 deletions
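For orientation, a minimal usage sketch of the Python surface added here, assembled from tape.py and unified_api_test.py in this diff. This is an experimental, testing-only API; the sketch assumes an eager (non-TFRT) context, as in the test.

```python
# Minimal sketch of the experimental Tape bindings added in this change,
# pieced together from tape.py and unified_api_test.py below (testing only).
from tensorflow.python.framework.experimental import _unified_api
from tensorflow.python.framework.experimental import context_stack as context_lib
from tensorflow.python.framework.experimental import math_ops
from tensorflow.python.framework.experimental import tape as tape_lib

ctx = _unified_api.NewImmediateExecutionContext(False)  # use_tfrt=False
with context_lib.set_default(ctx):
  a = ctx.CreateFloatScalarHandle(1.)
  b = ctx.CreateFloatScalarHandle(2.)
  with tape_lib.GradientTape() as tape:
    tape.watch(a)
    tape.watch(b)
    result = math_ops.add(a, b)  # recorded: add() reads the tape off tape_stack
  grads = tape.gradient(result, [a, b])
  print(grads[0].numpy(), grads[1].numpy())  # 1.0 1.0
```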

View File

@@ -103,9 +103,13 @@ filegroup(
"c_api_unified_experimental.h",
"c_api_unified_experimental_internal.h",
"dlpack.h",
"gradients.h",
"gradients_internal.h",
"immediate_execution_context.h",
"immediate_execution_operation.h",
"immediate_execution_tensor_handle.h",
"mnist_gradients_testutil.h",
"tape.h",
"tfe_cancellation_manager_internal.h",
"tfe_executor_internal.h",
"tfe_monitoring_internal.h",
@@ -164,31 +168,6 @@ cc_library(
],
)
cc_library(
name = "gradients",
srcs = [
"gradients.cc",
"gradients_internal.h",
],
hdrs = [
"gradients.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":abstract_context",
":abstract_operation",
":abstract_tensor_handle",
":c_api_unified_internal",
":tape",
"//tensorflow/core/common_runtime/eager:attr_builder",
"//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "gradients_internal",
srcs = [
@@ -266,9 +245,7 @@ cc_library(
":gradients_internal",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c:tf_tensor",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops",
"//tensorflow/c/experimental/ops",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/core/platform:status",
"@com_google_absl//absl/container:flat_hash_set",

View File

@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_
#define TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_
#include <memory>
#include "absl/types/span.h"
@@ -148,3 +150,5 @@ Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx);
} // namespace internal
} // namespace gradients
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_

View File

@@ -16,7 +16,7 @@ cc_library(
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/c/eager:gradients",
"//tensorflow/c/eager:gradients_internal",
"//tensorflow/core/lib/llvm_rtti",
],
)
@@ -34,7 +34,7 @@ cc_library(
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/c/eager:gradients",
"//tensorflow/c/eager:gradients_internal",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops",
@@ -55,10 +55,40 @@ cc_library(
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/c/eager:gradients",
"//tensorflow/c/eager:gradients_internal",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops",
"//tensorflow/core/lib/llvm_rtti",
],
)
cc_library(
name = "gradients",
hdrs = [
"array_grad.h",
"math_grad.h",
"nn_grad.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":array_grad",
":math_grad",
":nn_grad",
],
)
filegroup(
name = "pywrap_required_hdrs",
srcs = [
"array_grad.h",
"math_grad.h",
"nn_grad.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)

View File

@@ -68,6 +68,11 @@ cc_library(
cc_library(
name = "ops",
hdrs = [
"array_ops.h",
"math_ops.h",
"nn_ops.h",
],
visibility = [
"//tensorflow:internal",
],

View File

@@ -682,6 +682,7 @@ tf_python_pybind_extension(
"//tensorflow/c:headers",
"//tensorflow/c/eager:headers",
"//tensorflow/c/eager:pywrap_required_hdrs",
"//tensorflow/c/experimental/ops:pywrap_required_hdrs",
"//tensorflow/core/common_runtime/eager:pywrap_required_hdrs",
"//tensorflow/core/distributed_runtime:pywrap_required_hdrs",
"//tensorflow/core/distributed_runtime/eager:pywrap_required_hdrs",
@@ -702,6 +703,7 @@ tf_python_pybind_extension(
"//tensorflow/core/common_runtime:core_cpu_headers_lib",
"//tensorflow/core:lib_headers_for_pybind",
"@com_google_absl//absl/types:optional",
"//tensorflow/core/lib/llvm_rtti",
] + if_static(
extra_deps = [
"//tensorflow/core/protobuf:eager_service_proto_cc",
@@ -5939,6 +5941,8 @@ pywrap_tensorflow_macro(
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/c/experimental/ops",
"//tensorflow/c/experimental/gradients",
"//tensorflow/c/eager:mnist_gradients_testutil",
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/core/data/service:server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_rpc_factory_registration",
@@ -7448,6 +7452,7 @@ cc_library(
"//tensorflow/c/eager:headers",
"//tensorflow/c/eager:pywrap_required_hdrs",
"//tensorflow/c/experimental/ops:pywrap_required_hdrs",
"//tensorflow/c/experimental/gradients:pywrap_required_hdrs",
"//tensorflow/core/common_runtime/eager:pywrap_required_hdrs",
"//tensorflow/python/eager:pywrap_required_hdrs",
"//tensorflow/core/common_runtime:core_cpu_lib_headers",
@@ -7484,6 +7489,7 @@ tf_python_pybind_extension(
"//tensorflow/c:headers",
"//tensorflow/c/eager:headers",
"//tensorflow/c/eager:pywrap_required_hdrs",
"//tensorflow/c/experimental/ops:pywrap_required_hdrs",
"//tensorflow/core/common_runtime/eager:pywrap_required_hdrs",
"//tensorflow/core/distributed_runtime:pywrap_required_hdrs",
"//tensorflow/core/distributed_runtime/eager:pywrap_required_hdrs",
@@ -7511,6 +7517,7 @@ tf_python_pybind_extension(
"//tensorflow/core:lib_headers_for_pybind",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform",
"//tensorflow/core/lib/llvm_rtti",
] + if_static(
extra_deps = [
"//tensorflow/core/protobuf:eager_service_proto_cc",

View File

@@ -27,6 +27,22 @@ tf_python_pybind_extension(
],
)
tf_python_pybind_extension(
name = "_tape",
srcs = ["tape.cc"],
features = ["-layering_check"],
module_name = "_tape",
deps = [
"//tensorflow/c/eager:tfe_tensorhandle_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/python:pybind11_lib",
"//tensorflow/python:unified_api_pywrap_required_headers",
"@pybind11",
],
)
tf_python_pybind_extension(
name = "_math_ops",
srcs = ["math_ops.cc"],
@@ -44,10 +60,31 @@ tf_python_pybind_extension(
],
)
py_library(
name = "gradient_registry",
srcs = ["gradient_registry.py"],
deps = [":_tape"],
)
py_library(
name = "math_ops",
srcs = ["math_ops.py"],
deps = ["_math_ops"],
deps = [
":_math_ops",
":gradient_registry",
":tape_stack",
],
)
py_library(
name = "tape",
srcs = ["tape.py"],
deps = [
":_tape",
":context_stack",
":tape_stack",
"//tensorflow/python/data/util:nest",
],
)
py_library(
@@ -66,6 +103,12 @@ py_library(
deps = [":thread_local_stack"],
)
py_library(
name = "tape_stack",
srcs = ["tape_stack.py"],
deps = [":thread_local_stack"],
)
cuda_py_test(
name = "unified_api_test",
size = "small",
@@ -82,6 +125,7 @@ cuda_py_test(
":context_stack",
":def_function",
":math_ops",
":tape",
"//tensorflow/python:client_testlib",
"@absl_py//absl/testing:parameterized",
],

View File

@@ -0,0 +1,27 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Global GradientRegistry."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework.experimental import _tape
_GRADIENT_REGISTRY_GLOBAL = _tape.GradientRegistry()
def get_global_registry():
return _GRADIENT_REGISTRY_GLOBAL

View File

@@ -23,22 +23,39 @@ limitations under the License.
#include "pybind11/pybind11.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
// This library provides helpers for running ops and recording them on a
// GradientTape. This is currently needed because the tape does not provide
// an implementation of the abstract execution APIs, but that will change.
// TODO(b/168209775): Remove this and its imported symbols once the tape
// execution context is ready.
#include "tensorflow/c/eager/mnist_gradients_testutil.h"
using tensorflow::AbstractContext;
using tensorflow::AbstractTensorHandle;
using tensorflow::ops::Add;
using tensorflow::gradients::GradientRegistry;
using tensorflow::gradients::Tape;
namespace tensorflow {
PYBIND11_MODULE(_math_ops, m) {
m.def("add", [](AbstractContext* ctx, AbstractTensorHandle* a,
AbstractTensorHandle* b, const char* name) {
AbstractTensorHandle* b, const char* name, Tape* tape,
GradientRegistry* registry) {
int num_outputs = 1;
std::vector<AbstractTensorHandle*> outputs(1);
if (!name) {
name = "Add";
}
MaybeRaiseRegisteredFromStatus(
Add(ctx, {a, b}, absl::MakeSpan(outputs), name));
if (!tape) {
MaybeRaiseRegisteredFromStatus(
ops::Add(ctx, {a, b}, absl::MakeSpan(outputs), name));
} else {
MaybeRaiseRegisteredFromStatus(gradients::internal::Add(
ctx, tape, {a, b}, absl::MakeSpan(outputs), *registry));
}
return outputs[0];
});
}
} // namespace tensorflow
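For reference, a hedged sketch of invoking the extended binding directly; in normal use math_ops.py (next file) threads the tape and registry in, and passing Python None for the tape is assumed to map to a null Tape*, which selects the plain ops::Add path above.

```python
# Sketch: calling _math_ops.add with the new tape/registry arguments.
from tensorflow.python.framework.experimental import _math_ops
from tensorflow.python.framework.experimental import _unified_api
from tensorflow.python.framework.experimental import gradient_registry
from tensorflow.python.framework.experimental import tape_stack

ctx = _unified_api.NewImmediateExecutionContext(False)
a = ctx.CreateFloatScalarHandle(1.)
b = ctx.CreateFloatScalarHandle(2.)
tape = tape_stack.get_default()  # None here: no tape has been pushed yet
registry = gradient_registry.get_global_registry()
# With tape=None the binding runs plain ops::Add; with an active tape it
# records the op through gradients::internal::Add instead.
out = _math_ops.add(ctx, a, b, "Add", tape, registry)
```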

View File

@@ -20,8 +20,12 @@ from __future__ import print_function
from tensorflow.python.framework.experimental import _math_ops
from tensorflow.python.framework.experimental import context_stack as context
from tensorflow.python.framework.experimental import gradient_registry
from tensorflow.python.framework.experimental import tape_stack
def add(a, b, name=None):
ctx = context.get_default()
return _math_ops.add(ctx, a, b, name)
tape = tape_stack.get_default()
grad_registry = gradient_registry.get_global_registry()
return _math_ops.add(ctx, a, b, name, tape, grad_registry)

View File

@@ -0,0 +1,79 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <pybind11/stl.h>
#include "pybind11/pybind11.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
namespace py = pybind11;
namespace tensorflow {
namespace gradients {
Status RegisterGradients(GradientRegistry* registry) {
TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
TF_RETURN_IF_ERROR(registry->Register("MatMul", MatMulRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Relu", ReluRegisterer));
TF_RETURN_IF_ERROR(
registry->Register("SparseSoftmaxCrossEntropyWithLogits",
SparseSoftmaxCrossEntropyLossRegisterer));
return Status::OK();
}
PYBIND11_MODULE(_tape, m) {
py::class_<Tape>(m, "Tape")
.def(py::init([](bool persistent) { return new Tape(persistent); }))
.def("Watch",
[](Tape* self, AbstractTensorHandle* t) { self->Watch(ToId(t)); })
.def("ComputeGradient",
[](Tape* self, TapeVSpace* vspace,
std::vector<AbstractTensorHandle*> target_tensors,
std::vector<AbstractTensorHandle*> source_tensors,
std::vector<AbstractTensorHandle*> output_gradients) {
std::vector<int64> target_tensor_ids;
std::vector<int64> source_tensor_ids;
target_tensor_ids.reserve(target_tensors.size());
source_tensor_ids.reserve(source_tensors.size());
for (auto t : target_tensors) {
target_tensor_ids.emplace_back(ToId(t));
}
for (auto t : source_tensors) {
source_tensor_ids.emplace_back(ToId(t));
}
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
std::vector<AbstractTensorHandle*> results;
Status s = self->ComputeGradient(
*vspace, target_tensor_ids, source_tensor_ids,
source_tensors_that_are_targets, output_gradients, &results,
/*build_default_zeros_grads=*/false);
MaybeRaiseRegisteredFromStatus(s);
return results;
});
py::class_<TapeVSpace>(m, "TapeVSpace")
.def(py::init([](AbstractContext* ctx) { return new TapeVSpace(ctx); }));
py::class_<GradientRegistry>(m, "GradientRegistry").def(py::init([]() {
auto registry = new GradientRegistry();
MaybeRaiseRegisteredFromStatus(RegisterGradients(registry));
return registry;
}));
}
} // namespace gradients
} // namespace tensorflow
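A hedged sketch of driving these raw bindings directly, mirroring what tape.py (next file) does under the hood; the class and method names are exactly those bound above, and the empty output_gradients list matches the wrapper's default.

```python
# Sketch: the low-level _tape surface that tape.GradientTape wraps.
from tensorflow.python.framework.experimental import _math_ops
from tensorflow.python.framework.experimental import _tape
from tensorflow.python.framework.experimental import _unified_api
from tensorflow.python.framework.experimental import gradient_registry

ctx = _unified_api.NewImmediateExecutionContext(False)
a = ctx.CreateFloatScalarHandle(1.)
b = ctx.CreateFloatScalarHandle(2.)

tape = _tape.Tape(False)  # persistent=False
tape.Watch(a)
tape.Watch(b)
registry = gradient_registry.get_global_registry()
result = _math_ops.add(ctx, a, b, "Add", tape, registry)  # recorded on tape

vspace = _tape.TapeVSpace(ctx)
grads = tape.ComputeGradient(vspace, [result], [a, b], [])  # d(result)/d{a,b}
```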

View File

@@ -0,0 +1,54 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental impl for GradientTape using unified APIs, for testing only."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework.experimental import _tape
from tensorflow.python.framework.experimental import context_stack
from tensorflow.python.framework.experimental import tape_stack
from tensorflow.python.util import nest
class GradientTape(object):
"""GradientTape using the unified API."""
def __init__(self, persistent=False):
self._c_tape = _tape.Tape(persistent)
def watch(self, t):
self._c_tape.Watch(t)
# TODO(srbs): Add support for unconnected_gradients.
def gradient(self, targets, sources, output_gradients=None):
ctx = context_stack.get_default()
vspace = _tape.TapeVSpace(ctx)
flat_targets = nest.flatten(targets)
flat_sources = nest.flatten(sources)
out_grads = self._c_tape.ComputeGradient(vspace, flat_targets, flat_sources,
output_gradients or [])
return nest.pack_sequence_as(sources, out_grads)
def __enter__(self):
"""Enters a context inside which operations are recorded on this tape."""
if tape_stack.get_default():
raise ValueError("Nested tapes are not supported yet.")
tape_stack.push(self._c_tape)
return self
def __exit__(self, typ, value, traceback):
tape_stack.pop()

View File

@@ -0,0 +1,35 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Thread-local context manager stack."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework.experimental import thread_local_stack
_default_tape_stack = thread_local_stack.ThreadLocalStack()
def get_default():
return _default_tape_stack.peek()
def push(tape):
_default_tape_stack.push(tape)
def pop():
_default_tape_stack.pop()

View File

@@ -24,6 +24,7 @@ from tensorflow.python.framework.experimental import _unified_api
from tensorflow.python.framework.experimental import context_stack as context_lib
from tensorflow.python.framework.experimental import def_function
from tensorflow.python.framework.experimental import math_ops
from tensorflow.python.framework.experimental import tape as tape_lib
from tensorflow.python.platform import test
NewImmediateExecutionContext = _unified_api.NewImmediateExecutionContext
@@ -57,6 +58,38 @@ class UnifiedApiTest(test.TestCase, parameterized.TestCase):
eager_output = model(a, b)
self.assertAllEqual(eager_output.numpy(), 3.0)
@parameterized.named_parameters([
("EagerGraph", False, False),
("EagerMlir", False, True),
# TODO(srbs): Enable for TFRT. Segfaults right now.
# ("TfrtGraph", True, False),
# ("TfrtMlir", True, True),
])
def testAddGrad(self, use_tfrt, use_mlir):
if use_mlir:
SetTracingImplementation("mlir")
def model(a, b):
with tape_lib.GradientTape() as tape:
tape.watch(a)
tape.watch(b)
result = math_ops.add(a, b)
grads = tape.gradient(result, [a, b])
return grads
eager_ctx = NewImmediateExecutionContext(use_tfrt)
with context_lib.set_default(eager_ctx):
a = eager_ctx.CreateFloatScalarHandle(1.)
b = eager_ctx.CreateFloatScalarHandle(2.)
func_outputs = def_function.function(model)(a, b)
self.assertAllEqual(func_outputs[0].numpy(), 1.0)
self.assertAllEqual(func_outputs[1].numpy(), 1.0)
eager_outputs = model(a, b)
self.assertAllEqual(eager_outputs[0].numpy(), 1.0)
self.assertAllEqual(eager_outputs[1].numpy(), 1.0)
if __name__ == "__main__":
test.main()