diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 5c101bef85f..a77e76644b8 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -171,6 +171,87 @@ cc_library( ], ) +cc_library( + name = "gradients", + srcs = [ + "gradients.cc", + "gradients_internal.h", + ], + hdrs = [ + "gradients.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":abstract_context", + ":abstract_operation", + ":abstract_tensor_handle", + ":c_api_unified_internal", + ":tape", + "//tensorflow/core/common_runtime/eager:attr_builder", + "//tensorflow/core/lib/llvm_rtti", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "gradients_internal", + srcs = [ + "gradients.cc", + ], + hdrs = [ + "gradients.h", + "gradients_internal.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":abstract_context", + ":abstract_operation", + ":abstract_tensor_handle", + ":c_api_unified_internal", + ":tape", + "//tensorflow/core/common_runtime/eager:attr_builder", + "//tensorflow/core/lib/llvm_rtti", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + ], +) + +tf_cuda_cc_test( + name = "gradients_test", + size = "small", + srcs = [ + "gradients_test.cc", + ], + args = ["--heap_check=local"], + extra_copts = tfe_xla_copts(), + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags() + ["nomac"], + deps = [ + ":abstract_tensor_handle", + ":c_api_experimental", + ":c_api_test_util", + ":c_api_unified_internal", + ":gradients_internal", + "//tensorflow/c:c_api", + "//tensorflow/c:c_test_util", + "//tensorflow/c:tf_status_helper", + "//tensorflow/cc/profiler", + "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/lib/llvm_rtti", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "abstract_tensor_handle", hdrs = ["abstract_tensor_handle.h"], @@ -747,6 +828,7 @@ filegroup( "c_api_unified_experimental_eager.cc", "c_api_unified_experimental_graph.cc", "c_api_unified_experimental_internal.h", + "gradients.cc", # Uses RTTI. "*test*", "*dlpack*", ], diff --git a/tensorflow/c/eager/gradients.cc b/tensorflow/c/eager/gradients.cc new file mode 100644 index 00000000000..3a7a6282192 --- /dev/null +++ b/tensorflow/c/eager/gradients.cc @@ -0,0 +1,400 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+#include "tensorflow/c/eager/gradients.h"
+
+#include "absl/strings/str_cat.h"
+#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
+#include "tensorflow/c/eager/gradients_internal.h"
+#include "tensorflow/core/common_runtime/eager/attr_builder.h"
+#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
+
+namespace tensorflow {
+namespace gradients {
+
+Status GradientRegistry::Register(const string& op_name,
+                                  GradientFunctionFactory factory) {
+  auto iter = registry_.find(op_name);
+  if (iter != registry_.end()) {
+    const string error_msg =
+        "Gradient already exists for op: " + op_name + ".";
+    return errors::AlreadyExists(error_msg);
+  }
+  registry_.insert({op_name, factory});
+  return Status::OK();
+}
+Status GradientRegistry::Lookup(
+    const ForwardOperation& op,
+    std::unique_ptr<GradientFunction>* grad_fn) const {
+  auto iter = registry_.find(op.op_name);
+  if (iter == registry_.end()) {
+    const string error_msg = "No gradient defined for op: " + op.op_name + ".";
+    return errors::NotFound(error_msg);
+  }
+  grad_fn->reset(iter->second(op));
+  return Status::OK();
+}
+
+int64 ToId(AbstractTensorHandle* t) {
+  return static_cast<int64>(reinterpret_cast<uintptr_t>(t));
+}
+
+TapeTensor::TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx)
+    : handle_(handle), ctx_(ctx) {
+  // TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely
+  // on the client to keep this tensor live for the duration of the gradient
+  // computation.
+  // handle_->Ref();
+}
+TapeTensor::TapeTensor(const TapeTensor& other) {
+  handle_ = other.handle_;
+  // TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely
+  // on the client to keep this tensor live for the duration of the gradient
+  // computation.
+  // handle_->Ref();
+  ctx_ = other.ctx_;
+}
+TapeTensor::~TapeTensor() {
+  // TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely
+  // on the client to keep this tensor live for the duration of the gradient
+  // computation.
+  // handle_->Unref();
+}
+
+tensorflow::int64 TapeTensor::GetID() const { return ToId(handle_); }
+
+tensorflow::DataType TapeTensor::GetDType() const {
+  return handle_->DataType();
+}
+
+AbstractTensorHandle* TapeTensor::OnesLike() const {
+  AbstractOperationPtr op(ctx_->CreateOperation());
+  Status s = op->Reset("OnesLike", /*raw_device_name=*/nullptr);
+  if (!s.ok()) {
+    return nullptr;
+  }
+  if (isa<tracing::TracingOperation>(op.get())) {
+    s = dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
+        absl::StrCat("OnesLike", ToId(handle_)).c_str());
+    if (!s.ok()) {
+      return nullptr;
+    }
+  }
+  s = op->AddInput(handle_);
+  if (!s.ok()) {
+    return nullptr;
+  }
+  int num_outputs = 1;
+  // TODO(srbs): Figure out who is in charge of releasing this.
+  std::vector<AbstractTensorHandle*> outputs(num_outputs);
+  s = op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs);
+  if (!s.ok()) {
+    return nullptr;
+  }
+  return outputs[0];
+}
+AbstractTensorHandle* TapeTensor::ZerosLike() const {
+  AbstractOperationPtr op(ctx_->CreateOperation());
+  // TODO(srbs): Consider adding a TF_RETURN_NULLPTR_IF_ERROR.
+  Status s = op->Reset("ZerosLike", /*raw_device_name=*/nullptr);
+  if (!s.ok()) {
+    return nullptr;
+  }
+  if (isa<tracing::TracingOperation>(op.get())) {
+    s = dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
+        absl::StrCat("ZerosLike", ToId(handle_)).c_str());
+    if (!s.ok()) {
+      return nullptr;
+    }
+  }
+  s = op->AddInput(handle_);
+  if (!s.ok()) {
+    return nullptr;
+  }
+  int num_outputs = 1;
+  // TODO(srbs): Figure out who is in charge of releasing this.
+  std::vector<AbstractTensorHandle*> outputs(num_outputs);
+  s = op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs);
+  if (!s.ok()) {
+    return nullptr;
+  }
+  return outputs[0];
+}
+
+// Returns the number of elements in the gradient tensor.
+int64 TapeVSpace::NumElements(AbstractTensorHandle* tensor) const {
+  // TODO(srbs): It seems like this is used only for performance optimization
+  // and not for correctness. The only downside of keeping this 1 seems to be
+  // that the gradient accumulation is unbounded and we will never
+  // aggressively aggregate accumulated gradients to recover memory.
+  // Revisit and fix.
+  return 1;
+}
+
+// Consumes references to the tensors in the gradient_tensors list and returns
+// a tensor with the result.
+AbstractTensorHandle* TapeVSpace::AggregateGradients(
+    gtl::ArraySlice<AbstractTensorHandle*> gradient_tensors) const {
+  if (gradient_tensors.size() == 1) {
+    return gradient_tensors[0];
+  }
+
+  AbstractOperationPtr op(ctx_->CreateOperation());
+  Status s = op->Reset("AddN", /*raw_device_name=*/nullptr);
+  if (!s.ok()) {
+    return nullptr;
+  }
+  s = op->AddInputList(gradient_tensors);
+  if (!s.ok()) {
+    return nullptr;
+  }
+
+  int num_outputs = 1;
+  std::vector<AbstractTensorHandle*> outputs(num_outputs);
+  s = op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs);
+  if (!s.ok()) {
+    return nullptr;
+  }
+  return outputs[0];
+}
+
+// Calls the passed-in backward function.
+Status TapeVSpace::CallBackwardFunction(
+    GradientFunction* backward_function,
+    const std::vector<int64>& unneeded_gradients,
+    gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
+    std::vector<AbstractTensorHandle*>* result) const {
+  if (backward_function == nullptr) return Status::OK();
+  return backward_function->Compute(output_gradients, result);
+}
+
+// Looks up the ID of a Gradient.
+int64 TapeVSpace::TensorId(AbstractTensorHandle* tensor) const {
+  return ToId(tensor);
+}
+
+// Converts a Gradient to a TapeTensor.
+TapeTensor TapeVSpace::TapeTensorFromGradient(AbstractTensorHandle* g) const {
+  return TapeTensor(g, ctx_);
+}
+
+void TapeVSpace::MarkAsResult(AbstractTensorHandle* gradient) const {}
+
+void TapeVSpace::DeleteGradient(AbstractTensorHandle* gradient) const {
+  gradient->Release();
+}
+
+// Helper functions which delegate to `AbstractOperation`, update
+// the state of the ForwardOperation and call the tape as appropriate.
+// These APIs are mainly to facilitate testing and are subject to change.
+namespace internal { +Status Reset(AbstractOperation* op_, const char* op, + const char* raw_device_name, ForwardOperation* forward_op_) { + forward_op_->op_name = op; + return op_->Reset(op, raw_device_name); +} +Status AddInput(AbstractOperation* op_, AbstractTensorHandle* input, + ForwardOperation* forward_op_) { + TF_RETURN_IF_ERROR(op_->AddInput(input)); + forward_op_->inputs.push_back(input); + return Status::OK(); +} +Status AddInputList(AbstractOperation* op_, + absl::Span inputs, + ForwardOperation* forward_op_) { + TF_RETURN_IF_ERROR(op_->AddInputList(inputs)); + for (auto input : inputs) { + forward_op_->inputs.push_back(input); + } + return Status::OK(); +} + +Status SetAttrString(AbstractOperation* op_, const char* attr_name, + const char* data, size_t length, + ForwardOperation* forward_op_) { + forward_op_->attrs.Set(attr_name, StringPiece(data, length)); + return op_->SetAttrString(attr_name, data, length); +} +Status SetAttrInt(AbstractOperation* op_, const char* attr_name, int64_t value, + ForwardOperation* forward_op_) { + forward_op_->attrs.Set(attr_name, static_cast(value)); + return op_->SetAttrInt(attr_name, value); +} +Status SetAttrFloat(AbstractOperation* op_, const char* attr_name, float value, + ForwardOperation* forward_op_) { + forward_op_->attrs.Set(attr_name, value); + return op_->SetAttrFloat(attr_name, value); +} +Status SetAttrBool(AbstractOperation* op_, const char* attr_name, bool value, + ForwardOperation* forward_op_) { + forward_op_->attrs.Set(attr_name, value); + return op_->SetAttrBool(attr_name, value); +} +Status SetAttrType(AbstractOperation* op_, const char* attr_name, + DataType value, ForwardOperation* forward_op_) { + forward_op_->attrs.Set(attr_name, value); + return op_->SetAttrType(attr_name, value); +} +Status SetAttrShape(AbstractOperation* op_, const char* attr_name, + const int64_t* dims, const int num_dims, + ForwardOperation* forward_op_) { + if (num_dims > TensorShape::MaxDimensions()) { + return errors::InvalidArgument("Value specified for `", attr_name, "` has ", + num_dims, + " dimensions which is over the limit of ", + TensorShape::MaxDimensions(), "."); + } + TensorShapeProto proto; + if (num_dims < 0) { + proto.set_unknown_rank(true); + } else { + for (int d = 0; d < num_dims; ++d) { + proto.add_dim()->set_size(dims[d]); + } + } + + forward_op_->attrs.Set(attr_name, proto); + return op_->SetAttrShape(attr_name, dims, num_dims); +} +Status SetAttrFunction(AbstractOperation* op_, const char* attr_name, + const AbstractOperation* value, + ForwardOperation* forward_op_) { + return tensorflow::errors::Unimplemented( + "SetAttrFunction has not been implemented yet."); +} +Status SetAttrFunctionName(AbstractOperation* op_, const char* attr_name, + const char* value, size_t length, + ForwardOperation* forward_op_) { + return tensorflow::errors::Unimplemented( + "SetAttrFunctionName has not been implemented " + "yet."); +} +Status SetAttrTensor(AbstractOperation* op_, const char* attr_name, + AbstractTensorInterface* tensor, + ForwardOperation* forward_op_) { + return tensorflow::errors::Unimplemented( + "SetAttrTensor has not been implemented yet."); +} +Status SetAttrStringList(AbstractOperation* op_, const char* attr_name, + const void* const* values, const size_t* lengths, + int num_values, ForwardOperation* forward_op_) { + std::vector v(num_values); + for (int i = 0; i < num_values; ++i) { + v[i] = StringPiece(static_cast(values[i]), lengths[i]); + } + forward_op_->attrs.Set(attr_name, v); + return 
op_->SetAttrStringList(attr_name, values, lengths, num_values); +} +Status SetAttrFloatList(AbstractOperation* op_, const char* attr_name, + const float* values, int num_values, + ForwardOperation* forward_op_) { + forward_op_->attrs.Set(attr_name, + gtl::ArraySlice(values, num_values)); + return op_->SetAttrFloatList(attr_name, values, num_values); +} +Status SetAttrIntList(AbstractOperation* op_, const char* attr_name, + const int64_t* values, int num_values, + ForwardOperation* forward_op_) { + forward_op_->attrs.Set( + attr_name, gtl::ArraySlice( + reinterpret_cast(values), num_values)); + return op_->SetAttrIntList(attr_name, values, num_values); +} +Status SetAttrTypeList(AbstractOperation* op_, const char* attr_name, + const DataType* values, int num_values, + ForwardOperation* forward_op_) { + forward_op_->attrs.Set(attr_name, + gtl::ArraySlice(values, num_values)); + return op_->SetAttrTypeList(attr_name, values, num_values); +} +Status SetAttrBoolList(AbstractOperation* op_, const char* attr_name, + const unsigned char* values, int num_values, + ForwardOperation* forward_op_) { + std::unique_ptr b(new bool[num_values]); + for (int i = 0; i < num_values; ++i) { + b[i] = values[i]; + } + forward_op_->attrs.Set(attr_name, + gtl::ArraySlice(b.get(), num_values)); + return op_->SetAttrBoolList(attr_name, values, num_values); +} +Status SetAttrShapeList(AbstractOperation* op_, const char* attr_name, + const int64_t** dims, const int* num_dims, + int num_values, ForwardOperation* forward_op_) { + std::unique_ptr proto(new TensorShapeProto[num_values]); + for (int i = 0; i < num_values; ++i) { + const auto num_dims_i = num_dims[i]; + + if (num_dims_i > TensorShape::MaxDimensions()) { + return errors::InvalidArgument( + strings::StrCat("Value specified for `", attr_name, "` has ", + num_dims_i, " dimensions which is over the limit of ", + TensorShape::MaxDimensions(), ".")); + } + if (num_dims_i < 0) { + proto[i].set_unknown_rank(true); + } else { + const int64_t* dims_i = dims[i]; + auto proto_i = &proto[i]; + for (int d = 0; d < num_dims_i; ++d) { + proto_i->add_dim()->set_size(dims_i[d]); + } + } + } + forward_op_->attrs.Set( + attr_name, gtl::ArraySlice(proto.get(), num_values)); + return op_->SetAttrShapeList(attr_name, dims, num_dims, num_values); +} +Status SetAttrFunctionList(AbstractOperation* op_, const char* attr_name, + absl::Span values, + ForwardOperation* forward_op_) { + return tensorflow::errors::Unimplemented( + "SetAttrFunctionList has not been " + "implemented yet."); +} +Status Execute(AbstractOperation* op_, AbstractContext* ctx, + absl::Span retvals, int* num_retvals, + ForwardOperation* forward_op_, Tape* tape, + const GradientRegistry& registry) { + TF_RETURN_IF_ERROR(op_->Execute(retvals, num_retvals)); + std::vector input_ids(forward_op_->inputs.size()); + std::vector input_dtypes(forward_op_->inputs.size()); + for (int i = 0; i < forward_op_->inputs.size(); i++) { + input_ids[i] = ToId(forward_op_->inputs[i]); + input_dtypes[i] = forward_op_->inputs[i]->DataType(); + } + std::vector tape_tensors; + for (auto t : retvals) { + tape_tensors.push_back(TapeTensor(t, ctx)); + } + tape->RecordOperation( + op_->Name(), tape_tensors, input_ids, input_dtypes, + [registry, forward_op_]() -> GradientFunction* { + std::unique_ptr grad_fn; + Status s = registry.Lookup(*forward_op_, &grad_fn); + if (!s.ok()) { + return nullptr; + } + return grad_fn.release(); + }, + [](GradientFunction* ptr) { + if (ptr) { + delete ptr; + } + }); + return Status::OK(); +} +} // namespace 
internal
+
+}  // namespace gradients
+}  // namespace tensorflow
diff --git a/tensorflow/c/eager/gradients.h b/tensorflow/c/eager/gradients.h
new file mode 100644
index 00000000000..e09b6ff8613
--- /dev/null
+++ b/tensorflow/c/eager/gradients.h
@@ -0,0 +1,171 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_C_EAGER_GRADIENTS_H_
+#define TENSORFLOW_C_EAGER_GRADIENTS_H_
+
+#include "absl/container/flat_hash_map.h"
+#include "tensorflow/c/eager/abstract_context.h"
+#include "tensorflow/c/eager/abstract_tensor_handle.h"
+#include "tensorflow/c/eager/tape.h"
+#include "tensorflow/core/common_runtime/eager/attr_builder.h"
+
+namespace tensorflow {
+namespace gradients {
+
+// =============== Experimental C++ API for computing gradients ===============
+
+// Sample gradient function:
+//
+// class AddGradientFunction : public GradientFunction {
+//  public:
+//   Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs,
+//                  std::vector<AbstractTensorHandle*>* grad_outputs) override {
+//     grad_outputs->resize(2);
+//     (*grad_outputs)[0] = grad_inputs[0];
+//     (*grad_outputs)[1] = grad_inputs[0];
+//     return Status::OK();
+//   }
+//   ~AddGradientFunction() override {}
+// };
+//
+// GradientFunction* AddRegisterer(const ForwardOperation& op) {
+//   // More complex gradient functions can use inputs/attrs etc. from the
+//   // forward `op`.
+//   return new AddGradientFunction;
+// }
+//
+// Status RegisterGradients(GradientRegistry* registry) {
+//   return registry->Register("Add", AddRegisterer);
+// }
+class GradientFunction {
+ public:
+  // TODO(srbs): How do we support CompositeTensors, e.g. IndexedSlices, in
+  // `grad_inputs`?
+  virtual Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs,
+                         std::vector<AbstractTensorHandle*>* grad_outputs) = 0;
+  virtual ~GradientFunction() {}
+};
+
+// Metadata from the forward operation that is made available to the
+// gradient registerer to instantiate a GradientFunction.
+struct ForwardOperation {
+ public:
+  string op_name;
+  std::vector<AbstractTensorHandle*> inputs;
+  std::vector<AbstractTensorHandle*> outputs;
+  AttrBuilder attrs;
+  AbstractContext* ctx;
+};
+
+using GradientFunctionFactory =
+    std::function<GradientFunction*(const ForwardOperation& op)>;
+
+// Map from op name to a `GradientFunctionFactory`.
+class GradientRegistry {
+ public:
+  Status Register(const string& op, GradientFunctionFactory factory);
+  Status Lookup(const ForwardOperation& op,
+                std::unique_ptr<GradientFunction>* grad_fn) const;
+
+ private:
+  absl::flat_hash_map<string, GradientFunctionFactory> registry_;
+};
+
+// Returns a unique id for the tensor which is used by the tape to build
+// the gradient graph. See documentation of `TapeTensor` for more details.
+int64 ToId(AbstractTensorHandle* t);
+
+// Wrapper for a tensor output of an operation executing under a tape.
+//
+// `GetID` returns a unique id for the wrapped tensor which is used to maintain
+// a map (`tensorflow::eager::TensorTape`) from the wrapped tensor to the id of
+// the op that produced it (or -1 if this tensor was watched using
+// `GradientTape::Watch`.) The op_id is simply a unique index assigned to each
+// op executed under the tape. A separate map (`tensorflow::eager::OpTape`)
+// maintains the map from `op_id` to an `OpTapeEntry` which stores the
+// `op_type`, inputs, outputs, and the gradient function. These data structures
+// combined allow us to trace the data dependencies between operations and
+// hence compute gradients.
+//
+// This also implements `ZerosLike` and `OnesLike` to create the default
+// incoming gradients for tensors which do not already have an incoming
+// gradient.
+class TapeTensor {
+ public:
+  TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx);
+  TapeTensor(const TapeTensor& other);
+  ~TapeTensor();
+
+  tensorflow::int64 GetID() const;
+  tensorflow::DataType GetDType() const;
+
+  AbstractTensorHandle* OnesLike() const;
+  AbstractTensorHandle* ZerosLike() const;
+
+ private:
+  AbstractTensorHandle* handle_;
+  // The context where OnesLike and ZerosLike ops are to be created.
+  AbstractContext* ctx_;
+};
+
+// Vector space for actually computing gradients. Implements methods for
+// calling the backward function with incoming gradients and returning the
+// outgoing gradient, and for performing gradient aggregation.
+// See `tensorflow::eager::VSpace` for more details.
+class TapeVSpace
+    : public eager::VSpace<AbstractTensorHandle, GradientFunction, TapeTensor> {
+ public:
+  explicit TapeVSpace(AbstractContext* ctx) : ctx_(ctx) {}
+  ~TapeVSpace() override {}
+
+  // Returns the number of elements in the gradient tensor.
+  int64 NumElements(AbstractTensorHandle* tensor) const override;
+
+  // Consumes references to the tensors in the gradient_tensors list and
+  // returns a tensor with the result.
+  AbstractTensorHandle* AggregateGradients(
+      gtl::ArraySlice<AbstractTensorHandle*> gradient_tensors) const override;
+
+  // Calls the passed-in backward function.
+  Status CallBackwardFunction(
+      GradientFunction* backward_function,
+      const std::vector<int64>& unneeded_gradients,
+      gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
+      std::vector<AbstractTensorHandle*>* result) const override;
+
+  // Looks up the ID of a Gradient.
+  int64 TensorId(AbstractTensorHandle* tensor) const override;
+
+  // Converts a Gradient to a TapeTensor.
+  TapeTensor TapeTensorFromGradient(AbstractTensorHandle* g) const override;
+
+  void MarkAsResult(AbstractTensorHandle* gradient) const override;
+
+  void DeleteGradient(AbstractTensorHandle* gradient) const override;
+
+ private:
+  // The context where the aggregation op `Add` is to be created.
+  AbstractContext* ctx_;
+};
+
+// A tracing/immediate-execution agnostic tape.
+using Tape = tensorflow::eager::GradientTape<AbstractTensorHandle,
+                                             GradientFunction, TapeTensor>;
+
+}  // namespace gradients
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_C_EAGER_GRADIENTS_H_
diff --git a/tensorflow/c/eager/gradients_internal.h b/tensorflow/c/eager/gradients_internal.h
new file mode 100644
index 00000000000..5ddf017413a
--- /dev/null
+++ b/tensorflow/c/eager/gradients_internal.h
@@ -0,0 +1,87 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_C_EAGER_GRADIENTS_INTERNAL_H_
+#define TENSORFLOW_C_EAGER_GRADIENTS_INTERNAL_H_
+
+#include "tensorflow/c/eager/gradients.h"
+
+namespace tensorflow {
+namespace gradients {
+namespace internal {
+
+// Helper functions which delegate to `AbstractOperation`, update
+// the state of the ForwardOperation and call the tape as appropriate.
+// These APIs are mainly to facilitate testing and are subject to change.
+
+// Records the op name in the `ForwardOperation`.
+Status Reset(AbstractOperation*, const char* op, const char* raw_device_name,
+             ForwardOperation*);
+
+// Records the inputs in the `ForwardOperation`.
+Status AddInput(AbstractOperation*, AbstractTensorHandle*, ForwardOperation*);
+Status AddInputList(AbstractOperation*,
+                    absl::Span<AbstractTensorHandle* const> inputs,
+                    ForwardOperation*);
+
+// Sets the attrs in the `ForwardOperation`.
+Status SetAttrString(AbstractOperation*, const char* attr_name,
+                     const char* data, size_t length, ForwardOperation*);
+Status SetAttrInt(AbstractOperation*, const char* attr_name, int64_t value,
+                  ForwardOperation*);
+Status SetAttrFloat(AbstractOperation*, const char* attr_name, float value,
+                    ForwardOperation*);
+Status SetAttrBool(AbstractOperation*, const char* attr_name, bool value,
+                   ForwardOperation*);
+Status SetAttrType(AbstractOperation*, const char* attr_name, DataType value,
+                   ForwardOperation*);
+Status SetAttrShape(AbstractOperation*, const char* attr_name,
+                    const int64_t* dims, const int num_dims,
+                    ForwardOperation*);
+Status SetAttrFunction(AbstractOperation*, const char* attr_name,
+                       const AbstractOperation* value, ForwardOperation*);
+Status SetAttrFunctionName(AbstractOperation*, const char* attr_name,
+                           const char* value, size_t length,
+                           ForwardOperation*);
+Status SetAttrTensor(AbstractOperation*, const char* attr_name,
+                     AbstractTensorInterface* tensor, ForwardOperation*);
+Status SetAttrStringList(AbstractOperation*, const char* attr_name,
+                         const void* const* values, const size_t* lengths,
+                         int num_values, ForwardOperation*);
+Status SetAttrFloatList(AbstractOperation*, const char* attr_name,
+                        const float* values, int num_values,
+                        ForwardOperation*);
+Status SetAttrIntList(AbstractOperation*, const char* attr_name,
+                      const int64_t* values, int num_values,
+                      ForwardOperation*);
+Status SetAttrTypeList(AbstractOperation*, const char* attr_name,
+                       const DataType* values, int num_values,
+                       ForwardOperation*);
+Status SetAttrBoolList(AbstractOperation*, const char* attr_name,
+                       const unsigned char* values, int num_values,
+                       ForwardOperation*);
+Status SetAttrShapeList(AbstractOperation*, const char* attr_name,
+                        const int64_t** dims, const int* num_dims,
+                        int num_values, ForwardOperation*);
+Status SetAttrFunctionList(AbstractOperation*, const char* attr_name,
+                           absl::Span<const AbstractOperation*> values,
+                           ForwardOperation*);
+
+// Make the call to `Tape::RecordOperation`.
+Status Execute(AbstractOperation*, AbstractContext*,
+               absl::Span<AbstractTensorHandle*> retvals, int* num_retvals,
+               ForwardOperation*, Tape*, const GradientRegistry&);
+
+}  // namespace internal
+}  // namespace gradients
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_C_EAGER_GRADIENTS_INTERNAL_H_
diff --git a/tensorflow/c/eager/gradients_test.cc b/tensorflow/c/eager/gradients_test.cc
new file mode 100644
index 00000000000..5820058f3e2
--- /dev/null
+++ b/tensorflow/c/eager/gradients_test.cc
@@ -0,0 +1,328 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/eager/gradients.h" + +#include + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/eager/gradients_internal.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace gradients { +namespace internal { +namespace { + +class CppGradients + : public ::testing::TestWithParam> { + protected: + void SetUp() override { + TF_SetTracingImplementation(std::get<0>(GetParam())); + } +}; + +// Creates an Identity op. +Status Identity(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr identity_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR( + identity_op->Reset("Identity", /*raw_device_name=*/nullptr)); + if (isa(identity_op.get())) { + TF_RETURN_IF_ERROR(dyn_cast(identity_op.get()) + ->SetOpName(name)); + } + TF_RETURN_IF_ERROR(identity_op->AddInput(inputs[0])); + int num_retvals = 1; + TF_RETURN_IF_ERROR(identity_op->Execute(outputs, &num_retvals)); + return Status::OK(); +} + +// =================== Register gradients for Add ============================ +class AddGradientFunction : public GradientFunction { + public: + explicit AddGradientFunction(AbstractContext* ctx) : ctx_(ctx) {} + Status Compute(absl::Span grad_inputs, + std::vector* grad_outputs) override { + grad_outputs->resize(2); + std::vector identity_outputs(1); + TF_RETURN_IF_ERROR(Identity(ctx_, {grad_inputs[0]}, + absl::MakeSpan(identity_outputs), "Id0")); + (*grad_outputs)[0] = identity_outputs[0]; + TF_RETURN_IF_ERROR(Identity(ctx_, {grad_inputs[0]}, + absl::MakeSpan(identity_outputs), "Id1")); + (*grad_outputs)[1] = identity_outputs[0]; + return Status::OK(); + } + ~AddGradientFunction() override {} + + private: + AbstractContext* ctx_; +}; + +GradientFunction* AddRegisterer(const ForwardOperation& op) { + return new AddGradientFunction(op.ctx); +} + +Status RegisterGradients(GradientRegistry* registry) { + return registry->Register("Add", AddRegisterer); +} + +// =================== End gradient registrations ============================ + +// Computes `inputs[0] + inputs[1]` and records it on the tape. 
+Status Add(AbstractContext* ctx, Tape* tape, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + AbstractOperationPtr add_op(ctx->CreateOperation()); + ForwardOperation forward_op; + forward_op.ctx = ctx; + TF_RETURN_IF_ERROR( + Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op)); + if (isa(add_op.get())) { + TF_RETURN_IF_ERROR( + dyn_cast(add_op.get())->SetOpName("my_add")); + } + TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op)); + TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op)); + int num_retvals = 1; + return Execute(add_op.get(), ctx, outputs, &num_retvals, &forward_op, tape, + registry); +} + +// Computes +// y = inputs[0] + inputs[1] +// return grad(y, {inputs[0], inputs[1]}) +Status AddGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + TapeVSpace vspace(ctx); + auto tape = new Tape(/*persistent=*/false); + tape->Watch(ToId(inputs[0])); // Watch x. + tape->Watch(ToId(inputs[1])); // Watch y. + std::vector add_outputs(1); + TF_RETURN_IF_ERROR(Add(ctx, tape, inputs, absl::MakeSpan(add_outputs), + registry)); // Compute x+y. + std::unordered_map + source_tensors_that_are_targets; + + std::vector out_grads; + TF_RETURN_IF_ERROR(tape->ComputeGradient( + vspace, /*target_tensor_ids=*/{ToId(add_outputs[0])}, + /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])}, + source_tensors_that_are_targets, + /*output_gradients=*/{}, &out_grads)); + for (auto add_output : add_outputs) { + add_output->Release(); + } + outputs[0] = out_grads[0]; + outputs[1] = out_grads[1]; + delete tape; + return Status::OK(); +} + +AbstractContext* BuildFunction(const char* fn_name) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get()); + return unwrap(graph_ctx); +} + +Status CreateParamsForInputs(AbstractContext* ctx, + absl::Span inputs, + std::vector* params) { + tracing::TracingTensorHandle* handle = nullptr; + for (auto input : inputs) { + TF_RETURN_IF_ERROR(dyn_cast(ctx)->AddParameter( + input->DataType(), &handle)); + params->emplace_back(handle); + } + return Status::OK(); +} + +using Model = std::function, + absl::Span, const GradientRegistry&)>; + +// Runs `model` maybe wrapped in a function. 
+Status RunModel(Model model, AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, bool use_function, + const GradientRegistry& registry) { + if (use_function) { + const char* fn_name = "test_fn"; + std::unique_ptr scoped_func; + { + AbstractContextPtr func_ctx(BuildFunction(fn_name)); + std::vector func_inputs; + func_inputs.reserve(inputs.size()); + TF_RETURN_IF_ERROR( + CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs)); + OutputList output_list; + output_list.expected_num_outputs = outputs.size(); + output_list.outputs.resize(outputs.size()); + TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs), + absl::MakeSpan(output_list.outputs), registry)); + for (auto func_input : func_inputs) { + func_input->Release(); + } + AbstractFunction* func = nullptr; + TF_RETURN_IF_ERROR(dyn_cast(func_ctx.get()) + ->Finalize(&output_list, &func)); + scoped_func.reset(func); + output_list.outputs[0]->Release(); + output_list.outputs[1]->Release(); + TF_RETURN_IF_ERROR(ctx->RegisterFunction(func)); + } + + AbstractOperationPtr fn_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr)); + for (auto input : inputs) { + TF_RETURN_IF_ERROR(fn_op->AddInput(input)); + } + int retvals = outputs.size(); + TF_RETURN_IF_ERROR(fn_op->Execute(outputs, &retvals)); + TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name)); + return Status::OK(); + } else { + return model(ctx, inputs, outputs, registry); + } +} + +Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetTfrt(opts, use_tfrt); + *ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get())); + TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); + TFE_DeleteContextOptions(opts); + return Status::OK(); +} + +Status TestScalarTensorHandle(AbstractContext* ctx, float value, + AbstractTensorHandle** tensor) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_Context* eager_ctx = + TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); + TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); + TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value); + *tensor = + unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); + return Status::OK(); +} + +Status getValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_TensorHandle* result_t = + TF_AbstractTensorGetEagerTensor(wrap(t), status.get()); + TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); + *result_tensor = TFE_TensorHandleResolve(result_t, status.get()); + return Status::OK(); +} + +TEST_P(CppGradients, TestAddGrad) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + AbstractTensorHandlePtr x; + { + AbstractTensorHandle* x_raw = nullptr; + Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + x.reset(x_raw); + } + + AbstractTensorHandlePtr y; + { + AbstractTensorHandle* y_raw = nullptr; + Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &y_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + y.reset(y_raw); + } + + 
GradientRegistry registry; + Status s = RegisterGradients(®istry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + // Pseudo-code: + // + // tape.watch(x) + // tape.watch(y) + // y = x + y + // outputs = tape.gradient(y, [x, y]) + std::vector outputs(2); + s = RunModel(AddGradModel, ctx.get(), {x.get(), y.get()}, + absl::MakeSpan(outputs), + /*use_function=*/!std::get<2>(GetParam()), registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + TF_Tensor* result_tensor; + s = getValue(outputs[0], &result_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + auto result_value = static_cast(TF_TensorData(result_tensor)); + EXPECT_EQ(*result_value, 1.0); + outputs[0]->Release(); + TF_DeleteTensor(result_tensor); + result_tensor = nullptr; + + s = getValue(outputs[1], &result_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + result_value = static_cast(TF_TensorData(result_tensor)); + EXPECT_EQ(*result_value, 1.0); + outputs[1]->Release(); + TF_DeleteTensor(result_tensor); +} + +// TODO(b/160888630): Enable this test with mlir after AddInputList is +// supported. It is needed for AddN op which is used for gradient aggregation. +#ifdef PLATFORM_GOOGLE +INSTANTIATE_TEST_SUITE_P( + UnifiedCAPI, CppGradients, + ::testing::Combine(::testing::Values("graphdef"), + /*tfrt*/ ::testing::Values(false), + /*executing_eagerly*/ ::testing::Values(true, false))); +#else +INSTANTIATE_TEST_SUITE_P( + UnifiedCAPI, CppGradients, + ::testing::Combine(::testing::Values("graphdef"), + /*tfrt*/ ::testing::Values(false), + /*executing_eagerly*/ ::testing::Values(true, false))); +#endif +} // namespace +} // namespace internal +} // namespace gradients +} // namespace tensorflow
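
Taken together, the pieces in this change compose as follows: gradient functions are registered per op in a GradientRegistry, the forward pass is run through the internal::* helpers so each op is recorded on a Tape together with its inputs and attrs, and Tape::ComputeGradient is then driven by a TapeVSpace to produce input gradients. The sketch below condenses that flow from gradients_test.cc; it is illustrative only and not part of the patch. AddRegisterer and the tape-recording Add helper are the ones defined in the test, and x/y are scalar handles built as in TestScalarTensorHandle.

// Condensed from gradients_test.cc (illustrative sketch, not part of the patch).
// Computes z = x + y under a tape and returns dz/dx and dz/dy in `grads`.
Status AddGradSketch(AbstractContext* ctx, AbstractTensorHandle* x,
                     AbstractTensorHandle* y,
                     std::vector<AbstractTensorHandle*>* grads) {
  // 1. Register the "Add" gradient (AddRegisterer is defined in the test).
  GradientRegistry registry;
  TF_RETURN_IF_ERROR(registry.Register("Add", AddRegisterer));

  // 2. Run the forward op through the internal::* helpers so it is recorded
  //    on the tape along with its inputs (see the `Add` helper in the test).
  TapeVSpace vspace(ctx);
  std::unique_ptr<Tape> tape(new Tape(/*persistent=*/false));
  tape->Watch(ToId(x));
  tape->Watch(ToId(y));
  std::vector<AbstractTensorHandle*> add_outputs(1);
  TF_RETURN_IF_ERROR(
      Add(ctx, tape.get(), {x, y}, absl::MakeSpan(add_outputs), registry));

  // 3. Walk the tape backwards from z to {x, y}.
  std::unordered_map<tensorflow::int64, TapeTensor> sources_that_are_targets;
  TF_RETURN_IF_ERROR(tape->ComputeGradient(
      vspace, /*target_tensor_ids=*/{ToId(add_outputs[0])},
      /*source_tensor_ids=*/{ToId(x), ToId(y)}, sources_that_are_targets,
      /*output_gradients=*/{}, grads));
  for (auto* t : add_outputs) t->Release();
  return Status::OK();
}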