Simplify C++ tape APIs to match the proposal in https://github.com/tensorflow/community/pull/335

- The main change is to get rid of int64 tensor ids from Tape and directly use AbstractTensorHandles.
- Get rid of tensorflow::gradients::Context and directly use AbstractContext*.
- Get rid of DefaultGradientFunction and BackwardFunction (which wrapped a GradientFunction and a DefaultGradientFunction). We had introduced DefaultGradientFunction in order to support existing Python gradient functions, which expect all necessary incoming grads to be non-None. This is only relevant for ops with more than one output, which are few. We could handle those by creating a wrapper GradientFunction that builds the zeros if needed. Getting rid of DefaultGradientFunction greatly simplifies the API.
- Introduce ForwardOperation::skip_input_indices. This will be filled in in a follow-up change; there is a bug tracking this.
- Introduce helpers for implementing the behavior of tf.no_gradient and tf.stop_gradient, i.e. RegisterNotDifferentiable and NotDifferentiableGradientFunction (see the usage sketch after this list).
- One slight behavior change: previously, when an op did not have a GradientFunction registered, we silently recorded a nullptr GradientFunction on the tape, which sometimes led to uninformative error messages. Now we loudly raise an error when GradientRegistry::Lookup fails in TapeContext::Execute. So any op executing under a TapeContext must have a registered GradientFunction; non-differentiable ops need to be explicitly registered using RegisterNotDifferentiable, e.g. CheckNumerics in gradients_test.cc.
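A usage sketch of the updated API, condensed from the gradients.h doc comment and the updated test models in this change (AddGradSketch is an illustrative name, and includes/error handling are elided, so treat this as approximate rather than the exact committed code):

class AddGradientFunction : public GradientFunction {
 public:
  Status Compute(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_outputs,
                 absl::Span<AbstractTensorHandle*> grad_inputs) override {
    // d(x+y)/dx = d(x+y)/dy = incoming grad; each result slot owns a ref.
    grad_inputs[0] = grad_outputs[0];
    grad_inputs[1] = grad_outputs[0];
    grad_inputs[0]->Ref();
    grad_inputs[1]->Ref();
    return Status::OK();
  }
};

GradientFunction* AddRegisterer(const ForwardOperation& op) {
  return new AddGradientFunction;
}

Status AddGradSketch(AbstractContext* ctx, AbstractTensorHandle* x,
                     AbstractTensorHandle* y,
                     absl::Span<AbstractTensorHandle*> grads) {
  GradientRegistry registry;
  TF_RETURN_IF_ERROR(registry.Register("Add", AddRegisterer));
  // Ops without gradients must now be registered explicitly.
  TF_RETURN_IF_ERROR(RegisterNotDifferentiable(&registry, "CheckNumerics"));
  Tape tape(/*persistent=*/false);
  tape.Watch(x);  // The tape keys on handles directly; no more ToId().
  tape.Watch(y);
  std::vector<AbstractTensorHandle*> add_outputs(1);
  AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
  TF_RETURN_IF_ERROR(ops::Add(tape_ctx.get(), {x, y},
                              absl::MakeSpan(add_outputs), "Add"));
  TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/add_outputs,
                                          /*sources=*/{x, y},
                                          /*output_gradients=*/{}, grads));
  add_outputs[0]->Unref();
  return Status::OK();
}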

c/eager/tape.h: I changed the signatures of gradient functions to use `absl::Span<Gradient*>` instead of `std::vector<Gradient*>*` for the result gradients. This makes them consistent with the new Tape API and generally makes things cleaner.
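Concretely, the relevant VSpace hook goes from (declarations abbreviated from the diff below):

virtual Status CallBackwardFunction(
    BackwardFunction* backward_function,
    const std::vector<int64>& unneeded_gradients,
    gtl::ArraySlice<Gradient*> output_gradients,
    std::vector<Gradient*>* result) const = 0;

to:

virtual Status CallBackwardFunction(
    const string& op_type, BackwardFunction* backward_function,
    const std::vector<int64>& unneeded_gradients,
    gtl::ArraySlice<Gradient*> output_gradients,
    absl::Span<Gradient*> result) const = 0;

so the caller pre-sizes the result (one slot per op input) and the callee writes into it, e.g. in GradientTape::ComputeGradient:

std::vector<Gradient*> in_gradients(trace.input_tensor_id.size());
s = vspace.CallBackwardFunction(trace.op_type, trace.backward_function,
                                unneeded_gradients, out_gradients,
                                absl::MakeSpan(in_gradients));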

PiperOrigin-RevId: 345534016
Change-Id: Ie1bf5dff88f87390e6b470acc379d3852ce68b5c
Author: Saurabh Saxena (committed by TensorFlower Gardener)
Date: 2020-12-03 14:20:02 -08:00
Parent: 9193f62b81
Commit: 4e25cac495
20 changed files with 766 additions and 903 deletions


@@ -217,12 +217,12 @@ cc_library(
],
deps = [
":abstract_context",
":abstract_operation",
":abstract_tensor_handle",
":c_api_unified_internal",
":tape",
"//tensorflow/core/common_runtime/eager:attr_builder",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/core/platform:errors",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
],
@@ -278,6 +278,7 @@ tf_cuda_cc_test(
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/experimental/gradients:array_grad",
"//tensorflow/c/experimental/gradients:math_grad",
"//tensorflow/c/experimental/gradients:not_differentiable",
"//tensorflow/c/experimental/gradients/tape:tape_context",
"//tensorflow/c/experimental/ops",
"//tensorflow/cc/profiler",
@@ -454,6 +455,7 @@ tf_cuda_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/core/platform:tensor_float_32_utils",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
@@ -468,7 +470,6 @@ tf_cuda_cc_test(
args = ["--heap_check=local"],
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + [
"nomac",
"no_cuda_asan", # b/173825513
],
deps = [
@@ -493,6 +494,7 @@ tf_cuda_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/core/platform:tensor_float_32_utils",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",


@@ -20,11 +20,19 @@ limitations under the License.
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace gradients {
namespace {
// TODO(b/172558015): Using the pointer address as the identifier for the tensor
// may lead to collisions. Introduce another way to get a unique id for this
// tensor.
int64 ToId(const AbstractTensorHandle* t) {
return static_cast<int64>(reinterpret_cast<uintptr_t>(t));
}
Status ZerosLike(AbstractContext* ctx, AbstractTensorHandle* t,
AbstractTensorHandle** result) {
AbstractOperationPtr op(ctx->CreateOperation());
@@ -43,85 +51,28 @@ Status ZerosLike(AbstractContext* ctx, AbstractTensorHandle* t,
}
} // namespace
class IncomingGradientsImpl : public IncomingGradients {
public:
explicit IncomingGradientsImpl(
absl::Span<AbstractTensorHandle* const> grad_inputs, Context* ctx,
DefaultGradientFunction* default_gradients)
: grad_inputs_(grad_inputs),
ctx_(ctx),
default_gradients_(default_gradients) {}
AbstractTensorHandle* operator[](int i) const override {
return default_gradients_->get(ctx_, grad_inputs_, i);
}
size_t size() const override { return grad_inputs_.size(); }
private:
absl::Span<AbstractTensorHandle* const> grad_inputs_;
Context* ctx_;
DefaultGradientFunction* default_gradients_;
};
AllZerosDefaultGradients::AllZerosDefaultGradients(const ForwardOperation& op)
: outputs_(op.outputs) {
for (auto output : outputs_) {
output->Ref();
}
}
AbstractTensorHandle* AllZerosDefaultGradients::get(
Context* ctx, absl::Span<AbstractTensorHandle* const> grad_inputs, int i) {
if (grad_inputs[i]) {
return grad_inputs[i];
}
if (cached_default_grads_[i]) {
return cached_default_grads_[i].get();
}
AbstractTensorHandle* result = nullptr;
Status s = ZerosLike(ctx->ctx, outputs_[i], &result);
if (!s.ok()) {
if (result) {
result->Unref();
}
VLOG(1) << "Failed to create ZerosLike for index " << i;
return nullptr;
}
cached_default_grads_[i].reset(result);
return result;
}
PassThroughDefaultGradients::PassThroughDefaultGradients(
const ForwardOperation& op) {}
AbstractTensorHandle* PassThroughDefaultGradients::get(
Context* ctx, absl::Span<AbstractTensorHandle* const> grad_inputs, int i) {
return grad_inputs[i];
}
Status GradientRegistry::Register(
const string& op_name, BackwardFunctionFactory backward_function_factory) {
const string& op_name, GradientFunctionFactory gradient_function_factory) {
auto iter = registry_.find(op_name);
if (iter != registry_.end()) {
const string error_msg = "Gradient already exists for op: " + op_name + ".";
return errors::AlreadyExists(error_msg);
}
registry_.insert({op_name, backward_function_factory});
registry_.insert({op_name, gradient_function_factory});
return Status::OK();
}
Status GradientRegistry::Lookup(
const ForwardOperation& op,
std::unique_ptr<BackwardFunction>* backward_function) const {
std::unique_ptr<GradientFunction>* gradient_function) const {
auto iter = registry_.find(op.op_name);
if (iter == registry_.end()) {
const string error_msg = "No gradient defined for op: " + op.op_name + ".";
return errors::NotFound(error_msg);
}
backward_function->reset(iter->second(op));
gradient_function->reset(iter->second(op));
return Status::OK();
}
int64 ToId(AbstractTensorHandle* t) {
return static_cast<int64>(reinterpret_cast<uintptr_t>(t));
}
TapeTensor::TapeTensor(AbstractTensorHandle* handle) : handle_(handle) {
handle_->Ref();
}
@@ -140,6 +91,47 @@ AbstractTensorHandle* TapeTensor::GetHandle() const { return handle_; }
AbstractTensorHandle* TapeTensor::ZerosLike() const { return nullptr; }
class TapeVSpace
: public eager::VSpace<AbstractTensorHandle, GradientFunction, TapeTensor> {
public:
explicit TapeVSpace(AbstractContext* ctx) : ctx_(ctx) {}
~TapeVSpace() override {}
// Returns the number of elements in the gradient tensor.
int64 NumElements(AbstractTensorHandle* tensor) const override;
// Consumes references to the tensors in the gradient_tensors list and returns
// a tensor with the result.
AbstractTensorHandle* AggregateGradients(
gtl::ArraySlice<AbstractTensorHandle*> gradient_tensors) const override;
// Calls the passed-in backward function.
// op_type is the op's name provided in RecordOperation.
Status CallBackwardFunction(
const string& op_type, GradientFunction* gradient_function,
const std::vector<int64>& unneeded_gradients,
gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
absl::Span<AbstractTensorHandle*> result) const override;
// Builds a tensor filled with ones with the same shape and dtype as `t`.
Status BuildOnesLike(const TapeTensor& t,
AbstractTensorHandle** result) const override;
// Looks up the ID of a Gradient.
int64 TensorId(AbstractTensorHandle* tensor) const override;
// Converts a Gradient to a TapeTensor.
TapeTensor TapeTensorFromGradient(AbstractTensorHandle* g) const override;
void MarkAsResult(AbstractTensorHandle* gradient) const override;
void DeleteGradient(AbstractTensorHandle* gradient) const override;
private:
// The context where the aggregation op `Add` is to be created.
AbstractContext* ctx_;
};
// Returns the number of elements in the gradient tensor.
int64 TapeVSpace::NumElements(AbstractTensorHandle* tensor) const {
// TODO(srbs): It seems like this is used only for performance optimization
@@ -178,17 +170,20 @@ AbstractTensorHandle* TapeVSpace::AggregateGradients(
}
// Calls the passed-in backward function.
// op_type is the op's name provided in RecordOperation.
Status TapeVSpace::CallBackwardFunction(
BackwardFunction* backward_function,
const string& op_type, GradientFunction* gradient_function,
const std::vector<int64>& unneeded_gradients,
gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
std::vector<AbstractTensorHandle*>* result) const {
if (backward_function == nullptr) return Status::OK();
Context ctx = {ctx_};
IncomingGradientsImpl incoming_gradients(
output_gradients, &ctx, backward_function->GetDefaultGradientFunction());
return backward_function->GetGradientFunction()->Compute(
&ctx, incoming_gradients, result);
absl::Span<AbstractTensorHandle*> result) const {
if (gradient_function == nullptr) {
return errors::InvalidArgument(
"Provided null gradient_function for '", op_type, "'.\n",
"If the intent is to treat this op as non-differentiable consider "
"using RegisterNotDifferentiable or "
"NotDifferentiableGradientFunction.");
}
return gradient_function->Compute(ctx_, output_gradients, result);
}
Status TapeVSpace::BuildOnesLike(const TapeTensor& t,
@@ -224,6 +219,81 @@ void TapeVSpace::DeleteGradient(AbstractTensorHandle* gradient) const {
gradient->Unref();
}
void Tape::Watch(const AbstractTensorHandle* t) {
GradientTape::Watch(ToId(t));
}
void Tape::RecordOperation(absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle* const> outputs,
GradientFunction* gradient_function,
const string& op_name) {
std::vector<int64> input_ids(inputs.size());
std::vector<tensorflow::DataType> input_dtypes(inputs.size());
for (int i = 0; i < inputs.size(); i++) {
input_ids[i] = ToId(inputs[i]);
input_dtypes[i] = inputs[i]->DataType();
}
std::vector<TapeTensor> tape_tensors;
for (auto t : outputs) {
tape_tensors.push_back(TapeTensor(t));
}
GradientTape::RecordOperation(
op_name, tape_tensors, input_ids, input_dtypes,
[gradient_function]() -> GradientFunction* { return gradient_function; },
[](GradientFunction* ptr) {
if (ptr) {
delete ptr;
}
});
}
bool Tape::ShouldRecord(
absl::Span<const AbstractTensorHandle* const> tensors) const {
std::vector<int64> tensor_ids(tensors.size());
std::vector<tensorflow::DataType> tensor_dtypes(tensors.size());
for (int i = 0; i < tensors.size(); i++) {
tensor_ids[i] = ToId(tensors[i]);
tensor_dtypes[i] = tensors[i]->DataType();
}
return GradientTape::ShouldRecord(tensor_ids, tensor_dtypes);
}
void Tape::DeleteTrace(const AbstractTensorHandle* t) {
GradientTape::DeleteTrace(ToId(t));
}
std::vector<int64> MakeTensorIDList(
absl::Span<AbstractTensorHandle* const> tensors) {
std::vector<int64> ids(tensors.size());
for (int i = 0; i < tensors.size(); i++) {
ids[i] = ToId(tensors[i]);
}
return ids;
}
Status Tape::ComputeGradient(
AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> targets,
absl::Span<AbstractTensorHandle* const> sources,
absl::Span<AbstractTensorHandle* const> output_gradients,
absl::Span<AbstractTensorHandle*> result) {
TapeVSpace vspace(ctx);
std::vector<int64> target_tensor_ids = MakeTensorIDList(targets);
std::vector<int64> source_tensor_ids = MakeTensorIDList(sources);
tensorflow::gtl::FlatSet<tensorflow::int64> sources_set(
source_tensor_ids.begin(), source_tensor_ids.end());
std::unordered_map<int64, TapeTensor> sources_that_are_targets;
for (int i = 0; i < target_tensor_ids.size(); ++i) {
int64 target_id = target_tensor_ids[i];
if (sources_set.find(target_id) != sources_set.end()) {
auto tensor = targets[i];
sources_that_are_targets.insert(
std::make_pair(target_id, TapeTensor(tensor)));
}
}
TF_RETURN_IF_ERROR(GradientTape::ComputeGradient(
vspace, target_tensor_ids, source_tensor_ids, sources_that_are_targets,
output_gradients, result, /*build_default_zeros_grads*/ false));
return Status::OK();
}
// Helper functions which delegate to `AbstractOperation`, update
// the state of the ForwardOperation and call the tape as appropriate.
// These APIs are mainly to facilitate testing and are subject to change.
@@ -398,12 +468,6 @@ Status Execute(AbstractOperation* op_, AbstractContext* ctx,
ForwardOperation* forward_op_, Tape* tape,
const GradientRegistry& registry) {
TF_RETURN_IF_ERROR(op_->Execute(retvals, num_retvals));
std::vector<int64> input_ids(forward_op_->inputs.size());
std::vector<tensorflow::DataType> input_dtypes(forward_op_->inputs.size());
for (int i = 0; i < forward_op_->inputs.size(); i++) {
input_ids[i] = ToId(forward_op_->inputs[i]);
input_dtypes[i] = forward_op_->inputs[i]->DataType();
}
for (int i = 0; i < *num_retvals; i++) {
// TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs.
forward_op_->outputs.push_back(retvals[i]);
@@ -413,25 +477,10 @@ Status Execute(AbstractOperation* op_, AbstractContext* ctx,
// Consider getting rid of this and making the behavior between number types
// and string consistent.
forward_op_->attrs.BuildNodeDef();
std::vector<TapeTensor> tape_tensors;
for (auto t : retvals) {
tape_tensors.push_back(TapeTensor(t));
}
tape->RecordOperation(
op_->Name(), tape_tensors, input_ids, input_dtypes,
[registry, forward_op_]() -> BackwardFunction* {
std::unique_ptr<BackwardFunction> backward_fn;
Status s = registry.Lookup(*forward_op_, &backward_fn);
if (!s.ok()) {
return nullptr;
}
return backward_fn.release();
},
[](BackwardFunction* ptr) {
if (ptr) {
delete ptr;
}
});
std::unique_ptr<GradientFunction> gradient_fn;
TF_RETURN_IF_ERROR(registry.Lookup(*forward_op_, &gradient_fn));
tape->RecordOperation(forward_op_->inputs, retvals, gradient_fn.release(),
op_->Name());
return Status::OK();
}
} // namespace internal


@@ -33,10 +33,11 @@ namespace gradients {
// public:
// Status Compute(AbstractContext* ctx,
// absl::Span<AbstractTensorHandle* const> grad_inputs,
// std::vector<AbstractTensorHandle*>* grad_outputs) override {
// grad_outputs->resize(2);
// (*grad_outputs)[0] = grad_inputs[0];
// (*grad_outputs)[1] = grad_inputs[0];
// absl::Span<AbstractTensorHandle*> grad_outputs) override {
// grad_outputs[0] = grad_inputs[0];
// grad_outputs[1] = grad_inputs[0];
// grad_outputs[0]->Ref();
// grad_outputs[1]->Ref();
// return Status::OK();
// }
// ~AddGradientFunction() override {}
@@ -51,123 +52,41 @@ namespace gradients {
// Status RegisterGradients(GradientRegistry* registry) {
// return registry->Register("Add", AddRegisterer);
// }
struct Context {
public:
AbstractContext* ctx;
};
class IncomingGradients {
public:
virtual AbstractTensorHandle* operator[](int i) const = 0;
virtual size_t size() const = 0;
virtual ~IncomingGradients() {}
};
class GradientFunction {
public:
// TODO(srbs): Figure out how we support CompositeTensors, e.g. IndexedSlices,
// in `grad_inputs`.
virtual Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
std::vector<AbstractTensorHandle*>* grad_outputs) = 0;
virtual Status Compute(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> grad_outputs,
absl::Span<AbstractTensorHandle*> grad_inputs) = 0;
virtual ~GradientFunction() {}
};
// Metadata from the forward operation that is made available to the
// gradient registerer to instantiate a BackwardFunction.
// gradient registerer to instantiate a GradientFunction.
struct ForwardOperation {
public:
string op_name;
std::vector<AbstractTensorHandle*> inputs;
std::vector<AbstractTensorHandle*> outputs;
std::vector<int64> skip_input_indices;
AttrBuilder attrs;
};
// Interface for building default zeros gradients for op outputs which are
// missing incoming gradients. Custom implementations of this can be used to
// control which of the forward op's output tensors/their metadata needs to
// be kept around in memory to build the default zeros grad.
//
// Some common helper implementations are provided below.
class DefaultGradientFunction {
public:
virtual AbstractTensorHandle* get(
Context* ctx, absl::Span<AbstractTensorHandle* const> grad_inputs,
int i) = 0;
virtual ~DefaultGradientFunction() {}
};
using GradientFunctionFactory =
std::function<GradientFunction*(const ForwardOperation& op)>;
// Returns zeros for any `nullptr` in `grad_inputs`.
//
// This may require keeping track of all of forward op's output
// tensors and hence may incur a higher memory footprint. Use sparingly.
//
// Multiple calls to `AllZerosDefaultGradients::get` return the same tensor
// handle.
//
// The destructor of this class `Unref`'s any cached tensor handles so users of
// those tensor handles should `Ref` them in order to keep them alive if needed.
class AllZerosDefaultGradients : public DefaultGradientFunction {
public:
explicit AllZerosDefaultGradients(const ForwardOperation& op);
AbstractTensorHandle* get(Context* ctx,
absl::Span<AbstractTensorHandle* const> grad_inputs,
int i) override;
private:
// TODO(srbs): We do not always need to keep the tensors around. In immediate
// execution mode we just need to store the shape and dtype. During tracing
// we may need to keep the tensor around if the shape is not fully defined.
std::vector<AbstractTensorHandle*> outputs_;
std::vector<AbstractTensorHandlePtr> cached_default_grads_;
};
// Passes through `grad_inputs` as-is. The `GradientFunction`
// will be expected to deal with nullptr in `grad_inputs` if any.
class PassThroughDefaultGradients : public DefaultGradientFunction {
public:
explicit PassThroughDefaultGradients(const ForwardOperation& op);
AbstractTensorHandle* get(Context* ctx,
absl::Span<AbstractTensorHandle* const> grad_inputs,
int i) override;
};
// A `BackwardFunction` wraps a `GradientFunction` and a
// `DefaultGradientFunction`. Both are owned by this class' instance.
class BackwardFunction {
public:
BackwardFunction(GradientFunction* gradient_function,
DefaultGradientFunction* default_gradients)
: gradient_function_(gradient_function),
default_gradients_(default_gradients) {}
GradientFunction* GetGradientFunction() { return gradient_function_.get(); }
DefaultGradientFunction* GetDefaultGradientFunction() {
return default_gradients_.get();
}
private:
std::unique_ptr<GradientFunction> gradient_function_;
std::unique_ptr<DefaultGradientFunction> default_gradients_;
};
using BackwardFunctionFactory =
std::function<BackwardFunction*(const ForwardOperation& op)>;
// Map from op name to a `BackwardFunctionFactory`.
// Map from op name to a `GradientFunctionFactory`.
class GradientRegistry {
public:
Status Register(const string& op,
BackwardFunctionFactory backward_function_factory);
GradientFunctionFactory gradient_function_factory);
Status Lookup(const ForwardOperation& op,
std::unique_ptr<BackwardFunction>* backward_function) const;
std::unique_ptr<GradientFunction>* gradient_function) const;
private:
absl::flat_hash_map<string, BackwardFunctionFactory> registry_;
absl::flat_hash_map<string, GradientFunctionFactory> registry_;
};
// Returns a unique id for the tensor which is used by the tape to build
// the gradient graph. See documentation of `TapeTensor` for more details.
int64 ToId(AbstractTensorHandle* t);
// TODO(srbs): Figure out if we can avoid declaring this in the public header.
// Wrapper for a tensor output of an operation executing under a tape.
//
// `GetID` returns a unique id for the wrapped tensor which is used to maintain
@@ -203,59 +122,53 @@ class TapeTensor {
AbstractTensorHandle* handle_;
};
// Vector space for actually computing gradients. Implements methods for calling
// the backward function with incoming gradients and returning the outgoing
// gradient and for performing gradient aggregation.
// See `tensorflow::eager::VSpace` for more details.
class TapeVSpace
: public eager::VSpace<AbstractTensorHandle, BackwardFunction, TapeTensor> {
public:
explicit TapeVSpace(AbstractContext* ctx) : ctx_(ctx) {}
~TapeVSpace() override {}
// Returns the number of elements in the gradient tensor.
int64 NumElements(AbstractTensorHandle* tensor) const override;
// Consumes references to the tensors in the gradient_tensors list and returns
// a tensor with the result.
AbstractTensorHandle* AggregateGradients(
gtl::ArraySlice<AbstractTensorHandle*> gradient_tensors) const override;
// Calls the passed-in backward function.
Status CallBackwardFunction(
BackwardFunction* backward_function,
const std::vector<int64>& unneeded_gradients,
gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
std::vector<AbstractTensorHandle*>* result) const override;
// Builds a tensor filled with ones with the same shape and dtype as `t`.
Status BuildOnesLike(const TapeTensor& t,
AbstractTensorHandle** result) const override;
// Looks up the ID of a Gradient.
int64 TensorId(AbstractTensorHandle* tensor) const override;
// Converts a Gradient to a TapeTensor.
TapeTensor TapeTensorFromGradient(AbstractTensorHandle* g) const override;
void MarkAsResult(AbstractTensorHandle* gradient) const override;
void DeleteGradient(AbstractTensorHandle* gradient) const override;
private:
// The context where the aggregation op `Add` is to be created.
AbstractContext* ctx_;
};
// A tracing/immediate-execution agnostic tape.
//
// Gradient functions defined for this library support handling null incoming
// gradients. `Tape::ComputeGradient` should be called with
// `build_default_zeros_grads=false`. Calling with
// `build_default_zeros_grads=true` (the default) is equivalent but just results
// in extra work because `TapeTensor::ZerosLike` returns a `nullptr` anyway.
using Tape = tensorflow::eager::GradientTape<AbstractTensorHandle,
BackwardFunction, TapeTensor>;
// Gradient functions defined for this tape must support handling null incoming
// gradients.
class Tape : protected eager::GradientTape<AbstractTensorHandle,
GradientFunction, TapeTensor> {
public:
using GradientTape<AbstractTensorHandle, GradientFunction,
TapeTensor>::GradientTape;
// Returns whether the tape is persistent, i.e., whether the tape will hold
// onto its internal state after a call to `ComputeGradient`.
using GradientTape<AbstractTensorHandle, GradientFunction,
TapeTensor>::IsPersistent;
// Adds this tensor to the list of watched tensors.
//
// This is a no-op if the tensor is already being watched either from an
// earlier call to `GradientTape::Watch` or being an output of an op with
// watched inputs.
void Watch(const AbstractTensorHandle*);
// Records an operation with given inputs and outputs
// on the tape and marks all its outputs as watched if at
// least one input of the op is watched and has a trainable dtype.
// op_name is optional and is used for debugging only.
void RecordOperation(absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle* const> outputs,
GradientFunction* gradient_function,
const string& op_name = "");
// Returns whether any tensor in a list of tensors is being watched and has
// a trainable dtype.
bool ShouldRecord(
absl::Span<const AbstractTensorHandle* const> tensors) const;
// Unwatches this tensor on the tape. Mainly used for cleanup when deleting
// eager tensors.
void DeleteTrace(const AbstractTensorHandle*);
// Consumes the internal state of the tape (so cannot be called more than
// once unless the tape is persistent) and produces the gradient of the target
// tensors with respect to the source tensors. The output gradients are used
// if not empty and not null. The result is populated with one tensor per
// source element.
Status ComputeGradient(
AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> targets,
absl::Span<AbstractTensorHandle* const> sources,
absl::Span<AbstractTensorHandle* const> output_gradients,
absl::Span<AbstractTensorHandle*> result);
};
} // namespace gradients
} // namespace tensorflow


@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/c/eager/unified_api_testutil.h"
#include "tensorflow/c/experimental/gradients/array_grad.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/not_differentiable.h"
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
@@ -68,6 +69,7 @@ Status RegisterGradients(GradientRegistry* registry) {
TF_RETURN_IF_ERROR(registry->Register("Mul", MulRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Log1p", Log1pRegisterer));
TF_RETURN_IF_ERROR(registry->Register("DivNoNan", DivNoNanRegisterer));
TF_RETURN_IF_ERROR(RegisterNotDifferentiable(registry, "CheckNumerics"));
return Status::OK();
}
@@ -80,30 +82,20 @@ Status AddGradModel(AbstractContext* ctx,
GradientRegistry registry;
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
TapeVSpace vspace(ctx);
auto tape = std::make_unique<Tape>(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
tape->Watch(ToId(inputs[1])); // Watch y.
tape->Watch(inputs[0]); // Watch x.
tape->Watch(inputs[1]); // Watch y.
std::vector<AbstractTensorHandle*> add_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
TF_RETURN_IF_ERROR(ops::Add(tape_ctx.get(), inputs,
absl::MakeSpan(add_outputs),
"Add")); // Compute x+y.
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
std::vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(add_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/add_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
for (auto add_output : add_outputs) {
add_output->Unref();
}
outputs[0] = out_grads[0];
outputs[1] = out_grads[1];
return Status::OK();
}
@@ -116,26 +108,18 @@ Status ExpGradModel(AbstractContext* ctx,
GradientRegistry registry;
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
TapeVSpace vspace(ctx);
auto tape = std::make_unique<Tape>(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
tape->Watch(inputs[0]); // Watch x.
std::vector<AbstractTensorHandle*> exp_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
TF_RETURN_IF_ERROR(
ops::Exp(tape_ctx.get(), inputs, absl::MakeSpan(exp_outputs), "Exp"));
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
std::vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(exp_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/exp_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
for (auto exp_output : exp_outputs) {
exp_output->Unref();
}
outputs[0] = out_grads[0];
return Status::OK();
}
@@ -148,26 +132,18 @@ Status SqrtGradModel(AbstractContext* ctx,
GradientRegistry registry;
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
TapeVSpace vspace(ctx);
auto tape = std::make_unique<Tape>(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
tape->Watch(inputs[0]); // Watch x.
std::vector<AbstractTensorHandle*> sqrt_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
TF_RETURN_IF_ERROR(
ops::Sqrt(tape_ctx.get(), inputs, absl::MakeSpan(sqrt_outputs), "Sqrt"));
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
std::vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(sqrt_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/sqrt_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
for (auto sqrt_output : sqrt_outputs) {
sqrt_output->Unref();
}
outputs[0] = out_grads[0];
return Status::OK();
}
@@ -181,30 +157,21 @@ Status IdentityNGradModel(AbstractContext* ctx,
GradientRegistry registry;
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
TapeVSpace vspace(ctx);
auto tape = std::make_unique<Tape>(/*persistent=*/false);
tape->Watch(ToId(inputs[0]));
tape->Watch(ToId(inputs[1]));
tape->Watch(inputs[0]);
tape->Watch(inputs[1]);
vector<AbstractTensorHandle*> identity_n_outputs(2);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
TF_RETURN_IF_ERROR(ops::IdentityN(
tape_ctx.get(), inputs, absl::MakeSpan(identity_n_outputs), "IdentityN"));
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(identity_n_outputs[1])},
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx,
/*targets=*/{identity_n_outputs[1]},
/*sources=*/{inputs[0], inputs[1]},
/*output_gradients=*/{}, outputs));
for (auto identity_n_output : identity_n_outputs) {
identity_n_output->Unref();
}
outputs[0] = out_grads[0];
outputs[1] = out_grads[1];
return Status::OK();
}
@@ -217,27 +184,19 @@ Status NegGradModel(AbstractContext* ctx,
GradientRegistry registry;
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
TapeVSpace vspace(ctx);
auto tape = std::make_unique<Tape>(/*persistent=*/false);
tape->Watch(ToId(inputs[0]));
tape->Watch(inputs[0]);
std::vector<AbstractTensorHandle*> neg_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
TF_RETURN_IF_ERROR(
ops::Neg(tape_ctx.get(), inputs, absl::MakeSpan(neg_outputs), "Neg"));
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
std::vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(neg_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/neg_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
for (auto neg_output : neg_outputs) {
neg_output->Unref();
}
outputs[0] = out_grads[0];
return Status::OK();
}
@@ -250,30 +209,20 @@ Status SubGradModel(AbstractContext* ctx,
GradientRegistry registry;
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
TapeVSpace vspace(ctx);
auto tape = std::make_unique<Tape>(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
tape->Watch(ToId(inputs[1])); // Watch y.
tape->Watch(inputs[0]); // Watch x.
tape->Watch(inputs[1]); // Watch y.
std::vector<AbstractTensorHandle*> sub_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
TF_RETURN_IF_ERROR(ops::Sub(tape_ctx.get(), inputs,
absl::MakeSpan(sub_outputs),
"Sub")); // Compute x-y.
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
std::vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(sub_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/sub_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
for (auto sub_output : sub_outputs) {
sub_output->Unref();
}
outputs[0] = out_grads[0];
outputs[1] = out_grads[1];
return Status::OK();
}
@@ -286,30 +235,20 @@ Status MulGradModel(AbstractContext* ctx,
GradientRegistry registry;
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
tape->Watch(ToId(inputs[1])); // Watch y.
tape->Watch(inputs[0]); // Watch x.
tape->Watch(inputs[1]); // Watch y.
std::vector<AbstractTensorHandle*> mul_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), inputs,
absl::MakeSpan(mul_outputs),
"Mul")); // Compute x*y.
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
std::vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(mul_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/mul_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
for (auto mul_output : mul_outputs) {
mul_output->Unref();
}
outputs[0] = out_grads[0];
outputs[1] = out_grads[1];
delete tape;
return Status::OK();
}
@@ -322,27 +261,19 @@ Status Log1pGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle*> outputs) {
GradientRegistry registry;
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
tape->Watch(inputs[0]); // Watch x.
std::vector<AbstractTensorHandle*> log1p_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::Log1p(tape_ctx.get(), inputs,
absl::MakeSpan(log1p_outputs),
"Log1p")); // Compute log(1 + x).
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
std::vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(log1p_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/log1p_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
for (auto log1p_output : log1p_outputs) {
log1p_output->Unref();
}
outputs[0] = out_grads[0];
delete tape;
return Status::OK();
}
@@ -355,30 +286,20 @@ Status DivNoNanGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle*> outputs) {
GradientRegistry registry;
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
tape->Watch(ToId(inputs[1])); // Watch y.
tape->Watch(inputs[0]); // Watch x.
tape->Watch(inputs[1]); // Watch y.
std::vector<AbstractTensorHandle*> div_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::DivNoNan(tape_ctx.get(), inputs,
absl::MakeSpan(div_outputs),
"DivNoNan")); // Compute x / y.
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
std::vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(div_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/div_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
for (auto div_output : div_outputs) {
div_output->Unref();
}
outputs[0] = out_grads[0];
outputs[1] = out_grads[1];
delete tape;
return Status::OK();
}
@@ -890,6 +811,8 @@ TEST_P(CppGradients, TestSetAttrString) {
int num_retvals = 1;
std::vector<AbstractTensorHandle*> outputs(1);
GradientRegistry registry;
s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto tape = std::make_unique<Tape>(/*persistent=*/false);
s = Execute(check_numerics_op.get(), ctx.get(), absl::MakeSpan(outputs),
&num_retvals, &forward_op, tape.get(), registry);
@@ -901,6 +824,52 @@ TEST_P(CppGradients, TestSetAttrString) {
ASSERT_EQ(read_message, message);
}
Status RecordOperationWithNullGradientFunctionModel(
AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
Tape tape(/*persistent=*/false);
tape.Watch(inputs[0]);
std::vector<AbstractTensorHandle*> neg_outputs(1);
TF_RETURN_IF_ERROR(ops::Neg(ctx, inputs, absl::MakeSpan(neg_outputs), "Neg"));
tape.RecordOperation(inputs, neg_outputs, nullptr, "Neg");
return tape.ComputeGradient(ctx, /*targets=*/neg_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs);
}
TEST_P(CppGradients, TestRecordOperationWithNullGradientFunctionRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
std::vector<AbstractTensorHandle*> outputs(1);
Status s = RunModel(RecordOperationWithNullGradientFunctionModel, ctx.get(),
{x.get()}, absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()));
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
ASSERT_EQ(
"Provided null gradient_function for 'Neg'.\nIf the intent is to treat "
"this op as non-differentiable consider using RegisterNotDifferentiable "
"or NotDifferentiableGradientFunction.",
s.error_message());
ASSERT_EQ(nullptr, outputs[0]);
}
// TODO(b/164171226): Enable this test with tfrt after AddInputList is
// supported. It is needed for IdentityN.
#ifdef PLATFORM_GOOGLE


@@ -47,29 +47,19 @@ Status AddGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
tape->Watch(ToId(inputs[1])); // Watch y.
tape->Watch(inputs[0]); // Watch x.
tape->Watch(inputs[1]); // Watch y.
std::vector<AbstractTensorHandle*> add_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(
ops::Add(tape_ctx.get(), inputs, absl::MakeSpan(add_outputs), "Add"));
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
std::vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(add_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/add_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
for (auto add_output : add_outputs) {
add_output->Unref();
}
outputs[0] = out_grads[0];
outputs[1] = out_grads[1];
delete tape;
return Status::OK();
}
@@ -81,10 +71,9 @@ Status MatMulGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
tape->Watch(ToId(inputs[1])); // Watch y.
tape->Watch(inputs[0]); // Watch x.
tape->Watch(inputs[1]); // Watch y.
vector<AbstractTensorHandle*> mm_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), inputs,
@@ -92,21 +81,12 @@ Status MatMulGradModel(AbstractContext* ctx,
/*transpose_a=*/false,
/*transpose_b=*/false)); // Compute x*y.
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(mm_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/mm_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
for (auto mm_output : mm_outputs) {
mm_output->Unref();
}
outputs[0] = out_grads[0];
outputs[1] = out_grads[1];
delete tape;
return Status::OK();
}
@@ -138,39 +118,32 @@ Status MNISTForwardModel(AbstractContext* ctx,
AbstractTensorHandle* W2 = inputs[2];
AbstractTensorHandle* y_labels = inputs[3];
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(W1)); // Watch W1.
tape->Watch(ToId(W2)); // Watch W2.
vector<AbstractTensorHandle*> temp_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
absl::MakeSpan(temp_outputs), "matmul0",
TF_RETURN_IF_ERROR(ops::MatMul(ctx, {X, W1}, absl::MakeSpan(temp_outputs),
"matmul0",
/*transpose_a=*/false,
/*transpose_b=*/false)); // Compute X*W1
TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {temp_outputs[0]},
TF_RETURN_IF_ERROR(ops::Relu(ctx, {temp_outputs[0]},
absl::MakeSpan(temp_outputs),
"relu")); // Compute Relu(X*W1)
TF_RETURN_IF_ERROR(ops::MatMul(
tape_ctx.get(), {temp_outputs[0], W2}, absl::MakeSpan(temp_outputs),
"matmul1",
ctx, {temp_outputs[0], W2}, absl::MakeSpan(temp_outputs), "matmul1",
/*transpose_a=*/false, /*transpose_b=*/false)); // Compute W2*Relu(X*W1)
AbstractTensorHandle* scores = temp_outputs[0];
temp_outputs.resize(2);
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs),
ctx, {scores, y_labels}, absl::MakeSpan(temp_outputs),
"softmax_loss")); // Compute Softmax(Scores,labels)
AbstractTensorHandle* loss_vals = temp_outputs[0];
outputs[0] = scores;
outputs[1] = loss_vals;
delete tape;
return Status::OK();
}
@@ -181,21 +154,9 @@ Status MatMulTransposeModel(AbstractContext* ctx,
AbstractTensorHandle* X = inputs[0];
AbstractTensorHandle* W1 = inputs[1];
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(X));
tape->Watch(ToId(W1));
vector<AbstractTensorHandle*> temp_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
absl::MakeSpan(temp_outputs), "matmul0",
TF_RETURN_IF_ERROR(ops::MatMul(ctx, {X, W1}, outputs, "matmul0",
/*transpose_a=*/true,
/*transpose_b=*/false)); // Compute X*W1
outputs[0] = temp_outputs[0];
delete tape;
return Status::OK();
}
@@ -203,30 +164,22 @@ Status ReluGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch X
tape->Watch(inputs[0]); // Watch X
vector<AbstractTensorHandle*> relu_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), inputs,
absl::MakeSpan(relu_outputs),
"relu0")); // Relu(X)
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(relu_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/relu_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
for (auto relu_output : relu_outputs) {
relu_output->Unref();
}
outputs[0] = out_grads[0];
delete tape;
return Status::OK();
}
@@ -235,28 +188,19 @@ Status SoftmaxLossGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch scores.
tape->Watch(ToId(inputs[1])); // Watch labels.
tape->Watch(inputs[0]); // Watch scores.
tape->Watch(inputs[1]); // Watch labels.
vector<AbstractTensorHandle*> sm_outputs(2);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
tape_ctx.get(), inputs, absl::MakeSpan(sm_outputs), "softmax0"));
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx,
/*targets=*/sm_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(sm_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
outputs[0] = out_grads[0];
outputs[1] = out_grads[1];
delete tape;
return Status::OK();
}
@@ -270,11 +214,10 @@ Status MNISTGradModel(AbstractContext* ctx,
AbstractTensorHandle* W2 = inputs[2];
AbstractTensorHandle* y_labels = inputs[3];
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/true);
tape->Watch(ToId(X)); // Watch X.
tape->Watch(ToId(W1)); // Watch W1.
tape->Watch(ToId(W2)); // Watch W2.
tape->Watch(X); // Watch X.
tape->Watch(W1); // Watch W1.
tape->Watch(W2); // Watch W2.
vector<AbstractTensorHandle*> temp_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
@@ -303,24 +246,14 @@ Status MNISTGradModel(AbstractContext* ctx,
AbstractTensorHandle* loss = temp_outputs[0];
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(
tape->ComputeGradient(vspace, /*target_tensor_ids=*/{ToId(loss)},
/*source_tensor_ids=*/{ToId(W1), ToId(W2)},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/{loss},
/*sources=*/{W1, W2},
/*output_gradients=*/{},
outputs.subspan(0, 2)));
// Only release 2nd temp output as first holds loss values.
temp_outputs[1]->Unref();
outputs[0] = out_grads[0]; // dW1
outputs[1] = out_grads[1]; // dW2
outputs[2] = loss;
delete tape;
return Status::OK();
}
@@ -329,83 +262,33 @@ Status ScalarMulModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractTensorHandle* eta = inputs[0];
AbstractTensorHandle* A = inputs[1];
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
vector<AbstractTensorHandle*> temp_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), {eta, A},
absl::MakeSpan(temp_outputs),
"scalarMul0")); // Compute eta*A
outputs[0] = temp_outputs[0];
delete tape;
return Status::OK();
return ops::Mul(ctx, inputs, outputs,
"scalarMul0"); // Compute eta*A
}
Status MatMulModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractTensorHandle* X = inputs[0];
AbstractTensorHandle* W1 = inputs[1];
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
std::vector<AbstractTensorHandle*> temp_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
absl::MakeSpan(temp_outputs), "matmul0",
/*transpose_a=*/false,
/*transpose_b=*/false)); // Compute X*W1
outputs[0] = temp_outputs[0];
delete tape;
return Status::OK();
return ops::MatMul(ctx, inputs, outputs, "matmul0",
/*transpose_a=*/false,
/*transpose_b=*/false); // Compute X*W1
}
Status MulModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractTensorHandle* x = inputs[0];
AbstractTensorHandle* y = inputs[1];
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
std::vector<AbstractTensorHandle*> temp_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), {x, y},
absl::MakeSpan(temp_outputs),
"mul0")); // Compute x*y
outputs[0] = temp_outputs[0];
delete tape;
return Status::OK();
return ops::Mul(ctx, inputs, outputs,
"mul0"); // Compute x*y
}
Status SoftmaxModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractTensorHandle* x = inputs[0];
AbstractTensorHandle* labels = inputs[1];
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
std::vector<AbstractTensorHandle*> temp_outputs(2);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
tape_ctx.get(), {x, labels}, absl::MakeSpan(temp_outputs), "sm_loss"));
outputs[0] = temp_outputs[0]; // loss values
delete tape;
return Status::OK();
return ops::SparseSoftmaxCrossEntropyWithLogits(ctx, inputs, outputs,
"sm_loss");
}
// ============================= End Models ================================


@@ -93,11 +93,14 @@ class VSpace {
gtl::ArraySlice<Gradient*> gradient_tensors) const = 0;
// Calls the passed-in backward function.
//
// `unneeded_gradients` contains sorted list of input indices for which a
// gradient is not required.
virtual Status CallBackwardFunction(
BackwardFunction* backward_function,
const string& op_type, BackwardFunction* backward_function,
const std::vector<int64>& unneeded_gradients,
gtl::ArraySlice<Gradient*> output_gradients,
std::vector<Gradient*>* result) const = 0;
absl::Span<Gradient*> result) const = 0;
// Builds a tensor filled with ones with the same shape and dtype as `t`.
virtual Status BuildOnesLike(const TapeTensor& t,
@@ -133,11 +136,24 @@ class GradientTape {
}
}
// Returns whether any tensor in a list of tensors is being watched and has
// a trainable dtype.
bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids,
gtl::ArraySlice<tensorflow::DataType> dtypes);
gtl::ArraySlice<tensorflow::DataType> dtypes) const;
// Adds this tensor to the list of watched tensors.
//
// This is a no-op if the tensor is already being watched either from an
// earlier call to `GradientTape::Watch` or being an output of an op with
// watched inputs.
void Watch(int64 tensor_id);
// Records an operation with inputs `input_tensor_id` and outputs
// `output_tensors` on the tape and marks all its outputs as watched if at
// least one input of the op is watched and has trainable dtype.
//
// op_type is used to decide which of the incoming gradients can be left as
// nullptr instead of building zeros when build_default_zeros_grads == true.
void RecordOperation(
const string& op_type, const std::vector<TapeTensor>& output_tensors,
gtl::ArraySlice<int64> input_tensor_id,
@@ -159,9 +175,10 @@ class GradientTape {
const gtl::ArraySlice<int64> target_tensor_ids,
const gtl::ArraySlice<int64> source_tensor_ids,
const std::unordered_map<int64, TapeTensor>& sources_that_are_targets,
gtl::ArraySlice<Gradient*> output_gradients,
std::vector<Gradient*>* result, bool build_default_zeros_grads = true);
gtl::ArraySlice<Gradient*> output_gradients, absl::Span<Gradient*> result,
bool build_default_zeros_grads = true);
// Whether the tape is persistent. See ctor for detailed description.
bool IsPersistent() const { return persistent_; }
private:
@@ -311,11 +328,10 @@ class ForwardAccumulator {
// function is running; this effectively adds the backward tape to the active
// set (but does not require complicated callbacks to the language bindings).
Status ForwardpropFromTape(
const std::vector<TapeTensor>& output_tensors,
const string& op_type, const std::vector<TapeTensor>& output_tensors,
const std::function<BackwardFunction*()>& backward_function_getter,
const std::function<void(BackwardFunction*)>& backward_function_deleter,
const std::vector<Gradient*>& in_grads,
std::vector<Gradient*>* out_grads);
const std::vector<Gradient*>& in_grads, absl::Span<Gradient*> out_grads);
// Maps from tensor IDs to corresponding JVPs.
std::unordered_map<int64, Gradient*> accumulated_gradients_;
@@ -368,7 +384,7 @@ inline bool IsDtypeTrainable(DataType dtype) {
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
bool GradientTape<Gradient, BackwardFunction, TapeTensor>::ShouldRecord(
gtl::ArraySlice<int64> tensor_ids,
gtl::ArraySlice<tensorflow::DataType> dtypes) {
gtl::ArraySlice<tensorflow::DataType> dtypes) const {
CHECK_EQ(tensor_ids.size(), dtypes.size());
for (int i = 0; i < tensor_ids.size(); ++i) {
if (tensor_tape_.find(tensor_ids[i]) != tensor_tape_.end()) {
@@ -668,7 +684,7 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
const gtl::ArraySlice<int64> target_tensor_ids,
const gtl::ArraySlice<int64> source_tensor_ids,
const std::unordered_map<int64, TapeTensor>& sources_that_are_targets,
gtl::ArraySlice<Gradient*> output_gradients, std::vector<Gradient*>* result,
gtl::ArraySlice<Gradient*> output_gradients, absl::Span<Gradient*> result,
bool build_default_zeros_grads) {
std::unordered_set<int64> sources_set(source_tensor_ids.begin(),
source_tensor_ids.end());
@@ -757,23 +773,17 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
out_gradients.push_back(new_gradients);
}
}
std::vector<Gradient*> in_gradients;
VLOG(1) << "Calling gradient function for '" << trace.op_type << "'";
std::vector<Gradient*> in_gradients(trace.input_tensor_id.size());
DCHECK(build_default_zeros_grads || zero_indices.empty());
if (any_gradient_nonzero) {
for (const auto i : zero_indices) {
out_gradients[i] = trace.output_tensor_info[i].ZerosLike();
}
Status s;
s = vspace.CallBackwardFunction(trace.backward_function,
s = vspace.CallBackwardFunction(trace.op_type, trace.backward_function,
unneeded_gradients, out_gradients,
&in_gradients);
if (in_gradients.size() != trace.input_tensor_id.size()) {
return tensorflow::errors::Internal(
"Recorded operation '", trace.op_type,
"' returned too few gradients. Expected ",
trace.input_tensor_id.size(), " but received ",
in_gradients.size());
}
absl::MakeSpan(in_gradients));
if (!persistent_) {
trace.backward_function_deleter(trace.backward_function);
}
@@ -781,7 +791,6 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
return s;
}
} else {
in_gradients.resize(trace.input_tensor_id.size());
if (!persistent_) {
trace.backward_function_deleter(trace.backward_function);
}
@@ -791,8 +800,6 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
}
}
}
VLOG(1) << "Got " << in_gradients.size() << " in_gradients for "
<< trace.input_tensor_id.size() << " sources";
for (int i = 0, end = in_gradients.size(); i < end; ++i) {
const int64 id = trace.input_tensor_id[i];
if (in_gradients[i] != nullptr) {
@@ -856,20 +863,25 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
if (!state.op_tape.empty()) {
return tensorflow::errors::Internal("Invalid tape state.");
}
result->reserve(source_tensor_ids.size());
if (result.size() != source_tensor_ids.size()) {
return errors::Internal("Expected result Span to be of size ",
source_tensor_ids.size(), " found ", result.size(),
" in call to Tape::ComputeGradient.");
}
std::unordered_set<int64> used_gradient_ids(source_tensor_ids.size());
for (auto is : source_tensor_ids) {
auto grad_it = gradients.find(is);
for (int i = 0; i < source_tensor_ids.size(); i++) {
int64 tensor_id = source_tensor_ids[i];
auto grad_it = gradients.find(tensor_id);
if (grad_it == gradients.end()) {
result->push_back(nullptr);
result[i] = nullptr;
} else {
if (grad_it->second.size() > 1) {
Gradient* grad = vspace.AggregateGradients(grad_it->second);
grad_it->second.clear();
grad_it->second.push_back(grad);
}
result->push_back(grad_it->second[0]);
used_gradient_ids.insert(is);
result[i] = grad_it->second[0];
used_gradient_ids.insert(tensor_id);
}
}
VLOG(1) << "Final gradients size: "
@ -910,10 +922,10 @@ bool ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ShouldRecord(
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
Status
ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
const std::vector<TapeTensor>& output_tensors,
const string& op_type, const std::vector<TapeTensor>& output_tensors,
const std::function<BackwardFunction*()>& backward_function_getter,
const std::function<void(BackwardFunction*)>& backward_function_deleter,
const std::vector<Gradient*>& in_grads, std::vector<Gradient*>* out_grads) {
const std::vector<Gradient*>& in_grads, absl::Span<Gradient*> out_grads) {
/* This function is approximately equivalent to this Python code:
forwardprop_aids = tf.ones_like(output_tensors)
@ -957,7 +969,7 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
sources_set.insert(aid_id);
tape->Watch(aid_id);
}
std::vector<Gradient*> grad;
std::vector<Gradient*> grad(in_grads.size());
auto delete_grad = gtl::MakeCleanup([&grad, this] {
for (Gradient* tensor : grad) {
this->vspace_.DeleteGradient(tensor);
@ -969,16 +981,13 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
backward_function(backward_function_getter(),
backward_function_deleter);
TF_RETURN_IF_ERROR(vspace_.CallBackwardFunction(
backward_function.get(), unneeded_gradients, forwardprop_aids, &grad));
op_type, backward_function.get(), unneeded_gradients, forwardprop_aids,
absl::MakeSpan(grad)));
}
// Stop the tape from recording
pop_backward_tape.release()();
if (grad.size() != in_grads.size()) {
return tensorflow::errors::Internal("Wrong number of gradients returned.");
}
std::vector<int64> targets;
std::vector<Gradient*> used_in_grads;
// We may end up with slightly fewer elements than we reserve, but grad.size()
@ -1076,9 +1085,10 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
if (forward_function == nullptr) {
// We have no special-cased forward gradient. Fall back to running the
// backward function under a gradient tape.
forward_grads.resize(output_tensors.size());
TF_RETURN_IF_ERROR(ForwardpropFromTape(
output_tensors, backward_function_getter, backward_function_deleter,
in_grads, &forward_grads));
op_type, output_tensors, backward_function_getter,
backward_function_deleter, in_grads, absl::MakeSpan(forward_grads)));
} else {
TF_RETURN_IF_ERROR(
(*forward_function)(in_grads, &forward_grads, use_batch_));
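The Span-based signatures above shift sizing onto the caller: `result` must carry exactly one slot per source before `ComputeGradient` runs, or the new size check fails with an internal error. A minimal caller-side sketch, assuming `tape`, `vspace`, `target_ids`, `source_ids`, and `sources_that_are_targets` are already in scope:

    std::vector<Gradient*> result(source_ids.size());  // one slot per source
    TF_RETURN_IF_ERROR(tape->ComputeGradient(
        vspace, target_ids, source_ids, sources_that_are_targets,
        /*output_gradients=*/{}, absl::MakeSpan(result),
        /*build_default_zeros_grads=*/false));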

View File

@ -30,6 +30,7 @@ cc_library(
"//tensorflow:internal",
],
deps = [
"//tensorflow/c/eager:abstract_context",
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
@ -79,12 +80,29 @@ cc_library(
],
)
cc_library(
name = "not_differentiable",
srcs = ["not_differentiable.cc"],
hdrs = [
"not_differentiable.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/c/eager:abstract_context",
"//tensorflow/c/eager:gradients_internal",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "gradients",
hdrs = [
"array_grad.h",
"math_grad.h",
"nn_grad.h",
"not_differentiable.h",
],
visibility = [
"//tensorflow:internal",
@ -93,6 +111,7 @@ cc_library(
":array_grad",
":math_grad",
":nn_grad",
":not_differentiable",
"//tensorflow/c/eager:gradients_internal",
],
)
@ -103,6 +122,7 @@ filegroup(
"array_grad.h",
"math_grad.h",
"nn_grad.h",
"not_differentiable.h",
],
visibility = [
"//tensorflow/core:__pkg__",

View File

@ -14,23 +14,24 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/array_grad.h"
#include "tensorflow/c/eager/abstract_context.h"
namespace tensorflow {
namespace gradients {
namespace {
using std::vector;
class IdentityNGradientFunction : public GradientFunction {
public:
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
grad_outputs->resize(grad_inputs.size(), nullptr);
for (int i = 0; i < grad_inputs.size(); i++) {
auto grad_input = grad_inputs[i];
Status Compute(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> grad_outputs,
absl::Span<AbstractTensorHandle*> grad_inputs) override {
for (int i = 0; i < grad_outputs.size(); i++) {
auto grad_input = grad_outputs[i];
// TODO(srbs): Should we add a copy constructor to AbstractTensorHandle
// that takes care of this similar to `Tensor`?
if (grad_input) {
grad_input->Ref();
}
(*grad_outputs)[i] = grad_input;
grad_inputs[i] = grad_input;
}
return Status::OK();
}
@ -38,10 +39,8 @@ class IdentityNGradientFunction : public GradientFunction {
};
} // namespace
BackwardFunction* IdentityNRegisterer(const ForwardOperation& op) {
auto gradient_function = new IdentityNGradientFunction;
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
GradientFunction* IdentityNRegisterer(const ForwardOperation& op) {
return new IdentityNGradientFunction;
}
} // namespace gradients
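The rewritten IdentityN gradient shows the new calling convention: incoming gradients arrive in `grad_outputs` (one per forward-op output) and results go into the pre-sized `grad_inputs` span (one per forward-op input). A minimal sketch for a hypothetical single-input, single-output pass-through op, illustrative only:

    class PassThroughGradientFunction : public GradientFunction {
     public:
      Status Compute(AbstractContext* ctx,
                     absl::Span<AbstractTensorHandle* const> grad_outputs,
                     absl::Span<AbstractTensorHandle*> grad_inputs) override {
        grad_inputs[0] = grad_outputs[0];
        if (grad_inputs[0]) grad_inputs[0]->Ref();  // caller takes ownership
        return Status::OK();
      }
    };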

View File

@ -19,7 +19,7 @@ limitations under the License.
namespace tensorflow {
namespace gradients {
BackwardFunction* IdentityNRegisterer(const ForwardOperation& op);
GradientFunction* IdentityNRegisterer(const ForwardOperation& op);
} // namespace gradients
} // namespace tensorflow

View File

@ -37,17 +37,17 @@ namespace {
class AddGradientFunction : public GradientFunction {
public:
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
grad_outputs->resize(2);
Status Compute(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> grad_outputs,
absl::Span<AbstractTensorHandle*> grad_inputs) override {
// TODO(b/161805092): Support broadcasting.
DCHECK(grad_inputs[0]);
(*grad_outputs)[0] = grad_inputs[0];
(*grad_outputs)[1] = grad_inputs[0];
DCHECK(grad_outputs[0]);
grad_inputs[0] = grad_outputs[0];
grad_inputs[1] = grad_outputs[0];
(*grad_outputs)[0]->Ref();
(*grad_outputs)[1]->Ref();
grad_inputs[0]->Ref();
grad_inputs[1]->Ref();
return Status::OK();
}
~AddGradientFunction() override {}
@ -58,18 +58,18 @@ class ExpGradientFunction : public GradientFunction {
explicit ExpGradientFunction(AbstractTensorHandle* exp) : exp_(exp) {
exp->Ref();
}
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
Status Compute(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> grad_outputs,
absl::Span<AbstractTensorHandle*> grad_inputs) override {
vector<AbstractTensorHandle*> conj_outputs(1);
std::string name = "Conj_Exp_Grad";
TF_RETURN_IF_ERROR(Conj(ctx->ctx, {exp_.get()},
absl::MakeSpan(conj_outputs), name.c_str()));
TF_RETURN_IF_ERROR(
Conj(ctx, {exp_.get()}, absl::MakeSpan(conj_outputs), name.c_str()));
AbstractTensorHandlePtr conj_output_releaser(conj_outputs[0]);
grad_outputs->resize(1);
name = "Mul_Exp_Grad";
TF_RETURN_IF_ERROR(Mul(ctx->ctx, {conj_outputs[0], grad_inputs[0]},
absl::MakeSpan(*grad_outputs), name.c_str()));
TF_RETURN_IF_ERROR(Mul(ctx, {conj_outputs[0], grad_outputs[0]}, grad_inputs,
name.c_str()));
return Status::OK();
}
~ExpGradientFunction() override {}
@ -83,12 +83,12 @@ class SqrtGradientFunction : public GradientFunction {
explicit SqrtGradientFunction(AbstractTensorHandle* sqrt) : sqrt_(sqrt) {
sqrt->Ref();
}
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
Status Compute(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> grad_outputs,
absl::Span<AbstractTensorHandle*> grad_inputs) override {
std::string name = "Sqrt_Grad";
grad_outputs->resize(1);
TF_RETURN_IF_ERROR(SqrtGrad(ctx->ctx, {sqrt_.get(), grad_inputs[0]},
absl::MakeSpan(*grad_outputs), name.c_str()));
TF_RETURN_IF_ERROR(SqrtGrad(ctx, {sqrt_.get(), grad_outputs[0]},
absl::MakeSpan(grad_inputs), name.c_str()));
return Status::OK();
}
~SqrtGradientFunction() override {}
@ -101,10 +101,17 @@ class MatMulGradientFunction : public GradientFunction {
public:
explicit MatMulGradientFunction(vector<AbstractTensorHandle*> f_inputs,
AttrBuilder f_attrs)
: forward_inputs(f_inputs), forward_attrs(f_attrs) {}
: forward_inputs_(f_inputs), forward_attrs_(f_attrs) {
for (auto input : forward_inputs_) {
if (input) {
input->Ref();
}
}
}
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
Status Compute(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> grad_outputs,
absl::Span<AbstractTensorHandle*> grad_inputs) override {
/* Given upstream grad U and a matmul op A*B, the gradients are:
*
* dA = U * B.T
@ -112,29 +119,28 @@ class MatMulGradientFunction : public GradientFunction {
*
* where A.T means `transpose(A)`
*/
AbstractTensorHandle* upstream_grad = grad_inputs[0];
grad_outputs->resize(2);
AbstractTensorHandle* upstream_grad = grad_outputs[0];
// Get transpose attrs
bool t_a;
TF_RETURN_IF_ERROR(forward_attrs.Get("transpose_a", &t_a));
TF_RETURN_IF_ERROR(forward_attrs_.Get("transpose_a", &t_a));
bool t_b;
TF_RETURN_IF_ERROR(forward_attrs.Get("transpose_b", &t_b));
TF_RETURN_IF_ERROR(forward_attrs_.Get("transpose_b", &t_b));
// Conj each input
vector<AbstractTensorHandle*> conj_outputs(1);
std::string name = "Conj_A_MatMul_Grad";
TF_RETURN_IF_ERROR(Conj(ctx->ctx, {forward_inputs[0]},
TF_RETURN_IF_ERROR(Conj(ctx, {forward_inputs_[0]},
absl::MakeSpan(conj_outputs), name.c_str()));
AbstractTensorHandle* A = conj_outputs[0];
AbstractTensorHandlePtr A(conj_outputs[0]);
name = "Conj_B_MatMul_Grad";
TF_RETURN_IF_ERROR(Conj(ctx->ctx, {forward_inputs[1]},
TF_RETURN_IF_ERROR(Conj(ctx, {forward_inputs_[1]},
absl::MakeSpan(conj_outputs), name.c_str()));
AbstractTensorHandle* B = conj_outputs[0];
AbstractTensorHandlePtr B(conj_outputs[0]);
// Calc Grad
vector<AbstractTensorHandle*> matmul_A_outputs(1);
@ -142,50 +148,50 @@ class MatMulGradientFunction : public GradientFunction {
std::string name_grad_A = "MatMul_Grad_A";
std::string name_grad_B = "MatMul_Grad_B";
if (!t_a && !t_b) {
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, B},
TF_RETURN_IF_ERROR(MatMul(ctx, {upstream_grad, B.get()},
absl::MakeSpan(matmul_A_outputs),
name_grad_A.c_str(),
/*transpose_a = */ false,
/*transpose_b = */ true));
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {A, upstream_grad},
TF_RETURN_IF_ERROR(MatMul(ctx, {A.get(), upstream_grad},
absl::MakeSpan(matmul_B_outputs),
name_grad_B.c_str(),
/*transpose_a = */ true,
/*transpose_b = */ false));
} else if (!t_a && t_b) {
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, B},
TF_RETURN_IF_ERROR(MatMul(ctx, {upstream_grad, B.get()},
absl::MakeSpan(matmul_A_outputs),
name_grad_A.c_str(),
/*transpose_a = */ false,
/*transpose_b = */ false));
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, A},
TF_RETURN_IF_ERROR(MatMul(ctx, {upstream_grad, A.get()},
absl::MakeSpan(matmul_B_outputs),
name_grad_B.c_str(),
/*transpose_a = */ true,
/*transpose_b = */ false));
} else if (t_a && !t_b) {
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {B, upstream_grad},
TF_RETURN_IF_ERROR(MatMul(ctx, {B.get(), upstream_grad},
absl::MakeSpan(matmul_A_outputs),
name_grad_A.c_str(),
/*transpose_a = */ false,
/*transpose_b = */ true));
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {A, upstream_grad},
TF_RETURN_IF_ERROR(MatMul(ctx, {A.get(), upstream_grad},
absl::MakeSpan(matmul_B_outputs),
name_grad_B.c_str(),
/*transpose_a = */ false,
/*transpose_b = */ false));
} else { // t_a && t_b
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {B, upstream_grad},
TF_RETURN_IF_ERROR(MatMul(ctx, {B.get(), upstream_grad},
absl::MakeSpan(matmul_A_outputs),
name_grad_A.c_str(),
/*transpose_a = */ true,
/*transpose_b = */ true));
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, A},
TF_RETURN_IF_ERROR(MatMul(ctx, {upstream_grad, A.get()},
absl::MakeSpan(matmul_B_outputs),
name_grad_B.c_str(),
/*transpose_a = */ true,
@ -193,33 +199,40 @@ class MatMulGradientFunction : public GradientFunction {
}
// Gradient for A
(*grad_outputs)[0] = matmul_A_outputs[0];
grad_inputs[0] = matmul_A_outputs[0];
// Gradient for B
(*grad_outputs)[1] = matmul_B_outputs[0];
grad_inputs[1] = matmul_B_outputs[0];
return Status::OK();
}
~MatMulGradientFunction() override {}
~MatMulGradientFunction() override {
for (auto input : forward_inputs_) {
if (input) {
input->Unref();
}
}
}
private:
vector<AbstractTensorHandle*> forward_inputs;
AttrBuilder forward_attrs;
// TODO(b/174778737): Only hold needed inputs.
vector<AbstractTensorHandle*> forward_inputs_;
AttrBuilder forward_attrs_;
};
class NegGradientFunction : public GradientFunction {
public:
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
Status Compute(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> grad_outputs,
absl::Span<AbstractTensorHandle*> grad_inputs) override {
/* Given upstream grad U and a Neg op Y = -X, the gradients are:
*
* dX = -U
*
*/
grad_outputs->resize(1);
std::string name = "Neg_Grad";
TF_RETURN_IF_ERROR(ops::Neg(ctx->ctx, {grad_inputs[0]},
absl::MakeSpan(*grad_outputs), name.c_str()));
TF_RETURN_IF_ERROR(
ops::Neg(ctx, {grad_outputs[0]}, grad_inputs, name.c_str()));
return Status::OK();
}
~NegGradientFunction() override {}
@ -227,8 +240,9 @@ class NegGradientFunction : public GradientFunction {
class SubGradientFunction : public GradientFunction {
public:
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
Status Compute(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> grad_outputs,
absl::Span<AbstractTensorHandle*> grad_inputs) override {
/* Given upstream grad U and a Sub op A-B, the gradients are:
*
* dA = U
@ -236,20 +250,16 @@ class SubGradientFunction : public GradientFunction {
*
*/
grad_outputs->resize(2);
// Grad for A
DCHECK(grad_inputs[0]);
(*grad_outputs)[0] = grad_inputs[0];
(*grad_outputs)[0]->Ref();
DCHECK(grad_outputs[0]);
grad_inputs[0] = grad_outputs[0];
grad_inputs[0]->Ref();
// Grad for B
// Negate the upstream grad
std::vector<AbstractTensorHandle*> neg_outputs(1);
std::string name = "Neg_Sub_Grad_B";
TF_RETURN_IF_ERROR(ops::Neg(ctx->ctx, {grad_inputs[0]},
absl::MakeSpan(neg_outputs), name.c_str()));
(*grad_outputs)[1] = neg_outputs[0];
TF_RETURN_IF_ERROR(ops::Neg(ctx, {grad_outputs[0]},
grad_inputs.subspan(1, 1), name.c_str()));
return Status::OK();
}
@ -259,10 +269,17 @@ class SubGradientFunction : public GradientFunction {
class MulGradientFunction : public GradientFunction {
public:
explicit MulGradientFunction(vector<AbstractTensorHandle*> f_inputs)
: forward_inputs(f_inputs) {}
: forward_inputs_(f_inputs) {
for (auto input : forward_inputs_) {
if (input) {
input->Ref();
}
}
}
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
Status Compute(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> grad_outputs,
absl::Span<AbstractTensorHandle*> grad_inputs) override {
/* Given upstream grad U and a mul op A*B, the gradients are:
*
* dA = U * B
@ -270,36 +287,46 @@ class MulGradientFunction : public GradientFunction {
*
*/
AbstractTensorHandle* upstream_grad = grad_inputs[0];
grad_outputs->resize(2);
std::vector<AbstractTensorHandle*> mul_outputs(1);
AbstractTensorHandle* upstream_grad = grad_outputs[0];
// Gradient for A
std::string name = "Mul_Grad_A";
TF_RETURN_IF_ERROR(Mul(ctx->ctx, {upstream_grad, forward_inputs[1]},
absl::MakeSpan(mul_outputs), name.c_str()));
(*grad_outputs)[0] = mul_outputs[0];
TF_RETURN_IF_ERROR(Mul(ctx, {upstream_grad, forward_inputs_[1]},
grad_inputs.subspan(0, 1), name.c_str()));
// Gradient for B
name = "Mul_Grad_B";
TF_RETURN_IF_ERROR(Mul(ctx->ctx, {forward_inputs[0], upstream_grad},
absl::MakeSpan(mul_outputs), name.c_str()));
(*grad_outputs)[1] = mul_outputs[0];
TF_RETURN_IF_ERROR(Mul(ctx, {forward_inputs_[0], upstream_grad},
grad_inputs.subspan(1, 1), name.c_str()));
return Status::OK();
}
~MulGradientFunction() override {}
~MulGradientFunction() override {
for (auto input : forward_inputs_) {
if (input) {
input->Unref();
}
}
}
private:
vector<AbstractTensorHandle*> forward_inputs;
// TODO(b/174778737): Only hold needed inputs.
vector<AbstractTensorHandle*> forward_inputs_;
};
class Log1pGradientFunction : public GradientFunction {
public:
explicit Log1pGradientFunction(vector<AbstractTensorHandle*> f_inputs)
: forward_inputs(f_inputs) {}
: forward_inputs_(f_inputs) {
for (auto input : forward_inputs_) {
if (input) {
input->Ref();
}
}
}
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
Status Compute(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> grad_outputs,
absl::Span<AbstractTensorHandle*> grad_inputs) override {
// TODO(vnvo2409): Add control dependency
/* Given upstream grad U and a Log1p op: Y = log(1 + X), the gradients are:
*
@ -307,56 +334,72 @@ class Log1pGradientFunction : public GradientFunction {
*
*/
AbstractTensorHandle* upstream_grad = grad_inputs[0];
AbstractTensorHandle* X = forward_inputs[0];
AbstractTensorHandle* upstream_grad = grad_outputs[0];
AbstractTensorHandle* X = forward_inputs_[0];
grad_outputs->resize(1);
vector<AbstractTensorHandle*> temp_outputs(1);
// Calculate conjugate of X
std::string name = "Conj_Log1p_Grad_X";
TF_RETURN_IF_ERROR(
Conj(ctx->ctx, {X}, absl::MakeSpan(temp_outputs), name.c_str()));
Conj(ctx, {X}, absl::MakeSpan(temp_outputs), name.c_str()));
AbstractTensorHandle* Conj_X = temp_outputs[0];
AbstractTensorHandlePtr Conj_X(temp_outputs[0]);
// Create a ones tensor shaped like Conj(X)
name = "OnesLike_Log1p_Grad_X";
TF_RETURN_IF_ERROR(OnesLike(ctx->ctx, {Conj_X},
TF_RETURN_IF_ERROR(OnesLike(ctx, {Conj_X.get()},
absl::MakeSpan(temp_outputs), name.c_str()));
AbstractTensorHandle* Ones_X = temp_outputs[0];
AbstractTensorHandlePtr Ones_X(temp_outputs[0]);
name = "Add_Log1p_Grad_X";
// Calculate 1 + Conj(X)
TF_RETURN_IF_ERROR(Add(ctx->ctx, {Ones_X, Conj_X},
TF_RETURN_IF_ERROR(Add(ctx, {Ones_X.get(), Conj_X.get()},
absl::MakeSpan(temp_outputs), name.c_str()));
AbstractTensorHandle* Conj_XP1 = temp_outputs[0];
AbstractTensorHandlePtr Conj_XP1(temp_outputs[0]);
name = "Div_Log1p_Grad_X";
// Calculate U / (1 + Conj(X))
TF_RETURN_IF_ERROR(Div(ctx->ctx, {upstream_grad, Conj_XP1},
absl::MakeSpan(temp_outputs), name.c_str()));
(*grad_outputs)[0] = temp_outputs[0];
TF_RETURN_IF_ERROR(
Div(ctx, {upstream_grad, Conj_XP1.get()}, grad_inputs, name.c_str()));
return Status::OK();
}
~Log1pGradientFunction() override {}
~Log1pGradientFunction() override {
for (auto input : forward_inputs_) {
if (input) {
input->Unref();
}
}
}
private:
vector<AbstractTensorHandle*> forward_inputs;
// TODO(b/174778737): Only hold needed inputs.
vector<AbstractTensorHandle*> forward_inputs_;
};
class DivNoNanGradientFunction : public GradientFunction {
public:
explicit DivNoNanGradientFunction(vector<AbstractTensorHandle*> f_inputs,
vector<AbstractTensorHandle*> f_outputs)
: forward_inputs(f_inputs), forward_outputs(f_outputs) {}
: forward_inputs_(f_inputs), forward_outputs_(f_outputs) {
for (auto input : forward_inputs_) {
if (input) {
input->Ref();
}
}
for (auto output : forward_outputs_) {
if (output) {
output->Ref();
}
}
}
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
Status Compute(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> grad_outputs,
absl::Span<AbstractTensorHandle*> grad_inputs) override {
// TODO(vnvo2409): Add shape broadcasting
/* Given upstream grad U and a Div op: Z = X/Y, the gradients are:
*
@ -365,126 +408,88 @@ class DivNoNanGradientFunction : public GradientFunction {
*
*/
AbstractTensorHandle* upstream_grad = grad_inputs[0];
AbstractTensorHandle* Y = forward_inputs[1];
AbstractTensorHandle* Z = forward_outputs[0];
grad_outputs->resize(2);
vector<AbstractTensorHandle*> temp_outputs(1);
AbstractTensorHandle* upstream_grad = grad_outputs[0];
AbstractTensorHandle* Y = forward_inputs_[1];
AbstractTensorHandle* Z = forward_outputs_[0];
// Calculate dX = U / Y
std::string name = "Div_Grad_X";
TF_RETURN_IF_ERROR(DivNoNan(ctx->ctx, {upstream_grad, Y},
absl::MakeSpan(temp_outputs), name.c_str()));
(*grad_outputs)[0] = temp_outputs[0];
TF_RETURN_IF_ERROR(DivNoNan(ctx, {upstream_grad, Y},
grad_inputs.subspan(0, 1), name.c_str()));
vector<AbstractTensorHandle*> temp_outputs(1);
// Calculate dY = -U*Z / Y
name = "Neg_Div_Grad_Y";
TF_RETURN_IF_ERROR(Neg(ctx->ctx, {upstream_grad},
absl::MakeSpan(temp_outputs), name.c_str())); // -U
AbstractTensorHandle* MinusU = temp_outputs[0];
TF_RETURN_IF_ERROR(Neg(ctx, {upstream_grad}, absl::MakeSpan(temp_outputs),
name.c_str())); // -U
AbstractTensorHandlePtr MinusU(temp_outputs[0]);
name = "Mul_Div_Grad_Y";
TF_RETURN_IF_ERROR(Mul(ctx->ctx, {MinusU, Z}, absl::MakeSpan(temp_outputs),
TF_RETURN_IF_ERROR(Mul(ctx, {MinusU.get(), Z}, absl::MakeSpan(temp_outputs),
name.c_str())); // -U*Z
AbstractTensorHandle* UZ = temp_outputs[0];
AbstractTensorHandlePtr UZ(temp_outputs[0]);
name = "Div_Grad_Y";
TF_RETURN_IF_ERROR(DivNoNan(ctx->ctx, {UZ, Y}, absl::MakeSpan(temp_outputs),
TF_RETURN_IF_ERROR(DivNoNan(ctx, {UZ.get(), Y}, grad_inputs.subspan(1, 1),
name.c_str())); // -U*Z / Y
(*grad_outputs)[1] = temp_outputs[0];
return Status::OK();
}
~DivNoNanGradientFunction() override {}
~DivNoNanGradientFunction() override {
for (auto input : forward_inputs_) {
if (input) {
input->Unref();
}
}
for (auto output : forward_outputs_) {
if (output) {
output->Unref();
}
}
}
private:
vector<AbstractTensorHandle*> forward_inputs;
vector<AbstractTensorHandle*> forward_outputs;
// TODO(b/174778737): Only hold needed inputs and outputs.
vector<AbstractTensorHandle*> forward_inputs_;
vector<AbstractTensorHandle*> forward_outputs_;
};
} // namespace
BackwardFunction* AddRegisterer(const ForwardOperation& op) {
auto gradient_function = new AddGradientFunction;
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zeros
// grads in this case.
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
GradientFunction* AddRegisterer(const ForwardOperation& op) {
return new AddGradientFunction;
}
BackwardFunction* ExpRegisterer(const ForwardOperation& op) {
auto gradient_function = new ExpGradientFunction(op.outputs[0]);
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zeros
// grads in this case.
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
GradientFunction* ExpRegisterer(const ForwardOperation& op) {
return new ExpGradientFunction(op.outputs[0]);
}
BackwardFunction* MatMulRegisterer(const ForwardOperation& op) {
auto gradient_function = new MatMulGradientFunction(op.inputs, op.attrs);
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zeros
// grads in this case.
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
GradientFunction* MatMulRegisterer(const ForwardOperation& op) {
return new MatMulGradientFunction(op.inputs, op.attrs);
}
BackwardFunction* SqrtRegisterer(const ForwardOperation& op) {
auto gradient_function = new SqrtGradientFunction(op.outputs[0]);
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zeros
// grads in this case.
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
GradientFunction* SqrtRegisterer(const ForwardOperation& op) {
return new SqrtGradientFunction(op.outputs[0]);
}
BackwardFunction* NegRegisterer(const ForwardOperation& op) {
auto gradient_function = new NegGradientFunction;
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zeros
// grads in this case.
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
GradientFunction* NegRegisterer(const ForwardOperation& op) {
return new NegGradientFunction;
}
BackwardFunction* SubRegisterer(const ForwardOperation& op) {
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zeros
// grads in this case.
auto gradient_function = new SubGradientFunction;
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
GradientFunction* SubRegisterer(const ForwardOperation& op) {
return new SubGradientFunction;
}
BackwardFunction* MulRegisterer(const ForwardOperation& op) {
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zeros
// grads in this case.
auto gradient_function = new MulGradientFunction(op.inputs);
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
GradientFunction* MulRegisterer(const ForwardOperation& op) {
return new MulGradientFunction(op.inputs);
}
BackwardFunction* Log1pRegisterer(const ForwardOperation& op) {
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zeros
// grads in this case.
auto gradient_function = new Log1pGradientFunction(op.inputs);
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
GradientFunction* Log1pRegisterer(const ForwardOperation& op) {
return new Log1pGradientFunction(op.inputs);
}
BackwardFunction* DivNoNanRegisterer(const ForwardOperation& op) {
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zeros
// grads in this case.
auto gradient_function = new DivNoNanGradientFunction(op.inputs, op.outputs);
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
GradientFunction* DivNoNanRegisterer(const ForwardOperation& op) {
return new DivNoNanGradientFunction(op.inputs, op.outputs);
}
} // namespace gradients
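Since each registerer is now a bare `GradientFunction*` factory, wiring one into a `GradientRegistry` is a single `Register` call keyed by op name. A hedged sketch, mirroring the `RegisterGradients` helpers in the tests:

    Status RegisterGradients(GradientRegistry* registry) {
      TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer));
      TF_RETURN_IF_ERROR(registry->Register("MatMul", MatMulRegisterer));
      return Status::OK();
    }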

View File

@ -20,15 +20,15 @@ limitations under the License.
namespace tensorflow {
namespace gradients {
BackwardFunction* AddRegisterer(const ForwardOperation& op);
BackwardFunction* ExpRegisterer(const ForwardOperation& op);
BackwardFunction* MatMulRegisterer(const ForwardOperation& op);
BackwardFunction* SqrtRegisterer(const ForwardOperation& op);
BackwardFunction* NegRegisterer(const ForwardOperation& op);
BackwardFunction* SubRegisterer(const ForwardOperation& op);
BackwardFunction* MulRegisterer(const ForwardOperation& op);
BackwardFunction* Log1pRegisterer(const ForwardOperation& op);
BackwardFunction* DivNoNanRegisterer(const ForwardOperation& op);
GradientFunction* AddRegisterer(const ForwardOperation& op);
GradientFunction* ExpRegisterer(const ForwardOperation& op);
GradientFunction* MatMulRegisterer(const ForwardOperation& op);
GradientFunction* SqrtRegisterer(const ForwardOperation& op);
GradientFunction* NegRegisterer(const ForwardOperation& op);
GradientFunction* SubRegisterer(const ForwardOperation& op);
GradientFunction* MulRegisterer(const ForwardOperation& op);
GradientFunction* Log1pRegisterer(const ForwardOperation& op);
GradientFunction* DivNoNanRegisterer(const ForwardOperation& op);
} // namespace gradients
} // namespace tensorflow

View File

@ -36,29 +36,37 @@ namespace {
class ReluGradientFunction : public GradientFunction {
public:
explicit ReluGradientFunction(vector<AbstractTensorHandle*> f_outputs)
: forward_outputs(f_outputs) {}
: forward_outputs_(f_outputs) {
for (auto output : forward_outputs_) {
if (output) {
output->Ref();
}
}
}
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
AbstractTensorHandle* upstream_grad = grad_inputs[0];
AbstractTensorHandle* activations = forward_outputs[0];
grad_outputs->resize(1);
vector<AbstractTensorHandle*> relugrad_outputs(1);
Status Compute(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> grad_outputs,
absl::Span<AbstractTensorHandle*> grad_inputs) override {
AbstractTensorHandle* upstream_grad = grad_outputs[0];
AbstractTensorHandle* activations = forward_outputs_[0];
// Calculate Grad
std::string name = "relu_grad";
TF_RETURN_IF_ERROR(ReluGrad(ctx->ctx, {upstream_grad, activations},
absl::MakeSpan(relugrad_outputs),
name.c_str()));
(*grad_outputs)[0] = relugrad_outputs[0];
TF_RETURN_IF_ERROR(
ReluGrad(ctx, {upstream_grad, activations}, grad_inputs, name.c_str()));
return Status::OK();
}
~ReluGradientFunction() override {}
~ReluGradientFunction() override {
for (auto output : forward_outputs_) {
if (output) {
output->Unref();
}
}
}
private:
vector<AbstractTensorHandle*> forward_outputs;
// TODO(b/174778737): Only hold needed outputs.
vector<AbstractTensorHandle*> forward_outputs_;
};
Status BroadcastMul(AbstractContext* ctx, AbstractTensorHandle* vec,
@ -87,98 +95,79 @@ class SparseSoftmaxCrossEntropyWithLogitsGradientFunction
public:
explicit SparseSoftmaxCrossEntropyWithLogitsGradientFunction(
vector<AbstractTensorHandle*> f_outputs)
: forward_outputs(f_outputs) {}
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
grad_outputs->resize(2);
: forward_outputs_(f_outputs) {}
Status Compute(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> grad_outputs,
absl::Span<AbstractTensorHandle*> grad_inputs) override {
// Grad for Softmax Input
vector<AbstractTensorHandle*> mul_outputs(1);
TF_RETURN_IF_ERROR(BroadcastMul(
ctx->ctx, grad_inputs[0], forward_outputs[1],
absl::MakeSpan(mul_outputs))); // upstream_grad * local softmax grad
(*grad_outputs)[0] = mul_outputs[0];
ctx, grad_outputs[0], forward_outputs_[1],
grad_inputs.subspan(0, 1))); // upstream_grad * local softmax grad
// Grad for labels is null
(*grad_outputs)[1] = nullptr;
grad_inputs[1] = nullptr;
return Status::OK();
}
~SparseSoftmaxCrossEntropyWithLogitsGradientFunction() override {}
private:
vector<AbstractTensorHandle*> forward_outputs;
vector<AbstractTensorHandle*> forward_outputs_;
};
// TODO(vnvo2409): Add python test
class BiasAddGradientFunction : public GradientFunction {
public:
explicit BiasAddGradientFunction(AttrBuilder f_attrs)
: forward_attrs(f_attrs) {}
: forward_attrs_(f_attrs) {}
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
Status Compute(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> grad_outputs,
absl::Span<AbstractTensorHandle*> grad_inputs) override {
/* Given upstream grad U and a BiasAdd: A + bias, the gradients are:
*
* dA = U
* dbias = reduceSum(U, dims = channel_dim)
*/
AbstractTensorHandle* upstream_grad = grad_inputs[0];
AbstractTensorHandle* upstream_grad = grad_outputs[0];
DCHECK(upstream_grad);
grad_outputs->resize(2);
// Recover data format from forward pass for gradient.
std::string data_format;
TF_RETURN_IF_ERROR(forward_attrs.Get("data_format", &data_format));
TF_RETURN_IF_ERROR(forward_attrs_.Get("data_format", &data_format));
// Grad for A
(*grad_outputs)[0] = upstream_grad;
(*grad_outputs)[0]->Ref();
grad_inputs[0] = upstream_grad;
grad_inputs[0]->Ref();
// Grad for bias
vector<AbstractTensorHandle*> bias_add_grad_outputs(1);
std::string name = "bias_add_grad";
TF_RETURN_IF_ERROR(BiasAddGrad(ctx->ctx, {upstream_grad},
absl::MakeSpan(bias_add_grad_outputs),
TF_RETURN_IF_ERROR(BiasAddGrad(ctx, {upstream_grad},
grad_inputs.subspan(1, 1),
data_format.c_str(), name.c_str()));
(*grad_outputs)[1] = bias_add_grad_outputs[0];
return Status::OK();
}
~BiasAddGradientFunction() override {}
private:
AttrBuilder forward_attrs;
AttrBuilder forward_attrs_;
};
} // namespace
BackwardFunction* ReluRegisterer(const ForwardOperation& op) {
auto gradient_function = new ReluGradientFunction(op.outputs);
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zeros
// grads in this case.
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
GradientFunction* ReluRegisterer(const ForwardOperation& op) {
return new ReluGradientFunction(op.outputs);
}
BackwardFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer(
GradientFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer(
const ForwardOperation& op) {
auto gradient_function =
new SparseSoftmaxCrossEntropyWithLogitsGradientFunction(op.outputs);
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
return new SparseSoftmaxCrossEntropyWithLogitsGradientFunction(op.outputs);
}
BackwardFunction* BiasAddRegisterer(const ForwardOperation& op) {
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zeros
// grads in this case.
auto gradient_function = new BiasAddGradientFunction(op.attrs);
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
GradientFunction* BiasAddRegisterer(const ForwardOperation& op) {
return new BiasAddGradientFunction(op.attrs);
}
} // namespace gradients
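Note the `grad_inputs.subspan(1, 1)` pattern above: `absl::Span::subspan` returns a view over the same storage, so `BiasAddGrad` writes the bias gradient straight into `grad_inputs[1]` without a temporary vector. Condensed from the code above:

    // One-element alias of grad_inputs[1]; writing through it fills the
    // bias-grad slot in place.
    absl::Span<AbstractTensorHandle*> dbias_slot = grad_inputs.subspan(1, 1);
    TF_RETURN_IF_ERROR(BiasAddGrad(ctx, {upstream_grad}, dbias_slot,
                                   data_format.c_str(), name.c_str()));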

View File

@ -19,10 +19,10 @@ limitations under the License.
namespace tensorflow {
namespace gradients {
BackwardFunction* ReluRegisterer(const ForwardOperation& op);
BackwardFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer(
GradientFunction* ReluRegisterer(const ForwardOperation& op);
GradientFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer(
const ForwardOperation& op);
BackwardFunction* BiasAddRegisterer(const ForwardOperation& op);
GradientFunction* BiasAddRegisterer(const ForwardOperation& op);
} // namespace gradients
} // namespace tensorflow

View File

@ -25,7 +25,6 @@ namespace internal {
namespace {
using tensorflow::TF_StatusPtr;
using tracing::TracingOperation;
Status BiasAddModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
@ -38,30 +37,20 @@ Status BiasAddGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch A.
tape->Watch(ToId(inputs[1])); // Watch Bias.
Tape tape(/*persistent=*/false);
tape.Watch(inputs[0]); // Watch A.
tape.Watch(inputs[1]); // Watch Bias.
std::vector<AbstractTensorHandle*> temp_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
TF_RETURN_IF_ERROR(ops::BiasAdd(tape_ctx.get(), inputs,
absl::MakeSpan(temp_outputs), "BiasAddGrad"));
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
std::vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(temp_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
for (auto temp_output : temp_outputs) {
temp_output->Unref();
}
outputs[0] = out_grads[0];
outputs[1] = out_grads[1];
delete tape;
return Status::OK();
}

View File

@ -0,0 +1,34 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/not_differentiable.h"
namespace tensorflow {
namespace gradients {
Status NotDifferentiableGradientFunction::Compute(
AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> grad_outputs,
absl::Span<AbstractTensorHandle*> grad_inputs) {
for (int i = 0; i < grad_inputs.size(); i++) {
grad_inputs[i] = nullptr;
}
return Status::OK();
}
Status RegisterNotDifferentiable(GradientRegistry* registry, const string& op) {
return registry->Register(op, [](const ForwardOperation& op) {
return new NotDifferentiableGradientFunction;
});
}
} // namespace gradients
} // namespace tensorflow

View File

@ -0,0 +1,34 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NOT_DIFFERENTIABLE_H_
#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NOT_DIFFERENTIABLE_H_
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/gradients.h"
namespace tensorflow {
namespace gradients {
// Ignores `grad_outputs` and sets all entries in grad_inputs to nullptr.
class NotDifferentiableGradientFunction : public GradientFunction {
Status Compute(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> grad_outputs,
absl::Span<AbstractTensorHandle*> grad_inputs) override;
};
// Shorthand for registry->Register(op, new NotDifferentiableGradientFunction)
Status RegisterNotDifferentiable(GradientRegistry* registry, const string& op);
} // namespace gradients
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NOT_DIFFERENTIABLE_H_
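A hedged usage sketch (op name illustrative): because `TapeOperation::Execute` below now propagates a failed `GradientRegistry::Lookup`, ops with no gradient are registered explicitly:

    GradientRegistry registry;
    // Lookup now succeeds; the recorded function emits nullptr grads.
    TF_RETURN_IF_ERROR(RegisterNotDifferentiable(&registry, "CheckNumerics"));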

View File

@ -197,12 +197,6 @@ AbstractOperation* TapeOperation::GetBackingOperation() { return parent_op_; }
Status TapeOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
int* num_retvals) {
TF_RETURN_IF_ERROR(parent_op_->Execute(retvals, num_retvals));
std::vector<int64> input_ids(forward_op_.inputs.size());
std::vector<tensorflow::DataType> input_dtypes(forward_op_.inputs.size());
for (int i = 0; i < forward_op_.inputs.size(); i++) {
input_ids[i] = ToId(forward_op_.inputs[i]);
input_dtypes[i] = forward_op_.inputs[i]->DataType();
}
for (int i = 0; i < *num_retvals; i++) {
// TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs.
forward_op_.outputs.push_back(retvals[i]);
@ -212,25 +206,11 @@ Status TapeOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
// Consider getting rid of this and making the behavior between number types
// and string consistent.
forward_op_.attrs.BuildNodeDef();
std::vector<TapeTensor> tape_tensors;
for (auto t : retvals) {
tape_tensors.push_back(TapeTensor(t));
}
tape_->RecordOperation(
parent_op_->Name(), tape_tensors, input_ids, input_dtypes,
[this]() -> BackwardFunction* {
std::unique_ptr<BackwardFunction> backward_fn;
Status s = registry_.Lookup(forward_op_, &backward_fn);
if (!s.ok()) {
return nullptr;
}
return backward_fn.release();
},
[](BackwardFunction* ptr) {
if (ptr) {
delete ptr;
}
});
// TODO(b/170307493): Populate skip_input_indices here.
std::unique_ptr<GradientFunction> backward_fn;
TF_RETURN_IF_ERROR(registry_.Lookup(forward_op_, &backward_fn));
tape_->RecordOperation(forward_op_.inputs, forward_op_.outputs,
backward_fn.release(), parent_op_->Name());
return Status::OK();
}

View File

@ -1327,10 +1327,10 @@ class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction,
}
tensorflow::Status CallBackwardFunction(
PyBackwardFunction* backward_function,
const string& op_type, PyBackwardFunction* backward_function,
const std::vector<tensorflow::int64>& unneeded_gradients,
tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
std::vector<PyObject*>* result) const final {
absl::Span<PyObject*> result) const final {
PyObject* grads = PyTuple_New(output_gradients.size());
for (int i = 0; i < output_gradients.size(); ++i) {
if (output_gradients[i] == nullptr) {
@ -1346,7 +1346,6 @@ class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction,
if (py_result == nullptr) {
return tensorflow::errors::Internal("gradient function threw exceptions");
}
result->clear();
PyObject* seq =
PySequence_Fast(py_result, "expected a sequence of gradients");
if (seq == nullptr) {
@ -1354,16 +1353,21 @@ class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction,
"gradient function did not return a list");
}
int len = PySequence_Fast_GET_SIZE(seq);
if (len != result.size()) {
return tensorflow::errors::Internal(
"Recorded operation '", op_type,
"' returned too few gradients. Expected ", result.size(),
" but received ", len);
}
PyObject** seq_array = PySequence_Fast_ITEMS(seq);
VLOG(1) << "Gradient length is " << len;
result->reserve(len);
for (int i = 0; i < len; ++i) {
PyObject* item = seq_array[i];
if (item == Py_None) {
result->push_back(nullptr);
result[i] = nullptr;
} else {
Py_INCREF(item);
result->push_back(item);
result[i] = item;
}
}
Py_DECREF(seq);
@ -2774,10 +2778,10 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target,
Py_INCREF(tensor);
}
}
std::vector<PyObject*> result;
std::vector<PyObject*> result(sources_vec.size());
status->status = tape_obj->tape->ComputeGradient(
*py_vspace, target_vec, sources_vec, source_tensors_that_are_targets,
outgrad_vec, &result);
outgrad_vec, absl::MakeSpan(result));
if (!status->status.ok()) {
if (PyErr_Occurred()) {
// Do not propagate the erroneous status as that would swallow the

View File

@ -47,35 +47,19 @@ Status RegisterGradients(GradientRegistry* registry) {
PYBIND11_MODULE(_tape, m) {
py::class_<Tape>(m, "Tape")
.def(py::init([](bool persistent) { return new Tape(persistent); }))
.def("Watch",
[](Tape* self, AbstractTensorHandle* t) { self->Watch(ToId(t)); })
.def("Watch", [](Tape* self, AbstractTensorHandle* t) { self->Watch(t); })
.def("ComputeGradient",
[](Tape* self, TapeVSpace* vspace,
[](Tape* self, AbstractContext* ctx,
std::vector<AbstractTensorHandle*> target_tensors,
std::vector<AbstractTensorHandle*> source_tensors,
std::vector<AbstractTensorHandle*> output_gradients) {
std::vector<int64> target_tensor_ids;
std::vector<int64> source_tensor_ids;
target_tensor_ids.reserve(target_tensors.size());
source_tensor_ids.reserve(source_tensors.size());
for (auto t : target_tensors) {
target_tensor_ids.emplace_back(ToId(t));
}
for (auto t : source_tensors) {
source_tensor_ids.emplace_back(ToId(t));
}
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
std::vector<AbstractTensorHandle*> results;
Status s = self->ComputeGradient(
*vspace, target_tensor_ids, source_tensor_ids,
source_tensors_that_are_targets, output_gradients, &results,
/*build_default_zeros_grads=*/false);
std::vector<AbstractTensorHandle*> results(source_tensors.size());
Status s = self->ComputeGradient(ctx, target_tensors,
source_tensors, output_gradients,
absl::MakeSpan(results));
MaybeRaiseRegisteredFromStatus(s);
return results;
});
py::class_<TapeVSpace>(m, "TapeVSpace")
.def(py::init([](AbstractContext* ctx) { return new TapeVSpace(ctx); }));
py::class_<GradientRegistry>(m, "GradientRegistry").def(py::init([]() {
auto registry = new GradientRegistry();
MaybeRaiseRegisteredFromStatus(RegisterGradients(registry));

View File

@ -40,10 +40,9 @@ class GradientTape(object):
# TODO(srbs): Add support for unconnected_gradients.
def gradient(self, targets, sources, output_gradients=None):
ctx = context_stack.get_default()
vspace = _tape.TapeVSpace(ctx)
flat_targets = nest.flatten(targets)
flat_sources = nest.flatten(sources)
out_grads = self._c_tape.ComputeGradient(vspace, flat_targets, flat_sources,
out_grads = self._c_tape.ComputeGradient(ctx, flat_targets, flat_sources,
output_gradients or [])
return nest.pack_sequence_as(sources, out_grads)