- Integrate the C++ tape with the op-building APIs via TapeContext and TapeOperation, which delegate calls to a parent execution context and record the executed operations on the tape. Please see gradients_test.cc for usage, and the sketch after this list.

- This will replace the helper functions in gradients_internal.h; I will clean those up in a follow-up CL.
- Also drop ForwardOperation::ctx, since it is unused right now. We can add it back later if we need it.
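
For orientation, a minimal usage sketch of the new wrappers (mirroring the gradients_test.cc changes below; `ctx`, `tape`, and `registry` are assumed to be set up by the caller):

// Ops built through the wrapper execute on the parent context and are
// recorded on the tape, replacing the per-op helper functions.
AbstractContextPtr tape_ctx(new gradients::TapeContext(ctx, tape, registry));
std::vector<AbstractTensorHandle*> outputs(1);
TF_RETURN_IF_ERROR(
    ops::Add(tape_ctx.get(), inputs, absl::MakeSpan(outputs), "Add"));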

PiperOrigin-RevId: 333390787
Change-Id: I80f2c460a9538a1a14ed1497c59f7b37a633a633
Saurabh Saxena 2020-09-23 15:55:12 -07:00 committed by TensorFlower Gardener
parent a5dacff238
commit 7ec70d54b2
12 changed files with 477 additions and 77 deletions

View File

@@ -189,6 +189,7 @@ cc_library(
deps = [
":abstract_operation",
":c_api_unified_internal",
"//tensorflow/c/experimental/gradients/tape:tape_operation",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/core/platform:errors",
],
@@ -229,6 +230,7 @@ tf_cuda_cc_test(
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + ["nomac"],
deps = [
":abstract_context",
":abstract_tensor_handle",
":c_api_experimental",
":c_api_test_util",
@@ -239,7 +241,8 @@ tf_cuda_cc_test(
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/experimental/gradients:array_grad",
"//tensorflow/c/experimental/gradients:math_grad",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/gradients/tape:tape_context",
"//tensorflow/c/experimental/ops",
"//tensorflow/cc/profiler",
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/core:lib",

View File

@@ -32,7 +32,7 @@ namespace tensorflow {
// environment, a traced representation etc.
class AbstractContext {
protected:
enum AbstractContextKind { kGraph, kMlir, kEager, kTfrt };
enum AbstractContextKind { kGraph, kMlir, kEager, kTfrt, kTape };
explicit AbstractContext(AbstractContextKind kind) : kind_(kind) {}
virtual ~AbstractContext() {}

View File

@@ -30,7 +30,7 @@ namespace tensorflow {
// tracing or immediate execution mode.
class AbstractOperation {
protected:
enum AbstractOperationKind { kGraph, kMlir, kEager, kTfrt };
enum AbstractOperationKind { kGraph, kMlir, kEager, kTfrt, kTape };
explicit AbstractOperation(AbstractOperationKind kind) : kind_(kind) {}
virtual ~AbstractOperation() {}

View File

@@ -80,7 +80,6 @@ struct ForwardOperation {
std::vector<AbstractTensorHandle*> inputs;
std::vector<AbstractTensorHandle*> outputs;
AttrBuilder attrs;
AbstractContext* ctx;
};
// Interface for building default zeros gradients for op outputs which are

View File

@@ -18,6 +18,7 @@ limitations under the License.
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
@@ -26,7 +27,9 @@ limitations under the License.
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/gradients/array_grad.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
@@ -53,72 +56,14 @@ class CppGradients
};
Status RegisterGradients(GradientRegistry* registry) {
TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer));
// TODO(srbs): Rename ops::Add to ops::AddV2 and AddRegisterer to
// AddV2Registerer.
TF_RETURN_IF_ERROR(registry->Register("AddV2", AddRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer));
return Status::OK();
}
// Computes `inputs[0] + inputs[1]` and records it on the tape.
Status Add(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractOperationPtr add_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(
Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(add_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(add_op.get())->SetOpName("my_add"));
}
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op));
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op));
int num_retvals = 1;
return Execute(add_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
registry);
}
// Computes `exp(inputs[0])` and records it on the tape.
Status Exp(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractOperationPtr exp_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(
Reset(exp_op.get(), "Exp", /*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(exp_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(exp_op.get())->SetOpName("my_exp"));
}
TF_RETURN_IF_ERROR(AddInput(exp_op.get(), inputs[0], &forward_op));
int num_retvals = 1;
return Execute(exp_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
registry);
}
// Computes `IdentityN(inputs)` and records it on the tape.
Status IdentityN(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractOperationPtr identity_n_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(Reset(identity_n_op.get(), "IdentityN",
/*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(identity_n_op.get())) {
TF_RETURN_IF_ERROR(dyn_cast<TracingOperation>(identity_n_op.get())
->SetOpName("my_identity_n"));
}
TF_RETURN_IF_ERROR(AddInputList(identity_n_op.get(), inputs, &forward_op));
int num_retvals = outputs.size();
return Execute(identity_n_op.get(), ctx, outputs, &num_retvals, &forward_op,
tape, registry);
}
// Computes
// y = inputs[0] + inputs[1]
@@ -132,8 +77,10 @@ Status AddGradModel(AbstractContext* ctx,
tape->Watch(ToId(inputs[0])); // Watch x.
tape->Watch(ToId(inputs[1])); // Watch y.
std::vector<AbstractTensorHandle*> add_outputs(1);
TF_RETURN_IF_ERROR(Add(ctx, tape, inputs, absl::MakeSpan(add_outputs),
registry)); // Compute x+y.
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::Add(tape_ctx.get(), inputs,
absl::MakeSpan(add_outputs),
"Add")); // Compute x+y.
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
@@ -164,8 +111,9 @@ Status ExpGradModel(AbstractContext* ctx,
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
std::vector<AbstractTensorHandle*> exp_outputs(1);
TF_RETURN_IF_ERROR(Exp(ctx, tape, inputs, absl::MakeSpan(exp_outputs),
registry)); // Compute exp(x).
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(
ops::Exp(tape_ctx.get(), inputs, absl::MakeSpan(exp_outputs), "Exp"));
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
@@ -197,8 +145,9 @@ Status IdentityNGradModel(AbstractContext* ctx,
tape->Watch(ToId(inputs[1]));
vector<AbstractTensorHandle*> identity_n_outputs(2);
TF_RETURN_IF_ERROR(IdentityN(ctx, tape, inputs,
absl::MakeSpan(identity_n_outputs), registry));
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::IdentityN(
tape_ctx.get(), inputs, absl::MakeSpan(identity_n_outputs), "IdentityN"));
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
@@ -533,7 +482,6 @@ TEST_P(CppGradients, TestSetAttrString) {
AbstractOperationPtr check_numerics_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx.get();
Status s = Reset(check_numerics_op.get(), "CheckNumerics",
/*raw_device_name=*/nullptr, &forward_op);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();

View File

@@ -46,7 +46,6 @@ Status Add(AbstractContext* ctx, Tape* tape,
const GradientRegistry& registry) {
AbstractOperationPtr add_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(
Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(add_op.get())) {
@@ -68,7 +67,6 @@ Status MatMul(AbstractContext* ctx, Tape* tape,
const GradientRegistry& registry) {
AbstractOperationPtr matmul_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(Reset(matmul_op.get(), "MatMul",
/*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(matmul_op.get())) {
@@ -94,7 +92,6 @@ Status Mul(AbstractContext* ctx, Tape* tape,
const GradientRegistry& registry) {
AbstractOperationPtr mul_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(
Reset(mul_op.get(), "Mul", /*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(mul_op.get())) {
@@ -117,7 +114,6 @@ Status Relu(AbstractContext* ctx, Tape* tape,
const GradientRegistry& registry) {
AbstractOperationPtr relu_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(
Reset(relu_op.get(), "Relu", /*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(relu_op.get())) {
@@ -142,7 +138,6 @@ Status SparseSoftmaxCrossEntropyWithLogits(
AbstractOperationPtr sm_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(Reset(sm_op.get(), "SparseSoftmaxCrossEntropyWithLogits",
/*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(sm_op.get())) {

View File

@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/c/eager/tracing_utils.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/experimental/gradients/tape/tape_operation.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
@@ -25,6 +26,11 @@ Status MaybeSetOpName(AbstractOperation* op, const char* op_name) {
if (isa<TracingOperation>(op)) {
TF_RETURN_IF_ERROR(dyn_cast<TracingOperation>(op)->SetOpName(op_name));
}
if (isa<gradients::TapeOperation>(op)) {
TF_RETURN_IF_ERROR(MaybeSetOpName(
dyn_cast<gradients::TapeOperation>(op)->GetBackingOperation(),
op_name));
}
return Status::OK();
}
} // namespace tracing
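
A hedged sketch of how this dispatch is reached, assuming `tape_ctx` is a TapeContext whose parent is a tracing context:

AbstractOperationPtr op(tape_ctx->CreateOperation());  // a TapeOperation
TF_RETURN_IF_ERROR(op->Reset("Add", /*raw_device_name=*/nullptr));
// Unwraps via GetBackingOperation(): names the node when the backing op
// is a TracingOperation; a plain eager backing op is left untouched.
TF_RETURN_IF_ERROR(tracing::MaybeSetOpName(op.get(), "my_add"));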

View File

@@ -0,0 +1,40 @@
# A tape built on top of unified execution APIs.
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
package(
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "tape_context",
srcs = ["tape_context.cc"],
hdrs = [
"tape_context.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":tape_operation",
"//tensorflow/c/eager:abstract_context",
"//tensorflow/c/eager:abstract_function",
"//tensorflow/c/eager:abstract_operation",
],
)
cc_library(
name = "tape_operation",
srcs = ["tape_operation.cc"],
hdrs = [
"tape_operation.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/c/eager:abstract_context",
"//tensorflow/c/eager:abstract_function",
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:gradients_internal",
],
)

View File

@@ -0,0 +1,47 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/experimental/gradients/tape/tape_operation.h"
namespace tensorflow {
namespace gradients {
TapeContext::TapeContext(AbstractContext* c, Tape* tape,
const GradientRegistry& registry)
: AbstractContext(kTape), parent_ctx_(c), tape_(tape), registry_(registry) {
// TODO(srbs): Make AbstractContext ref counted.
// parent_ctx_->Ref();
}
void TapeContext::Release() {
// TODO(srbs): Change to Unref()
delete this;
}
TapeContext::~TapeContext() {
// TODO(srbs): Make AbstractContext ref counted.
// parent_ctx_->Unref();
}
TapeOperation* TapeContext::CreateOperation() {
return new TapeOperation(parent_ctx_->CreateOperation(), tape_, registry_);
}
Status TapeContext::RegisterFunction(AbstractFunction* f) {
return parent_ctx_->RegisterFunction(f);
}
Status TapeContext::RemoveFunction(const string& func) {
return parent_ctx_->RemoveFunction(func);
}
} // namespace gradients
} // namespace tensorflow
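
Note the covariant return above: TapeContext::CreateOperation narrows the base class's AbstractOperation* to TapeOperation*, so callers holding the concrete type need no downcast. A minimal sketch (variable names hypothetical):

TapeContext* tape_ctx = new TapeContext(ctx, tape, registry);
TapeOperation* op = tape_ctx->CreateOperation();  // concrete type, no cast
AbstractContext* base = tape_ctx;
AbstractOperation* base_op = base->CreateOperation();  // same dynamic type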

View File

@@ -0,0 +1,44 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_CONTEXT_H_
#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_CONTEXT_H_
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/experimental/gradients/tape/tape_operation.h"
namespace tensorflow {
namespace gradients {
class TapeContext : public AbstractContext {
public:
explicit TapeContext(AbstractContext*, Tape*, const GradientRegistry&);
void Release() override;
TapeOperation* CreateOperation() override;
Status RegisterFunction(AbstractFunction*) override;
Status RemoveFunction(const string& func) override;
// For LLVM style RTTI.
static bool classof(const AbstractContext* ptr) {
return ptr->getKind() == kTape;
}
~TapeContext() override;
private:
AbstractContext* parent_ctx_; // Not owned.
Tape* tape_;
const GradientRegistry& registry_;
};
} // namespace gradients
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_CONTEXT_H_
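
The classof hook above plugs TapeContext into the LLVM-style RTTI these APIs use; a minimal sketch, assuming the isa/dyn_cast helpers from //tensorflow/core/lib/llvm_rtti and a caller-provided `ctx`:

if (isa<gradients::TapeContext>(ctx)) {
  // getKind() == kTape: ops created from this context are tape-recorded.
  auto* tape_ctx = dyn_cast<gradients::TapeContext>(ctx);
  (void)tape_ctx;  // e.g., hand it to code that needs the wrapper type
}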

View File

@@ -0,0 +1,238 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/tape/tape_operation.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/gradients.h"
namespace tensorflow {
namespace gradients {
TapeOperation::TapeOperation(AbstractOperation* parent_op, Tape* tape,
const GradientRegistry& registry)
: AbstractOperation(kTape),
parent_op_(parent_op),
tape_(tape),
registry_(registry) {
// TODO(srbs): Make AbstractOperation RefCounted.
// parent_op_->Ref();
}
void TapeOperation::Release() {
// TODO(srbs): Change to Unref().
delete this;
}
TapeOperation::~TapeOperation() {
// TODO(srbs): Make AbstractOperation RefCounted.
// parent_op_->Unref();
}
Status TapeOperation::Reset(const char* op, const char* raw_device_name) {
forward_op_.op_name = op;
forward_op_.attrs.Reset(op);
forward_op_.inputs.clear();
forward_op_.outputs.clear();
return parent_op_->Reset(op, raw_device_name);
}
const string& TapeOperation::Name() const { return parent_op_->Name(); }
const string& TapeOperation::DeviceName() const {
return parent_op_->DeviceName();
}
Status TapeOperation::SetDeviceName(const char* name) {
return parent_op_->SetDeviceName(name);
}
Status TapeOperation::AddInput(AbstractTensorHandle* input) {
TF_RETURN_IF_ERROR(parent_op_->AddInput(input));
forward_op_.inputs.push_back(input);
return Status::OK();
}
Status TapeOperation::AddInputList(
absl::Span<AbstractTensorHandle* const> inputs) {
TF_RETURN_IF_ERROR(parent_op_->AddInputList(inputs));
for (auto input : inputs) {
forward_op_.inputs.push_back(input);
}
return Status::OK();
}
Status TapeOperation::SetAttrString(const char* attr_name, const char* data,
size_t length) {
forward_op_.attrs.Set(attr_name, StringPiece(data, length));
return parent_op_->SetAttrString(attr_name, data, length);
}
Status TapeOperation::SetAttrInt(const char* attr_name, int64_t value) {
forward_op_.attrs.Set(attr_name, static_cast<int64>(value));
return parent_op_->SetAttrInt(attr_name, value);
}
Status TapeOperation::SetAttrFloat(const char* attr_name, float value) {
forward_op_.attrs.Set(attr_name, value);
return parent_op_->SetAttrFloat(attr_name, value);
}
Status TapeOperation::SetAttrBool(const char* attr_name, bool value) {
forward_op_.attrs.Set(attr_name, value);
return parent_op_->SetAttrBool(attr_name, value);
}
Status TapeOperation::SetAttrType(const char* attr_name, DataType value) {
forward_op_.attrs.Set(attr_name, value);
return parent_op_->SetAttrType(attr_name, value);
}
Status TapeOperation::SetAttrShape(const char* attr_name, const int64_t* dims,
const int num_dims) {
if (num_dims > TensorShape::MaxDimensions()) {
return errors::InvalidArgument("Value specified for `", attr_name, "` has ",
num_dims,
" dimensions which is over the limit of ",
TensorShape::MaxDimensions(), ".");
}
TensorShapeProto proto;
if (num_dims < 0) {
proto.set_unknown_rank(true);
} else {
for (int d = 0; d < num_dims; ++d) {
proto.add_dim()->set_size(dims[d]);
}
}
forward_op_.attrs.Set(attr_name, proto);
return parent_op_->SetAttrShape(attr_name, dims, num_dims);
}
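// A short sketch of the shape encoding above (`op` below is a
// hypothetical TapeOperation*): a -1 entry in `dims` marks one unknown
// dimension, while a negative `num_dims` marks the whole rank as unknown.
// Example (an assumption for illustration, not part of this file):
//   int64_t dims[] = {2, -1};  // rank 2, second dimension unknown
//   TF_RETURN_IF_ERROR(op->SetAttrShape("shape", dims, /*num_dims=*/2));
//   TF_RETURN_IF_ERROR(
//       op->SetAttrShape("shape", /*dims=*/nullptr, /*num_dims=*/-1));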
Status TapeOperation::SetAttrFunction(const char* attr_name,
const AbstractOperation* value) {
return tensorflow::errors::Unimplemented(
"SetAttrFunction has not been implemented yet.");
}
Status TapeOperation::SetAttrFunctionName(const char* attr_name,
const char* value, size_t length) {
return tensorflow::errors::Unimplemented(
"SetAttrFunctionName has not been implemented "
"yet.");
}
Status TapeOperation::SetAttrTensor(const char* attr_name,
AbstractTensorInterface* tensor) {
return tensorflow::errors::Unimplemented(
"SetAttrTensor has not been implemented yet.");
}
Status TapeOperation::SetAttrStringList(const char* attr_name,
const void* const* values,
const size_t* lengths, int num_values) {
std::vector<StringPiece> v(num_values);
for (int i = 0; i < num_values; ++i) {
v[i] = StringPiece(static_cast<const char*>(values[i]), lengths[i]);
}
forward_op_.attrs.Set(attr_name, v);
return parent_op_->SetAttrStringList(attr_name, values, lengths, num_values);
}
Status TapeOperation::SetAttrFloatList(const char* attr_name,
const float* values, int num_values) {
forward_op_.attrs.Set(attr_name,
gtl::ArraySlice<const float>(values, num_values));
return parent_op_->SetAttrFloatList(attr_name, values, num_values);
}
Status TapeOperation::SetAttrIntList(const char* attr_name,
const int64_t* values, int num_values) {
forward_op_.attrs.Set(
attr_name, gtl::ArraySlice<const int64>(
reinterpret_cast<const int64*>(values), num_values));
return parent_op_->SetAttrIntList(attr_name, values, num_values);
}
Status TapeOperation::SetAttrTypeList(const char* attr_name,
const DataType* values, int num_values) {
forward_op_.attrs.Set(attr_name,
gtl::ArraySlice<const DataType>(values, num_values));
return parent_op_->SetAttrTypeList(attr_name, values, num_values);
}
Status TapeOperation::SetAttrBoolList(const char* attr_name,
const unsigned char* values,
int num_values) {
std::unique_ptr<bool[]> b(new bool[num_values]);
for (int i = 0; i < num_values; ++i) {
b[i] = values[i];
}
forward_op_.attrs.Set(attr_name,
gtl::ArraySlice<const bool>(b.get(), num_values));
return parent_op_->SetAttrBoolList(attr_name, values, num_values);
}
Status TapeOperation::SetAttrShapeList(const char* attr_name,
const int64_t** dims,
const int* num_dims, int num_values) {
std::unique_ptr<TensorShapeProto[]> proto(new TensorShapeProto[num_values]);
for (int i = 0; i < num_values; ++i) {
const auto num_dims_i = num_dims[i];
if (num_dims_i > TensorShape::MaxDimensions()) {
return errors::InvalidArgument(
strings::StrCat("Value specified for `", attr_name, "` has ",
num_dims_i, " dimensions which is over the limit of ",
TensorShape::MaxDimensions(), "."));
}
if (num_dims_i < 0) {
proto[i].set_unknown_rank(true);
} else {
const int64_t* dims_i = dims[i];
auto proto_i = &proto[i];
for (int d = 0; d < num_dims_i; ++d) {
proto_i->add_dim()->set_size(dims_i[d]);
}
}
}
forward_op_.attrs.Set(
attr_name, gtl::ArraySlice<TensorShapeProto>(proto.get(), num_values));
return parent_op_->SetAttrShapeList(attr_name, dims, num_dims, num_values);
}
Status TapeOperation::SetAttrFunctionList(
const char* attr_name, absl::Span<const AbstractOperation*> values) {
return tensorflow::errors::Unimplemented(
"SetAttrFunctionList has not been "
"implemented yet.");
}
AbstractOperation* TapeOperation::GetBackingOperation() { return parent_op_; }
Status TapeOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
int* num_retvals) {
TF_RETURN_IF_ERROR(parent_op_->Execute(retvals, num_retvals));
std::vector<int64> input_ids(forward_op_.inputs.size());
std::vector<tensorflow::DataType> input_dtypes(forward_op_.inputs.size());
for (int i = 0; i < forward_op_.inputs.size(); i++) {
input_ids[i] = ToId(forward_op_.inputs[i]);
input_dtypes[i] = forward_op_.inputs[i]->DataType();
}
for (int i = 0; i < *num_retvals; i++) {
// TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs.
forward_op_.outputs.push_back(retvals[i]);
}
// TODO(b/166669239): This is needed to support AttrBuilder::Get for string
// attributes. Number type attrs and DataType attrs work fine without this.
// Consider getting rid of this and making the behavior between number types
// and string consistent.
forward_op_.attrs.BuildNodeDef();
std::vector<TapeTensor> tape_tensors;
for (auto t : retvals) {
tape_tensors.push_back(TapeTensor(t));
}
tape_->RecordOperation(
parent_op_->Name(), tape_tensors, input_ids, input_dtypes,
[this]() -> BackwardFunction* {
std::unique_ptr<BackwardFunction> backward_fn;
Status s = registry_.Lookup(forward_op_, &backward_fn);
if (!s.ok()) {
return nullptr;
}
return backward_fn.release();
},
[](BackwardFunction* ptr) {
if (ptr) {
delete ptr;
}
});
return Status::OK();
}
} // namespace gradients
} // namespace tensorflow
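
Putting the pieces together, a hedged end-to-end sketch mirroring ExpGradModel in gradients_test.cc above (`x`, `ctx`, and `registry` are assumed to be set up by the caller):

auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(x));
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
std::vector<AbstractTensorHandle*> inputs = {x};
std::vector<AbstractTensorHandle*> outputs(1);
TF_RETURN_IF_ERROR(
    ops::Exp(tape_ctx.get(), inputs, absl::MakeSpan(outputs), "Exp"));
// Execute has now run Exp on the parent context, recorded x's id/dtype
// and the output TapeTensor on the tape, and stored a deferred
// registry.Lookup for Exp's BackwardFunction, so the existing
// Tape::ComputeGradient machinery works unchanged.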

View File

@@ -0,0 +1,80 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_OPERATION_H_
#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_OPERATION_H_
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/gradients.h"
namespace tensorflow {
namespace gradients {
class TapeOperation : public AbstractOperation {
public:
explicit TapeOperation(AbstractOperation*, Tape*, const GradientRegistry&);
void Release() override;
Status Reset(const char* op, const char* raw_device_name) override;
const string& Name() const override;
const string& DeviceName() const override;
Status SetDeviceName(const char* name) override;
Status AddInput(AbstractTensorHandle* input) override;
Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override;
Status Execute(absl::Span<AbstractTensorHandle*> retvals,
int* num_retvals) override;
Status SetAttrString(const char* attr_name, const char* data,
size_t length) override;
Status SetAttrInt(const char* attr_name, int64_t value) override;
Status SetAttrFloat(const char* attr_name, float value) override;
Status SetAttrBool(const char* attr_name, bool value) override;
Status SetAttrType(const char* attr_name, DataType value) override;
Status SetAttrShape(const char* attr_name, const int64_t* dims,
const int num_dims) override;
Status SetAttrFunction(const char* attr_name,
const AbstractOperation* value) override;
Status SetAttrFunctionName(const char* attr_name, const char* value,
size_t length) override;
Status SetAttrTensor(const char* attr_name,
AbstractTensorInterface* tensor) override;
Status SetAttrStringList(const char* attr_name, const void* const* values,
const size_t* lengths, int num_values) override;
Status SetAttrFloatList(const char* attr_name, const float* values,
int num_values) override;
Status SetAttrIntList(const char* attr_name, const int64_t* values,
int num_values) override;
Status SetAttrTypeList(const char* attr_name, const DataType* values,
int num_values) override;
Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
int num_values) override;
Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
const int* num_dims, int num_values) override;
Status SetAttrFunctionList(
const char* attr_name,
absl::Span<const AbstractOperation*> values) override;
AbstractOperation* GetBackingOperation();
// For LLVM style RTTI.
static bool classof(const AbstractOperation* ptr) {
return ptr->getKind() == kTape;
}
~TapeOperation() override;
private:
AbstractOperation* parent_op_;
ForwardOperation forward_op_;
Tape* tape_;
const GradientRegistry& registry_;
};
} // namespace gradients
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_OPERATION_H_