From 7ec70d54b2f5e178d9d7307cc1c65180603a2f2c Mon Sep 17 00:00:00 2001 From: Saurabh Saxena Date: Wed, 23 Sep 2020 15:55:12 -0700 Subject: [PATCH] - Integrate C++ tape with op building APIs via TapeContext and TapeOperation which delegate calls to a parent execution context and record operations on the tape. Please see gradients_test.cc for usage. - This will replace the helper functions in gradients_internal.h. I will clean that up in a followup CL. - Also drop ForwardOperation::ctx since that is unused right now. We can add it later if we need. PiperOrigin-RevId: 333390787 Change-Id: I80f2c460a9538a1a14ed1497c59f7b37a633a633 --- tensorflow/c/eager/BUILD | 5 +- tensorflow/c/eager/abstract_context.h | 2 +- tensorflow/c/eager/abstract_operation.h | 2 +- tensorflow/c/eager/gradients.h | 1 - tensorflow/c/eager/gradients_test.cc | 84 ++----- .../c/eager/mnist_gradients_testutil.cc | 5 - tensorflow/c/eager/tracing_utils.cc | 6 + .../c/experimental/gradients/tape/BUILD | 40 +++ .../gradients/tape/tape_context.cc | 47 ++++ .../gradients/tape/tape_context.h | 44 ++++ .../gradients/tape/tape_operation.cc | 238 ++++++++++++++++++ .../gradients/tape/tape_operation.h | 80 ++++++ 12 files changed, 477 insertions(+), 77 deletions(-) create mode 100644 tensorflow/c/experimental/gradients/tape/BUILD create mode 100644 tensorflow/c/experimental/gradients/tape/tape_context.cc create mode 100644 tensorflow/c/experimental/gradients/tape/tape_context.h create mode 100644 tensorflow/c/experimental/gradients/tape/tape_operation.cc create mode 100644 tensorflow/c/experimental/gradients/tape/tape_operation.h diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 7b664d57e5a..a71fdc8aa06 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -189,6 +189,7 @@ cc_library( deps = [ ":abstract_operation", ":c_api_unified_internal", + "//tensorflow/c/experimental/gradients/tape:tape_operation", "//tensorflow/core/lib/llvm_rtti", 
"//tensorflow/core/platform:errors", ], @@ -229,6 +230,7 @@ tf_cuda_cc_test( linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags() + ["nomac"], deps = [ + ":abstract_context", ":abstract_tensor_handle", ":c_api_experimental", ":c_api_test_util", @@ -239,7 +241,8 @@ tf_cuda_cc_test( "//tensorflow/c:tf_status_helper", "//tensorflow/c/experimental/gradients:array_grad", "//tensorflow/c/experimental/gradients:math_grad", - "//tensorflow/c/experimental/ops:array_ops", + "//tensorflow/c/experimental/gradients/tape:tape_context", + "//tensorflow/c/experimental/ops", "//tensorflow/cc/profiler", "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", "//tensorflow/core:lib", diff --git a/tensorflow/c/eager/abstract_context.h b/tensorflow/c/eager/abstract_context.h index b488255d150..d31b1e13611 100644 --- a/tensorflow/c/eager/abstract_context.h +++ b/tensorflow/c/eager/abstract_context.h @@ -32,7 +32,7 @@ namespace tensorflow { // environment, a traced representation etc. class AbstractContext { protected: - enum AbstractContextKind { kGraph, kMlir, kEager, kTfrt }; + enum AbstractContextKind { kGraph, kMlir, kEager, kTfrt, kTape }; explicit AbstractContext(AbstractContextKind kind) : kind_(kind) {} virtual ~AbstractContext() {} diff --git a/tensorflow/c/eager/abstract_operation.h b/tensorflow/c/eager/abstract_operation.h index b332679cc7c..4c630528f5d 100644 --- a/tensorflow/c/eager/abstract_operation.h +++ b/tensorflow/c/eager/abstract_operation.h @@ -30,7 +30,7 @@ namespace tensorflow { // tracing or immediate execution mode. 
class AbstractOperation { protected: - enum AbstractOperationKind { kGraph, kMlir, kEager, kTfrt }; + enum AbstractOperationKind { kGraph, kMlir, kEager, kTfrt, kTape }; explicit AbstractOperation(AbstractOperationKind kind) : kind_(kind) {} virtual ~AbstractOperation() {} diff --git a/tensorflow/c/eager/gradients.h b/tensorflow/c/eager/gradients.h index dc49b3c2cb4..d240806c044 100644 --- a/tensorflow/c/eager/gradients.h +++ b/tensorflow/c/eager/gradients.h @@ -80,7 +80,6 @@ struct ForwardOperation { std::vector inputs; std::vector outputs; AttrBuilder attrs; - AbstractContext* ctx; }; // Interface for building default zeros gradients for op outputs which are diff --git a/tensorflow/c/eager/gradients_test.cc b/tensorflow/c/eager/gradients_test.cc index 3aedf55e97a..01834879b67 100644 --- a/tensorflow/c/eager/gradients_test.cc +++ b/tensorflow/c/eager/gradients_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_context.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_test_util.h" @@ -26,7 +27,9 @@ limitations under the License. #include "tensorflow/c/eager/gradients_internal.h" #include "tensorflow/c/experimental/gradients/array_grad.h" #include "tensorflow/c/experimental/gradients/math_grad.h" +#include "tensorflow/c/experimental/gradients/tape/tape_context.h" #include "tensorflow/c/experimental/ops/array_ops.h" +#include "tensorflow/c/experimental/ops/math_ops.h" #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/c/tf_tensor.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" @@ -53,72 +56,14 @@ class CppGradients }; Status RegisterGradients(GradientRegistry* registry) { - TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer)); + // TODO(srbs): Rename ops::Add to ops::AddV2 and AddRegisterer to + // AddV2Registerer. 
+ TF_RETURN_IF_ERROR(registry->Register("AddV2", AddRegisterer)); TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer)); TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer)); return Status::OK(); } -// Computes `inputs[0] + inputs[1]` and records it on the tape. -Status Add(AbstractContext* ctx, Tape* tape, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry) { - AbstractOperationPtr add_op(ctx->CreateOperation()); - ForwardOperation forward_op; - forward_op.ctx = ctx; - TF_RETURN_IF_ERROR( - Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op)); - if (isa(add_op.get())) { - TF_RETURN_IF_ERROR( - dyn_cast(add_op.get())->SetOpName("my_add")); - } - TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op)); - TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op)); - int num_retvals = 1; - return Execute(add_op.get(), ctx, outputs, &num_retvals, &forward_op, tape, - registry); -} - -// Computes `exp(inputs[0])` and records it on the tape. -Status Exp(AbstractContext* ctx, Tape* tape, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry) { - AbstractOperationPtr exp_op(ctx->CreateOperation()); - ForwardOperation forward_op; - forward_op.ctx = ctx; - TF_RETURN_IF_ERROR( - Reset(exp_op.get(), "Exp", /*raw_device_name=*/nullptr, &forward_op)); - if (isa(exp_op.get())) { - TF_RETURN_IF_ERROR( - dyn_cast(exp_op.get())->SetOpName("my_exp")); - } - TF_RETURN_IF_ERROR(AddInput(exp_op.get(), inputs[0], &forward_op)); - int num_retvals = 1; - return Execute(exp_op.get(), ctx, outputs, &num_retvals, &forward_op, tape, - registry); -} - -// Computes `IdentityN(inputs)` and records it on the tape. 
-Status IdentityN(AbstractContext* ctx, Tape* tape, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry) { - AbstractOperationPtr identity_n_op(ctx->CreateOperation()); - ForwardOperation forward_op; - forward_op.ctx = ctx; - TF_RETURN_IF_ERROR(Reset(identity_n_op.get(), "IdentityN", - /*raw_device_name=*/nullptr, &forward_op)); - if (isa(identity_n_op.get())) { - TF_RETURN_IF_ERROR(dyn_cast(identity_n_op.get()) - ->SetOpName("my_identity_n")); - } - TF_RETURN_IF_ERROR(AddInputList(identity_n_op.get(), inputs, &forward_op)); - int num_retvals = outputs.size(); - return Execute(identity_n_op.get(), ctx, outputs, &num_retvals, &forward_op, - tape, registry); -} // Computes // y = inputs[0] + inputs[1] @@ -132,8 +77,10 @@ Status AddGradModel(AbstractContext* ctx, tape->Watch(ToId(inputs[0])); // Watch x. tape->Watch(ToId(inputs[1])); // Watch y. std::vector add_outputs(1); - TF_RETURN_IF_ERROR(Add(ctx, tape, inputs, absl::MakeSpan(add_outputs), - registry)); // Compute x+y. + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR(ops::Add(tape_ctx.get(), inputs, + absl::MakeSpan(add_outputs), + "Add")); // Compute x+y. std::unordered_map source_tensors_that_are_targets; @@ -164,8 +111,9 @@ Status ExpGradModel(AbstractContext* ctx, auto tape = new Tape(/*persistent=*/false); tape->Watch(ToId(inputs[0])); // Watch x. std::vector exp_outputs(1); - TF_RETURN_IF_ERROR(Exp(ctx, tape, inputs, absl::MakeSpan(exp_outputs), - registry)); // Compute x+y. 
+ AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR( + ops::Exp(tape_ctx.get(), inputs, absl::MakeSpan(exp_outputs), "Exp")); std::unordered_map source_tensors_that_are_targets; @@ -197,8 +145,9 @@ Status IdentityNGradModel(AbstractContext* ctx, tape->Watch(ToId(inputs[1])); vector identity_n_outputs(2); - TF_RETURN_IF_ERROR(IdentityN(ctx, tape, inputs, - absl::MakeSpan(identity_n_outputs), registry)); + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR(ops::IdentityN( + tape_ctx.get(), inputs, absl::MakeSpan(identity_n_outputs), "IdentityN")); std::unordered_map source_tensors_that_are_targets; @@ -533,7 +482,6 @@ TEST_P(CppGradients, TestSetAttrString) { AbstractOperationPtr check_numerics_op(ctx->CreateOperation()); ForwardOperation forward_op; - forward_op.ctx = ctx.get(); Status s = Reset(check_numerics_op.get(), "CheckNumerics", /*raw_device_name=*/nullptr, &forward_op); ASSERT_EQ(errors::OK, s.code()) << s.error_message(); diff --git a/tensorflow/c/eager/mnist_gradients_testutil.cc b/tensorflow/c/eager/mnist_gradients_testutil.cc index 932605ab8e0..8354b37354e 100644 --- a/tensorflow/c/eager/mnist_gradients_testutil.cc +++ b/tensorflow/c/eager/mnist_gradients_testutil.cc @@ -46,7 +46,6 @@ Status Add(AbstractContext* ctx, Tape* tape, const GradientRegistry& registry) { AbstractOperationPtr add_op(ctx->CreateOperation()); ForwardOperation forward_op; - forward_op.ctx = ctx; TF_RETURN_IF_ERROR( Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op)); if (isa(add_op.get())) { @@ -68,7 +67,6 @@ Status MatMul(AbstractContext* ctx, Tape* tape, const GradientRegistry& registry) { AbstractOperationPtr matmul_op(ctx->CreateOperation()); ForwardOperation forward_op; - forward_op.ctx = ctx; TF_RETURN_IF_ERROR(Reset(matmul_op.get(), "MatMul", /*raw_device_name=*/nullptr, &forward_op)); if (isa(matmul_op.get())) { @@ -94,7 +92,6 @@ Status Mul(AbstractContext* ctx, Tape* tape, const 
GradientRegistry& registry) { AbstractOperationPtr mul_op(ctx->CreateOperation()); ForwardOperation forward_op; - forward_op.ctx = ctx; TF_RETURN_IF_ERROR( Reset(mul_op.get(), "Mul", /*raw_device_name=*/nullptr, &forward_op)); if (isa(mul_op.get())) { @@ -117,7 +114,6 @@ Status Relu(AbstractContext* ctx, Tape* tape, const GradientRegistry& registry) { AbstractOperationPtr relu_op(ctx->CreateOperation()); ForwardOperation forward_op; - forward_op.ctx = ctx; TF_RETURN_IF_ERROR( Reset(relu_op.get(), "Relu", /*raw_device_name=*/nullptr, &forward_op)); if (isa(relu_op.get())) { @@ -142,7 +138,6 @@ Status SparseSoftmaxCrossEntropyWithLogits( AbstractOperationPtr sm_op(ctx->CreateOperation()); ForwardOperation forward_op; - forward_op.ctx = ctx; TF_RETURN_IF_ERROR(Reset(sm_op.get(), "SparseSoftmaxCrossEntropyWithLogits", /*raw_device_name=*/nullptr, &forward_op)); if (isa(sm_op.get())) { diff --git a/tensorflow/c/eager/tracing_utils.cc b/tensorflow/c/eager/tracing_utils.cc index 72e0d35ea24..8eec4bc7d9a 100644 --- a/tensorflow/c/eager/tracing_utils.cc +++ b/tensorflow/c/eager/tracing_utils.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/c/eager/tracing_utils.h" #include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/experimental/gradients/tape/tape_operation.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" #include "tensorflow/core/platform/errors.h" @@ -25,6 +26,11 @@ Status MaybeSetOpName(AbstractOperation* op, const char* op_name) { if (isa(op)) { TF_RETURN_IF_ERROR(dyn_cast(op)->SetOpName(op_name)); } + if (isa(op)) { + TF_RETURN_IF_ERROR(MaybeSetOpName( + dyn_cast(op)->GetBackingOperation(), + op_name)); + } return Status::OK(); } } // namespace tracing diff --git a/tensorflow/c/experimental/gradients/tape/BUILD b/tensorflow/c/experimental/gradients/tape/BUILD new file mode 100644 index 00000000000..e57cc11b84f --- /dev/null +++ b/tensorflow/c/experimental/gradients/tape/BUILD @@ -0,0 +1,40 @@ +# A tape built on top of unified execution APIs. +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + +package( + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "tape_context", + srcs = ["tape_context.cc"], + hdrs = [ + "tape_context.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":tape_operation", + "//tensorflow/c/eager:abstract_context", + "//tensorflow/c/eager:abstract_function", + "//tensorflow/c/eager:abstract_operation", + ], +) + +cc_library( + name = "tape_operation", + srcs = ["tape_operation.cc"], + hdrs = [ + "tape_operation.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + "//tensorflow/c/eager:abstract_context", + "//tensorflow/c/eager:abstract_function", + "//tensorflow/c/eager:abstract_operation", + "//tensorflow/c/eager:gradients_internal", + ], +) diff --git a/tensorflow/c/experimental/gradients/tape/tape_context.cc b/tensorflow/c/experimental/gradients/tape/tape_context.cc new file mode 100644 index 00000000000..1fa1a3f24f1 --- /dev/null +++ b/tensorflow/c/experimental/gradients/tape/tape_context.cc @@ -0,0 +1,47 @@ +/* Copyright 2020 The TensorFlow 
Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/experimental/gradients/tape/tape_context.h" + +#include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/experimental/gradients/tape/tape_operation.h" + +namespace tensorflow { +namespace gradients { +TapeContext::TapeContext(AbstractContext* c, Tape* tape, + const GradientRegistry& registry) + : AbstractContext(kTape), parent_ctx_(c), tape_(tape), registry_(registry) { + // TODO(srbs): Make AbstractContext ref counted. + // parent_ctx_->Ref(); +} +void TapeContext::Release() { + // TODO(srbs): Change to Unref() + delete this; +} +TapeContext::~TapeContext() { + // TODO(srbs): Make AbstractContext ref counted. 
+ // parent_ctx_->Unref(); +} +TapeOperation* TapeContext::CreateOperation() { + return new TapeOperation(parent_ctx_->CreateOperation(), tape_, registry_); +} +Status TapeContext::RegisterFunction(AbstractFunction* f) { + return parent_ctx_->RegisterFunction(f); +} +Status TapeContext::RemoveFunction(const string& func) { + return parent_ctx_->RemoveFunction(func); +} + +} // namespace gradients +} // namespace tensorflow diff --git a/tensorflow/c/experimental/gradients/tape/tape_context.h b/tensorflow/c/experimental/gradients/tape/tape_context.h new file mode 100644 index 00000000000..291053226fb --- /dev/null +++ b/tensorflow/c/experimental/gradients/tape/tape_context.h @@ -0,0 +1,44 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_CONTEXT_H_ +#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_CONTEXT_H_ + +#include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/experimental/gradients/tape/tape_operation.h" + +namespace tensorflow { +namespace gradients { +class TapeContext : public AbstractContext { + public: + explicit TapeContext(AbstractContext*, Tape*, const GradientRegistry&); + void Release() override; + TapeOperation* CreateOperation() override; + Status RegisterFunction(AbstractFunction*) override; + Status RemoveFunction(const string& func) override; + // For LLVM style RTTI. + static bool classof(const AbstractContext* ptr) { + return ptr->getKind() == kTape; + } + ~TapeContext() override; + + private: + AbstractContext* parent_ctx_; // Not owned. + Tape* tape_; + const GradientRegistry& registry_; +}; +} // namespace gradients +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_CONTEXT_H_ diff --git a/tensorflow/c/experimental/gradients/tape/tape_operation.cc b/tensorflow/c/experimental/gradients/tape/tape_operation.cc new file mode 100644 index 00000000000..0b247d08f6c --- /dev/null +++ b/tensorflow/c/experimental/gradients/tape/tape_operation.cc @@ -0,0 +1,238 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/c/experimental/gradients/tape/tape_operation.h" + +#include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/eager/gradients.h" + +namespace tensorflow { +namespace gradients { +TapeOperation::TapeOperation(AbstractOperation* parent_op, Tape* tape, + const GradientRegistry& registry) + : AbstractOperation(kTape), + parent_op_(parent_op), + tape_(tape), + registry_(registry) { + // TODO(srbs): Make AbstractOperation RefCounted. + // parent_op_->Ref(); +} +void TapeOperation::Release() { + // TODO(srbs): Change to Unref(). + delete this; +} +TapeOperation::~TapeOperation() { + // TODO(srbs): Make AbstractOperation RefCounted. + // parent_op_->Unref(); +} +Status TapeOperation::Reset(const char* op, const char* raw_device_name) { + forward_op_.op_name = op; + forward_op_.attrs.Reset(op); + forward_op_.inputs.clear(); + forward_op_.outputs.clear(); + return parent_op_->Reset(op, raw_device_name); +} +const string& TapeOperation::Name() const { return parent_op_->Name(); } +const string& TapeOperation::DeviceName() const { + return parent_op_->DeviceName(); +} +Status TapeOperation::SetDeviceName(const char* name) { + return parent_op_->SetDeviceName(name); +} +Status TapeOperation::AddInput(AbstractTensorHandle* input) { + TF_RETURN_IF_ERROR(parent_op_->AddInput(input)); + forward_op_.inputs.push_back(input); + return Status::OK(); +} +Status TapeOperation::AddInputList( + absl::Span inputs) { + TF_RETURN_IF_ERROR(parent_op_->AddInputList(inputs)); + for (auto input : inputs) { + forward_op_.inputs.push_back(input); + } + return Status::OK(); +} +Status TapeOperation::SetAttrString(const char* attr_name, const char* data, + size_t length) { + forward_op_.attrs.Set(attr_name, StringPiece(data, length)); + return parent_op_->SetAttrString(attr_name, data, length); +} +Status TapeOperation::SetAttrInt(const char* attr_name, int64_t value) { + 
forward_op_.attrs.Set(attr_name, static_cast(value)); + return parent_op_->SetAttrInt(attr_name, value); +} +Status TapeOperation::SetAttrFloat(const char* attr_name, float value) { + forward_op_.attrs.Set(attr_name, value); + return parent_op_->SetAttrFloat(attr_name, value); +} +Status TapeOperation::SetAttrBool(const char* attr_name, bool value) { + forward_op_.attrs.Set(attr_name, value); + return parent_op_->SetAttrBool(attr_name, value); +} +Status TapeOperation::SetAttrType(const char* attr_name, DataType value) { + forward_op_.attrs.Set(attr_name, value); + return parent_op_->SetAttrType(attr_name, value); +} +Status TapeOperation::SetAttrShape(const char* attr_name, const int64_t* dims, + const int num_dims) { + if (num_dims > TensorShape::MaxDimensions()) { + return errors::InvalidArgument("Value specified for `", attr_name, "` has ", + num_dims, + " dimensions which is over the limit of ", + TensorShape::MaxDimensions(), "."); + } + TensorShapeProto proto; + if (num_dims < 0) { + proto.set_unknown_rank(true); + } else { + for (int d = 0; d < num_dims; ++d) { + proto.add_dim()->set_size(dims[d]); + } + } + + forward_op_.attrs.Set(attr_name, proto); + return parent_op_->SetAttrShape(attr_name, dims, num_dims); +} +Status TapeOperation::SetAttrFunction(const char* attr_name, + const AbstractOperation* value) { + return tensorflow::errors::Unimplemented( + "SetAttrFunction has not been implemented yet."); +} +Status TapeOperation::SetAttrFunctionName(const char* attr_name, + const char* value, size_t length) { + return tensorflow::errors::Unimplemented( + "SetAttrFunctionName has not been implemented " + "yet."); +} +Status TapeOperation::SetAttrTensor(const char* attr_name, + AbstractTensorInterface* tensor) { + return tensorflow::errors::Unimplemented( + "SetAttrTensor has not been implemented yet."); +} +Status TapeOperation::SetAttrStringList(const char* attr_name, + const void* const* values, + const size_t* lengths, int num_values) { + std::vector 
v(num_values); + for (int i = 0; i < num_values; ++i) { + v[i] = StringPiece(static_cast(values[i]), lengths[i]); + } + forward_op_.attrs.Set(attr_name, v); + return parent_op_->SetAttrStringList(attr_name, values, lengths, num_values); +} +Status TapeOperation::SetAttrFloatList(const char* attr_name, + const float* values, int num_values) { + forward_op_.attrs.Set(attr_name, + gtl::ArraySlice(values, num_values)); + return parent_op_->SetAttrFloatList(attr_name, values, num_values); +} +Status TapeOperation::SetAttrIntList(const char* attr_name, + const int64_t* values, int num_values) { + forward_op_.attrs.Set( + attr_name, gtl::ArraySlice( + reinterpret_cast(values), num_values)); + return parent_op_->SetAttrIntList(attr_name, values, num_values); +} +Status TapeOperation::SetAttrTypeList(const char* attr_name, + const DataType* values, int num_values) { + forward_op_.attrs.Set(attr_name, + gtl::ArraySlice(values, num_values)); + return parent_op_->SetAttrTypeList(attr_name, values, num_values); +} +Status TapeOperation::SetAttrBoolList(const char* attr_name, + const unsigned char* values, + int num_values) { + std::unique_ptr b(new bool[num_values]); + for (int i = 0; i < num_values; ++i) { + b[i] = values[i]; + } + forward_op_.attrs.Set(attr_name, + gtl::ArraySlice(b.get(), num_values)); + return parent_op_->SetAttrBoolList(attr_name, values, num_values); +} +Status TapeOperation::SetAttrShapeList(const char* attr_name, + const int64_t** dims, + const int* num_dims, int num_values) { + std::unique_ptr proto(new TensorShapeProto[num_values]); + for (int i = 0; i < num_values; ++i) { + const auto num_dims_i = num_dims[i]; + + if (num_dims_i > TensorShape::MaxDimensions()) { + return errors::InvalidArgument( + strings::StrCat("Value specified for `", attr_name, "` has ", + num_dims_i, " dimensions which is over the limit of ", + TensorShape::MaxDimensions(), ".")); + } + if (num_dims_i < 0) { + proto[i].set_unknown_rank(true); + } else { + const int64_t* dims_i = 
dims[i]; + auto proto_i = &proto[i]; + for (int d = 0; d < num_dims_i; ++d) { + proto_i->add_dim()->set_size(dims_i[d]); + } + } + } + forward_op_.attrs.Set( + attr_name, gtl::ArraySlice(proto.get(), num_values)); + return parent_op_->SetAttrShapeList(attr_name, dims, num_dims, num_values); +} +Status TapeOperation::SetAttrFunctionList( + const char* attr_name, absl::Span values) { + return tensorflow::errors::Unimplemented( + "SetAttrFunctionList has not been " + "implemented yet."); +} +AbstractOperation* TapeOperation::GetBackingOperation() { return parent_op_; } +Status TapeOperation::Execute(absl::Span retvals, + int* num_retvals) { + TF_RETURN_IF_ERROR(parent_op_->Execute(retvals, num_retvals)); + std::vector input_ids(forward_op_.inputs.size()); + std::vector input_dtypes(forward_op_.inputs.size()); + for (int i = 0; i < forward_op_.inputs.size(); i++) { + input_ids[i] = ToId(forward_op_.inputs[i]); + input_dtypes[i] = forward_op_.inputs[i]->DataType(); + } + for (int i = 0; i < *num_retvals; i++) { + // TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs. + forward_op_.outputs.push_back(retvals[i]); + } + // TODO(b/166669239): This is needed to support AttrBuilder::Get for string + // attributes. Number type attrs and DataType attrs work fine without this. + // Consider getting rid of this and making the behavior between number types + // and string consistent. 
+ forward_op_.attrs.BuildNodeDef(); + std::vector tape_tensors; + for (int i = 0; i < *num_retvals; i++) {  // Skip unpopulated retvals slots. + tape_tensors.push_back(TapeTensor(retvals[i])); + } + tape_->RecordOperation( + parent_op_->Name(), tape_tensors, input_ids, input_dtypes, + [this]() -> BackwardFunction* { + std::unique_ptr backward_fn; + Status s = registry_.Lookup(forward_op_, &backward_fn); + if (!s.ok()) { + return nullptr; + } + return backward_fn.release(); + }, + [](BackwardFunction* ptr) { + if (ptr) { + delete ptr; + } + }); + return Status::OK(); +} + +} // namespace gradients +} // namespace tensorflow diff --git a/tensorflow/c/experimental/gradients/tape/tape_operation.h b/tensorflow/c/experimental/gradients/tape/tape_operation.h new file mode 100644 index 00000000000..b971176d9e7 --- /dev/null +++ b/tensorflow/c/experimental/gradients/tape/tape_operation.h @@ -0,0 +1,80 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_OPERATION_H_ +#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_OPERATION_H_ + +#include "tensorflow/c/eager/abstract_operation.h" +#include "tensorflow/c/eager/gradients.h" + +namespace tensorflow { +namespace gradients { +class TapeOperation : public AbstractOperation { + public: + explicit TapeOperation(AbstractOperation*, Tape*, const GradientRegistry&); + void Release() override; + Status Reset(const char* op, const char* raw_device_name) override; + const string& Name() const override; + const string& DeviceName() const override; + Status SetDeviceName(const char* name) override; + Status AddInput(AbstractTensorHandle* input) override; + Status AddInputList(absl::Span inputs) override; + Status Execute(absl::Span retvals, + int* num_retvals) override; + Status SetAttrString(const char* attr_name, const char* data, + size_t length) override; + Status SetAttrInt(const char* attr_name, int64_t value) override; + Status SetAttrFloat(const char* attr_name, float value) override; + Status SetAttrBool(const char* attr_name, bool value) override; + Status SetAttrType(const char* attr_name, DataType value) override; + Status SetAttrShape(const char* attr_name, const int64_t* dims, + const int num_dims) override; + Status SetAttrFunction(const char* attr_name, + const AbstractOperation* value) override; + Status SetAttrFunctionName(const char* attr_name, const char* value, + size_t length) override; + Status SetAttrTensor(const char* attr_name, + AbstractTensorInterface* tensor) override; + Status SetAttrStringList(const char* attr_name, const void* const* values, + const size_t* lengths, int num_values) override; + Status SetAttrFloatList(const char* attr_name, const float* values, + int num_values) override; + Status SetAttrIntList(const char* attr_name, const int64_t* values, + int num_values) override; + Status 
SetAttrTypeList(const char* attr_name, const DataType* values, + int num_values) override; + Status SetAttrBoolList(const char* attr_name, const unsigned char* values, + int num_values) override; + Status SetAttrShapeList(const char* attr_name, const int64_t** dims, + const int* num_dims, int num_values) override; + Status SetAttrFunctionList( + const char* attr_name, + absl::Span values) override; + AbstractOperation* GetBackingOperation(); + // For LLVM style RTTI. + static bool classof(const AbstractOperation* ptr) { + return ptr->getKind() == kTape; + } + ~TapeOperation() override; + + private: + AbstractOperation* parent_op_; + ForwardOperation forward_op_; + Tape* tape_; + const GradientRegistry& registry_; +}; + +} // namespace gradients +} // namespace tensorflow +#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_OPERATION_H_