First pass at implementing a C++ API for gradients built on top of abstract interfaces.

`gradients.h` has the public headers.
`gradients_internals.h` has some helpers for testing, while we figure out how to hook this into the AbstractOperation itself.

PiperOrigin-RevId: 320682374
Change-Id: I53e1c76f3c0897ff66f2563501806c425de69f24
This commit is contained in:
Saurabh Saxena 2020-07-10 14:59:04 -07:00 committed by TensorFlower Gardener
parent 37d20f87f3
commit 02e3cb289c
5 changed files with 1068 additions and 0 deletions

View File

@ -171,6 +171,87 @@ cc_library(
],
)
cc_library(
name = "gradients",
srcs = [
"gradients.cc",
"gradients_internal.h",
],
hdrs = [
"gradients.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":abstract_context",
":abstract_operation",
":abstract_tensor_handle",
":c_api_unified_internal",
":tape",
"//tensorflow/core/common_runtime/eager:attr_builder",
"//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "gradients_internal",
srcs = [
"gradients.cc",
],
hdrs = [
"gradients.h",
"gradients_internal.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":abstract_context",
":abstract_operation",
":abstract_tensor_handle",
":c_api_unified_internal",
":tape",
"//tensorflow/core/common_runtime/eager:attr_builder",
"//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
],
)
tf_cuda_cc_test(
name = "gradients_test",
size = "small",
srcs = [
"gradients_test.cc",
],
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + ["nomac"],
deps = [
":abstract_tensor_handle",
":c_api_experimental",
":c_api_test_util",
":c_api_unified_internal",
":gradients_internal",
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/c:tf_status_helper",
"//tensorflow/cc/profiler",
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "abstract_tensor_handle",
hdrs = ["abstract_tensor_handle.h"],
@ -747,6 +828,7 @@ filegroup(
"c_api_unified_experimental_eager.cc",
"c_api_unified_experimental_graph.cc",
"c_api_unified_experimental_internal.h",
"gradients.cc", # Uses RTTI.
"*test*",
"*dlpack*",
],

View File

@ -0,0 +1,400 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/gradients.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
namespace tensorflow {
namespace gradients {
Status GradientRegistry::Register(const string& op_name,
GradientFunctionFactory factory) {
auto iter = registry_.find(op_name);
if (iter != registry_.end()) {
const string error_msg = "Gradient already exists for op: " + op_name + ".";
return errors::AlreadyExists(error_msg);
}
registry_.insert({op_name, factory});
return Status::OK();
}
Status GradientRegistry::Lookup(
const ForwardOperation& op,
std::unique_ptr<GradientFunction>* grad_fn) const {
auto iter = registry_.find(op.op_name);
if (iter == registry_.end()) {
const string error_msg = "No gradient defined for op: " + op.op_name + ".";
return errors::NotFound(error_msg);
}
grad_fn->reset(iter->second(op));
return Status::OK();
}
int64 ToId(AbstractTensorHandle* t) {
return static_cast<int64>(reinterpret_cast<uintptr_t>(t));
}
TapeTensor::TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx)
: handle_(handle), ctx_(ctx) {
// TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely
// on the client to keep this tensor live for the duration of the gradient
// computation.
// handle_->Ref();
}
TapeTensor::TapeTensor(const TapeTensor& other) {
handle_ = other.handle_;
// TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely
// on the client to keep this tensor live for the duration of the gradient
// computation.
// handle_->Ref();
ctx_ = other.ctx_;
}
TapeTensor::~TapeTensor() {
// TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely
// on the client to keep this tensor live for the duration of the gradient
// computation.
// handle_->Unref();
}
tensorflow::int64 TapeTensor::GetID() const { return ToId(handle_); }
tensorflow::DataType TapeTensor::GetDType() const {
return handle_->DataType();
}
AbstractTensorHandle* TapeTensor::OnesLike() const {
AbstractOperationPtr op(ctx_->CreateOperation());
Status s = op->Reset("OnesLike", /*raw_device_name=*/nullptr);
if (!s.ok()) {
return nullptr;
}
if (isa<tracing::TracingOperation>(op.get())) {
s = dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
absl::StrCat("OnesLike", ToId(handle_)).c_str());
if (!s.ok()) {
return nullptr;
}
}
s = op->AddInput(handle_);
if (!s.ok()) {
return nullptr;
}
int num_outputs = 1;
// TODO(srbs): Figure out who is in charge of releasing this.
std::vector<AbstractTensorHandle*> outputs(num_outputs);
s = op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs);
if (!s.ok()) {
return nullptr;
}
return outputs[0];
}
AbstractTensorHandle* TapeTensor::ZerosLike() const {
AbstractOperationPtr op(ctx_->CreateOperation());
// TODO(srbs): Consider adding a TF_RETURN_NULLPTR_IF_ERROR.
Status s = op->Reset("ZerosLike", /*raw_device_name=*/nullptr);
if (!s.ok()) {
return nullptr;
}
if (isa<tracing::TracingOperation>(op.get())) {
s = dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
absl::StrCat("OnesLike", ToId(handle_)).c_str());
if (!s.ok()) {
return nullptr;
}
}
s = op->AddInput(handle_);
if (!s.ok()) {
return nullptr;
}
int num_outputs = 1;
// TODO(srbs): Figure out who is in charge of releasing this.
std::vector<AbstractTensorHandle*> outputs(num_outputs);
s = op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs);
if (!s.ok()) {
return nullptr;
}
return outputs[0];
}
// Returns the number of elements in the gradient tensor.
int64 TapeVSpace::NumElements(AbstractTensorHandle* tensor) const {
// TODO(srbs): It seems like this is used only for performance optimization
// and not for correctness. The only downside of keeping this 1 seems to be
// that the gradient accumulation is unbounded and we will never
// aggressively aggregate accumulated gradients to recover memory.
// Revisit and fix.
return 1;
}
// Consumes references to the tensors in the gradient_tensors list and returns
// a tensor with the result.
AbstractTensorHandle* TapeVSpace::AggregateGradients(
gtl::ArraySlice<AbstractTensorHandle*> gradient_tensors) const {
if (gradient_tensors.size() == 1) {
return gradient_tensors[0];
}
AbstractOperationPtr op(ctx_->CreateOperation());
Status s = op->Reset("AddN", /*raw_device_name=*/nullptr);
if (!s.ok()) {
return nullptr;
}
s = op->AddInputList(gradient_tensors);
if (!s.ok()) {
return nullptr;
}
int num_outputs = 1;
std::vector<AbstractTensorHandle*> outputs(num_outputs);
s = op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs);
if (!s.ok()) {
return nullptr;
}
return outputs[0];
}
// Calls the passed-in backward function.
Status TapeVSpace::CallBackwardFunction(
GradientFunction* backward_function,
const std::vector<int64>& unneeded_gradients,
gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
std::vector<AbstractTensorHandle*>* result) const {
if (backward_function == nullptr) return Status::OK();
return backward_function->Compute(output_gradients, result);
}
// Looks up the ID of a Gradient.
int64 TapeVSpace::TensorId(AbstractTensorHandle* tensor) const {
return ToId(tensor);
}
// Converts a Gradient to a TapeTensor.
TapeTensor TapeVSpace::TapeTensorFromGradient(AbstractTensorHandle* g) const {
return TapeTensor(g, ctx_);
}
void TapeVSpace::MarkAsResult(AbstractTensorHandle* gradient) const {}
void TapeVSpace::DeleteGradient(AbstractTensorHandle* gradient) const {
gradient->Release();
}
// Helper functions which delegate to `AbstractOperation`, update
// the state of the ForwardOperation and call the tape as appropriate.
// These APIs are mainly to faciliate testing and are subject to change.
namespace internal {
Status Reset(AbstractOperation* op_, const char* op,
const char* raw_device_name, ForwardOperation* forward_op_) {
forward_op_->op_name = op;
return op_->Reset(op, raw_device_name);
}
Status AddInput(AbstractOperation* op_, AbstractTensorHandle* input,
ForwardOperation* forward_op_) {
TF_RETURN_IF_ERROR(op_->AddInput(input));
forward_op_->inputs.push_back(input);
return Status::OK();
}
Status AddInputList(AbstractOperation* op_,
absl::Span<AbstractTensorHandle* const> inputs,
ForwardOperation* forward_op_) {
TF_RETURN_IF_ERROR(op_->AddInputList(inputs));
for (auto input : inputs) {
forward_op_->inputs.push_back(input);
}
return Status::OK();
}
Status SetAttrString(AbstractOperation* op_, const char* attr_name,
const char* data, size_t length,
ForwardOperation* forward_op_) {
forward_op_->attrs.Set(attr_name, StringPiece(data, length));
return op_->SetAttrString(attr_name, data, length);
}
Status SetAttrInt(AbstractOperation* op_, const char* attr_name, int64_t value,
ForwardOperation* forward_op_) {
forward_op_->attrs.Set(attr_name, static_cast<int64>(value));
return op_->SetAttrInt(attr_name, value);
}
Status SetAttrFloat(AbstractOperation* op_, const char* attr_name, float value,
ForwardOperation* forward_op_) {
forward_op_->attrs.Set(attr_name, value);
return op_->SetAttrFloat(attr_name, value);
}
Status SetAttrBool(AbstractOperation* op_, const char* attr_name, bool value,
ForwardOperation* forward_op_) {
forward_op_->attrs.Set(attr_name, value);
return op_->SetAttrBool(attr_name, value);
}
Status SetAttrType(AbstractOperation* op_, const char* attr_name,
DataType value, ForwardOperation* forward_op_) {
forward_op_->attrs.Set(attr_name, value);
return op_->SetAttrType(attr_name, value);
}
Status SetAttrShape(AbstractOperation* op_, const char* attr_name,
const int64_t* dims, const int num_dims,
ForwardOperation* forward_op_) {
if (num_dims > TensorShape::MaxDimensions()) {
return errors::InvalidArgument("Value specified for `", attr_name, "` has ",
num_dims,
" dimensions which is over the limit of ",
TensorShape::MaxDimensions(), ".");
}
TensorShapeProto proto;
if (num_dims < 0) {
proto.set_unknown_rank(true);
} else {
for (int d = 0; d < num_dims; ++d) {
proto.add_dim()->set_size(dims[d]);
}
}
forward_op_->attrs.Set(attr_name, proto);
return op_->SetAttrShape(attr_name, dims, num_dims);
}
Status SetAttrFunction(AbstractOperation* op_, const char* attr_name,
const AbstractOperation* value,
ForwardOperation* forward_op_) {
return tensorflow::errors::Unimplemented(
"SetAttrFunction has not been implemented yet.");
}
Status SetAttrFunctionName(AbstractOperation* op_, const char* attr_name,
const char* value, size_t length,
ForwardOperation* forward_op_) {
return tensorflow::errors::Unimplemented(
"SetAttrFunctionName has not been implemented "
"yet.");
}
Status SetAttrTensor(AbstractOperation* op_, const char* attr_name,
AbstractTensorInterface* tensor,
ForwardOperation* forward_op_) {
return tensorflow::errors::Unimplemented(
"SetAttrTensor has not been implemented yet.");
}
Status SetAttrStringList(AbstractOperation* op_, const char* attr_name,
const void* const* values, const size_t* lengths,
int num_values, ForwardOperation* forward_op_) {
std::vector<StringPiece> v(num_values);
for (int i = 0; i < num_values; ++i) {
v[i] = StringPiece(static_cast<const char*>(values[i]), lengths[i]);
}
forward_op_->attrs.Set(attr_name, v);
return op_->SetAttrStringList(attr_name, values, lengths, num_values);
}
Status SetAttrFloatList(AbstractOperation* op_, const char* attr_name,
const float* values, int num_values,
ForwardOperation* forward_op_) {
forward_op_->attrs.Set(attr_name,
gtl::ArraySlice<const float>(values, num_values));
return op_->SetAttrFloatList(attr_name, values, num_values);
}
Status SetAttrIntList(AbstractOperation* op_, const char* attr_name,
const int64_t* values, int num_values,
ForwardOperation* forward_op_) {
forward_op_->attrs.Set(
attr_name, gtl::ArraySlice<const int64>(
reinterpret_cast<const int64*>(values), num_values));
return op_->SetAttrIntList(attr_name, values, num_values);
}
Status SetAttrTypeList(AbstractOperation* op_, const char* attr_name,
const DataType* values, int num_values,
ForwardOperation* forward_op_) {
forward_op_->attrs.Set(attr_name,
gtl::ArraySlice<const DataType>(values, num_values));
return op_->SetAttrTypeList(attr_name, values, num_values);
}
Status SetAttrBoolList(AbstractOperation* op_, const char* attr_name,
const unsigned char* values, int num_values,
ForwardOperation* forward_op_) {
std::unique_ptr<bool[]> b(new bool[num_values]);
for (int i = 0; i < num_values; ++i) {
b[i] = values[i];
}
forward_op_->attrs.Set(attr_name,
gtl::ArraySlice<const bool>(b.get(), num_values));
return op_->SetAttrBoolList(attr_name, values, num_values);
}
Status SetAttrShapeList(AbstractOperation* op_, const char* attr_name,
const int64_t** dims, const int* num_dims,
int num_values, ForwardOperation* forward_op_) {
std::unique_ptr<TensorShapeProto[]> proto(new TensorShapeProto[num_values]);
for (int i = 0; i < num_values; ++i) {
const auto num_dims_i = num_dims[i];
if (num_dims_i > TensorShape::MaxDimensions()) {
return errors::InvalidArgument(
strings::StrCat("Value specified for `", attr_name, "` has ",
num_dims_i, " dimensions which is over the limit of ",
TensorShape::MaxDimensions(), "."));
}
if (num_dims_i < 0) {
proto[i].set_unknown_rank(true);
} else {
const int64_t* dims_i = dims[i];
auto proto_i = &proto[i];
for (int d = 0; d < num_dims_i; ++d) {
proto_i->add_dim()->set_size(dims_i[d]);
}
}
}
forward_op_->attrs.Set(
attr_name, gtl::ArraySlice<TensorShapeProto>(proto.get(), num_values));
return op_->SetAttrShapeList(attr_name, dims, num_dims, num_values);
}
Status SetAttrFunctionList(AbstractOperation* op_, const char* attr_name,
absl::Span<const AbstractOperation*> values,
ForwardOperation* forward_op_) {
return tensorflow::errors::Unimplemented(
"SetAttrFunctionList has not been "
"implemented yet.");
}
Status Execute(AbstractOperation* op_, AbstractContext* ctx,
absl::Span<AbstractTensorHandle*> retvals, int* num_retvals,
ForwardOperation* forward_op_, Tape* tape,
const GradientRegistry& registry) {
TF_RETURN_IF_ERROR(op_->Execute(retvals, num_retvals));
std::vector<int64> input_ids(forward_op_->inputs.size());
std::vector<tensorflow::DataType> input_dtypes(forward_op_->inputs.size());
for (int i = 0; i < forward_op_->inputs.size(); i++) {
input_ids[i] = ToId(forward_op_->inputs[i]);
input_dtypes[i] = forward_op_->inputs[i]->DataType();
}
std::vector<TapeTensor> tape_tensors;
for (auto t : retvals) {
tape_tensors.push_back(TapeTensor(t, ctx));
}
tape->RecordOperation(
op_->Name(), tape_tensors, input_ids, input_dtypes,
[registry, forward_op_]() -> GradientFunction* {
std::unique_ptr<GradientFunction> grad_fn;
Status s = registry.Lookup(*forward_op_, &grad_fn);
if (!s.ok()) {
return nullptr;
}
return grad_fn.release();
},
[](GradientFunction* ptr) {
if (ptr) {
delete ptr;
}
});
return Status::OK();
}
} // namespace internal
} // namespace gradients
} // namespace tensorflow

View File

@ -0,0 +1,171 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_GRADIENTS_H_
#define TENSORFLOW_C_EAGER_GRADIENTS_H_
#include "absl/container/flat_hash_map.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/tape.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
namespace tensorflow {
namespace gradients {
// =============== Experimental C++ API for computing gradients ===============
// Sample gradient function:
//
// class AddGradientFunction : public GradientFunction {
// public:
// Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs,
// std::vector<AbstractTensorHandle*>* grad_outputs) override {
// grad_outputs->resize(2);
// (*grad_outputs)[0] = grad_inputs[0];
// (*grad_outputs)[1] = grad_inputs[0];
// return Status::OK();
// }
// ~AddGradientFunction() override {}
// };
//
// GradientFunction* AddRegisterer(const ForwardOperation& op) {
// // More complex gradient functions can use inputs/attrs etc. from the
// // forward `op`.
// return new AddGradientFunction;
// }
//
// Status RegisterGradients(GradientRegistry* registry) {
// return registry->Register("Add", AddRegisterer);
// }
class GradientFunction {
public:
// TODO(srbs): How we support CompositeTensors e.g. IndexedSlices in
// `grad_inputs`.
virtual Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs,
std::vector<AbstractTensorHandle*>* grad_outputs) = 0;
virtual ~GradientFunction() {}
};
// Metadata from the forward operation that is made available to the
// gradient registerer to instantiate a GradientFunction.
struct ForwardOperation {
public:
string op_name;
std::vector<AbstractTensorHandle*> inputs;
std::vector<AbstractTensorHandle*> outputs;
AttrBuilder attrs;
AbstractContext* ctx;
};
using GradientFunctionFactory =
std::function<GradientFunction*(const ForwardOperation& op)>;
// Map from op name to a `GradientFunctionFactory`.
class GradientRegistry {
public:
Status Register(const string& op, GradientFunctionFactory factory);
Status Lookup(const ForwardOperation& op,
std::unique_ptr<GradientFunction>* grad_fn) const;
private:
absl::flat_hash_map<string, GradientFunctionFactory> registry_;
};
// Returns a unique id for the tensor which is used by the tape to build
// the gradient graph. See documentation of `TapeTensor` for more details.
int64 ToId(AbstractTensorHandle* t);
// Wrapper for a tensor output of an operation executing under a tape.
//
// `GetID` returns a unique id for the wrapped tensor which is used to maintain
// a map (`tensorflow::eager::TensorTape`) from the wrapped tensor to the id of
// the op that produced it (or -1 if this tensor was watched using
// `GradientTape::Watch`.) The op_id is simply a unique index assigned to each
// op executed under the tape. A separate map (`tensorflow::eager::OpTape`)
// maintains the map from `op_id` to a `OpTapeEntry` which stores the `op_type`,
// inputs and outputs and the gradient function These data structures combined
// allow us to trace the data dependencies between operations and hence compute
// gradients.
//
// This also implements `ZerosLike` and `OnesLike` to create the default
// incoming gradients for tensors which do not already have an incoming
// gradient.
class TapeTensor {
public:
TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx);
TapeTensor(const TapeTensor& other);
~TapeTensor();
tensorflow::int64 GetID() const;
tensorflow::DataType GetDType() const;
AbstractTensorHandle* OnesLike() const;
AbstractTensorHandle* ZerosLike() const;
private:
AbstractTensorHandle* handle_;
// The context where OnesLike and ZerosLike ops are to be created.
AbstractContext* ctx_;
};
// Vector space for actually computing gradients. Implements methods for calling
// the backward function with incoming gradients and returning the outgoing
// gradient and for performing gradient aggregation.
// See `tensorflow::eager::VSpace` for more details.
class TapeVSpace
: public eager::VSpace<AbstractTensorHandle, GradientFunction, TapeTensor> {
public:
explicit TapeVSpace(AbstractContext* ctx) : ctx_(ctx) {}
~TapeVSpace() override {}
// Returns the number of elements in the gradient tensor.
int64 NumElements(AbstractTensorHandle* tensor) const override;
// Consumes references to the tensors in the gradient_tensors list and returns
// a tensor with the result.
AbstractTensorHandle* AggregateGradients(
gtl::ArraySlice<AbstractTensorHandle*> gradient_tensors) const override;
// Calls the passed-in backward function.
Status CallBackwardFunction(
GradientFunction* backward_function,
const std::vector<int64>& unneeded_gradients,
gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
std::vector<AbstractTensorHandle*>* result) const override;
// Looks up the ID of a Gradient.
int64 TensorId(AbstractTensorHandle* tensor) const override;
// Converts a Gradient to a TapeTensor.
TapeTensor TapeTensorFromGradient(AbstractTensorHandle* g) const override;
void MarkAsResult(AbstractTensorHandle* gradient) const override;
void DeleteGradient(AbstractTensorHandle* gradient) const override;
private:
// The context where the aggregation op `Add` is to be created.
AbstractContext* ctx_;
};
// A tracing/immediate-execution agnostic tape.
using Tape = tensorflow::eager::GradientTape<AbstractTensorHandle,
GradientFunction, TapeTensor>;
} // namespace gradients
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_GRADIENTS_H_

View File

@ -0,0 +1,87 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_GRADIENTS_INTERNAL_H_
#define TENSORFLOW_C_EAGER_GRADIENTS_INTERNAL_H_
#include "tensorflow/c/eager/gradients.h"
namespace tensorflow {
namespace gradients {
namespace internal {
// Helper functions which delegate to `AbstractOperation`, update
// the state of the ForwardOperation and call the tape as appropriate.
// These APIs are mainly to faciliate testing and are subject to change.
// Records the op name in the `ForwardOperation`.
Status Reset(AbstractOperation*, const char* op, const char* raw_device_name,
ForwardOperation*);
// Records the inputs in the `ForwardOperation`.
Status AddInput(AbstractOperation*, AbstractTensorHandle*, ForwardOperation*);
Status AddInputList(AbstractOperation*,
absl::Span<AbstractTensorHandle* const> inputs,
ForwardOperation*);
// Sets the attrs in the `ForwardOperation`.
Status SetAttrString(AbstractOperation*, const char* attr_name,
const char* data, size_t length, ForwardOperation*);
Status SetAttrInt(AbstractOperation*, const char* attr_name, int64_t value,
ForwardOperation*);
Status SetAttrFloat(AbstractOperation*, const char* attr_name, float value,
ForwardOperation*);
Status SetAttrBool(AbstractOperation*, const char* attr_name, bool value,
ForwardOperation*);
Status SetAttrType(AbstractOperation*, const char* attr_name, DataType value,
ForwardOperation*);
Status SetAttrShape(AbstractOperation*, const char* attr_name,
const int64_t* dims, const int num_dims, ForwardOperation*);
Status SetAttrFunction(AbstractOperation*, const char* attr_name,
const AbstractOperation* value, ForwardOperation*);
Status SetAttrFunctionName(AbstractOperation*, const char* attr_name,
const char* value, size_t length, ForwardOperation*);
Status SetAttrTensor(AbstractOperation*, const char* attr_name,
AbstractTensorInterface* tensor, ForwardOperation*);
Status SetAttrStringList(AbstractOperation*, const char* attr_name,
const void* const* values, const size_t* lengths,
int num_values, ForwardOperation*);
Status SetAttrFloatList(AbstractOperation*, const char* attr_name,
const float* values, int num_values, ForwardOperation*);
Status SetAttrIntList(AbstractOperation*, const char* attr_name,
const int64_t* values, int num_values, ForwardOperation*);
Status SetAttrTypeList(AbstractOperation*, const char* attr_name,
const DataType* values, int num_values,
ForwardOperation*);
Status SetAttrBoolList(AbstractOperation*, const char* attr_name,
const unsigned char* values, int num_values,
ForwardOperation*);
Status SetAttrShapeList(AbstractOperation*, const char* attr_name,
const int64_t** dims, const int* num_dims,
int num_values, ForwardOperation*);
Status SetAttrFunctionList(AbstractOperation*, const char* attr_name,
absl::Span<const AbstractOperation*> values,
ForwardOperation*);
// Make the call to `Tape::RecordOperation`.
Status Execute(AbstractOperation*, AbstractContext*,
absl::Span<AbstractTensorHandle*> retvals, int* num_retvals,
ForwardOperation*, Tape*, const GradientRegistry&);
} // namespace internal
} // namespace gradients
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_GRADIENTS_INTERNAL_H_

View File

@ -0,0 +1,328 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/gradients.h"
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace gradients {
namespace internal {
namespace {
class CppGradients
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
protected:
void SetUp() override {
TF_SetTracingImplementation(std::get<0>(GetParam()));
}
};
// Creates an Identity op.
Status Identity(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr identity_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(
identity_op->Reset("Identity", /*raw_device_name=*/nullptr));
if (isa<tracing::TracingOperation>(identity_op.get())) {
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(identity_op.get())
->SetOpName(name));
}
TF_RETURN_IF_ERROR(identity_op->AddInput(inputs[0]));
int num_retvals = 1;
TF_RETURN_IF_ERROR(identity_op->Execute(outputs, &num_retvals));
return Status::OK();
}
// =================== Register gradients for Add ============================
class AddGradientFunction : public GradientFunction {
public:
explicit AddGradientFunction(AbstractContext* ctx) : ctx_(ctx) {}
Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs,
std::vector<AbstractTensorHandle*>* grad_outputs) override {
grad_outputs->resize(2);
std::vector<AbstractTensorHandle*> identity_outputs(1);
TF_RETURN_IF_ERROR(Identity(ctx_, {grad_inputs[0]},
absl::MakeSpan(identity_outputs), "Id0"));
(*grad_outputs)[0] = identity_outputs[0];
TF_RETURN_IF_ERROR(Identity(ctx_, {grad_inputs[0]},
absl::MakeSpan(identity_outputs), "Id1"));
(*grad_outputs)[1] = identity_outputs[0];
return Status::OK();
}
~AddGradientFunction() override {}
private:
AbstractContext* ctx_;
};
GradientFunction* AddRegisterer(const ForwardOperation& op) {
return new AddGradientFunction(op.ctx);
}
Status RegisterGradients(GradientRegistry* registry) {
return registry->Register("Add", AddRegisterer);
}
// =================== End gradient registrations ============================
// Computes `inputs[0] + inputs[1]` and records it on the tape.
Status Add(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractOperationPtr add_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(
Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op));
if (isa<tracing::TracingOperation>(add_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<tracing::TracingOperation>(add_op.get())->SetOpName("my_add"));
}
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op));
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op));
int num_retvals = 1;
return Execute(add_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
registry);
}
// Computes
// y = inputs[0] + inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status AddGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
tape->Watch(ToId(inputs[1])); // Watch y.
std::vector<AbstractTensorHandle*> add_outputs(1);
TF_RETURN_IF_ERROR(Add(ctx, tape, inputs, absl::MakeSpan(add_outputs),
registry)); // Compute x+y.
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
std::vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(add_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads));
for (auto add_output : add_outputs) {
add_output->Release();
}
outputs[0] = out_grads[0];
outputs[1] = out_grads[1];
delete tape;
return Status::OK();
}
AbstractContext* BuildFunction(const char* fn_name) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get());
return unwrap(graph_ctx);
}
Status CreateParamsForInputs(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
std::vector<AbstractTensorHandle*>* params) {
tracing::TracingTensorHandle* handle = nullptr;
for (auto input : inputs) {
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
input->DataType(), &handle));
params->emplace_back(handle);
}
return Status::OK();
}
using Model = std::function<Status(
AbstractContext*, absl::Span<AbstractTensorHandle* const>,
absl::Span<AbstractTensorHandle*>, const GradientRegistry&)>;
// Runs `model` maybe wrapped in a function.
Status RunModel(Model model, AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, bool use_function,
const GradientRegistry& registry) {
if (use_function) {
const char* fn_name = "test_fn";
std::unique_ptr<AbstractFunction> scoped_func;
{
AbstractContextPtr func_ctx(BuildFunction(fn_name));
std::vector<AbstractTensorHandle*> func_inputs;
func_inputs.reserve(inputs.size());
TF_RETURN_IF_ERROR(
CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs));
OutputList output_list;
output_list.expected_num_outputs = outputs.size();
output_list.outputs.resize(outputs.size());
TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs),
absl::MakeSpan(output_list.outputs), registry));
for (auto func_input : func_inputs) {
func_input->Release();
}
AbstractFunction* func = nullptr;
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
->Finalize(&output_list, &func));
scoped_func.reset(func);
output_list.outputs[0]->Release();
output_list.outputs[1]->Release();
TF_RETURN_IF_ERROR(ctx->RegisterFunction(func));
}
AbstractOperationPtr fn_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr));
for (auto input : inputs) {
TF_RETURN_IF_ERROR(fn_op->AddInput(input));
}
int retvals = outputs.size();
TF_RETURN_IF_ERROR(fn_op->Execute(outputs, &retvals));
TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name));
return Status::OK();
} else {
return model(ctx, inputs, outputs, registry);
}
}
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
*ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get()));
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_DeleteContextOptions(opts);
return Status::OK();
}
Status TestScalarTensorHandle(AbstractContext* ctx, float value,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value);
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return Status::OK();
}
Status getValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_TensorHandle* result_t =
TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
*result_tensor = TFE_TensorHandleResolve(result_t, status.get());
return Status::OK();
}
TEST_P(CppGradients, TestAddGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
AbstractTensorHandlePtr y;
{
AbstractTensorHandle* y_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &y_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
y.reset(y_raw);
}
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Pseudo-code:
//
// tape.watch(x)
// tape.watch(y)
// y = x + y
// outputs = tape.gradient(y, [x, y])
std::vector<AbstractTensorHandle*> outputs(2);
s = RunModel(AddGradModel, ctx.get(), {x.get(), y.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
s = getValue(outputs[0], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, 1.0);
outputs[0]->Release();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
s = getValue(outputs[1], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, 1.0);
outputs[1]->Release();
TF_DeleteTensor(result_tensor);
}
// TODO(b/160888630): Enable this test with mlir after AddInputList is
// supported. It is needed for AddN op which is used for gradient aggregation.
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients,
::testing::Combine(::testing::Values("graphdef"),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients,
::testing::Combine(::testing::Values("graphdef"),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#endif
} // namespace
} // namespace internal
} // namespace gradients
} // namespace tensorflow