- Integrate C++ tape with op building APIs via TapeContext and TapeOperation which delegate calls to a parent execution context and record operations on the tape. Please see gradients_test.cc for usage.
- This will replace the helper functions in gradients_internal.h; I will clean those up in a follow-up CL. - Also drop ForwardOperation::ctx since it is currently unused. We can add it back later if we need it. PiperOrigin-RevId: 333390787 Change-Id: I80f2c460a9538a1a14ed1497c59f7b37a633a633
This commit is contained in:
parent
a5dacff238
commit
7ec70d54b2
@ -189,6 +189,7 @@ cc_library(
|
||||
deps = [
|
||||
":abstract_operation",
|
||||
":c_api_unified_internal",
|
||||
"//tensorflow/c/experimental/gradients/tape:tape_operation",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
"//tensorflow/core/platform:errors",
|
||||
],
|
||||
@ -229,6 +230,7 @@ tf_cuda_cc_test(
|
||||
linkstatic = tf_kernel_tests_linkstatic(),
|
||||
tags = tf_cuda_tests_tags() + ["nomac"],
|
||||
deps = [
|
||||
":abstract_context",
|
||||
":abstract_tensor_handle",
|
||||
":c_api_experimental",
|
||||
":c_api_test_util",
|
||||
@ -239,7 +241,8 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/c/experimental/gradients:array_grad",
|
||||
"//tensorflow/c/experimental/gradients:math_grad",
|
||||
"//tensorflow/c/experimental/ops:array_ops",
|
||||
"//tensorflow/c/experimental/gradients/tape:tape_context",
|
||||
"//tensorflow/c/experimental/ops",
|
||||
"//tensorflow/cc/profiler",
|
||||
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
|
||||
"//tensorflow/core:lib",
|
||||
|
@ -32,7 +32,7 @@ namespace tensorflow {
|
||||
// environment, a traced representation etc.
|
||||
class AbstractContext {
|
||||
protected:
|
||||
enum AbstractContextKind { kGraph, kMlir, kEager, kTfrt };
|
||||
enum AbstractContextKind { kGraph, kMlir, kEager, kTfrt, kTape };
|
||||
explicit AbstractContext(AbstractContextKind kind) : kind_(kind) {}
|
||||
virtual ~AbstractContext() {}
|
||||
|
||||
|
@ -30,7 +30,7 @@ namespace tensorflow {
|
||||
// tracing or immediate execution mode.
|
||||
class AbstractOperation {
|
||||
protected:
|
||||
enum AbstractOperationKind { kGraph, kMlir, kEager, kTfrt };
|
||||
enum AbstractOperationKind { kGraph, kMlir, kEager, kTfrt, kTape };
|
||||
explicit AbstractOperation(AbstractOperationKind kind) : kind_(kind) {}
|
||||
virtual ~AbstractOperation() {}
|
||||
|
||||
|
@ -80,7 +80,6 @@ struct ForwardOperation {
|
||||
std::vector<AbstractTensorHandle*> inputs;
|
||||
std::vector<AbstractTensorHandle*> outputs;
|
||||
AttrBuilder attrs;
|
||||
AbstractContext* ctx;
|
||||
};
|
||||
|
||||
// Interface for building default zeros gradients for op outputs which are
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/abstract_context.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
@ -26,7 +27,9 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/gradients_internal.h"
|
||||
#include "tensorflow/c/experimental/gradients/array_grad.h"
|
||||
#include "tensorflow/c/experimental/gradients/math_grad.h"
|
||||
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
|
||||
#include "tensorflow/c/experimental/ops/array_ops.h"
|
||||
#include "tensorflow/c/experimental/ops/math_ops.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
@ -53,72 +56,14 @@ class CppGradients
|
||||
};
|
||||
|
||||
// Registers the gradient registerers used by the tests in this file.
// Stops at, and returns, the first registration error.
Status RegisterGradients(GradientRegistry* registry) {
  // "Add" and "AddV2" share the same registerer.
  // TODO(srbs): Rename ops::Add to ops::AddV2 and AddRegister to
  // AddV2Registerer.
  TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("AddV2", AddRegisterer));

  TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer));

  return Status::OK();
}
|
||||
|
||||
// Computes `inputs[0] + inputs[1]` and records it on the tape.
|
||||
Status Add(AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
AbstractOperationPtr add_op(ctx->CreateOperation());
|
||||
ForwardOperation forward_op;
|
||||
forward_op.ctx = ctx;
|
||||
TF_RETURN_IF_ERROR(
|
||||
Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op));
|
||||
if (isa<TracingOperation>(add_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<TracingOperation>(add_op.get())->SetOpName("my_add"));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op));
|
||||
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op));
|
||||
int num_retvals = 1;
|
||||
return Execute(add_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
||||
registry);
|
||||
}
|
||||
|
||||
// Computes `exp(inputs[0])` and records it on the tape.
|
||||
Status Exp(AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
AbstractOperationPtr exp_op(ctx->CreateOperation());
|
||||
ForwardOperation forward_op;
|
||||
forward_op.ctx = ctx;
|
||||
TF_RETURN_IF_ERROR(
|
||||
Reset(exp_op.get(), "Exp", /*raw_device_name=*/nullptr, &forward_op));
|
||||
if (isa<TracingOperation>(exp_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<TracingOperation>(exp_op.get())->SetOpName("my_exp"));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(AddInput(exp_op.get(), inputs[0], &forward_op));
|
||||
int num_retvals = 1;
|
||||
return Execute(exp_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
||||
registry);
|
||||
}
|
||||
|
||||
// Computes `IdentityN(inputs)` and records it on the tape.
|
||||
Status IdentityN(AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
AbstractOperationPtr identity_n_op(ctx->CreateOperation());
|
||||
ForwardOperation forward_op;
|
||||
forward_op.ctx = ctx;
|
||||
TF_RETURN_IF_ERROR(Reset(identity_n_op.get(), "IdentityN",
|
||||
/*raw_device_name=*/nullptr, &forward_op));
|
||||
if (isa<TracingOperation>(identity_n_op.get())) {
|
||||
TF_RETURN_IF_ERROR(dyn_cast<TracingOperation>(identity_n_op.get())
|
||||
->SetOpName("my_identity_n"));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(AddInputList(identity_n_op.get(), inputs, &forward_op));
|
||||
int num_retvals = outputs.size();
|
||||
return Execute(identity_n_op.get(), ctx, outputs, &num_retvals, &forward_op,
|
||||
tape, registry);
|
||||
}
|
||||
|
||||
// Computes
|
||||
// y = inputs[0] + inputs[1]
|
||||
@ -132,8 +77,10 @@ Status AddGradModel(AbstractContext* ctx,
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
tape->Watch(ToId(inputs[1])); // Watch y.
|
||||
std::vector<AbstractTensorHandle*> add_outputs(1);
|
||||
TF_RETURN_IF_ERROR(Add(ctx, tape, inputs, absl::MakeSpan(add_outputs),
|
||||
registry)); // Compute x+y.
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::Add(tape_ctx.get(), inputs,
|
||||
absl::MakeSpan(add_outputs),
|
||||
"Add")); // Compute x+y.
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
|
||||
@ -164,8 +111,9 @@ Status ExpGradModel(AbstractContext* ctx,
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
std::vector<AbstractTensorHandle*> exp_outputs(1);
|
||||
TF_RETURN_IF_ERROR(Exp(ctx, tape, inputs, absl::MakeSpan(exp_outputs),
|
||||
registry)); // Compute x+y.
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(
|
||||
ops::Exp(tape_ctx.get(), inputs, absl::MakeSpan(exp_outputs), "Exp"));
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
|
||||
@ -197,8 +145,9 @@ Status IdentityNGradModel(AbstractContext* ctx,
|
||||
tape->Watch(ToId(inputs[1]));
|
||||
|
||||
vector<AbstractTensorHandle*> identity_n_outputs(2);
|
||||
TF_RETURN_IF_ERROR(IdentityN(ctx, tape, inputs,
|
||||
absl::MakeSpan(identity_n_outputs), registry));
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::IdentityN(
|
||||
tape_ctx.get(), inputs, absl::MakeSpan(identity_n_outputs), "IdentityN"));
|
||||
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
@ -533,7 +482,6 @@ TEST_P(CppGradients, TestSetAttrString) {
|
||||
|
||||
AbstractOperationPtr check_numerics_op(ctx->CreateOperation());
|
||||
ForwardOperation forward_op;
|
||||
forward_op.ctx = ctx.get();
|
||||
Status s = Reset(check_numerics_op.get(), "CheckNumerics",
|
||||
/*raw_device_name=*/nullptr, &forward_op);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
@ -46,7 +46,6 @@ Status Add(AbstractContext* ctx, Tape* tape,
|
||||
const GradientRegistry& registry) {
|
||||
AbstractOperationPtr add_op(ctx->CreateOperation());
|
||||
ForwardOperation forward_op;
|
||||
forward_op.ctx = ctx;
|
||||
TF_RETURN_IF_ERROR(
|
||||
Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op));
|
||||
if (isa<TracingOperation>(add_op.get())) {
|
||||
@ -68,7 +67,6 @@ Status MatMul(AbstractContext* ctx, Tape* tape,
|
||||
const GradientRegistry& registry) {
|
||||
AbstractOperationPtr matmul_op(ctx->CreateOperation());
|
||||
ForwardOperation forward_op;
|
||||
forward_op.ctx = ctx;
|
||||
TF_RETURN_IF_ERROR(Reset(matmul_op.get(), "MatMul",
|
||||
/*raw_device_name=*/nullptr, &forward_op));
|
||||
if (isa<TracingOperation>(matmul_op.get())) {
|
||||
@ -94,7 +92,6 @@ Status Mul(AbstractContext* ctx, Tape* tape,
|
||||
const GradientRegistry& registry) {
|
||||
AbstractOperationPtr mul_op(ctx->CreateOperation());
|
||||
ForwardOperation forward_op;
|
||||
forward_op.ctx = ctx;
|
||||
TF_RETURN_IF_ERROR(
|
||||
Reset(mul_op.get(), "Mul", /*raw_device_name=*/nullptr, &forward_op));
|
||||
if (isa<TracingOperation>(mul_op.get())) {
|
||||
@ -117,7 +114,6 @@ Status Relu(AbstractContext* ctx, Tape* tape,
|
||||
const GradientRegistry& registry) {
|
||||
AbstractOperationPtr relu_op(ctx->CreateOperation());
|
||||
ForwardOperation forward_op;
|
||||
forward_op.ctx = ctx;
|
||||
TF_RETURN_IF_ERROR(
|
||||
Reset(relu_op.get(), "Relu", /*raw_device_name=*/nullptr, &forward_op));
|
||||
if (isa<TracingOperation>(relu_op.get())) {
|
||||
@ -142,7 +138,6 @@ Status SparseSoftmaxCrossEntropyWithLogits(
|
||||
|
||||
AbstractOperationPtr sm_op(ctx->CreateOperation());
|
||||
ForwardOperation forward_op;
|
||||
forward_op.ctx = ctx;
|
||||
TF_RETURN_IF_ERROR(Reset(sm_op.get(), "SparseSoftmaxCrossEntropyWithLogits",
|
||||
/*raw_device_name=*/nullptr, &forward_op));
|
||||
if (isa<TracingOperation>(sm_op.get())) {
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/tracing_utils.h"
|
||||
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/experimental/gradients/tape/tape_operation.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
|
||||
@ -25,6 +26,11 @@ Status MaybeSetOpName(AbstractOperation* op, const char* op_name) {
|
||||
if (isa<TracingOperation>(op)) {
|
||||
TF_RETURN_IF_ERROR(dyn_cast<TracingOperation>(op)->SetOpName(op_name));
|
||||
}
|
||||
if (isa<gradients::TapeOperation>(op)) {
|
||||
TF_RETURN_IF_ERROR(MaybeSetOpName(
|
||||
dyn_cast<gradients::TapeOperation>(op)->GetBackingOperation(),
|
||||
op_name));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace tracing
|
||||
|
40
tensorflow/c/experimental/gradients/tape/BUILD
Normal file
40
tensorflow/c/experimental/gradients/tape/BUILD
Normal file
@ -0,0 +1,40 @@
|
||||
# A tape built on top of unified execution APIs.
|
||||
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
|
||||
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tape_context",
|
||||
srcs = ["tape_context.cc"],
|
||||
hdrs = [
|
||||
"tape_context.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":tape_operation",
|
||||
"//tensorflow/c/eager:abstract_context",
|
||||
"//tensorflow/c/eager:abstract_function",
|
||||
"//tensorflow/c/eager:abstract_operation",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tape_operation",
|
||||
srcs = ["tape_operation.cc"],
|
||||
hdrs = [
|
||||
"tape_operation.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c/eager:abstract_context",
|
||||
"//tensorflow/c/eager:abstract_function",
|
||||
"//tensorflow/c/eager:abstract_operation",
|
||||
"//tensorflow/c/eager:gradients_internal",
|
||||
],
|
||||
)
|
47
tensorflow/c/experimental/gradients/tape/tape_context.cc
Normal file
47
tensorflow/c/experimental/gradients/tape/tape_context.cc
Normal file
@ -0,0 +1,47 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
|
||||
|
||||
#include "tensorflow/c/eager/abstract_context.h"
|
||||
#include "tensorflow/c/experimental/gradients/tape/tape_operation.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
// Wraps the execution context `c`. `tape` and `registry` are remembered so
// that they can be handed to every operation created via CreateOperation().
// None of the three arguments is owned by the new object.
TapeContext::TapeContext(AbstractContext* c, Tape* tape,
                         const GradientRegistry& registry)
    : AbstractContext(kTape),
      parent_ctx_(c),
      tape_(tape),
      registry_(registry) {
  // TODO(srbs): Make AbstractContext ref counted.
  // parent_ctx_->Ref();
}
|
||||
// Destroys this context. There is no reference counting yet, so the object is
// deleted unconditionally; the parent context is left untouched.
void TapeContext::Release() {
  // TODO(srbs): Change to Unref()
  delete this;
}
|
||||
// Intentionally empty: the parent context is not owned, so nothing is
// released here.
TapeContext::~TapeContext() {
  // TODO(srbs): Make AbstractContext ref counted.
  // parent_ctx_->Unref();
}
|
||||
// Creates an operation in the parent context and wraps it in a TapeOperation
// bound to this context's tape and gradient registry.
TapeOperation* TapeContext::CreateOperation() {
  auto* backing_op = parent_ctx_->CreateOperation();
  return new TapeOperation(backing_op, tape_, registry_);
}
|
||||
// Forwards function registration to the parent context unchanged.
Status TapeContext::RegisterFunction(AbstractFunction* f) {
  Status parent_status = parent_ctx_->RegisterFunction(f);
  return parent_status;
}
|
||||
// Forwards function removal to the parent context unchanged.
Status TapeContext::RemoveFunction(const string& func) {
  Status parent_status = parent_ctx_->RemoveFunction(func);
  return parent_status;
}
|
||||
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
44
tensorflow/c/experimental/gradients/tape/tape_context.h
Normal file
44
tensorflow/c/experimental/gradients/tape/tape_context.h
Normal file
@ -0,0 +1,44 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_CONTEXT_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_CONTEXT_H_
|
||||
|
||||
#include "tensorflow/c/eager/abstract_context.h"
|
||||
#include "tensorflow/c/experimental/gradients/tape/tape_operation.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
class TapeContext : public AbstractContext {
|
||||
public:
|
||||
explicit TapeContext(AbstractContext*, Tape*, const GradientRegistry&);
|
||||
void Release() override;
|
||||
TapeOperation* CreateOperation() override;
|
||||
Status RegisterFunction(AbstractFunction*) override;
|
||||
Status RemoveFunction(const string& func) override;
|
||||
// For LLVM style RTTI.
|
||||
static bool classof(const AbstractContext* ptr) {
|
||||
return ptr->getKind() == kTape;
|
||||
}
|
||||
~TapeContext() override;
|
||||
|
||||
private:
|
||||
AbstractContext* parent_ctx_; // Not owned.
|
||||
Tape* tape_;
|
||||
const GradientRegistry& registry_;
|
||||
};
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_CONTEXT_H_
|
238
tensorflow/c/experimental/gradients/tape/tape_operation.cc
Normal file
238
tensorflow/c/experimental/gradients/tape/tape_operation.cc
Normal file
@ -0,0 +1,238 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/experimental/gradients/tape/tape_operation.h"
|
||||
|
||||
#include "tensorflow/c/eager/abstract_context.h"
|
||||
#include "tensorflow/c/eager/gradients.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
// Wraps `parent_op`, the operation that does the real work. `tape` and
// `registry` are used later by Execute() to record the operation and to look
// up its backward function.
TapeOperation::TapeOperation(AbstractOperation* parent_op, Tape* tape,
                             const GradientRegistry& registry)
    : AbstractOperation(kTape),
      parent_op_(parent_op),
      tape_(tape),
      registry_(registry) {
  // TODO(srbs): Make AbstractOperation RefCounted.
  // parent_op_->Ref();
}
|
||||
// Destroys this wrapper. There is no reference counting yet, so the object is
// deleted unconditionally.
void TapeOperation::Release() {
  // TODO(srbs): Change to Unref().
  delete this;
}
|
||||
// Intentionally empty for now.
// NOTE(review): the backing operation is never released here, while
// TapeContext::CreateOperation() passes in a freshly created one. This looks
// like it leaks the backing operation until the ref-counting TODO below is
// done -- confirm the intended ownership.
TapeOperation::~TapeOperation() {
  // TODO(srbs): Make AbstractOperation RefCounted.
  // parent_op->Unref();
}
|
||||
// Re-initialises the recorded forward-op state (name, attributes, inputs and
// outputs) for op type `op`, then resets the backing operation.
// Returns the backing operation's status.
Status TapeOperation::Reset(const char* op, const char* raw_device_name) {
  // Drop everything recorded for a previous use of this object.
  forward_op_.inputs.clear();
  forward_op_.outputs.clear();

  forward_op_.op_name = op;
  forward_op_.attrs.Reset(op);

  return parent_op_->Reset(op, raw_device_name);
}
|
||||
// Returns the name reported by the backing operation.
const string& TapeOperation::Name() const {
  return parent_op_->Name();
}
|
||||
// Returns the device name reported by the backing operation.
const string& TapeOperation::DeviceName() const {
  const string& device = parent_op_->DeviceName();
  return device;
}
|
||||
// Forwards the device name to the backing operation; nothing is recorded for
// the tape.
Status TapeOperation::SetDeviceName(const char* name) {
  Status parent_status = parent_op_->SetDeviceName(name);
  return parent_status;
}
|
||||
// Adds `input` to the backing operation and, only if that succeeds, remembers
// it as an input of the forward op.
Status TapeOperation::AddInput(AbstractTensorHandle* input) {
  const Status parent_status = parent_op_->AddInput(input);
  if (!parent_status.ok()) {
    return parent_status;
  }
  forward_op_.inputs.push_back(input);
  return Status::OK();
}
|
||||
// Adds all of `inputs` to the backing operation and, only if that succeeds,
// appends them in order to the forward op's inputs.
Status TapeOperation::AddInputList(
    absl::Span<AbstractTensorHandle* const> inputs) {
  TF_RETURN_IF_ERROR(parent_op_->AddInputList(inputs));
  forward_op_.inputs.insert(forward_op_.inputs.end(), inputs.begin(),
                            inputs.end());
  return Status::OK();
}
|
||||
// Mirrors the string attribute into the forward op's attrs, then sets it on
// the backing operation and returns that status.
Status TapeOperation::SetAttrString(const char* attr_name, const char* data,
                                    size_t length) {
  const StringPiece attr_value(data, length);
  forward_op_.attrs.Set(attr_name, attr_value);
  return parent_op_->SetAttrString(attr_name, data, length);
}
|
||||
// Mirrors the integer attribute into the forward op's attrs, then sets it on
// the backing operation and returns that status.
Status TapeOperation::SetAttrInt(const char* attr_name, int64_t value) {
  const int64 recorded_value = static_cast<int64>(value);
  forward_op_.attrs.Set(attr_name, recorded_value);
  return parent_op_->SetAttrInt(attr_name, value);
}
|
||||
// Mirrors the float attribute into the forward op's attrs, then sets it on
// the backing operation and returns that status.
Status TapeOperation::SetAttrFloat(const char* attr_name, float value) {
  forward_op_.attrs.Set(attr_name, value);
  Status parent_status = parent_op_->SetAttrFloat(attr_name, value);
  return parent_status;
}
|
||||
// Mirrors the bool attribute into the forward op's attrs, then sets it on the
// backing operation and returns that status.
Status TapeOperation::SetAttrBool(const char* attr_name, bool value) {
  forward_op_.attrs.Set(attr_name, value);
  Status parent_status = parent_op_->SetAttrBool(attr_name, value);
  return parent_status;
}
|
||||
// Mirrors the DataType attribute into the forward op's attrs, then sets it on
// the backing operation and returns that status.
Status TapeOperation::SetAttrType(const char* attr_name, DataType value) {
  forward_op_.attrs.Set(attr_name, value);
  Status parent_status = parent_op_->SetAttrType(attr_name, value);
  return parent_status;
}
|
||||
// Records a shape attribute on the forward op and forwards it to the backing
// operation.
//
// A negative `num_dims` is recorded as a shape of unknown rank; otherwise the
// first `num_dims` entries of `dims` become the dimension sizes. Returns
// InvalidArgument, without touching either op, when `num_dims` exceeds
// TensorShape::MaxDimensions().
Status TapeOperation::SetAttrShape(const char* attr_name, const int64_t* dims,
                                   const int num_dims) {
  if (num_dims > TensorShape::MaxDimensions()) {
    return errors::InvalidArgument("Value specified for `", attr_name, "` has ",
                                   num_dims,
                                   " dimensions which is over the limit of ",
                                   TensorShape::MaxDimensions(), ".");
  }

  TensorShapeProto shape_proto;
  const bool rank_is_known = !(num_dims < 0);
  if (rank_is_known) {
    for (int dim_index = 0; dim_index < num_dims; ++dim_index) {
      shape_proto.add_dim()->set_size(dims[dim_index]);
    }
  } else {
    shape_proto.set_unknown_rank(true);
  }

  forward_op_.attrs.Set(attr_name, shape_proto);
  return parent_op_->SetAttrShape(attr_name, dims, num_dims);
}
|
||||
// Not supported yet: always returns Unimplemented and forwards nothing to the
// backing operation.
Status TapeOperation::SetAttrFunction(const char* attr_name,
                                      const AbstractOperation* value) {
  return tensorflow::errors::Unimplemented(
      "SetAttrFunction has not been implemented yet.");
}
|
||||
// Not supported yet: always returns Unimplemented and forwards nothing to the
// backing operation.
Status TapeOperation::SetAttrFunctionName(const char* attr_name,
                                          const char* value, size_t length) {
  return tensorflow::errors::Unimplemented(
      "SetAttrFunctionName has not been implemented yet.");
}
|
||||
// Not supported yet: always returns Unimplemented and forwards nothing to the
// backing operation.
Status TapeOperation::SetAttrTensor(const char* attr_name,
                                    AbstractTensorInterface* tensor) {
  return tensorflow::errors::Unimplemented(
      "SetAttrTensor has not been implemented yet.");
}
|
||||
// Records a list-of-strings attribute on the forward op and forwards it to
// the backing operation. `values[i]` is read as `lengths[i]` chars.
Status TapeOperation::SetAttrStringList(const char* attr_name,
                                        const void* const* values,
                                        const size_t* lengths, int num_values) {
  std::vector<StringPiece> pieces(num_values);
  for (int idx = 0; idx < num_values; ++idx) {
    const char* chars = static_cast<const char*>(values[idx]);
    pieces[idx] = StringPiece(chars, lengths[idx]);
  }
  forward_op_.attrs.Set(attr_name, pieces);
  return parent_op_->SetAttrStringList(attr_name, values, lengths, num_values);
}
|
||||
// Records a list-of-floats attribute on the forward op and forwards it to the
// backing operation.
Status TapeOperation::SetAttrFloatList(const char* attr_name,
                                       const float* values, int num_values) {
  const gtl::ArraySlice<const float> float_slice(values, num_values);
  forward_op_.attrs.Set(attr_name, float_slice);
  return parent_op_->SetAttrFloatList(attr_name, values, num_values);
}
|
||||
// Records a list-of-ints attribute on the forward op and forwards it to the
// backing operation. The int64_t buffer is viewed as tensorflow int64 for the
// recorded copy.
Status TapeOperation::SetAttrIntList(const char* attr_name,
                                     const int64_t* values, int num_values) {
  const int64* as_tf_int64 = reinterpret_cast<const int64*>(values);
  const gtl::ArraySlice<const int64> int_slice(as_tf_int64, num_values);
  forward_op_.attrs.Set(attr_name, int_slice);
  return parent_op_->SetAttrIntList(attr_name, values, num_values);
}
|
||||
// Records a list-of-DataTypes attribute on the forward op and forwards it to
// the backing operation.
Status TapeOperation::SetAttrTypeList(const char* attr_name,
                                      const DataType* values, int num_values) {
  const gtl::ArraySlice<const DataType> type_slice(values, num_values);
  forward_op_.attrs.Set(attr_name, type_slice);
  return parent_op_->SetAttrTypeList(attr_name, values, num_values);
}
|
||||
// Records a list-of-bools attribute on the forward op and forwards it to the
// backing operation. The incoming bytes are first converted into a temporary
// bool array, which is what gets recorded.
Status TapeOperation::SetAttrBoolList(const char* attr_name,
                                      const unsigned char* values,
                                      int num_values) {
  std::unique_ptr<bool[]> flags(new bool[num_values]);
  for (int idx = 0; idx < num_values; ++idx) {
    flags[idx] = values[idx];
  }
  const gtl::ArraySlice<const bool> flag_slice(flags.get(), num_values);
  forward_op_.attrs.Set(attr_name, flag_slice);
  return parent_op_->SetAttrBoolList(attr_name, values, num_values);
}
|
||||
// Records a list-of-shapes attribute on the forward op and forwards it to the
// backing operation.
//
// Entry i has `num_dims[i]` dimensions read from `dims[i]`; a negative
// `num_dims[i]` is recorded as unknown rank. Returns InvalidArgument, before
// recording or forwarding anything, as soon as an entry exceeds
// TensorShape::MaxDimensions().
Status TapeOperation::SetAttrShapeList(const char* attr_name,
                                       const int64_t** dims,
                                       const int* num_dims, int num_values) {
  std::unique_ptr<TensorShapeProto[]> shapes(new TensorShapeProto[num_values]);
  for (int idx = 0; idx < num_values; ++idx) {
    const auto rank = num_dims[idx];

    if (rank > TensorShape::MaxDimensions()) {
      return errors::InvalidArgument(
          strings::StrCat("Value specified for `", attr_name, "` has ", rank,
                          " dimensions which is over the limit of ",
                          TensorShape::MaxDimensions(), "."));
    }

    TensorShapeProto& shape = shapes[idx];
    if (rank < 0) {
      shape.set_unknown_rank(true);
      continue;
    }
    const int64_t* sizes = dims[idx];
    for (int d = 0; d < rank; ++d) {
      shape.add_dim()->set_size(sizes[d]);
    }
  }
  forward_op_.attrs.Set(
      attr_name, gtl::ArraySlice<TensorShapeProto>(shapes.get(), num_values));
  return parent_op_->SetAttrShapeList(attr_name, dims, num_dims, num_values);
}
|
||||
// Not supported yet: always returns Unimplemented and forwards nothing to the
// backing operation.
Status TapeOperation::SetAttrFunctionList(
    const char* attr_name, absl::Span<const AbstractOperation*> values) {
  return tensorflow::errors::Unimplemented(
      "SetAttrFunctionList has not been implemented yet.");
}
|
||||
// Returns the wrapped operation that this object forwards to.
AbstractOperation* TapeOperation::GetBackingOperation() {
  return parent_op_;
}
|
||||
// Executes the backing operation and then records the operation on the tape.
//
// If the backing operation fails, its status is returned and nothing is
// recorded. Otherwise the ids and dtypes of the inputs collected through
// AddInput/AddInputList, and the produced outputs, are passed to
// `tape_->RecordOperation` together with a getter that looks up the backward
// function for `forward_op_` in `registry_` (yielding nullptr when the lookup
// fails) and a deleter for that function.
//
// Only the first `*num_retvals` entries of `retvals` are treated as outputs,
// both for `forward_op_.outputs` and for the tensors handed to the tape. Any
// further slots in the span are ignored instead of being wrapped as
// TapeTensors.
Status TapeOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
                              int* num_retvals) {
  TF_RETURN_IF_ERROR(parent_op_->Execute(retvals, num_retvals));

  const size_t num_inputs = forward_op_.inputs.size();
  std::vector<int64> input_ids(num_inputs);
  std::vector<tensorflow::DataType> input_dtypes(num_inputs);
  for (size_t i = 0; i < num_inputs; ++i) {
    input_ids[i] = ToId(forward_op_.inputs[i]);
    input_dtypes[i] = forward_op_.inputs[i]->DataType();
  }

  for (int i = 0; i < *num_retvals; ++i) {
    // TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs.
    forward_op_.outputs.push_back(retvals[i]);
  }

  // TODO(b/166669239): This is needed to support AttrBuilder::Get for string
  // attributes. Number type attrs and DataType attrs work fine without this.
  // Consider getting rid of this and making the behavior between number types
  // and string consistent.
  forward_op_.attrs.BuildNodeDef();

  // Wrap exactly the outputs produced by this execution; `retvals` may be
  // larger than `*num_retvals`.
  std::vector<TapeTensor> tape_tensors;
  for (int i = 0; i < *num_retvals; ++i) {
    tape_tensors.push_back(TapeTensor(retvals[i]));
  }

  tape_->RecordOperation(
      parent_op_->Name(), tape_tensors, input_ids, input_dtypes,
      [this]() -> BackwardFunction* {
        std::unique_ptr<BackwardFunction> backward_fn;
        Status s = registry_.Lookup(forward_op_, &backward_fn);
        if (!s.ok()) {
          return nullptr;
        }
        return backward_fn.release();
      },
      [](BackwardFunction* ptr) {
        if (ptr) {
          delete ptr;
        }
      });
  return Status::OK();
}
|
||||
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
80
tensorflow/c/experimental/gradients/tape/tape_operation.h
Normal file
80
tensorflow/c/experimental/gradients/tape/tape_operation.h
Normal file
@ -0,0 +1,80 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_OPERATION_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_OPERATION_H_
|
||||
|
||||
#include "tensorflow/c/eager/abstract_operation.h"
|
||||
#include "tensorflow/c/eager/gradients.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
class TapeOperation : public AbstractOperation {
|
||||
public:
|
||||
explicit TapeOperation(AbstractOperation*, Tape*, const GradientRegistry&);
|
||||
void Release() override;
|
||||
Status Reset(const char* op, const char* raw_device_name) override;
|
||||
const string& Name() const override;
|
||||
const string& DeviceName() const override;
|
||||
Status SetDeviceName(const char* name) override;
|
||||
Status AddInput(AbstractTensorHandle* input) override;
|
||||
Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override;
|
||||
Status Execute(absl::Span<AbstractTensorHandle*> retvals,
|
||||
int* num_retvals) override;
|
||||
Status SetAttrString(const char* attr_name, const char* data,
|
||||
size_t length) override;
|
||||
Status SetAttrInt(const char* attr_name, int64_t value) override;
|
||||
Status SetAttrFloat(const char* attr_name, float value) override;
|
||||
Status SetAttrBool(const char* attr_name, bool value) override;
|
||||
Status SetAttrType(const char* attr_name, DataType value) override;
|
||||
Status SetAttrShape(const char* attr_name, const int64_t* dims,
|
||||
const int num_dims) override;
|
||||
Status SetAttrFunction(const char* attr_name,
|
||||
const AbstractOperation* value) override;
|
||||
Status SetAttrFunctionName(const char* attr_name, const char* value,
|
||||
size_t length) override;
|
||||
Status SetAttrTensor(const char* attr_name,
|
||||
AbstractTensorInterface* tensor) override;
|
||||
Status SetAttrStringList(const char* attr_name, const void* const* values,
|
||||
const size_t* lengths, int num_values) override;
|
||||
Status SetAttrFloatList(const char* attr_name, const float* values,
|
||||
int num_values) override;
|
||||
Status SetAttrIntList(const char* attr_name, const int64_t* values,
|
||||
int num_values) override;
|
||||
Status SetAttrTypeList(const char* attr_name, const DataType* values,
|
||||
int num_values) override;
|
||||
Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
|
||||
int num_values) override;
|
||||
Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
|
||||
const int* num_dims, int num_values) override;
|
||||
Status SetAttrFunctionList(
|
||||
const char* attr_name,
|
||||
absl::Span<const AbstractOperation*> values) override;
|
||||
AbstractOperation* GetBackingOperation();
|
||||
// For LLVM style RTTI.
|
||||
static bool classof(const AbstractOperation* ptr) {
|
||||
return ptr->getKind() == kTape;
|
||||
}
|
||||
~TapeOperation() override;
|
||||
|
||||
private:
|
||||
AbstractOperation* parent_op_;
|
||||
ForwardOperation forward_op_;
|
||||
Tape* tape_;
|
||||
const GradientRegistry& registry_;
|
||||
};
|
||||
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_OPERATION_H_
|
Loading…
x
Reference in New Issue
Block a user