/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <pybind11/stl.h>

#include "pybind11/pybind11.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/python/lib/core/pybind11_status.h"

namespace py = pybind11;

namespace tensorflow {
namespace gradients {

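// Populates `registry` with the gradient functions exposed to the Python tape:
// each op name is mapped to the registerer that builds its GradientFunction.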
Status RegisterGradients(GradientRegistry* registry) {
  // TODO(srbs): Rename ops::Add and AddRegisterer to AddV2.
  TF_RETURN_IF_ERROR(registry->Register("AddV2", AddRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("MatMul", MatMulRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("Relu", ReluRegisterer));
  TF_RETURN_IF_ERROR(
      registry->Register("SparseSoftmaxCrossEntropyWithLogits",
                         SparseSoftmaxCrossEntropyWithLogitsRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("Neg", NegRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("Sub", SubRegisterer));
  return Status::OK();
}

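// Exposes the C++ tape machinery to Python as the `_tape` extension module:
// Tape, TapeVSpace, GradientRegistry, and TapeContext.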
PYBIND11_MODULE(_tape, m) {
  py::class_<Tape>(m, "Tape")
      .def(py::init([](bool persistent) { return new Tape(persistent); }))
      .def("Watch",
           [](Tape* self, AbstractTensorHandle* t) { self->Watch(ToId(t)); })
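      // Converts the target/source handles to tape tensor ids, runs backprop
      // on the tape, and raises the resulting Status as a Python exception on
      // failure.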
      .def("ComputeGradient",
           [](Tape* self, TapeVSpace* vspace,
              std::vector<AbstractTensorHandle*> target_tensors,
              std::vector<AbstractTensorHandle*> source_tensors,
              std::vector<AbstractTensorHandle*> output_gradients) {
             std::vector<int64> target_tensor_ids;
             std::vector<int64> source_tensor_ids;
             target_tensor_ids.reserve(target_tensors.size());
             source_tensor_ids.reserve(source_tensors.size());
             for (auto t : target_tensors) {
               target_tensor_ids.emplace_back(ToId(t));
             }
             for (auto t : source_tensors) {
               source_tensor_ids.emplace_back(ToId(t));
             }
             std::unordered_map<tensorflow::int64, TapeTensor>
                 source_tensors_that_are_targets;
             std::vector<AbstractTensorHandle*> results;
             Status s = self->ComputeGradient(
                 *vspace, target_tensor_ids, source_tensor_ids,
                 source_tensors_that_are_targets, output_gradients, &results,
                 /*build_default_zeros_grads=*/false);
             MaybeRaiseRegisteredFromStatus(s);
             return results;
           });
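  // TapeVSpace implements the vector-space operations the tape relies on
  // during backprop (e.g. aggregating gradients), executed against the given
  // AbstractContext.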
  py::class_<TapeVSpace>(m, "TapeVSpace")
      .def(py::init([](AbstractContext* ctx) { return new TapeVSpace(ctx); }));
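  // A GradientRegistry constructed from Python comes pre-populated via
  // RegisterGradients(); registration failures surface as Python exceptions.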
  py::class_<GradientRegistry>(m, "GradientRegistry").def(py::init([]() {
    auto registry = new GradientRegistry();
    MaybeRaiseRegisteredFromStatus(RegisterGradients(registry));
    return registry;
  }));
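  // TapeContext wraps an existing AbstractContext so that ops executed through
  // it are recorded on `tape`, with gradient functions looked up in `registry`.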
  py::class_<TapeContext, AbstractContext>(m, "TapeContext")
      .def(py::init(
          [](AbstractContext* ctx, Tape* tape, GradientRegistry* registry) {
            return new TapeContext(ctx, tape, *registry);
          }));
}
}  // namespace gradients
}  // namespace tensorflow