First pass at implementing a C++ API for gradients built on top of the abstract interfaces.

`gradients.h` contains the public API. `gradients_internal.h` has some helpers for testing while we figure out how to hook this into `AbstractOperation` itself.

PiperOrigin-RevId: 320682374
Change-Id: I53e1c76f3c0897ff66f2563501806c425de69f24
This commit is contained in: parent 37d20f87f3, commit 02e3cb289c
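For orientation, here is a minimal sketch of how the pieces in this change fit together: a `GradientFunction` is registered in the `GradientRegistry` under the forward op's name, the `internal::` helpers record the forward op on a `Tape`, and `Tape::ComputeGradient` is then driven through a `TapeVSpace` (see `AddGradModel` in the test below). The sketch follows the sample in `gradients.h`; the names `IdentityGradientFunction` and `RegisterIdentityGradient` are illustrative only and are not part of this change.

// Sketch only, under the assumption that the headers added in this change
// are available. Mirrors the AddGradientFunction sample in gradients.h.
#include "tensorflow/c/eager/gradients.h"

namespace tensorflow {
namespace gradients {

// Hypothetical gradient function that forwards the incoming gradient
// unchanged (the gradient of Identity).
class IdentityGradientFunction : public GradientFunction {
 public:
  Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs,
                 std::vector<AbstractTensorHandle*>* grad_outputs) override {
    grad_outputs->resize(1);
    // Note: AbstractTensorHandle is not ref-counted yet (b/160888114), so the
    // caller is responsible for keeping grad_inputs[0] alive.
    (*grad_outputs)[0] = grad_inputs[0];
    return Status::OK();
  }
  ~IdentityGradientFunction() override {}
};

// Registers the gradient under the forward op name. The factory receives the
// ForwardOperation so more complex gradients can inspect its inputs and attrs.
Status RegisterIdentityGradient(GradientRegistry* registry) {
  return registry->Register("Identity", [](const ForwardOperation& op) {
    return new IdentityGradientFunction;
  });
}

}  // namespace gradients
}  // namespace tensorflow

With a registry populated this way, a model records forward ops through the `internal::Reset`/`AddInput`/`Execute` helpers and then calls `Tape::ComputeGradient` with a `TapeVSpace`, exactly as `AddGradModel` does in `gradients_test.cc`.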
@@ -171,6 +171,87 @@ cc_library(
    ],
)

cc_library(
    name = "gradients",
    srcs = [
        "gradients.cc",
        "gradients_internal.h",
    ],
    hdrs = [
        "gradients.h",
    ],
    visibility = [
        "//tensorflow:internal",
    ],
    deps = [
        ":abstract_context",
        ":abstract_operation",
        ":abstract_tensor_handle",
        ":c_api_unified_internal",
        ":tape",
        "//tensorflow/core/common_runtime/eager:attr_builder",
        "//tensorflow/core/lib/llvm_rtti",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "gradients_internal",
    srcs = [
        "gradients.cc",
    ],
    hdrs = [
        "gradients.h",
        "gradients_internal.h",
    ],
    visibility = [
        "//tensorflow:internal",
    ],
    deps = [
        ":abstract_context",
        ":abstract_operation",
        ":abstract_tensor_handle",
        ":c_api_unified_internal",
        ":tape",
        "//tensorflow/core/common_runtime/eager:attr_builder",
        "//tensorflow/core/lib/llvm_rtti",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/strings",
    ],
)

tf_cuda_cc_test(
    name = "gradients_test",
    size = "small",
    srcs = [
        "gradients_test.cc",
    ],
    args = ["--heap_check=local"],
    extra_copts = tfe_xla_copts(),
    linkstatic = tf_kernel_tests_linkstatic(),
    tags = tf_cuda_tests_tags() + ["nomac"],
    deps = [
        ":abstract_tensor_handle",
        ":c_api_experimental",
        ":c_api_test_util",
        ":c_api_unified_internal",
        ":gradients_internal",
        "//tensorflow/c:c_api",
        "//tensorflow/c:c_test_util",
        "//tensorflow/c:tf_status_helper",
        "//tensorflow/cc/profiler",
        "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/lib/llvm_rtti",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
    ],
)

cc_library(
    name = "abstract_tensor_handle",
    hdrs = ["abstract_tensor_handle.h"],

@@ -747,6 +828,7 @@ filegroup(
        "c_api_unified_experimental_eager.cc",
        "c_api_unified_experimental_graph.cc",
        "c_api_unified_experimental_internal.h",
        "gradients.cc",  # Uses RTTI.
        "*test*",
        "*dlpack*",
    ],
tensorflow/c/eager/gradients.cc (new file, 400 lines)
@@ -0,0 +1,400 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/gradients.h"

#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"

namespace tensorflow {
namespace gradients {

Status GradientRegistry::Register(const string& op_name,
                                  GradientFunctionFactory factory) {
  auto iter = registry_.find(op_name);
  if (iter != registry_.end()) {
    const string error_msg = "Gradient already exists for op: " + op_name + ".";
    return errors::AlreadyExists(error_msg);
  }
  registry_.insert({op_name, factory});
  return Status::OK();
}
Status GradientRegistry::Lookup(
    const ForwardOperation& op,
    std::unique_ptr<GradientFunction>* grad_fn) const {
  auto iter = registry_.find(op.op_name);
  if (iter == registry_.end()) {
    const string error_msg = "No gradient defined for op: " + op.op_name + ".";
    return errors::NotFound(error_msg);
  }
  grad_fn->reset(iter->second(op));
  return Status::OK();
}

int64 ToId(AbstractTensorHandle* t) {
  return static_cast<int64>(reinterpret_cast<uintptr_t>(t));
}

TapeTensor::TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx)
    : handle_(handle), ctx_(ctx) {
  // TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely
  // on the client to keep this tensor live for the duration of the gradient
  // computation.
  // handle_->Ref();
}
TapeTensor::TapeTensor(const TapeTensor& other) {
  handle_ = other.handle_;
  // TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely
  // on the client to keep this tensor live for the duration of the gradient
  // computation.
  // handle_->Ref();
  ctx_ = other.ctx_;
}
TapeTensor::~TapeTensor() {
  // TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely
  // on the client to keep this tensor live for the duration of the gradient
  // computation.
  // handle_->Unref();
}

tensorflow::int64 TapeTensor::GetID() const { return ToId(handle_); }

tensorflow::DataType TapeTensor::GetDType() const {
  return handle_->DataType();
}

AbstractTensorHandle* TapeTensor::OnesLike() const {
  AbstractOperationPtr op(ctx_->CreateOperation());
  Status s = op->Reset("OnesLike", /*raw_device_name=*/nullptr);
  if (!s.ok()) {
    return nullptr;
  }
  if (isa<tracing::TracingOperation>(op.get())) {
    s = dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
        absl::StrCat("OnesLike", ToId(handle_)).c_str());
    if (!s.ok()) {
      return nullptr;
    }
  }
  s = op->AddInput(handle_);
  if (!s.ok()) {
    return nullptr;
  }
  int num_outputs = 1;
  // TODO(srbs): Figure out who is in charge of releasing this.
  std::vector<AbstractTensorHandle*> outputs(num_outputs);
  s = op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs);
  if (!s.ok()) {
    return nullptr;
  }
  return outputs[0];
}
AbstractTensorHandle* TapeTensor::ZerosLike() const {
  AbstractOperationPtr op(ctx_->CreateOperation());
  // TODO(srbs): Consider adding a TF_RETURN_NULLPTR_IF_ERROR.
  Status s = op->Reset("ZerosLike", /*raw_device_name=*/nullptr);
  if (!s.ok()) {
    return nullptr;
  }
  if (isa<tracing::TracingOperation>(op.get())) {
    s = dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
        absl::StrCat("ZerosLike", ToId(handle_)).c_str());
    if (!s.ok()) {
      return nullptr;
    }
  }
  s = op->AddInput(handle_);
  if (!s.ok()) {
    return nullptr;
  }
  int num_outputs = 1;
  // TODO(srbs): Figure out who is in charge of releasing this.
  std::vector<AbstractTensorHandle*> outputs(num_outputs);
  s = op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs);
  if (!s.ok()) {
    return nullptr;
  }
  return outputs[0];
}

// Returns the number of elements in the gradient tensor.
int64 TapeVSpace::NumElements(AbstractTensorHandle* tensor) const {
  // TODO(srbs): It seems like this is used only for performance optimization
  // and not for correctness. The only downside of keeping this 1 seems to be
  // that the gradient accumulation is unbounded and we will never
  // aggressively aggregate accumulated gradients to recover memory.
  // Revisit and fix.
  return 1;
}

// Consumes references to the tensors in the gradient_tensors list and returns
// a tensor with the result.
AbstractTensorHandle* TapeVSpace::AggregateGradients(
    gtl::ArraySlice<AbstractTensorHandle*> gradient_tensors) const {
  if (gradient_tensors.size() == 1) {
    return gradient_tensors[0];
  }

  AbstractOperationPtr op(ctx_->CreateOperation());
  Status s = op->Reset("AddN", /*raw_device_name=*/nullptr);
  if (!s.ok()) {
    return nullptr;
  }
  s = op->AddInputList(gradient_tensors);
  if (!s.ok()) {
    return nullptr;
  }

  int num_outputs = 1;
  std::vector<AbstractTensorHandle*> outputs(num_outputs);
  s = op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs);
  if (!s.ok()) {
    return nullptr;
  }
  return outputs[0];
}

// Calls the passed-in backward function.
Status TapeVSpace::CallBackwardFunction(
    GradientFunction* backward_function,
    const std::vector<int64>& unneeded_gradients,
    gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
    std::vector<AbstractTensorHandle*>* result) const {
  if (backward_function == nullptr) return Status::OK();
  return backward_function->Compute(output_gradients, result);
}

// Looks up the ID of a Gradient.
int64 TapeVSpace::TensorId(AbstractTensorHandle* tensor) const {
  return ToId(tensor);
}

// Converts a Gradient to a TapeTensor.
TapeTensor TapeVSpace::TapeTensorFromGradient(AbstractTensorHandle* g) const {
  return TapeTensor(g, ctx_);
}

void TapeVSpace::MarkAsResult(AbstractTensorHandle* gradient) const {}

void TapeVSpace::DeleteGradient(AbstractTensorHandle* gradient) const {
  gradient->Release();
}

// Helper functions which delegate to `AbstractOperation`, update
// the state of the ForwardOperation and call the tape as appropriate.
// These APIs are mainly to facilitate testing and are subject to change.
namespace internal {
Status Reset(AbstractOperation* op_, const char* op,
             const char* raw_device_name, ForwardOperation* forward_op_) {
  forward_op_->op_name = op;
  return op_->Reset(op, raw_device_name);
}
Status AddInput(AbstractOperation* op_, AbstractTensorHandle* input,
                ForwardOperation* forward_op_) {
  TF_RETURN_IF_ERROR(op_->AddInput(input));
  forward_op_->inputs.push_back(input);
  return Status::OK();
}
Status AddInputList(AbstractOperation* op_,
                    absl::Span<AbstractTensorHandle* const> inputs,
                    ForwardOperation* forward_op_) {
  TF_RETURN_IF_ERROR(op_->AddInputList(inputs));
  for (auto input : inputs) {
    forward_op_->inputs.push_back(input);
  }
  return Status::OK();
}

Status SetAttrString(AbstractOperation* op_, const char* attr_name,
                     const char* data, size_t length,
                     ForwardOperation* forward_op_) {
  forward_op_->attrs.Set(attr_name, StringPiece(data, length));
  return op_->SetAttrString(attr_name, data, length);
}
Status SetAttrInt(AbstractOperation* op_, const char* attr_name, int64_t value,
                  ForwardOperation* forward_op_) {
  forward_op_->attrs.Set(attr_name, static_cast<int64>(value));
  return op_->SetAttrInt(attr_name, value);
}
Status SetAttrFloat(AbstractOperation* op_, const char* attr_name, float value,
                    ForwardOperation* forward_op_) {
  forward_op_->attrs.Set(attr_name, value);
  return op_->SetAttrFloat(attr_name, value);
}
Status SetAttrBool(AbstractOperation* op_, const char* attr_name, bool value,
                   ForwardOperation* forward_op_) {
  forward_op_->attrs.Set(attr_name, value);
  return op_->SetAttrBool(attr_name, value);
}
Status SetAttrType(AbstractOperation* op_, const char* attr_name,
                   DataType value, ForwardOperation* forward_op_) {
  forward_op_->attrs.Set(attr_name, value);
  return op_->SetAttrType(attr_name, value);
}
Status SetAttrShape(AbstractOperation* op_, const char* attr_name,
                    const int64_t* dims, const int num_dims,
                    ForwardOperation* forward_op_) {
  if (num_dims > TensorShape::MaxDimensions()) {
    return errors::InvalidArgument("Value specified for `", attr_name, "` has ",
                                   num_dims,
                                   " dimensions which is over the limit of ",
                                   TensorShape::MaxDimensions(), ".");
  }
  TensorShapeProto proto;
  if (num_dims < 0) {
    proto.set_unknown_rank(true);
  } else {
    for (int d = 0; d < num_dims; ++d) {
      proto.add_dim()->set_size(dims[d]);
    }
  }

  forward_op_->attrs.Set(attr_name, proto);
  return op_->SetAttrShape(attr_name, dims, num_dims);
}
Status SetAttrFunction(AbstractOperation* op_, const char* attr_name,
                       const AbstractOperation* value,
                       ForwardOperation* forward_op_) {
  return tensorflow::errors::Unimplemented(
      "SetAttrFunction has not been implemented yet.");
}
Status SetAttrFunctionName(AbstractOperation* op_, const char* attr_name,
                           const char* value, size_t length,
                           ForwardOperation* forward_op_) {
  return tensorflow::errors::Unimplemented(
      "SetAttrFunctionName has not been implemented "
      "yet.");
}
Status SetAttrTensor(AbstractOperation* op_, const char* attr_name,
                     AbstractTensorInterface* tensor,
                     ForwardOperation* forward_op_) {
  return tensorflow::errors::Unimplemented(
      "SetAttrTensor has not been implemented yet.");
}
Status SetAttrStringList(AbstractOperation* op_, const char* attr_name,
                         const void* const* values, const size_t* lengths,
                         int num_values, ForwardOperation* forward_op_) {
  std::vector<StringPiece> v(num_values);
  for (int i = 0; i < num_values; ++i) {
    v[i] = StringPiece(static_cast<const char*>(values[i]), lengths[i]);
  }
  forward_op_->attrs.Set(attr_name, v);
  return op_->SetAttrStringList(attr_name, values, lengths, num_values);
}
Status SetAttrFloatList(AbstractOperation* op_, const char* attr_name,
                        const float* values, int num_values,
                        ForwardOperation* forward_op_) {
  forward_op_->attrs.Set(attr_name,
                         gtl::ArraySlice<const float>(values, num_values));
  return op_->SetAttrFloatList(attr_name, values, num_values);
}
Status SetAttrIntList(AbstractOperation* op_, const char* attr_name,
                      const int64_t* values, int num_values,
                      ForwardOperation* forward_op_) {
  forward_op_->attrs.Set(
      attr_name, gtl::ArraySlice<const int64>(
                     reinterpret_cast<const int64*>(values), num_values));
  return op_->SetAttrIntList(attr_name, values, num_values);
}
Status SetAttrTypeList(AbstractOperation* op_, const char* attr_name,
                       const DataType* values, int num_values,
                       ForwardOperation* forward_op_) {
  forward_op_->attrs.Set(attr_name,
                         gtl::ArraySlice<const DataType>(values, num_values));
  return op_->SetAttrTypeList(attr_name, values, num_values);
}
Status SetAttrBoolList(AbstractOperation* op_, const char* attr_name,
                       const unsigned char* values, int num_values,
                       ForwardOperation* forward_op_) {
  std::unique_ptr<bool[]> b(new bool[num_values]);
  for (int i = 0; i < num_values; ++i) {
    b[i] = values[i];
  }
  forward_op_->attrs.Set(attr_name,
                         gtl::ArraySlice<const bool>(b.get(), num_values));
  return op_->SetAttrBoolList(attr_name, values, num_values);
}
Status SetAttrShapeList(AbstractOperation* op_, const char* attr_name,
                        const int64_t** dims, const int* num_dims,
                        int num_values, ForwardOperation* forward_op_) {
  std::unique_ptr<TensorShapeProto[]> proto(new TensorShapeProto[num_values]);
  for (int i = 0; i < num_values; ++i) {
    const auto num_dims_i = num_dims[i];

    if (num_dims_i > TensorShape::MaxDimensions()) {
      return errors::InvalidArgument(
          strings::StrCat("Value specified for `", attr_name, "` has ",
                          num_dims_i, " dimensions which is over the limit of ",
                          TensorShape::MaxDimensions(), "."));
    }
    if (num_dims_i < 0) {
      proto[i].set_unknown_rank(true);
    } else {
      const int64_t* dims_i = dims[i];
      auto proto_i = &proto[i];
      for (int d = 0; d < num_dims_i; ++d) {
        proto_i->add_dim()->set_size(dims_i[d]);
      }
    }
  }
  forward_op_->attrs.Set(
      attr_name, gtl::ArraySlice<TensorShapeProto>(proto.get(), num_values));
  return op_->SetAttrShapeList(attr_name, dims, num_dims, num_values);
}
Status SetAttrFunctionList(AbstractOperation* op_, const char* attr_name,
                           absl::Span<const AbstractOperation*> values,
                           ForwardOperation* forward_op_) {
  return tensorflow::errors::Unimplemented(
      "SetAttrFunctionList has not been "
      "implemented yet.");
}
Status Execute(AbstractOperation* op_, AbstractContext* ctx,
               absl::Span<AbstractTensorHandle*> retvals, int* num_retvals,
               ForwardOperation* forward_op_, Tape* tape,
               const GradientRegistry& registry) {
  TF_RETURN_IF_ERROR(op_->Execute(retvals, num_retvals));
  std::vector<int64> input_ids(forward_op_->inputs.size());
  std::vector<tensorflow::DataType> input_dtypes(forward_op_->inputs.size());
  for (int i = 0; i < forward_op_->inputs.size(); i++) {
    input_ids[i] = ToId(forward_op_->inputs[i]);
    input_dtypes[i] = forward_op_->inputs[i]->DataType();
  }
  std::vector<TapeTensor> tape_tensors;
  for (auto t : retvals) {
    tape_tensors.push_back(TapeTensor(t, ctx));
  }
  tape->RecordOperation(
      op_->Name(), tape_tensors, input_ids, input_dtypes,
      [registry, forward_op_]() -> GradientFunction* {
        std::unique_ptr<GradientFunction> grad_fn;
        Status s = registry.Lookup(*forward_op_, &grad_fn);
        if (!s.ok()) {
          return nullptr;
        }
        return grad_fn.release();
      },
      [](GradientFunction* ptr) {
        if (ptr) {
          delete ptr;
        }
      });
  return Status::OK();
}
}  // namespace internal

}  // namespace gradients
}  // namespace tensorflow
tensorflow/c/eager/gradients.h (new file, 171 lines)
@@ -0,0 +1,171 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EAGER_GRADIENTS_H_
#define TENSORFLOW_C_EAGER_GRADIENTS_H_

#include "absl/container/flat_hash_map.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/tape.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"

namespace tensorflow {
namespace gradients {

// =============== Experimental C++ API for computing gradients ===============

// Sample gradient function:
//
// class AddGradientFunction : public GradientFunction {
//  public:
//   Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs,
//                  std::vector<AbstractTensorHandle*>* grad_outputs) override {
//     grad_outputs->resize(2);
//     (*grad_outputs)[0] = grad_inputs[0];
//     (*grad_outputs)[1] = grad_inputs[0];
//     return Status::OK();
//   }
//   ~AddGradientFunction() override {}
// };
//
// GradientFunction* AddRegisterer(const ForwardOperation& op) {
//   // More complex gradient functions can use inputs/attrs etc. from the
//   // forward `op`.
//   return new AddGradientFunction;
// }
//
// Status RegisterGradients(GradientRegistry* registry) {
//   return registry->Register("Add", AddRegisterer);
// }
class GradientFunction {
 public:
  // TODO(srbs): Figure out how to support CompositeTensors, e.g.
  // IndexedSlices, in `grad_inputs`.
  virtual Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs,
                         std::vector<AbstractTensorHandle*>* grad_outputs) = 0;
  virtual ~GradientFunction() {}
};

// Metadata from the forward operation that is made available to the
// gradient registerer to instantiate a GradientFunction.
struct ForwardOperation {
 public:
  string op_name;
  std::vector<AbstractTensorHandle*> inputs;
  std::vector<AbstractTensorHandle*> outputs;
  AttrBuilder attrs;
  AbstractContext* ctx;
};

using GradientFunctionFactory =
    std::function<GradientFunction*(const ForwardOperation& op)>;

// Map from op name to a `GradientFunctionFactory`.
class GradientRegistry {
 public:
  Status Register(const string& op, GradientFunctionFactory factory);
  Status Lookup(const ForwardOperation& op,
                std::unique_ptr<GradientFunction>* grad_fn) const;

 private:
  absl::flat_hash_map<string, GradientFunctionFactory> registry_;
};

// Returns a unique id for the tensor which is used by the tape to build
// the gradient graph. See documentation of `TapeTensor` for more details.
int64 ToId(AbstractTensorHandle* t);

// Wrapper for a tensor output of an operation executing under a tape.
//
// `GetID` returns a unique id for the wrapped tensor which is used to maintain
// a map (`tensorflow::eager::TensorTape`) from the wrapped tensor to the id of
// the op that produced it (or -1 if this tensor was watched using
// `GradientTape::Watch`.) The op_id is simply a unique index assigned to each
// op executed under the tape. A separate map (`tensorflow::eager::OpTape`)
// maintains the map from `op_id` to an `OpTapeEntry` which stores the
// `op_type`, inputs, outputs and the gradient function. These data structures
// combined allow us to trace the data dependencies between operations and
// hence compute gradients.
//
// This also implements `ZerosLike` and `OnesLike` to create the default
// incoming gradients for tensors which do not already have an incoming
// gradient.
class TapeTensor {
 public:
  TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx);
  TapeTensor(const TapeTensor& other);
  ~TapeTensor();

  tensorflow::int64 GetID() const;
  tensorflow::DataType GetDType() const;

  AbstractTensorHandle* OnesLike() const;
  AbstractTensorHandle* ZerosLike() const;

 private:
  AbstractTensorHandle* handle_;
  // The context where the OnesLike and ZerosLike ops are to be created.
  AbstractContext* ctx_;
};

// Vector space for actually computing gradients. Implements methods for
// calling the backward function with incoming gradients and returning the
// outgoing gradient, and for performing gradient aggregation.
// See `tensorflow::eager::VSpace` for more details.
class TapeVSpace
    : public eager::VSpace<AbstractTensorHandle, GradientFunction, TapeTensor> {
 public:
  explicit TapeVSpace(AbstractContext* ctx) : ctx_(ctx) {}
  ~TapeVSpace() override {}

  // Returns the number of elements in the gradient tensor.
  int64 NumElements(AbstractTensorHandle* tensor) const override;

  // Consumes references to the tensors in the gradient_tensors list and
  // returns a tensor with the result.
  AbstractTensorHandle* AggregateGradients(
      gtl::ArraySlice<AbstractTensorHandle*> gradient_tensors) const override;

  // Calls the passed-in backward function.
  Status CallBackwardFunction(
      GradientFunction* backward_function,
      const std::vector<int64>& unneeded_gradients,
      gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
      std::vector<AbstractTensorHandle*>* result) const override;

  // Looks up the ID of a Gradient.
  int64 TensorId(AbstractTensorHandle* tensor) const override;

  // Converts a Gradient to a TapeTensor.
  TapeTensor TapeTensorFromGradient(AbstractTensorHandle* g) const override;

  void MarkAsResult(AbstractTensorHandle* gradient) const override;

  void DeleteGradient(AbstractTensorHandle* gradient) const override;

 private:
  // The context where the aggregation op `Add` is to be created.
  AbstractContext* ctx_;
};

// A tracing/immediate-execution agnostic tape.
using Tape = tensorflow::eager::GradientTape<AbstractTensorHandle,
                                             GradientFunction, TapeTensor>;

}  // namespace gradients
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EAGER_GRADIENTS_H_
tensorflow/c/eager/gradients_internal.h (new file, 87 lines)
@@ -0,0 +1,87 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EAGER_GRADIENTS_INTERNAL_H_
#define TENSORFLOW_C_EAGER_GRADIENTS_INTERNAL_H_

#include "tensorflow/c/eager/gradients.h"

namespace tensorflow {
namespace gradients {
namespace internal {

// Helper functions which delegate to `AbstractOperation`, update
// the state of the ForwardOperation and call the tape as appropriate.
// These APIs are mainly to facilitate testing and are subject to change.

// Records the op name in the `ForwardOperation`.
Status Reset(AbstractOperation*, const char* op, const char* raw_device_name,
             ForwardOperation*);

// Records the inputs in the `ForwardOperation`.
Status AddInput(AbstractOperation*, AbstractTensorHandle*, ForwardOperation*);
Status AddInputList(AbstractOperation*,
                    absl::Span<AbstractTensorHandle* const> inputs,
                    ForwardOperation*);

// Sets the attrs in the `ForwardOperation`.
Status SetAttrString(AbstractOperation*, const char* attr_name,
                     const char* data, size_t length, ForwardOperation*);
Status SetAttrInt(AbstractOperation*, const char* attr_name, int64_t value,
                  ForwardOperation*);
Status SetAttrFloat(AbstractOperation*, const char* attr_name, float value,
                    ForwardOperation*);
Status SetAttrBool(AbstractOperation*, const char* attr_name, bool value,
                   ForwardOperation*);
Status SetAttrType(AbstractOperation*, const char* attr_name, DataType value,
                   ForwardOperation*);
Status SetAttrShape(AbstractOperation*, const char* attr_name,
                    const int64_t* dims, const int num_dims, ForwardOperation*);
Status SetAttrFunction(AbstractOperation*, const char* attr_name,
                       const AbstractOperation* value, ForwardOperation*);
Status SetAttrFunctionName(AbstractOperation*, const char* attr_name,
                           const char* value, size_t length, ForwardOperation*);
Status SetAttrTensor(AbstractOperation*, const char* attr_name,
                     AbstractTensorInterface* tensor, ForwardOperation*);
Status SetAttrStringList(AbstractOperation*, const char* attr_name,
                         const void* const* values, const size_t* lengths,
                         int num_values, ForwardOperation*);
Status SetAttrFloatList(AbstractOperation*, const char* attr_name,
                        const float* values, int num_values, ForwardOperation*);
Status SetAttrIntList(AbstractOperation*, const char* attr_name,
                      const int64_t* values, int num_values, ForwardOperation*);
Status SetAttrTypeList(AbstractOperation*, const char* attr_name,
                       const DataType* values, int num_values,
                       ForwardOperation*);
Status SetAttrBoolList(AbstractOperation*, const char* attr_name,
                       const unsigned char* values, int num_values,
                       ForwardOperation*);
Status SetAttrShapeList(AbstractOperation*, const char* attr_name,
                        const int64_t** dims, const int* num_dims,
                        int num_values, ForwardOperation*);
Status SetAttrFunctionList(AbstractOperation*, const char* attr_name,
                           absl::Span<const AbstractOperation*> values,
                           ForwardOperation*);

// Makes the call to `Tape::RecordOperation`.
Status Execute(AbstractOperation*, AbstractContext*,
               absl::Span<AbstractTensorHandle*> retvals, int* num_retvals,
               ForwardOperation*, Tape*, const GradientRegistry&);

}  // namespace internal
}  // namespace gradients
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EAGER_GRADIENTS_INTERNAL_H_
tensorflow/c/eager/gradients_test.cc (new file, 328 lines)
@@ -0,0 +1,328 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/gradients.h"

#include <memory>

#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace gradients {
namespace internal {
namespace {

class CppGradients
    : public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
 protected:
  void SetUp() override {
    TF_SetTracingImplementation(std::get<0>(GetParam()));
  }
};

// Creates an Identity op.
Status Identity(AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs, const char* name) {
  AbstractOperationPtr identity_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(
      identity_op->Reset("Identity", /*raw_device_name=*/nullptr));
  if (isa<tracing::TracingOperation>(identity_op.get())) {
    TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(identity_op.get())
                           ->SetOpName(name));
  }
  TF_RETURN_IF_ERROR(identity_op->AddInput(inputs[0]));
  int num_retvals = 1;
  TF_RETURN_IF_ERROR(identity_op->Execute(outputs, &num_retvals));
  return Status::OK();
}

// =================== Register gradients for Add ============================
class AddGradientFunction : public GradientFunction {
 public:
  explicit AddGradientFunction(AbstractContext* ctx) : ctx_(ctx) {}
  Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs,
                 std::vector<AbstractTensorHandle*>* grad_outputs) override {
    grad_outputs->resize(2);
    std::vector<AbstractTensorHandle*> identity_outputs(1);
    TF_RETURN_IF_ERROR(Identity(ctx_, {grad_inputs[0]},
                                absl::MakeSpan(identity_outputs), "Id0"));
    (*grad_outputs)[0] = identity_outputs[0];
    TF_RETURN_IF_ERROR(Identity(ctx_, {grad_inputs[0]},
                                absl::MakeSpan(identity_outputs), "Id1"));
    (*grad_outputs)[1] = identity_outputs[0];
    return Status::OK();
  }
  ~AddGradientFunction() override {}

 private:
  AbstractContext* ctx_;
};

GradientFunction* AddRegisterer(const ForwardOperation& op) {
  return new AddGradientFunction(op.ctx);
}

Status RegisterGradients(GradientRegistry* registry) {
  return registry->Register("Add", AddRegisterer);
}

// =================== End gradient registrations ============================

// Computes `inputs[0] + inputs[1]` and records it on the tape.
Status Add(AbstractContext* ctx, Tape* tape,
           absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs,
           const GradientRegistry& registry) {
  AbstractOperationPtr add_op(ctx->CreateOperation());
  ForwardOperation forward_op;
  forward_op.ctx = ctx;
  TF_RETURN_IF_ERROR(
      Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op));
  if (isa<tracing::TracingOperation>(add_op.get())) {
    TF_RETURN_IF_ERROR(
        dyn_cast<tracing::TracingOperation>(add_op.get())->SetOpName("my_add"));
  }
  TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op));
  TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op));
  int num_retvals = 1;
  return Execute(add_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
                 registry);
}

// Computes
// y = inputs[0] + inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status AddGradModel(AbstractContext* ctx,
                    absl::Span<AbstractTensorHandle* const> inputs,
                    absl::Span<AbstractTensorHandle*> outputs,
                    const GradientRegistry& registry) {
  TapeVSpace vspace(ctx);
  auto tape = new Tape(/*persistent=*/false);
  tape->Watch(ToId(inputs[0]));  // Watch x.
  tape->Watch(ToId(inputs[1]));  // Watch y.
  std::vector<AbstractTensorHandle*> add_outputs(1);
  TF_RETURN_IF_ERROR(Add(ctx, tape, inputs, absl::MakeSpan(add_outputs),
                         registry));  // Compute x+y.
  std::unordered_map<tensorflow::int64, TapeTensor>
      source_tensors_that_are_targets;

  std::vector<AbstractTensorHandle*> out_grads;
  TF_RETURN_IF_ERROR(tape->ComputeGradient(
      vspace, /*target_tensor_ids=*/{ToId(add_outputs[0])},
      /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
      source_tensors_that_are_targets,
      /*output_gradients=*/{}, &out_grads));
  for (auto add_output : add_outputs) {
    add_output->Release();
  }
  outputs[0] = out_grads[0];
  outputs[1] = out_grads[1];
  delete tape;
  return Status::OK();
}

AbstractContext* BuildFunction(const char* fn_name) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get());
  return unwrap(graph_ctx);
}

Status CreateParamsForInputs(AbstractContext* ctx,
                             absl::Span<AbstractTensorHandle* const> inputs,
                             std::vector<AbstractTensorHandle*>* params) {
  tracing::TracingTensorHandle* handle = nullptr;
  for (auto input : inputs) {
    TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
        input->DataType(), &handle));
    params->emplace_back(handle);
  }
  return Status::OK();
}

using Model = std::function<Status(
    AbstractContext*, absl::Span<AbstractTensorHandle* const>,
    absl::Span<AbstractTensorHandle*>, const GradientRegistry&)>;

// Runs `model` maybe wrapped in a function.
Status RunModel(Model model, AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs, bool use_function,
                const GradientRegistry& registry) {
  if (use_function) {
    const char* fn_name = "test_fn";
    std::unique_ptr<AbstractFunction> scoped_func;
    {
      AbstractContextPtr func_ctx(BuildFunction(fn_name));
      std::vector<AbstractTensorHandle*> func_inputs;
      func_inputs.reserve(inputs.size());
      TF_RETURN_IF_ERROR(
          CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs));
      OutputList output_list;
      output_list.expected_num_outputs = outputs.size();
      output_list.outputs.resize(outputs.size());
      TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs),
                               absl::MakeSpan(output_list.outputs), registry));
      for (auto func_input : func_inputs) {
        func_input->Release();
      }
      AbstractFunction* func = nullptr;
      TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
                             ->Finalize(&output_list, &func));
      scoped_func.reset(func);
      output_list.outputs[0]->Release();
      output_list.outputs[1]->Release();
      TF_RETURN_IF_ERROR(ctx->RegisterFunction(func));
    }

    AbstractOperationPtr fn_op(ctx->CreateOperation());
    TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr));
    for (auto input : inputs) {
      TF_RETURN_IF_ERROR(fn_op->AddInput(input));
    }
    int retvals = outputs.size();
    TF_RETURN_IF_ERROR(fn_op->Execute(outputs, &retvals));
    TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name));
    return Status::OK();
  } else {
    return model(ctx, inputs, outputs, registry);
  }
}

Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetTfrt(opts, use_tfrt);
  *ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get()));
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  TFE_DeleteContextOptions(opts);
  return Status::OK();
}

Status TestScalarTensorHandle(AbstractContext* ctx, float value,
                              AbstractTensorHandle** tensor) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_Context* eager_ctx =
      TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value);
  *tensor =
      unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
  return Status::OK();
}

Status getValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_TensorHandle* result_t =
      TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  *result_tensor = TFE_TensorHandleResolve(result_t, status.get());
  return Status::OK();
}

TEST_P(CppGradients, TestAddGrad) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  AbstractTensorHandlePtr x;
  {
    AbstractTensorHandle* x_raw = nullptr;
    Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    x.reset(x_raw);
  }

  AbstractTensorHandlePtr y;
  {
    AbstractTensorHandle* y_raw = nullptr;
    Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &y_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    y.reset(y_raw);
  }

  GradientRegistry registry;
  Status s = RegisterGradients(&registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  // Pseudo-code:
  //
  // tape.watch(x)
  // tape.watch(y)
  // y = x + y
  // outputs = tape.gradient(y, [x, y])
  std::vector<AbstractTensorHandle*> outputs(2);
  s = RunModel(AddGradModel, ctx.get(), {x.get(), y.get()},
               absl::MakeSpan(outputs),
               /*use_function=*/!std::get<2>(GetParam()), registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  TF_Tensor* result_tensor;
  s = getValue(outputs[0], &result_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
  auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
  EXPECT_EQ(*result_value, 1.0);
  outputs[0]->Release();
  TF_DeleteTensor(result_tensor);
  result_tensor = nullptr;

  s = getValue(outputs[1], &result_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
  result_value = static_cast<float*>(TF_TensorData(result_tensor));
  EXPECT_EQ(*result_value, 1.0);
  outputs[1]->Release();
  TF_DeleteTensor(result_tensor);
}

// TODO(b/160888630): Enable this test with mlir after AddInputList is
// supported. It is needed for AddN op which is used for gradient aggregation.
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
    UnifiedCAPI, CppGradients,
    ::testing::Combine(::testing::Values("graphdef"),
                       /*tfrt*/ ::testing::Values(false),
                       /*executing_eagerly*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
    UnifiedCAPI, CppGradients,
    ::testing::Combine(::testing::Values("graphdef"),
                       /*tfrt*/ ::testing::Values(false),
                       /*executing_eagerly*/ ::testing::Values(true, false)));
#endif
}  // namespace
}  // namespace internal
}  // namespace gradients
}  // namespace tensorflow