Simplify C++ tape APIs to match the proposal in https://github.com/tensorflow/community/pull/335
- The main change is to get rid of int64 tensor ids from Tape and directly use AbstractTensorHandles. - Get rid of tensorflow::gradients::Context and directly use AbstractContext*. - Get rid of DefaultGradientFunction and BackwardFunction(which was a wrapper for a GradientFunction and DefaultGradientFunction). We had introduced DefaultGradientFunction in order to support existing python gradient functions which expect all necessary incoming grads to be non-None. This is only relevant for ops with more than one output, which are few. We could handle those by creating a wrapper GradientFunction that builds the zeros if needed. Getting rid of DefaultGradientFunction greatly simplifies the API. - Introduce ForwardOperation::skip_input_indices. This will be filled up in a follow-up change. There is a bug tracking this. - Introduce helpers for implementing behavior of tf.no_gradient and tf.stop_gradient, i.e. RegisterNotDifferentiable and NotDifferentiableGradientFunction. - One slight behavior change: Currently, when an op does not have a GradientFunction registered we silently record a nullptr GradientFunction on the tape. This sometimes leads to uninformative error messages. Now we loudly raise an error when GradientRegistry::Lookup fails in TapeContext::Execute. So any op executing under a TapeContext must have a registered GradientFunction. Non-differentiable ops need to be explicitly registered using RegisterNotDifferentiable e.g. CheckNumerics in gradients_test.cc c/eager/tape.h: I changed the signatures of gradient functions to use `absl::Span<Gradient*>` instead of `vector<Gradient*>*` for the result grads. This makes it consistent with the new Tape API and generally makes things cleaner. PiperOrigin-RevId: 345534016 Change-Id: Ie1bf5dff88f87390e6b470acc379d3852ce68b5c
This commit is contained in:
parent
9193f62b81
commit
4e25cac495
@ -217,12 +217,12 @@ cc_library(
|
||||
],
|
||||
deps = [
|
||||
":abstract_context",
|
||||
":abstract_operation",
|
||||
":abstract_tensor_handle",
|
||||
":c_api_unified_internal",
|
||||
":tape",
|
||||
"//tensorflow/core/common_runtime/eager:attr_builder",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
"//tensorflow/core/platform:errors",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
@ -278,6 +278,7 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/c/experimental/gradients:array_grad",
|
||||
"//tensorflow/c/experimental/gradients:math_grad",
|
||||
"//tensorflow/c/experimental/gradients:not_differentiable",
|
||||
"//tensorflow/c/experimental/gradients/tape:tape_context",
|
||||
"//tensorflow/c/experimental/ops",
|
||||
"//tensorflow/cc/profiler",
|
||||
@ -454,6 +455,7 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
"//tensorflow/core/platform:tensor_float_32_utils",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
@ -468,7 +470,6 @@ tf_cuda_cc_test(
|
||||
args = ["--heap_check=local"],
|
||||
linkstatic = tf_kernel_tests_linkstatic(),
|
||||
tags = tf_cuda_tests_tags() + [
|
||||
"nomac",
|
||||
"no_cuda_asan", # b/173825513
|
||||
],
|
||||
deps = [
|
||||
@ -493,6 +494,7 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
"//tensorflow/core/platform:tensor_float_32_utils",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
|
@ -20,11 +20,19 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/gradients_internal.h"
|
||||
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
|
||||
namespace {
|
||||
|
||||
// TODO(b/172558015): Using the pointer address as the identifier for the tensor
|
||||
// may lead to collisions. Introduce another way to get a unique id for this
|
||||
// tensor.
|
||||
int64 ToId(const AbstractTensorHandle* t) {
|
||||
return static_cast<int64>(reinterpret_cast<uintptr_t>(t));
|
||||
}
|
||||
|
||||
Status ZerosLike(AbstractContext* ctx, AbstractTensorHandle* t,
|
||||
AbstractTensorHandle** result) {
|
||||
AbstractOperationPtr op(ctx->CreateOperation());
|
||||
@ -43,85 +51,28 @@ Status ZerosLike(AbstractContext* ctx, AbstractTensorHandle* t,
|
||||
}
|
||||
} // namespace
|
||||
|
||||
class IncomingGradientsImpl : public IncomingGradients {
|
||||
public:
|
||||
explicit IncomingGradientsImpl(
|
||||
absl::Span<AbstractTensorHandle* const> grad_inputs, Context* ctx,
|
||||
DefaultGradientFunction* default_gradients)
|
||||
: grad_inputs_(grad_inputs),
|
||||
ctx_(ctx),
|
||||
default_gradients_(default_gradients) {}
|
||||
AbstractTensorHandle* operator[](int i) const override {
|
||||
return default_gradients_->get(ctx_, grad_inputs_, i);
|
||||
}
|
||||
size_t size() const override { return grad_inputs_.size(); }
|
||||
|
||||
private:
|
||||
absl::Span<AbstractTensorHandle* const> grad_inputs_;
|
||||
Context* ctx_;
|
||||
DefaultGradientFunction* default_gradients_;
|
||||
};
|
||||
|
||||
AllZerosDefaultGradients::AllZerosDefaultGradients(const ForwardOperation& op)
|
||||
: outputs_(op.outputs) {
|
||||
for (auto output : outputs_) {
|
||||
output->Ref();
|
||||
}
|
||||
}
|
||||
AbstractTensorHandle* AllZerosDefaultGradients::get(
|
||||
Context* ctx, absl::Span<AbstractTensorHandle* const> grad_inputs, int i) {
|
||||
if (grad_inputs[i]) {
|
||||
return grad_inputs[i];
|
||||
}
|
||||
if (cached_default_grads_[i]) {
|
||||
return cached_default_grads_[i].get();
|
||||
}
|
||||
AbstractTensorHandle* result = nullptr;
|
||||
Status s = ZerosLike(ctx->ctx, outputs_[i], &result);
|
||||
if (!s.ok()) {
|
||||
if (result) {
|
||||
result->Unref();
|
||||
}
|
||||
VLOG(1) << "Failed to create ZerosLike for index " << i;
|
||||
return nullptr;
|
||||
}
|
||||
cached_default_grads_[i].reset(result);
|
||||
return result;
|
||||
}
|
||||
|
||||
PassThroughDefaultGradients::PassThroughDefaultGradients(
|
||||
const ForwardOperation& op) {}
|
||||
AbstractTensorHandle* PassThroughDefaultGradients::get(
|
||||
Context* ctx, absl::Span<AbstractTensorHandle* const> grad_inputs, int i) {
|
||||
return grad_inputs[i];
|
||||
}
|
||||
|
||||
Status GradientRegistry::Register(
|
||||
const string& op_name, BackwardFunctionFactory backward_function_factory) {
|
||||
const string& op_name, GradientFunctionFactory gradient_function_factory) {
|
||||
auto iter = registry_.find(op_name);
|
||||
if (iter != registry_.end()) {
|
||||
const string error_msg = "Gradient already exists for op: " + op_name + ".";
|
||||
return errors::AlreadyExists(error_msg);
|
||||
}
|
||||
registry_.insert({op_name, backward_function_factory});
|
||||
registry_.insert({op_name, gradient_function_factory});
|
||||
return Status::OK();
|
||||
}
|
||||
Status GradientRegistry::Lookup(
|
||||
const ForwardOperation& op,
|
||||
std::unique_ptr<BackwardFunction>* backward_function) const {
|
||||
std::unique_ptr<GradientFunction>* gradient_function) const {
|
||||
auto iter = registry_.find(op.op_name);
|
||||
if (iter == registry_.end()) {
|
||||
const string error_msg = "No gradient defined for op: " + op.op_name + ".";
|
||||
return errors::NotFound(error_msg);
|
||||
}
|
||||
backward_function->reset(iter->second(op));
|
||||
gradient_function->reset(iter->second(op));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int64 ToId(AbstractTensorHandle* t) {
|
||||
return static_cast<int64>(reinterpret_cast<uintptr_t>(t));
|
||||
}
|
||||
|
||||
TapeTensor::TapeTensor(AbstractTensorHandle* handle) : handle_(handle) {
|
||||
handle_->Ref();
|
||||
}
|
||||
@ -140,6 +91,47 @@ AbstractTensorHandle* TapeTensor::GetHandle() const { return handle_; }
|
||||
|
||||
AbstractTensorHandle* TapeTensor::ZerosLike() const { return nullptr; }
|
||||
|
||||
class TapeVSpace
|
||||
: public eager::VSpace<AbstractTensorHandle, GradientFunction, TapeTensor> {
|
||||
public:
|
||||
explicit TapeVSpace(AbstractContext* ctx) : ctx_(ctx) {}
|
||||
~TapeVSpace() override {}
|
||||
|
||||
// Returns the number of elements in the gradient tensor.
|
||||
int64 NumElements(AbstractTensorHandle* tensor) const override;
|
||||
|
||||
// Consumes references to the tensors in the gradient_tensors list and returns
|
||||
// a tensor with the result.
|
||||
AbstractTensorHandle* AggregateGradients(
|
||||
gtl::ArraySlice<AbstractTensorHandle*> gradient_tensors) const override;
|
||||
|
||||
// Calls the passed-in backward function.
|
||||
// op_type is the op's name provided in RecordOperation.
|
||||
Status CallBackwardFunction(
|
||||
const string& op_type, GradientFunction* gradient_function,
|
||||
const std::vector<int64>& unneeded_gradients,
|
||||
gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
|
||||
absl::Span<AbstractTensorHandle*> result) const override;
|
||||
|
||||
// Builds a tensor filled with ones with the same shape and dtype as `t`.
|
||||
Status BuildOnesLike(const TapeTensor& t,
|
||||
AbstractTensorHandle** result) const override;
|
||||
|
||||
// Looks up the ID of a Gradient.
|
||||
int64 TensorId(AbstractTensorHandle* tensor) const override;
|
||||
|
||||
// Converts a Gradient to a TapeTensor.
|
||||
TapeTensor TapeTensorFromGradient(AbstractTensorHandle* g) const override;
|
||||
|
||||
void MarkAsResult(AbstractTensorHandle* gradient) const override;
|
||||
|
||||
void DeleteGradient(AbstractTensorHandle* gradient) const override;
|
||||
|
||||
private:
|
||||
// The context where the aggregation op `Add` is to be created.
|
||||
AbstractContext* ctx_;
|
||||
};
|
||||
|
||||
// Returns the number of elements in the gradient tensor.
|
||||
int64 TapeVSpace::NumElements(AbstractTensorHandle* tensor) const {
|
||||
// TODO(srbs): It seems like this is used only for performance optimization
|
||||
@ -178,17 +170,20 @@ AbstractTensorHandle* TapeVSpace::AggregateGradients(
|
||||
}
|
||||
|
||||
// Calls the passed-in backward function.
|
||||
// op_type is the op's name provided in RecordOperation.
|
||||
Status TapeVSpace::CallBackwardFunction(
|
||||
BackwardFunction* backward_function,
|
||||
const string& op_type, GradientFunction* gradient_function,
|
||||
const std::vector<int64>& unneeded_gradients,
|
||||
gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
|
||||
std::vector<AbstractTensorHandle*>* result) const {
|
||||
if (backward_function == nullptr) return Status::OK();
|
||||
Context ctx = {ctx_};
|
||||
IncomingGradientsImpl incoming_gradients(
|
||||
output_gradients, &ctx, backward_function->GetDefaultGradientFunction());
|
||||
return backward_function->GetGradientFunction()->Compute(
|
||||
&ctx, incoming_gradients, result);
|
||||
absl::Span<AbstractTensorHandle*> result) const {
|
||||
if (gradient_function == nullptr) {
|
||||
return errors::InvalidArgument(
|
||||
"Provided null gradient_function for '", op_type, "'.\n",
|
||||
"If the intent is to treat this op as non-differentiable consider "
|
||||
"using RegisterNotDifferentiable or "
|
||||
"NotDifferentiableGradientFunction.");
|
||||
}
|
||||
return gradient_function->Compute(ctx_, output_gradients, result);
|
||||
}
|
||||
|
||||
Status TapeVSpace::BuildOnesLike(const TapeTensor& t,
|
||||
@ -224,6 +219,81 @@ void TapeVSpace::DeleteGradient(AbstractTensorHandle* gradient) const {
|
||||
gradient->Unref();
|
||||
}
|
||||
|
||||
void Tape::Watch(const AbstractTensorHandle* t) {
|
||||
GradientTape::Watch(ToId(t));
|
||||
}
|
||||
void Tape::RecordOperation(absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle* const> outputs,
|
||||
GradientFunction* gradient_function,
|
||||
const string& op_name) {
|
||||
std::vector<int64> input_ids(inputs.size());
|
||||
std::vector<tensorflow::DataType> input_dtypes(inputs.size());
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
input_ids[i] = ToId(inputs[i]);
|
||||
input_dtypes[i] = inputs[i]->DataType();
|
||||
}
|
||||
std::vector<TapeTensor> tape_tensors;
|
||||
for (auto t : outputs) {
|
||||
tape_tensors.push_back(TapeTensor(t));
|
||||
}
|
||||
GradientTape::RecordOperation(
|
||||
op_name, tape_tensors, input_ids, input_dtypes,
|
||||
[gradient_function]() -> GradientFunction* { return gradient_function; },
|
||||
[](GradientFunction* ptr) {
|
||||
if (ptr) {
|
||||
delete ptr;
|
||||
}
|
||||
});
|
||||
}
|
||||
bool Tape::ShouldRecord(
|
||||
absl::Span<const AbstractTensorHandle* const> tensors) const {
|
||||
std::vector<int64> tensor_ids(tensors.size());
|
||||
std::vector<tensorflow::DataType> tensor_dtypes(tensors.size());
|
||||
for (int i = 0; i < tensors.size(); i++) {
|
||||
tensor_ids[i] = ToId(tensors[i]);
|
||||
tensor_dtypes[i] = tensors[i]->DataType();
|
||||
}
|
||||
return GradientTape::ShouldRecord(tensor_ids, tensor_dtypes);
|
||||
}
|
||||
void Tape::DeleteTrace(const AbstractTensorHandle* t) {
|
||||
GradientTape::DeleteTrace(ToId(t));
|
||||
}
|
||||
|
||||
std::vector<int64> MakeTensorIDList(
|
||||
absl::Span<AbstractTensorHandle* const> tensors) {
|
||||
std::vector<int64> ids(tensors.size());
|
||||
for (int i = 0; i < tensors.size(); i++) {
|
||||
ids[i] = ToId(tensors[i]);
|
||||
}
|
||||
return ids;
|
||||
}
|
||||
|
||||
Status Tape::ComputeGradient(
|
||||
AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> targets,
|
||||
absl::Span<AbstractTensorHandle* const> sources,
|
||||
absl::Span<AbstractTensorHandle* const> output_gradients,
|
||||
absl::Span<AbstractTensorHandle*> result) {
|
||||
TapeVSpace vspace(ctx);
|
||||
std::vector<int64> target_tensor_ids = MakeTensorIDList(targets);
|
||||
std::vector<int64> source_tensor_ids = MakeTensorIDList(sources);
|
||||
tensorflow::gtl::FlatSet<tensorflow::int64> sources_set(
|
||||
source_tensor_ids.begin(), source_tensor_ids.end());
|
||||
std::unordered_map<int64, TapeTensor> sources_that_are_targets;
|
||||
for (int i = 0; i < target_tensor_ids.size(); ++i) {
|
||||
int64 target_id = target_tensor_ids[i];
|
||||
if (sources_set.find(target_id) != sources_set.end()) {
|
||||
auto tensor = targets[i];
|
||||
sources_that_are_targets.insert(
|
||||
std::make_pair(target_id, TapeTensor(tensor)));
|
||||
}
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(GradientTape::ComputeGradient(
|
||||
vspace, target_tensor_ids, source_tensor_ids, sources_that_are_targets,
|
||||
output_gradients, result, /*build_default_zeros_grads*/ false));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Helper functions which delegate to `AbstractOperation`, update
|
||||
// the state of the ForwardOperation and call the tape as appropriate.
|
||||
// These APIs are mainly to facilitate testing and are subject to change.
|
||||
@ -398,12 +468,6 @@ Status Execute(AbstractOperation* op_, AbstractContext* ctx,
|
||||
ForwardOperation* forward_op_, Tape* tape,
|
||||
const GradientRegistry& registry) {
|
||||
TF_RETURN_IF_ERROR(op_->Execute(retvals, num_retvals));
|
||||
std::vector<int64> input_ids(forward_op_->inputs.size());
|
||||
std::vector<tensorflow::DataType> input_dtypes(forward_op_->inputs.size());
|
||||
for (int i = 0; i < forward_op_->inputs.size(); i++) {
|
||||
input_ids[i] = ToId(forward_op_->inputs[i]);
|
||||
input_dtypes[i] = forward_op_->inputs[i]->DataType();
|
||||
}
|
||||
for (int i = 0; i < *num_retvals; i++) {
|
||||
// TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs.
|
||||
forward_op_->outputs.push_back(retvals[i]);
|
||||
@ -413,25 +477,10 @@ Status Execute(AbstractOperation* op_, AbstractContext* ctx,
|
||||
// Consider getting rid of this and making the behavior between number types
|
||||
// and string consistent.
|
||||
forward_op_->attrs.BuildNodeDef();
|
||||
std::vector<TapeTensor> tape_tensors;
|
||||
for (auto t : retvals) {
|
||||
tape_tensors.push_back(TapeTensor(t));
|
||||
}
|
||||
tape->RecordOperation(
|
||||
op_->Name(), tape_tensors, input_ids, input_dtypes,
|
||||
[registry, forward_op_]() -> BackwardFunction* {
|
||||
std::unique_ptr<BackwardFunction> backward_fn;
|
||||
Status s = registry.Lookup(*forward_op_, &backward_fn);
|
||||
if (!s.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return backward_fn.release();
|
||||
},
|
||||
[](BackwardFunction* ptr) {
|
||||
if (ptr) {
|
||||
delete ptr;
|
||||
}
|
||||
});
|
||||
std::unique_ptr<GradientFunction> gradient_fn;
|
||||
TF_RETURN_IF_ERROR(registry.Lookup(*forward_op_, &gradient_fn));
|
||||
tape->RecordOperation(forward_op_->inputs, retvals, gradient_fn.release(),
|
||||
op_->Name());
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace internal
|
||||
|
@ -33,10 +33,11 @@ namespace gradients {
|
||||
// public:
|
||||
// Status Compute(Context* ctx,
|
||||
// absl::Span<AbstractTensorHandle* const> grad_inputs,
|
||||
// std::vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
// grad_outputs->resize(2);
|
||||
// (*grad_outputs)[0] = grad_inputs[0];
|
||||
// (*grad_outputs)[1] = grad_inputs[0];
|
||||
// absl::Span<AbstractTensorHandle*> grad_outputs) override {
|
||||
// grad_outputs[0] = grad_inputs[0];
|
||||
// grad_outputs[1] = grad_inputs[0];
|
||||
// grad_outputs[0]->Ref();
|
||||
// grad_outputs[1]->Ref();
|
||||
// return Status::OK();
|
||||
// }
|
||||
// ~AddGradientFunction() override {}
|
||||
@ -51,123 +52,41 @@ namespace gradients {
|
||||
// Status RegisterGradients(GradientRegistry* registry) {
|
||||
// return registry->Register("Add", AddRegisterer);
|
||||
// }
|
||||
struct Context {
|
||||
public:
|
||||
AbstractContext* ctx;
|
||||
};
|
||||
|
||||
class IncomingGradients {
|
||||
public:
|
||||
virtual AbstractTensorHandle* operator[](int i) const = 0;
|
||||
virtual size_t size() const = 0;
|
||||
virtual ~IncomingGradients() {}
|
||||
};
|
||||
|
||||
class GradientFunction {
|
||||
public:
|
||||
// TODO(srbs): How we support CompositeTensors e.g. IndexedSlices in
|
||||
// `grad_inputs`.
|
||||
virtual Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
|
||||
std::vector<AbstractTensorHandle*>* grad_outputs) = 0;
|
||||
virtual Status Compute(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> grad_outputs,
|
||||
absl::Span<AbstractTensorHandle*> grad_inputs) = 0;
|
||||
virtual ~GradientFunction() {}
|
||||
};
|
||||
|
||||
// Metadata from the forward operation that is made available to the
|
||||
// gradient registerer to instantiate a BackwardFunction.
|
||||
// gradient registerer to instantiate a GradientFunction.
|
||||
struct ForwardOperation {
|
||||
public:
|
||||
string op_name;
|
||||
std::vector<AbstractTensorHandle*> inputs;
|
||||
std::vector<AbstractTensorHandle*> outputs;
|
||||
std::vector<int64> skip_input_indices;
|
||||
AttrBuilder attrs;
|
||||
};
|
||||
|
||||
// Interface for building default zeros gradients for op outputs which are
|
||||
// missing incoming gradients. Custom implementations of this can be used to
|
||||
// control which of the forward op's output tensors/their metadata needs to
|
||||
// be kept around in memory to build the default zeros grad.
|
||||
//
|
||||
// Some common helper implementations are provided below.
|
||||
class DefaultGradientFunction {
|
||||
public:
|
||||
virtual AbstractTensorHandle* get(
|
||||
Context* ctx, absl::Span<AbstractTensorHandle* const> grad_inputs,
|
||||
int i) = 0;
|
||||
virtual ~DefaultGradientFunction() {}
|
||||
};
|
||||
using GradientFunctionFactory =
|
||||
std::function<GradientFunction*(const ForwardOperation& op)>;
|
||||
|
||||
// Returns zeros for any `nullptr` in `grad_inputs`.
|
||||
//
|
||||
// This may require keeping track of all of forward op's output
|
||||
// tensors and hence may incur a higher memory footprint. Use sparingly.
|
||||
//
|
||||
// Multiple calls to `AllZerosDefaultGradients::get` return the same tensor
|
||||
// handle.
|
||||
//
|
||||
// The destructor of this class `Unref`'s any cached tensor handles so users of
|
||||
// those tensor handles should `Ref` them in order to keep them alive if needed.
|
||||
class AllZerosDefaultGradients : public DefaultGradientFunction {
|
||||
public:
|
||||
explicit AllZerosDefaultGradients(const ForwardOperation& op);
|
||||
AbstractTensorHandle* get(Context* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> grad_inputs,
|
||||
int i) override;
|
||||
|
||||
private:
|
||||
// TODO(srbs): We do not always need to keep the tensors around. In immediate
|
||||
// execution mode we just need to store the shape and dtype. During tracing
|
||||
// we may need to keep the tensor around if the shape is not full defined.
|
||||
std::vector<AbstractTensorHandle*> outputs_;
|
||||
std::vector<AbstractTensorHandlePtr> cached_default_grads_;
|
||||
};
|
||||
|
||||
// Passes through `grad_inputs` as-is. The `GradientFunction`
|
||||
// will be expected to deal with nullptr in `grad_inputs` if any.
|
||||
class PassThroughDefaultGradients : public DefaultGradientFunction {
|
||||
public:
|
||||
explicit PassThroughDefaultGradients(const ForwardOperation& op);
|
||||
AbstractTensorHandle* get(Context* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> grad_inputs,
|
||||
int i) override;
|
||||
};
|
||||
|
||||
// A `BackwardFunction` wraps a `GradientFunction` and a
|
||||
// `DefaultGradientFunction`. Both are owned by this class' instance.
|
||||
class BackwardFunction {
|
||||
public:
|
||||
BackwardFunction(GradientFunction* gradient_function,
|
||||
DefaultGradientFunction* default_gradients)
|
||||
: gradient_function_(gradient_function),
|
||||
default_gradients_(default_gradients) {}
|
||||
GradientFunction* GetGradientFunction() { return gradient_function_.get(); }
|
||||
DefaultGradientFunction* GetDefaultGradientFunction() {
|
||||
return default_gradients_.get();
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<GradientFunction> gradient_function_;
|
||||
std::unique_ptr<DefaultGradientFunction> default_gradients_;
|
||||
};
|
||||
|
||||
using BackwardFunctionFactory =
|
||||
std::function<BackwardFunction*(const ForwardOperation& op)>;
|
||||
|
||||
// Map from op name to a `BackwardFunctionFactory`.
|
||||
// Map from op name to a `GradientFunctionFactory`.
|
||||
class GradientRegistry {
|
||||
public:
|
||||
Status Register(const string& op,
|
||||
BackwardFunctionFactory backward_function_factory);
|
||||
GradientFunctionFactory gradient_function_factory);
|
||||
Status Lookup(const ForwardOperation& op,
|
||||
std::unique_ptr<BackwardFunction>* backward_function) const;
|
||||
std::unique_ptr<GradientFunction>* gradient_function) const;
|
||||
|
||||
private:
|
||||
absl::flat_hash_map<string, BackwardFunctionFactory> registry_;
|
||||
absl::flat_hash_map<string, GradientFunctionFactory> registry_;
|
||||
};
|
||||
|
||||
// Returns a unique id for the tensor which is used by the tape to build
|
||||
// the gradient graph. See documentation of `TapeTensor` for more details.
|
||||
int64 ToId(AbstractTensorHandle* t);
|
||||
|
||||
// TODO(srbs): Figure out if we can avoid declaring this in the public header.
|
||||
// Wrapper for a tensor output of an operation executing under a tape.
|
||||
//
|
||||
// `GetID` returns a unique id for the wrapped tensor which is used to maintain
|
||||
@ -203,59 +122,53 @@ class TapeTensor {
|
||||
AbstractTensorHandle* handle_;
|
||||
};
|
||||
|
||||
// Vector space for actually computing gradients. Implements methods for calling
|
||||
// the backward function with incoming gradients and returning the outgoing
|
||||
// gradient and for performing gradient aggregation.
|
||||
// See `tensorflow::eager::VSpace` for more details.
|
||||
class TapeVSpace
|
||||
: public eager::VSpace<AbstractTensorHandle, BackwardFunction, TapeTensor> {
|
||||
public:
|
||||
explicit TapeVSpace(AbstractContext* ctx) : ctx_(ctx) {}
|
||||
~TapeVSpace() override {}
|
||||
|
||||
// Returns the number of elements in the gradient tensor.
|
||||
int64 NumElements(AbstractTensorHandle* tensor) const override;
|
||||
|
||||
// Consumes references to the tensors in the gradient_tensors list and returns
|
||||
// a tensor with the result.
|
||||
AbstractTensorHandle* AggregateGradients(
|
||||
gtl::ArraySlice<AbstractTensorHandle*> gradient_tensors) const override;
|
||||
|
||||
// Calls the passed-in backward function.
|
||||
Status CallBackwardFunction(
|
||||
BackwardFunction* backward_function,
|
||||
const std::vector<int64>& unneeded_gradients,
|
||||
gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
|
||||
std::vector<AbstractTensorHandle*>* result) const override;
|
||||
|
||||
// Builds a tensor filled with ones with the same shape and dtype as `t`.
|
||||
Status BuildOnesLike(const TapeTensor& t,
|
||||
AbstractTensorHandle** result) const override;
|
||||
|
||||
// Looks up the ID of a Gradient.
|
||||
int64 TensorId(AbstractTensorHandle* tensor) const override;
|
||||
|
||||
// Converts a Gradient to a TapeTensor.
|
||||
TapeTensor TapeTensorFromGradient(AbstractTensorHandle* g) const override;
|
||||
|
||||
void MarkAsResult(AbstractTensorHandle* gradient) const override;
|
||||
|
||||
void DeleteGradient(AbstractTensorHandle* gradient) const override;
|
||||
|
||||
private:
|
||||
// The context where the aggregation op `Add` is to be created.
|
||||
AbstractContext* ctx_;
|
||||
};
|
||||
|
||||
// A tracing/immediate-execution agnostic tape.
|
||||
//
|
||||
// Gradient functions defined for this library support handling null incoming
|
||||
// gradients. `Tape::ComputeGradient` should be called with
|
||||
// `build_default_zeros_grads=false`. Calling with
|
||||
// `build_default_zeros_grads=true` (the default) is equivalent but just results
|
||||
// in extra work because `TapeTensor::ZerosLike` returns a `nullptr` anyway.
|
||||
using Tape = tensorflow::eager::GradientTape<AbstractTensorHandle,
|
||||
BackwardFunction, TapeTensor>;
|
||||
// Gradient functions defined for this tape must support handling null incoming
|
||||
// gradients.
|
||||
class Tape : protected eager::GradientTape<AbstractTensorHandle,
|
||||
GradientFunction, TapeTensor> {
|
||||
public:
|
||||
using GradientTape<AbstractTensorHandle, GradientFunction,
|
||||
TapeTensor>::GradientTape;
|
||||
// Returns whether the tape is persistent, i.e., whether the tape will hold
|
||||
// onto its internal state after a call to `ComputeGradient`.
|
||||
using GradientTape<AbstractTensorHandle, GradientFunction,
|
||||
TapeTensor>::IsPersistent;
|
||||
|
||||
// Adds this tensor to the list of watched tensors.
|
||||
//
|
||||
// This is a no-op if the tensor is already being watched either from an
|
||||
// earlier call to `GradientTape::Watch` or being an output of an op with
|
||||
// watched inputs.
|
||||
void Watch(const AbstractTensorHandle*);
|
||||
// Records an operation with given inputs and outputs
|
||||
// on the tape and marks all its outputs as watched if at
|
||||
// least one input of the op is watched and has a trainable dtype.
|
||||
// op_name is optional and is used for debugging only.
|
||||
void RecordOperation(absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle* const> outputs,
|
||||
GradientFunction* gradient_function,
|
||||
const string& op_name = "");
|
||||
// Returns whether any tensor in a list of tensors is being watched and has
|
||||
// a trainable dtype.
|
||||
bool ShouldRecord(
|
||||
absl::Span<const AbstractTensorHandle* const> tensors) const;
|
||||
// Unwatches this tensor on the tape. Mainly used for cleanup when deleting
|
||||
// eager tensors.
|
||||
void DeleteTrace(const AbstractTensorHandle*);
|
||||
|
||||
// Consumes the internal state of the tape (so cannot be called more than
|
||||
// once unless the tape is persistent) and produces the gradient of the target
|
||||
// tensors with respect to the source tensors. The output gradients are used
|
||||
// if not empty and not null. The result is populated with one tensor per
|
||||
// target element.
|
||||
Status ComputeGradient(
|
||||
AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> targets,
|
||||
absl::Span<AbstractTensorHandle* const> sources,
|
||||
absl::Span<AbstractTensorHandle* const> output_gradients,
|
||||
absl::Span<AbstractTensorHandle*> result);
|
||||
};
|
||||
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/unified_api_testutil.h"
|
||||
#include "tensorflow/c/experimental/gradients/array_grad.h"
|
||||
#include "tensorflow/c/experimental/gradients/math_grad.h"
|
||||
#include "tensorflow/c/experimental/gradients/not_differentiable.h"
|
||||
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
|
||||
#include "tensorflow/c/experimental/ops/array_ops.h"
|
||||
#include "tensorflow/c/experimental/ops/math_ops.h"
|
||||
@ -68,6 +69,7 @@ Status RegisterGradients(GradientRegistry* registry) {
|
||||
TF_RETURN_IF_ERROR(registry->Register("Mul", MulRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("Log1p", Log1pRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("DivNoNan", DivNoNanRegisterer));
|
||||
TF_RETURN_IF_ERROR(RegisterNotDifferentiable(registry, "CheckNumerics"));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -80,30 +82,20 @@ Status AddGradModel(AbstractContext* ctx,
|
||||
GradientRegistry registry;
|
||||
TF_RETURN_IF_ERROR(RegisterGradients(®istry));
|
||||
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = std::make_unique<Tape>(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
tape->Watch(ToId(inputs[1])); // Watch y.
|
||||
tape->Watch(inputs[0]); // Watch x.
|
||||
tape->Watch(inputs[1]); // Watch y.
|
||||
std::vector<AbstractTensorHandle*> add_outputs(1);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
|
||||
TF_RETURN_IF_ERROR(ops::Add(tape_ctx.get(), inputs,
|
||||
absl::MakeSpan(add_outputs),
|
||||
"Add")); // Compute x+y.
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
|
||||
std::vector<AbstractTensorHandle*> out_grads;
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(
|
||||
vspace, /*target_tensor_ids=*/{ToId(add_outputs[0])},
|
||||
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
|
||||
source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads,
|
||||
/*build_default_zeros_grads=*/false));
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/add_outputs,
|
||||
/*sources=*/inputs,
|
||||
/*output_gradients=*/{}, outputs));
|
||||
for (auto add_output : add_outputs) {
|
||||
add_output->Unref();
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
outputs[1] = out_grads[1];
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -116,26 +108,18 @@ Status ExpGradModel(AbstractContext* ctx,
|
||||
GradientRegistry registry;
|
||||
TF_RETURN_IF_ERROR(RegisterGradients(®istry));
|
||||
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = std::make_unique<Tape>(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
tape->Watch(inputs[0]); // Watch x.
|
||||
std::vector<AbstractTensorHandle*> exp_outputs(1);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
|
||||
TF_RETURN_IF_ERROR(
|
||||
ops::Exp(tape_ctx.get(), inputs, absl::MakeSpan(exp_outputs), "Exp"));
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
|
||||
std::vector<AbstractTensorHandle*> out_grads;
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(
|
||||
vspace, /*target_tensor_ids=*/{ToId(exp_outputs[0])},
|
||||
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads,
|
||||
/*build_default_zeros_grads=*/false));
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/exp_outputs,
|
||||
/*sources=*/inputs,
|
||||
/*output_gradients=*/{}, outputs));
|
||||
for (auto exp_output : exp_outputs) {
|
||||
exp_output->Unref();
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -148,26 +132,18 @@ Status SqrtGradModel(AbstractContext* ctx,
|
||||
GradientRegistry registry;
|
||||
TF_RETURN_IF_ERROR(RegisterGradients(®istry));
|
||||
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = std::make_unique<Tape>(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
tape->Watch(inputs[0]); // Watch x.
|
||||
std::vector<AbstractTensorHandle*> sqrt_outputs(1);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
|
||||
TF_RETURN_IF_ERROR(
|
||||
ops::Sqrt(tape_ctx.get(), inputs, absl::MakeSpan(sqrt_outputs), "Sqrt"));
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
|
||||
std::vector<AbstractTensorHandle*> out_grads;
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(
|
||||
vspace, /*target_tensor_ids=*/{ToId(sqrt_outputs[0])},
|
||||
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads,
|
||||
/*build_default_zeros_grads=*/false));
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/sqrt_outputs,
|
||||
/*sources=*/inputs,
|
||||
/*output_gradients=*/{}, outputs));
|
||||
for (auto sqrt_output : sqrt_outputs) {
|
||||
sqrt_output->Unref();
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -181,30 +157,21 @@ Status IdentityNGradModel(AbstractContext* ctx,
|
||||
GradientRegistry registry;
|
||||
TF_RETURN_IF_ERROR(RegisterGradients(®istry));
|
||||
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = std::make_unique<Tape>(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0]));
|
||||
tape->Watch(ToId(inputs[1]));
|
||||
tape->Watch(inputs[0]);
|
||||
tape->Watch(inputs[1]);
|
||||
|
||||
vector<AbstractTensorHandle*> identity_n_outputs(2);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
|
||||
TF_RETURN_IF_ERROR(ops::IdentityN(
|
||||
tape_ctx.get(), inputs, absl::MakeSpan(identity_n_outputs), "IdentityN"));
|
||||
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
vector<AbstractTensorHandle*> out_grads;
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(
|
||||
vspace, /*target_tensor_ids=*/{ToId(identity_n_outputs[1])},
|
||||
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
|
||||
source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads,
|
||||
/*build_default_zeros_grads=*/false));
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx,
|
||||
/*targets=*/{identity_n_outputs[1]},
|
||||
/*sources=*/{inputs[0], inputs[1]},
|
||||
/*output_gradients=*/{}, outputs));
|
||||
for (auto identity_n_output : identity_n_outputs) {
|
||||
identity_n_output->Unref();
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
outputs[1] = out_grads[1];
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -217,27 +184,19 @@ Status NegGradModel(AbstractContext* ctx,
|
||||
GradientRegistry registry;
|
||||
TF_RETURN_IF_ERROR(RegisterGradients(®istry));
|
||||
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = std::make_unique<Tape>(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0]));
|
||||
tape->Watch(inputs[0]);
|
||||
|
||||
std::vector<AbstractTensorHandle*> neg_outputs(1);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
|
||||
TF_RETURN_IF_ERROR(
|
||||
ops::Neg(tape_ctx.get(), inputs, absl::MakeSpan(neg_outputs), "Neg"));
|
||||
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
std::vector<AbstractTensorHandle*> out_grads;
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(
|
||||
vspace, /*target_tensor_ids=*/{ToId(neg_outputs[0])},
|
||||
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads,
|
||||
/*build_default_zeros_grads=*/false));
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/neg_outputs,
|
||||
/*sources=*/inputs,
|
||||
/*output_gradients=*/{}, outputs));
|
||||
for (auto neg_output : neg_outputs) {
|
||||
neg_output->Unref();
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -250,30 +209,20 @@ Status SubGradModel(AbstractContext* ctx,
|
||||
GradientRegistry registry;
|
||||
TF_RETURN_IF_ERROR(RegisterGradients(®istry));
|
||||
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = std::make_unique<Tape>(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
tape->Watch(ToId(inputs[1])); // Watch y.
|
||||
tape->Watch(inputs[0]); // Watch x.
|
||||
tape->Watch(inputs[1]); // Watch y.
|
||||
std::vector<AbstractTensorHandle*> sub_outputs(1);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
|
||||
TF_RETURN_IF_ERROR(ops::Sub(tape_ctx.get(), inputs,
|
||||
absl::MakeSpan(sub_outputs),
|
||||
"Sub")); // Compute x-y.
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
|
||||
std::vector<AbstractTensorHandle*> out_grads;
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(
|
||||
vspace, /*target_tensor_ids=*/{ToId(sub_outputs[0])},
|
||||
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
|
||||
source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads,
|
||||
/*build_default_zeros_grads=*/false));
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/sub_outputs,
|
||||
/*sources=*/inputs,
|
||||
/*output_gradients=*/{}, outputs));
|
||||
for (auto sub_output : sub_outputs) {
|
||||
sub_output->Unref();
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
outputs[1] = out_grads[1];
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -286,30 +235,20 @@ Status MulGradModel(AbstractContext* ctx,
|
||||
GradientRegistry registry;
|
||||
TF_RETURN_IF_ERROR(RegisterGradients(®istry));
|
||||
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
tape->Watch(ToId(inputs[1])); // Watch y.
|
||||
tape->Watch(inputs[0]); // Watch x.
|
||||
tape->Watch(inputs[1]); // Watch y.
|
||||
std::vector<AbstractTensorHandle*> mul_outputs(1);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), inputs,
|
||||
absl::MakeSpan(mul_outputs),
|
||||
"Mul")); // Compute x*y.
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
|
||||
std::vector<AbstractTensorHandle*> out_grads;
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(
|
||||
vspace, /*target_tensor_ids=*/{ToId(mul_outputs[0])},
|
||||
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
|
||||
source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads,
|
||||
/*build_default_zeros_grads=*/false));
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/mul_outputs,
|
||||
/*sources=*/inputs,
|
||||
/*output_gradients=*/{}, outputs));
|
||||
for (auto mul_output : mul_outputs) {
|
||||
mul_output->Unref();
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
outputs[1] = out_grads[1];
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
}
|
||||
@ -322,27 +261,19 @@ Status Log1pGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle*> outputs) {
|
||||
GradientRegistry registry;
|
||||
TF_RETURN_IF_ERROR(RegisterGradients(®istry));
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
tape->Watch(inputs[0]); // Watch x.
|
||||
std::vector<AbstractTensorHandle*> log1p_outputs(1);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::Log1p(tape_ctx.get(), inputs,
|
||||
absl::MakeSpan(log1p_outputs),
|
||||
"Log1p")); // Compute log(1 + x).
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
|
||||
std::vector<AbstractTensorHandle*> out_grads;
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(
|
||||
vspace, /*target_tensor_ids=*/{ToId(log1p_outputs[0])},
|
||||
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads,
|
||||
/*build_default_zeros_grads=*/false));
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/log1p_outputs,
|
||||
/*sources=*/inputs,
|
||||
/*output_gradients=*/{}, outputs));
|
||||
for (auto log1p_output : log1p_outputs) {
|
||||
log1p_output->Unref();
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
}
|
||||
@ -355,30 +286,20 @@ Status DivNoNanGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle*> outputs) {
|
||||
GradientRegistry registry;
|
||||
TF_RETURN_IF_ERROR(RegisterGradients(®istry));
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
tape->Watch(ToId(inputs[1])); // Watch y.
|
||||
tape->Watch(inputs[0]); // Watch x.
|
||||
tape->Watch(inputs[1]); // Watch y.
|
||||
std::vector<AbstractTensorHandle*> div_outputs(1);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::DivNoNan(tape_ctx.get(), inputs,
|
||||
absl::MakeSpan(div_outputs),
|
||||
"DivNoNan")); // Compute x / y.
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
|
||||
std::vector<AbstractTensorHandle*> out_grads;
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(
|
||||
vspace, /*target_tensor_ids=*/{ToId(div_outputs[0])},
|
||||
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
|
||||
source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads,
|
||||
/*build_default_zeros_grads=*/false));
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/div_outputs,
|
||||
/*sources=*/inputs,
|
||||
/*output_gradients=*/{}, outputs));
|
||||
for (auto div_output : div_outputs) {
|
||||
div_output->Unref();
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
outputs[1] = out_grads[1];
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
}
|
||||
@ -890,6 +811,8 @@ TEST_P(CppGradients, TestSetAttrString) {
|
||||
int num_retvals = 1;
|
||||
std::vector<AbstractTensorHandle*> outputs(1);
|
||||
GradientRegistry registry;
|
||||
s = RegisterGradients(®istry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
auto tape = std::make_unique<Tape>(/*persistent=*/false);
|
||||
s = Execute(check_numerics_op.get(), ctx.get(), absl::MakeSpan(outputs),
|
||||
&num_retvals, &forward_op, tape.get(), registry);
|
||||
@ -901,6 +824,52 @@ TEST_P(CppGradients, TestSetAttrString) {
|
||||
ASSERT_EQ(read_message, message);
|
||||
}
|
||||
|
||||
Status RecordOperationWithNullGradientFunctionModel(
|
||||
AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs) {
|
||||
Tape tape(/*persistent=*/false);
|
||||
tape.Watch(inputs[0]);
|
||||
std::vector<AbstractTensorHandle*> neg_outputs(1);
|
||||
TF_RETURN_IF_ERROR(ops::Neg(ctx, inputs, absl::MakeSpan(neg_outputs), "Neg"));
|
||||
tape.RecordOperation(inputs, neg_outputs, nullptr, "Neg");
|
||||
return tape.ComputeGradient(ctx, /*targets=*/neg_outputs,
|
||||
/*sources=*/inputs,
|
||||
/*output_gradients=*/{}, outputs);
|
||||
}
|
||||
|
||||
TEST_P(CppGradients, TestRecordOperationWithNullGradientFunctionRaises) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
AbstractContextPtr ctx;
|
||||
{
|
||||
AbstractContext* ctx_raw = nullptr;
|
||||
Status s =
|
||||
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
ctx.reset(ctx_raw);
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr x;
|
||||
{
|
||||
AbstractTensorHandle* x_raw = nullptr;
|
||||
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
x.reset(x_raw);
|
||||
}
|
||||
|
||||
std::vector<AbstractTensorHandle*> outputs(1);
|
||||
Status s = RunModel(RecordOperationWithNullGradientFunctionModel, ctx.get(),
|
||||
{x.get()}, absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()));
|
||||
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
|
||||
ASSERT_EQ(
|
||||
"Provided null gradient_function for 'Neg'.\nIf the intent is to treat "
|
||||
"this op as non-differentiable consider using RegisterNotDifferentiable "
|
||||
"or NotDifferentiableGradientFunction.",
|
||||
s.error_message());
|
||||
ASSERT_EQ(nullptr, outputs[0]);
|
||||
}
|
||||
|
||||
// TODO(b/164171226): Enable this test with tfrt after AddInputList is
|
||||
// supported. It is needed for IdentityN.
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
|
@ -47,29 +47,19 @@ Status AddGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
tape->Watch(ToId(inputs[1])); // Watch y.
|
||||
tape->Watch(inputs[0]); // Watch x.
|
||||
tape->Watch(inputs[1]); // Watch y.
|
||||
std::vector<AbstractTensorHandle*> add_outputs(1);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(
|
||||
ops::Add(tape_ctx.get(), inputs, absl::MakeSpan(add_outputs), "Add"));
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
|
||||
std::vector<AbstractTensorHandle*> out_grads;
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(
|
||||
vspace, /*target_tensor_ids=*/{ToId(add_outputs[0])},
|
||||
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
|
||||
source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads,
|
||||
/*build_default_zeros_grads=*/false));
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/add_outputs,
|
||||
/*sources=*/inputs,
|
||||
/*output_gradients=*/{}, outputs));
|
||||
for (auto add_output : add_outputs) {
|
||||
add_output->Unref();
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
outputs[1] = out_grads[1];
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
}
|
||||
@ -81,10 +71,9 @@ Status MatMulGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
tape->Watch(ToId(inputs[1])); // Watch y.
|
||||
tape->Watch(inputs[0]); // Watch x.
|
||||
tape->Watch(inputs[1]); // Watch y.
|
||||
vector<AbstractTensorHandle*> mm_outputs(1);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), inputs,
|
||||
@ -92,21 +81,12 @@ Status MatMulGradModel(AbstractContext* ctx,
|
||||
/*transpose_a=*/false,
|
||||
/*transpose_b=*/false)); // Compute x*y.
|
||||
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
|
||||
vector<AbstractTensorHandle*> out_grads;
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(
|
||||
vspace, /*target_tensor_ids=*/{ToId(mm_outputs[0])},
|
||||
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
|
||||
source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads,
|
||||
/*build_default_zeros_grads=*/false));
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/mm_outputs,
|
||||
/*sources=*/inputs,
|
||||
/*output_gradients=*/{}, outputs));
|
||||
for (auto mm_output : mm_outputs) {
|
||||
mm_output->Unref();
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
outputs[1] = out_grads[1];
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
}
|
||||
@ -138,39 +118,32 @@ Status MNISTForwardModel(AbstractContext* ctx,
|
||||
AbstractTensorHandle* W2 = inputs[2];
|
||||
AbstractTensorHandle* y_labels = inputs[3];
|
||||
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
tape->Watch(ToId(W1)); // Watch W1.
|
||||
tape->Watch(ToId(W2)); // Watch W2.
|
||||
vector<AbstractTensorHandle*> temp_outputs(1);
|
||||
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
|
||||
absl::MakeSpan(temp_outputs), "matmul0",
|
||||
TF_RETURN_IF_ERROR(ops::MatMul(ctx, {X, W1}, absl::MakeSpan(temp_outputs),
|
||||
"matmul0",
|
||||
/*transpose_a=*/false,
|
||||
/*transpose_b=*/false)); // Compute X*W1
|
||||
|
||||
TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {temp_outputs[0]},
|
||||
TF_RETURN_IF_ERROR(ops::Relu(ctx, {temp_outputs[0]},
|
||||
absl::MakeSpan(temp_outputs),
|
||||
"relu")); // Compute Relu(X*W1)
|
||||
|
||||
TF_RETURN_IF_ERROR(ops::MatMul(
|
||||
tape_ctx.get(), {temp_outputs[0], W2}, absl::MakeSpan(temp_outputs),
|
||||
"matmul1",
|
||||
ctx, {temp_outputs[0], W2}, absl::MakeSpan(temp_outputs), "matmul1",
|
||||
/*transpose_a=*/false, /*transpose_b=*/false)); // Compute W2*Relu(X*W1)
|
||||
|
||||
AbstractTensorHandle* scores = temp_outputs[0];
|
||||
|
||||
temp_outputs.resize(2);
|
||||
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
|
||||
tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs),
|
||||
ctx, {scores, y_labels}, absl::MakeSpan(temp_outputs),
|
||||
"softmax_loss")); // Compute Softmax(Scores,labels)
|
||||
|
||||
AbstractTensorHandle* loss_vals = temp_outputs[0];
|
||||
|
||||
outputs[0] = scores;
|
||||
outputs[1] = loss_vals;
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -181,21 +154,9 @@ Status MatMulTransposeModel(AbstractContext* ctx,
|
||||
AbstractTensorHandle* X = inputs[0];
|
||||
AbstractTensorHandle* W1 = inputs[1];
|
||||
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
tape->Watch(ToId(X));
|
||||
tape->Watch(ToId(W1));
|
||||
vector<AbstractTensorHandle*> temp_outputs(1);
|
||||
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
|
||||
absl::MakeSpan(temp_outputs), "matmul0",
|
||||
TF_RETURN_IF_ERROR(ops::MatMul(ctx, {X, W1}, outputs, "matmul0",
|
||||
/*transpose_a=*/true,
|
||||
/*transpose_b=*/false)); // Compute X*W1
|
||||
|
||||
outputs[0] = temp_outputs[0];
|
||||
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -203,30 +164,22 @@ Status ReluGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch X
|
||||
tape->Watch(inputs[0]); // Watch X
|
||||
vector<AbstractTensorHandle*> relu_outputs(1);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), inputs,
|
||||
absl::MakeSpan(relu_outputs),
|
||||
"relu0")); // Relu(X)
|
||||
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
|
||||
vector<AbstractTensorHandle*> out_grads;
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(
|
||||
vspace, /*target_tensor_ids=*/{ToId(relu_outputs[0])},
|
||||
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads,
|
||||
/*build_default_zeros_grads=*/false));
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/relu_outputs,
|
||||
/*sources=*/inputs,
|
||||
/*output_gradients=*/{}, outputs));
|
||||
|
||||
for (auto relu_output : relu_outputs) {
|
||||
relu_output->Unref();
|
||||
}
|
||||
|
||||
outputs[0] = out_grads[0];
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
}
|
||||
@ -235,28 +188,19 @@ Status SoftmaxLossGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch scores.
|
||||
tape->Watch(ToId(inputs[1])); // Watch labels.
|
||||
tape->Watch(inputs[0]); // Watch scores.
|
||||
tape->Watch(inputs[1]); // Watch labels.
|
||||
vector<AbstractTensorHandle*> sm_outputs(2);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
|
||||
tape_ctx.get(), inputs, absl::MakeSpan(sm_outputs), "softmax0"));
|
||||
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx,
|
||||
/*targets=*/sm_outputs,
|
||||
/*sources=*/inputs,
|
||||
/*output_gradients=*/{}, outputs));
|
||||
|
||||
vector<AbstractTensorHandle*> out_grads;
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(
|
||||
vspace, /*target_tensor_ids=*/{ToId(sm_outputs[0])},
|
||||
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
|
||||
source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads,
|
||||
/*build_default_zeros_grads=*/false));
|
||||
|
||||
outputs[0] = out_grads[0];
|
||||
outputs[1] = out_grads[1];
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
}
|
||||
@ -270,11 +214,10 @@ Status MNISTGradModel(AbstractContext* ctx,
|
||||
AbstractTensorHandle* W2 = inputs[2];
|
||||
AbstractTensorHandle* y_labels = inputs[3];
|
||||
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/true);
|
||||
tape->Watch(ToId(X)); // Watch X.
|
||||
tape->Watch(ToId(W1)); // Watch W1.
|
||||
tape->Watch(ToId(W2)); // Watch W1.
|
||||
tape->Watch(X); // Watch X.
|
||||
tape->Watch(W1); // Watch W1.
|
||||
tape->Watch(W2); // Watch W1.
|
||||
vector<AbstractTensorHandle*> temp_outputs(1);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
|
||||
@ -303,24 +246,14 @@ Status MNISTGradModel(AbstractContext* ctx,
|
||||
|
||||
AbstractTensorHandle* loss = temp_outputs[0];
|
||||
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
|
||||
vector<AbstractTensorHandle*> out_grads;
|
||||
TF_RETURN_IF_ERROR(
|
||||
tape->ComputeGradient(vspace, /*target_tensor_ids=*/{ToId(loss)},
|
||||
/*source_tensor_ids=*/{ToId(W1), ToId(W2)},
|
||||
source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads,
|
||||
/*build_default_zeros_grads=*/false));
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/{loss},
|
||||
/*sources=*/{W1, W2},
|
||||
/*output_gradients=*/{},
|
||||
outputs.subspan(0, 2)));
|
||||
|
||||
// Only release 2nd temp output as first holds loss values.
|
||||
temp_outputs[1]->Unref();
|
||||
|
||||
outputs[0] = out_grads[0]; // dW1
|
||||
outputs[1] = out_grads[1]; // dW2
|
||||
outputs[2] = loss;
|
||||
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
}
|
||||
@ -329,83 +262,33 @@ Status ScalarMulModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
AbstractTensorHandle* eta = inputs[0];
|
||||
AbstractTensorHandle* A = inputs[1];
|
||||
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
vector<AbstractTensorHandle*> temp_outputs(1);
|
||||
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), {eta, A},
|
||||
absl::MakeSpan(temp_outputs),
|
||||
"scalarMul0")); // Compute eta*A
|
||||
|
||||
outputs[0] = temp_outputs[0];
|
||||
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
return ops::Mul(ctx, inputs, outputs,
|
||||
"scalarMul0"); // Compute eta*A
|
||||
}
|
||||
|
||||
Status MatMulModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
AbstractTensorHandle* X = inputs[0];
|
||||
AbstractTensorHandle* W1 = inputs[1];
|
||||
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
std::vector<AbstractTensorHandle*> temp_outputs(1);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
|
||||
absl::MakeSpan(temp_outputs), "matmul0",
|
||||
/*transpose_a=*/false,
|
||||
/*transpose_b=*/false)); // Compute X*W1
|
||||
|
||||
outputs[0] = temp_outputs[0];
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
return ops::MatMul(ctx, inputs, outputs, "matmul0",
|
||||
/*transpose_a=*/false,
|
||||
/*transpose_b=*/false); // Compute X*W1
|
||||
}
|
||||
|
||||
Status MulModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
AbstractTensorHandle* x = inputs[0];
|
||||
AbstractTensorHandle* y = inputs[1];
|
||||
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
std::vector<AbstractTensorHandle*> temp_outputs(1);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), {x, y},
|
||||
absl::MakeSpan(temp_outputs),
|
||||
"mul0")); // Compute x*y
|
||||
|
||||
outputs[0] = temp_outputs[0];
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
return ops::Mul(ctx, inputs, outputs,
|
||||
"mul0"); // Compute x*y
|
||||
}
|
||||
|
||||
Status SoftmaxModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
AbstractTensorHandle* x = inputs[0];
|
||||
AbstractTensorHandle* labels = inputs[1];
|
||||
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
std::vector<AbstractTensorHandle*> temp_outputs(2);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
|
||||
tape_ctx.get(), {x, labels}, absl::MakeSpan(temp_outputs), "sm_loss"));
|
||||
|
||||
outputs[0] = temp_outputs[0]; // loss values
|
||||
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
return ops::SparseSoftmaxCrossEntropyWithLogits(ctx, inputs, outputs,
|
||||
"sm_loss");
|
||||
}
|
||||
|
||||
// ============================= End Models ================================
|
||||
|
@ -93,11 +93,14 @@ class VSpace {
|
||||
gtl::ArraySlice<Gradient*> gradient_tensors) const = 0;
|
||||
|
||||
// Calls the passed-in backward function.
|
||||
//
|
||||
// `unneeded_gradients` contains sorted list of input indices for which a
|
||||
// gradient is not required.
|
||||
virtual Status CallBackwardFunction(
|
||||
BackwardFunction* backward_function,
|
||||
const string& op_type, BackwardFunction* backward_function,
|
||||
const std::vector<int64>& unneeded_gradients,
|
||||
gtl::ArraySlice<Gradient*> output_gradients,
|
||||
std::vector<Gradient*>* result) const = 0;
|
||||
absl::Span<Gradient*> result) const = 0;
|
||||
|
||||
// Builds a tensor filled with ones with the same shape and dtype as `t`.
|
||||
virtual Status BuildOnesLike(const TapeTensor& t,
|
||||
@ -133,11 +136,24 @@ class GradientTape {
|
||||
}
|
||||
}
|
||||
|
||||
// Returns whether any tensor in a list of tensors is being watched and has
|
||||
// a trainable dtype.
|
||||
bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids,
|
||||
gtl::ArraySlice<tensorflow::DataType> dtypes);
|
||||
gtl::ArraySlice<tensorflow::DataType> dtypes) const;
|
||||
|
||||
// Adds this tensor to the list of watched tensors.
|
||||
//
|
||||
// This is a no-op if the tensor is already being watched either from an
|
||||
// earlier call to `GradientTape::Watch` or being an output of an op with
|
||||
// watched inputs.
|
||||
void Watch(int64 tensor_id);
|
||||
|
||||
// Records an operation with inputs `input_tensor_id` and outputs
|
||||
// `output_tensors` on the tape and marks all its outputs as watched if at
|
||||
// least one input of the op is watched and has trainable dtype.
|
||||
//
|
||||
// op_type is used to decide which of the incoming gradients can be left as
|
||||
// nullptr instead of building zeros when build_default_zeros_grads == true.
|
||||
void RecordOperation(
|
||||
const string& op_type, const std::vector<TapeTensor>& output_tensors,
|
||||
gtl::ArraySlice<int64> input_tensor_id,
|
||||
@ -159,9 +175,10 @@ class GradientTape {
|
||||
const gtl::ArraySlice<int64> target_tensor_ids,
|
||||
const gtl::ArraySlice<int64> source_tensor_ids,
|
||||
const std::unordered_map<int64, TapeTensor>& sources_that_are_targets,
|
||||
gtl::ArraySlice<Gradient*> output_gradients,
|
||||
std::vector<Gradient*>* result, bool build_default_zeros_grads = true);
|
||||
gtl::ArraySlice<Gradient*> output_gradients, absl::Span<Gradient*> result,
|
||||
bool build_default_zeros_grads = true);
|
||||
|
||||
// Whether the tape is persistent. See ctor for detailed description.
|
||||
bool IsPersistent() const { return persistent_; }
|
||||
|
||||
private:
|
||||
@ -311,11 +328,10 @@ class ForwardAccumulator {
|
||||
// function is running; this effectively adds the backward tape to the active
|
||||
// set (but does not require complicated callbacks to the language bindings).
|
||||
Status ForwardpropFromTape(
|
||||
const std::vector<TapeTensor>& output_tensors,
|
||||
const string& op_type, const std::vector<TapeTensor>& output_tensors,
|
||||
const std::function<BackwardFunction*()>& backward_function_getter,
|
||||
const std::function<void(BackwardFunction*)>& backward_function_deleter,
|
||||
const std::vector<Gradient*>& in_grads,
|
||||
std::vector<Gradient*>* out_grads);
|
||||
const std::vector<Gradient*>& in_grads, absl::Span<Gradient*> out_grads);
|
||||
|
||||
// Maps from tensor IDs to corresponding JVPs.
|
||||
std::unordered_map<int64, Gradient*> accumulated_gradients_;
|
||||
@ -368,7 +384,7 @@ inline bool IsDtypeTrainable(DataType dtype) {
|
||||
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
|
||||
bool GradientTape<Gradient, BackwardFunction, TapeTensor>::ShouldRecord(
|
||||
gtl::ArraySlice<int64> tensor_ids,
|
||||
gtl::ArraySlice<tensorflow::DataType> dtypes) {
|
||||
gtl::ArraySlice<tensorflow::DataType> dtypes) const {
|
||||
CHECK_EQ(tensor_ids.size(), dtypes.size());
|
||||
for (int i = 0; i < tensor_ids.size(); ++i) {
|
||||
if (tensor_tape_.find(tensor_ids[i]) != tensor_tape_.end()) {
|
||||
@ -668,7 +684,7 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
|
||||
const gtl::ArraySlice<int64> target_tensor_ids,
|
||||
const gtl::ArraySlice<int64> source_tensor_ids,
|
||||
const std::unordered_map<int64, TapeTensor>& sources_that_are_targets,
|
||||
gtl::ArraySlice<Gradient*> output_gradients, std::vector<Gradient*>* result,
|
||||
gtl::ArraySlice<Gradient*> output_gradients, absl::Span<Gradient*> result,
|
||||
bool build_default_zeros_grads) {
|
||||
std::unordered_set<int64> sources_set(source_tensor_ids.begin(),
|
||||
source_tensor_ids.end());
|
||||
@ -757,23 +773,17 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
|
||||
out_gradients.push_back(new_gradients);
|
||||
}
|
||||
}
|
||||
std::vector<Gradient*> in_gradients;
|
||||
VLOG(1) << "Calling gradient function for '" << trace.op_type << "'";
|
||||
std::vector<Gradient*> in_gradients(trace.input_tensor_id.size());
|
||||
DCHECK(build_default_zeros_grads || zero_indices.empty());
|
||||
if (any_gradient_nonzero) {
|
||||
for (const auto i : zero_indices) {
|
||||
out_gradients[i] = trace.output_tensor_info[i].ZerosLike();
|
||||
}
|
||||
Status s;
|
||||
s = vspace.CallBackwardFunction(trace.backward_function,
|
||||
s = vspace.CallBackwardFunction(trace.op_type, trace.backward_function,
|
||||
unneeded_gradients, out_gradients,
|
||||
&in_gradients);
|
||||
if (in_gradients.size() != trace.input_tensor_id.size()) {
|
||||
return tensorflow::errors::Internal(
|
||||
"Recorded operation '", trace.op_type,
|
||||
"' returned too few gradients. Expected ",
|
||||
trace.input_tensor_id.size(), " but received ",
|
||||
in_gradients.size());
|
||||
}
|
||||
absl::MakeSpan(in_gradients));
|
||||
if (!persistent_) {
|
||||
trace.backward_function_deleter(trace.backward_function);
|
||||
}
|
||||
@ -781,7 +791,6 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
|
||||
return s;
|
||||
}
|
||||
} else {
|
||||
in_gradients.resize(trace.input_tensor_id.size());
|
||||
if (!persistent_) {
|
||||
trace.backward_function_deleter(trace.backward_function);
|
||||
}
|
||||
@ -791,8 +800,6 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
|
||||
}
|
||||
}
|
||||
}
|
||||
VLOG(1) << "Got " << in_gradients.size() << " in_gradients for "
|
||||
<< trace.input_tensor_id.size() << " sources";
|
||||
for (int i = 0, end = in_gradients.size(); i < end; ++i) {
|
||||
const int64 id = trace.input_tensor_id[i];
|
||||
if (in_gradients[i] != nullptr) {
|
||||
@ -856,20 +863,25 @@ Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
|
||||
if (!state.op_tape.empty()) {
|
||||
return tensorflow::errors::Internal("Invalid tape state.");
|
||||
}
|
||||
result->reserve(source_tensor_ids.size());
|
||||
if (result.size() != source_tensor_ids.size()) {
|
||||
return errors::Internal("Expected result Span to be of size ",
|
||||
source_tensor_ids.size(), " found ", result.size(),
|
||||
" in call to Tape::ComputeGradient.");
|
||||
}
|
||||
std::unordered_set<int64> used_gradient_ids(source_tensor_ids.size());
|
||||
for (auto is : source_tensor_ids) {
|
||||
auto grad_it = gradients.find(is);
|
||||
for (int i = 0; i < source_tensor_ids.size(); i++) {
|
||||
int64 tensor_id = source_tensor_ids[i];
|
||||
auto grad_it = gradients.find(tensor_id);
|
||||
if (grad_it == gradients.end()) {
|
||||
result->push_back(nullptr);
|
||||
result[i] = nullptr;
|
||||
} else {
|
||||
if (grad_it->second.size() > 1) {
|
||||
Gradient* grad = vspace.AggregateGradients(grad_it->second);
|
||||
grad_it->second.clear();
|
||||
grad_it->second.push_back(grad);
|
||||
}
|
||||
result->push_back(grad_it->second[0]);
|
||||
used_gradient_ids.insert(is);
|
||||
result[i] = grad_it->second[0];
|
||||
used_gradient_ids.insert(tensor_id);
|
||||
}
|
||||
}
|
||||
VLOG(1) << "Final gradients size: "
|
||||
@ -910,10 +922,10 @@ bool ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ShouldRecord(
|
||||
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
|
||||
Status
|
||||
ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
|
||||
const std::vector<TapeTensor>& output_tensors,
|
||||
const string& op_type, const std::vector<TapeTensor>& output_tensors,
|
||||
const std::function<BackwardFunction*()>& backward_function_getter,
|
||||
const std::function<void(BackwardFunction*)>& backward_function_deleter,
|
||||
const std::vector<Gradient*>& in_grads, std::vector<Gradient*>* out_grads) {
|
||||
const std::vector<Gradient*>& in_grads, absl::Span<Gradient*> out_grads) {
|
||||
/* This function is approximately equivalent to this Python code:
|
||||
|
||||
forwardprop_aids = tf.ones_like(output_tensors)
|
||||
@ -957,7 +969,7 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
|
||||
sources_set.insert(aid_id);
|
||||
tape->Watch(aid_id);
|
||||
}
|
||||
std::vector<Gradient*> grad;
|
||||
std::vector<Gradient*> grad(in_grads.size());
|
||||
auto delete_grad = gtl::MakeCleanup([&grad, this] {
|
||||
for (Gradient* tensor : grad) {
|
||||
this->vspace_.DeleteGradient(tensor);
|
||||
@ -969,16 +981,13 @@ ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::ForwardpropFromTape(
|
||||
backward_function(backward_function_getter(),
|
||||
backward_function_deleter);
|
||||
TF_RETURN_IF_ERROR(vspace_.CallBackwardFunction(
|
||||
backward_function.get(), unneeded_gradients, forwardprop_aids, &grad));
|
||||
op_type, backward_function.get(), unneeded_gradients, forwardprop_aids,
|
||||
absl::MakeSpan(grad)));
|
||||
}
|
||||
|
||||
// Stop the tape from recording
|
||||
pop_backward_tape.release()();
|
||||
|
||||
if (grad.size() != in_grads.size()) {
|
||||
return tensorflow::errors::Internal("Wrong number of gradients returned.");
|
||||
}
|
||||
|
||||
std::vector<int64> targets;
|
||||
std::vector<Gradient*> used_in_grads;
|
||||
// We may end up with slightly fewer elements than we reserve, but grad.size()
|
||||
@ -1076,9 +1085,10 @@ Status ForwardAccumulator<Gradient, BackwardFunction, TapeTensor>::Accumulate(
|
||||
if (forward_function == nullptr) {
|
||||
// We have no special-cased forward gradient. Fall back to running the
|
||||
// backward function under a gradient tape.
|
||||
forward_grads.resize(output_tensors.size());
|
||||
TF_RETURN_IF_ERROR(ForwardpropFromTape(
|
||||
output_tensors, backward_function_getter, backward_function_deleter,
|
||||
in_grads, &forward_grads));
|
||||
op_type, output_tensors, backward_function_getter,
|
||||
backward_function_deleter, in_grads, absl::MakeSpan(forward_grads)));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(
|
||||
(*forward_function)(in_grads, &forward_grads, use_batch_));
|
||||
|
@ -30,6 +30,7 @@ cc_library(
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c/eager:abstract_context",
|
||||
"//tensorflow/c/eager:abstract_operation",
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:c_api_unified_internal",
|
||||
@ -79,12 +80,29 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "not_differentiable",
|
||||
srcs = ["not_differentiable.cc"],
|
||||
hdrs = [
|
||||
"not_differentiable.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c/eager:abstract_context",
|
||||
"//tensorflow/c/eager:gradients_internal",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gradients",
|
||||
hdrs = [
|
||||
"array_grad.h",
|
||||
"math_grad.h",
|
||||
"nn_grad.h",
|
||||
"not_differentiable.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
@ -93,6 +111,7 @@ cc_library(
|
||||
":array_grad",
|
||||
":math_grad",
|
||||
":nn_grad",
|
||||
":not_differentiable",
|
||||
"//tensorflow/c/eager:gradients_internal",
|
||||
],
|
||||
)
|
||||
@ -103,6 +122,7 @@ filegroup(
|
||||
"array_grad.h",
|
||||
"math_grad.h",
|
||||
"nn_grad.h",
|
||||
"not_differentiable.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/core:__pkg__",
|
||||
|
@ -14,23 +14,24 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/experimental/gradients/array_grad.h"
|
||||
|
||||
#include "tensorflow/c/eager/abstract_context.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
namespace {
|
||||
using std::vector;
|
||||
class IdentityNGradientFunction : public GradientFunction {
|
||||
public:
|
||||
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
|
||||
vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
grad_outputs->resize(grad_inputs.size(), nullptr);
|
||||
for (int i = 0; i < grad_inputs.size(); i++) {
|
||||
auto grad_input = grad_inputs[i];
|
||||
Status Compute(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> grad_outputs,
|
||||
absl::Span<AbstractTensorHandle*> grad_inputs) override {
|
||||
for (int i = 0; i < grad_outputs.size(); i++) {
|
||||
auto grad_input = grad_outputs[i];
|
||||
// TODO(srbs): Should we add a copy contructor to AbstractTensorHandle
|
||||
// that takes care of this similar to `Tensor`?
|
||||
if (grad_input) {
|
||||
grad_input->Ref();
|
||||
}
|
||||
(*grad_outputs)[i] = grad_input;
|
||||
grad_inputs[i] = grad_input;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -38,10 +39,8 @@ class IdentityNGradientFunction : public GradientFunction {
|
||||
};
|
||||
} // namespace
|
||||
|
||||
BackwardFunction* IdentityNRegisterer(const ForwardOperation& op) {
|
||||
auto gradient_function = new IdentityNGradientFunction;
|
||||
auto default_gradients = new PassThroughDefaultGradients(op);
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
GradientFunction* IdentityNRegisterer(const ForwardOperation& op) {
|
||||
return new IdentityNGradientFunction;
|
||||
}
|
||||
|
||||
} // namespace gradients
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
BackwardFunction* IdentityNRegisterer(const ForwardOperation& op);
|
||||
GradientFunction* IdentityNRegisterer(const ForwardOperation& op);
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -37,17 +37,17 @@ namespace {
|
||||
|
||||
class AddGradientFunction : public GradientFunction {
|
||||
public:
|
||||
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
|
||||
vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
grad_outputs->resize(2);
|
||||
Status Compute(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> grad_outputs,
|
||||
absl::Span<AbstractTensorHandle*> grad_inputs) override {
|
||||
// TODO(b/161805092): Support broadcasting.
|
||||
|
||||
DCHECK(grad_inputs[0]);
|
||||
(*grad_outputs)[0] = grad_inputs[0];
|
||||
(*grad_outputs)[1] = grad_inputs[0];
|
||||
DCHECK(grad_outputs[0]);
|
||||
grad_inputs[0] = grad_outputs[0];
|
||||
grad_inputs[1] = grad_outputs[0];
|
||||
|
||||
(*grad_outputs)[0]->Ref();
|
||||
(*grad_outputs)[1]->Ref();
|
||||
grad_inputs[0]->Ref();
|
||||
grad_inputs[1]->Ref();
|
||||
return Status::OK();
|
||||
}
|
||||
~AddGradientFunction() override {}
|
||||
@ -58,18 +58,18 @@ class ExpGradientFunction : public GradientFunction {
|
||||
explicit ExpGradientFunction(AbstractTensorHandle* exp) : exp_(exp) {
|
||||
exp->Ref();
|
||||
}
|
||||
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
|
||||
vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
Status Compute(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> grad_outputs,
|
||||
absl::Span<AbstractTensorHandle*> grad_inputs) override {
|
||||
vector<AbstractTensorHandle*> conj_outputs(1);
|
||||
std::string name = "Conj_Exp_Grad";
|
||||
TF_RETURN_IF_ERROR(Conj(ctx->ctx, {exp_.get()},
|
||||
absl::MakeSpan(conj_outputs), name.c_str()));
|
||||
TF_RETURN_IF_ERROR(
|
||||
Conj(ctx, {exp_.get()}, absl::MakeSpan(conj_outputs), name.c_str()));
|
||||
AbstractTensorHandlePtr conj_output_releaser(conj_outputs[0]);
|
||||
grad_outputs->resize(1);
|
||||
|
||||
name = "Mul_Exp_Grad";
|
||||
TF_RETURN_IF_ERROR(Mul(ctx->ctx, {conj_outputs[0], grad_inputs[0]},
|
||||
absl::MakeSpan(*grad_outputs), name.c_str()));
|
||||
TF_RETURN_IF_ERROR(Mul(ctx, {conj_outputs[0], grad_outputs[0]}, grad_inputs,
|
||||
name.c_str()));
|
||||
return Status::OK();
|
||||
}
|
||||
~ExpGradientFunction() override {}
|
||||
@ -83,12 +83,12 @@ class SqrtGradientFunction : public GradientFunction {
|
||||
explicit SqrtGradientFunction(AbstractTensorHandle* sqrt) : sqrt_(sqrt) {
|
||||
sqrt->Ref();
|
||||
}
|
||||
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
|
||||
vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
Status Compute(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> grad_outputs,
|
||||
absl::Span<AbstractTensorHandle*> grad_inputs) override {
|
||||
std::string name = "Sqrt_Grad";
|
||||
grad_outputs->resize(1);
|
||||
TF_RETURN_IF_ERROR(SqrtGrad(ctx->ctx, {sqrt_.get(), grad_inputs[0]},
|
||||
absl::MakeSpan(*grad_outputs), name.c_str()));
|
||||
TF_RETURN_IF_ERROR(SqrtGrad(ctx, {sqrt_.get(), grad_outputs[0]},
|
||||
absl::MakeSpan(grad_inputs), name.c_str()));
|
||||
return Status::OK();
|
||||
}
|
||||
~SqrtGradientFunction() override {}
|
||||
@ -101,10 +101,17 @@ class MatMulGradientFunction : public GradientFunction {
|
||||
public:
|
||||
explicit MatMulGradientFunction(vector<AbstractTensorHandle*> f_inputs,
|
||||
AttrBuilder f_attrs)
|
||||
: forward_inputs(f_inputs), forward_attrs(f_attrs) {}
|
||||
: forward_inputs_(f_inputs), forward_attrs_(f_attrs) {
|
||||
for (auto input : forward_inputs_) {
|
||||
if (input) {
|
||||
input->Ref();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
|
||||
vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
Status Compute(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> grad_outputs,
|
||||
absl::Span<AbstractTensorHandle*> grad_inputs) override {
|
||||
/* Given upstream grad U and a matmul op A*B, the gradients are:
|
||||
*
|
||||
* dA = U * B.T
|
||||
@ -112,29 +119,28 @@ class MatMulGradientFunction : public GradientFunction {
|
||||
*
|
||||
* where A.T means `transpose(A)`
|
||||
*/
|
||||
AbstractTensorHandle* upstream_grad = grad_inputs[0];
|
||||
grad_outputs->resize(2);
|
||||
AbstractTensorHandle* upstream_grad = grad_outputs[0];
|
||||
|
||||
// Get transpose attrs
|
||||
bool t_a;
|
||||
TF_RETURN_IF_ERROR(forward_attrs.Get("transpose_a", &t_a));
|
||||
TF_RETURN_IF_ERROR(forward_attrs_.Get("transpose_a", &t_a));
|
||||
|
||||
bool t_b;
|
||||
TF_RETURN_IF_ERROR(forward_attrs.Get("transpose_b", &t_b));
|
||||
TF_RETURN_IF_ERROR(forward_attrs_.Get("transpose_b", &t_b));
|
||||
|
||||
// Conj each input
|
||||
vector<AbstractTensorHandle*> conj_outputs(1);
|
||||
std::string name = "Conj_A_MatMul_Grad";
|
||||
TF_RETURN_IF_ERROR(Conj(ctx->ctx, {forward_inputs[0]},
|
||||
TF_RETURN_IF_ERROR(Conj(ctx, {forward_inputs_[0]},
|
||||
absl::MakeSpan(conj_outputs), name.c_str()));
|
||||
|
||||
AbstractTensorHandle* A = conj_outputs[0];
|
||||
AbstractTensorHandlePtr A(conj_outputs[0]);
|
||||
|
||||
name = "Conj_B_MatMul_Grad";
|
||||
TF_RETURN_IF_ERROR(Conj(ctx->ctx, {forward_inputs[1]},
|
||||
TF_RETURN_IF_ERROR(Conj(ctx, {forward_inputs_[1]},
|
||||
absl::MakeSpan(conj_outputs), name.c_str()));
|
||||
|
||||
AbstractTensorHandle* B = conj_outputs[0];
|
||||
AbstractTensorHandlePtr B(conj_outputs[0]);
|
||||
|
||||
// Calc Grad
|
||||
vector<AbstractTensorHandle*> matmul_A_outputs(1);
|
||||
@ -142,50 +148,50 @@ class MatMulGradientFunction : public GradientFunction {
|
||||
std::string name_grad_A = "MatMul_Grad_A";
|
||||
std::string name_grad_B = "MatMul_Grad_B";
|
||||
if (!t_a && !t_b) {
|
||||
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, B},
|
||||
TF_RETURN_IF_ERROR(MatMul(ctx, {upstream_grad, B.get()},
|
||||
absl::MakeSpan(matmul_A_outputs),
|
||||
name_grad_A.c_str(),
|
||||
/*transpose_a = */ false,
|
||||
/*transpose_b = */ true));
|
||||
|
||||
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {A, upstream_grad},
|
||||
TF_RETURN_IF_ERROR(MatMul(ctx, {A.get(), upstream_grad},
|
||||
absl::MakeSpan(matmul_B_outputs),
|
||||
name_grad_B.c_str(),
|
||||
/*transpose_a = */ true,
|
||||
/*transpose_b = */ false));
|
||||
} else if (!t_a && t_b) {
|
||||
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, B},
|
||||
TF_RETURN_IF_ERROR(MatMul(ctx, {upstream_grad, B.get()},
|
||||
absl::MakeSpan(matmul_A_outputs),
|
||||
name_grad_A.c_str(),
|
||||
/*transpose_a = */ false,
|
||||
/*transpose_b = */ false));
|
||||
|
||||
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, A},
|
||||
TF_RETURN_IF_ERROR(MatMul(ctx, {upstream_grad, A.get()},
|
||||
absl::MakeSpan(matmul_B_outputs),
|
||||
name_grad_B.c_str(),
|
||||
/*transpose_a = */ true,
|
||||
/*transpose_b = */ false));
|
||||
|
||||
} else if (t_a && !t_b) {
|
||||
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {B, upstream_grad},
|
||||
TF_RETURN_IF_ERROR(MatMul(ctx, {B.get(), upstream_grad},
|
||||
absl::MakeSpan(matmul_A_outputs),
|
||||
name_grad_A.c_str(),
|
||||
/*transpose_a = */ false,
|
||||
/*transpose_b = */ true));
|
||||
|
||||
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {A, upstream_grad},
|
||||
TF_RETURN_IF_ERROR(MatMul(ctx, {A.get(), upstream_grad},
|
||||
absl::MakeSpan(matmul_B_outputs),
|
||||
name_grad_B.c_str(),
|
||||
/*transpose_a = */ false,
|
||||
/*transpose_b = */ false));
|
||||
} else { // t_a && t_b
|
||||
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {B, upstream_grad},
|
||||
TF_RETURN_IF_ERROR(MatMul(ctx, {B.get(), upstream_grad},
|
||||
absl::MakeSpan(matmul_A_outputs),
|
||||
name_grad_A.c_str(),
|
||||
/*transpose_a = */ true,
|
||||
/*transpose_b = */ true));
|
||||
|
||||
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, A},
|
||||
TF_RETURN_IF_ERROR(MatMul(ctx, {upstream_grad, A.get()},
|
||||
absl::MakeSpan(matmul_B_outputs),
|
||||
name_grad_B.c_str(),
|
||||
/*transpose_a = */ true,
|
||||
@ -193,33 +199,40 @@ class MatMulGradientFunction : public GradientFunction {
|
||||
}
|
||||
|
||||
// Gradient for A
|
||||
(*grad_outputs)[0] = matmul_A_outputs[0];
|
||||
grad_inputs[0] = matmul_A_outputs[0];
|
||||
|
||||
// Gradient for B
|
||||
(*grad_outputs)[1] = matmul_B_outputs[0];
|
||||
grad_inputs[1] = matmul_B_outputs[0];
|
||||
return Status::OK();
|
||||
}
|
||||
~MatMulGradientFunction() override {}
|
||||
~MatMulGradientFunction() override {
|
||||
for (auto input : forward_inputs_) {
|
||||
if (input) {
|
||||
input->Unref();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
vector<AbstractTensorHandle*> forward_inputs;
|
||||
AttrBuilder forward_attrs;
|
||||
// TODO(b/174778737): Only hold needed inputs.
|
||||
vector<AbstractTensorHandle*> forward_inputs_;
|
||||
AttrBuilder forward_attrs_;
|
||||
};
|
||||
|
||||
class NegGradientFunction : public GradientFunction {
|
||||
public:
|
||||
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
|
||||
vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
Status Compute(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> grad_outputs,
|
||||
absl::Span<AbstractTensorHandle*> grad_inputs) override {
|
||||
/* Given upstream grad U and a Neg op Y = -X, the gradients are:
|
||||
*
|
||||
* dX = -U
|
||||
*
|
||||
*/
|
||||
|
||||
grad_outputs->resize(1);
|
||||
std::string name = "Neg_Grad";
|
||||
TF_RETURN_IF_ERROR(ops::Neg(ctx->ctx, {grad_inputs[0]},
|
||||
absl::MakeSpan(*grad_outputs), name.c_str()));
|
||||
TF_RETURN_IF_ERROR(
|
||||
ops::Neg(ctx, {grad_outputs[0]}, grad_inputs, name.c_str()));
|
||||
return Status::OK();
|
||||
}
|
||||
~NegGradientFunction() override {}
|
||||
@ -227,8 +240,9 @@ class NegGradientFunction : public GradientFunction {
|
||||
|
||||
class SubGradientFunction : public GradientFunction {
|
||||
public:
|
||||
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
|
||||
vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
Status Compute(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> grad_outputs,
|
||||
absl::Span<AbstractTensorHandle*> grad_inputs) override {
|
||||
/* Given upstream grad U and a Sub op A-B, the gradients are:
|
||||
*
|
||||
* dA = U
|
||||
@ -236,20 +250,16 @@ class SubGradientFunction : public GradientFunction {
|
||||
*
|
||||
*/
|
||||
|
||||
grad_outputs->resize(2);
|
||||
|
||||
// Grad for A
|
||||
DCHECK(grad_inputs[0]);
|
||||
(*grad_outputs)[0] = grad_inputs[0];
|
||||
(*grad_outputs)[0]->Ref();
|
||||
DCHECK(grad_outputs[0]);
|
||||
grad_inputs[0] = grad_outputs[0];
|
||||
grad_inputs[0]->Ref();
|
||||
|
||||
// Grad for B
|
||||
// negate the upstream grad
|
||||
std::vector<AbstractTensorHandle*> neg_outputs(1);
|
||||
std::string name = "Neg_Sub_Grad_B";
|
||||
TF_RETURN_IF_ERROR(ops::Neg(ctx->ctx, {grad_inputs[0]},
|
||||
absl::MakeSpan(neg_outputs), name.c_str()));
|
||||
(*grad_outputs)[1] = neg_outputs[0];
|
||||
TF_RETURN_IF_ERROR(ops::Neg(ctx, {grad_outputs[0]},
|
||||
grad_inputs.subspan(1, 1), name.c_str()));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
@ -259,10 +269,17 @@ class SubGradientFunction : public GradientFunction {
|
||||
class MulGradientFunction : public GradientFunction {
|
||||
public:
|
||||
explicit MulGradientFunction(vector<AbstractTensorHandle*> f_inputs)
|
||||
: forward_inputs(f_inputs) {}
|
||||
: forward_inputs_(f_inputs) {
|
||||
for (auto input : forward_inputs_) {
|
||||
if (input) {
|
||||
input->Ref();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
|
||||
vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
Status Compute(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> grad_outputs,
|
||||
absl::Span<AbstractTensorHandle*> grad_inputs) override {
|
||||
/* Given upstream grad U and a mul op A*B, the gradients are:
|
||||
*
|
||||
* dA = U * B
|
||||
@ -270,36 +287,46 @@ class MulGradientFunction : public GradientFunction {
|
||||
*
|
||||
*/
|
||||
|
||||
AbstractTensorHandle* upstream_grad = grad_inputs[0];
|
||||
grad_outputs->resize(2);
|
||||
std::vector<AbstractTensorHandle*> mul_outputs(1);
|
||||
AbstractTensorHandle* upstream_grad = grad_outputs[0];
|
||||
|
||||
// Gradient for A
|
||||
std::string name = "Mul_Grad_A";
|
||||
TF_RETURN_IF_ERROR(Mul(ctx->ctx, {upstream_grad, forward_inputs[1]},
|
||||
absl::MakeSpan(mul_outputs), name.c_str()));
|
||||
(*grad_outputs)[0] = mul_outputs[0];
|
||||
TF_RETURN_IF_ERROR(Mul(ctx, {upstream_grad, forward_inputs_[1]},
|
||||
grad_inputs.subspan(0, 1), name.c_str()));
|
||||
|
||||
// Gradient for B
|
||||
name = "Mul_Grad_B";
|
||||
TF_RETURN_IF_ERROR(Mul(ctx->ctx, {forward_inputs[0], upstream_grad},
|
||||
absl::MakeSpan(mul_outputs), name.c_str()));
|
||||
(*grad_outputs)[1] = mul_outputs[0];
|
||||
TF_RETURN_IF_ERROR(Mul(ctx, {forward_inputs_[0], upstream_grad},
|
||||
grad_inputs.subspan(1, 1), name.c_str()));
|
||||
return Status::OK();
|
||||
}
|
||||
~MulGradientFunction() override {}
|
||||
~MulGradientFunction() override {
|
||||
for (auto input : forward_inputs_) {
|
||||
if (input) {
|
||||
input->Unref();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
vector<AbstractTensorHandle*> forward_inputs;
|
||||
// TODO(b/174778737): Only hold needed inputs.
|
||||
vector<AbstractTensorHandle*> forward_inputs_;
|
||||
};
|
||||
|
||||
class Log1pGradientFunction : public GradientFunction {
|
||||
public:
|
||||
explicit Log1pGradientFunction(vector<AbstractTensorHandle*> f_inputs)
|
||||
: forward_inputs(f_inputs) {}
|
||||
: forward_inputs_(f_inputs) {
|
||||
for (auto input : forward_inputs_) {
|
||||
if (input) {
|
||||
input->Ref();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
|
||||
vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
Status Compute(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> grad_outputs,
|
||||
absl::Span<AbstractTensorHandle*> grad_inputs) override {
|
||||
// TODO(vnvo2409): Add control dependency
|
||||
/* Given upstream grad U and a Log1p op: Y = log(1 + X), the gradients are:
|
||||
*
|
||||
@ -307,56 +334,72 @@ class Log1pGradientFunction : public GradientFunction {
|
||||
*
|
||||
*/
|
||||
|
||||
AbstractTensorHandle* upstream_grad = grad_inputs[0];
|
||||
AbstractTensorHandle* X = forward_inputs[0];
|
||||
AbstractTensorHandle* upstream_grad = grad_outputs[0];
|
||||
AbstractTensorHandle* X = forward_inputs_[0];
|
||||
|
||||
grad_outputs->resize(1);
|
||||
vector<AbstractTensorHandle*> temp_outputs(1);
|
||||
|
||||
// Calculate conjugate of X
|
||||
std::string name = "Conj_Log1p_Grad_X";
|
||||
TF_RETURN_IF_ERROR(
|
||||
Conj(ctx->ctx, {X}, absl::MakeSpan(temp_outputs), name.c_str()));
|
||||
Conj(ctx, {X}, absl::MakeSpan(temp_outputs), name.c_str()));
|
||||
|
||||
AbstractTensorHandle* Conj_X = temp_outputs[0];
|
||||
AbstractTensorHandlePtr Conj_X(temp_outputs[0]);
|
||||
|
||||
// Creates Ones
|
||||
name = "OnesLike_Log1p_Grad_X";
|
||||
TF_RETURN_IF_ERROR(OnesLike(ctx->ctx, {Conj_X},
|
||||
TF_RETURN_IF_ERROR(OnesLike(ctx, {Conj_X.get()},
|
||||
absl::MakeSpan(temp_outputs), name.c_str()));
|
||||
|
||||
AbstractTensorHandle* Ones_X = temp_outputs[0];
|
||||
AbstractTensorHandlePtr Ones_X(temp_outputs[0]);
|
||||
|
||||
name = "Add_Log1p_Grad_X";
|
||||
// Calculate 1 + Conj(X)
|
||||
TF_RETURN_IF_ERROR(Add(ctx->ctx, {Ones_X, Conj_X},
|
||||
TF_RETURN_IF_ERROR(Add(ctx, {Ones_X.get(), Conj_X.get()},
|
||||
absl::MakeSpan(temp_outputs), name.c_str()));
|
||||
|
||||
AbstractTensorHandle* Conj_XP1 = temp_outputs[0];
|
||||
AbstractTensorHandlePtr Conj_XP1(temp_outputs[0]);
|
||||
|
||||
name = "Div_Log1p_Grad_X";
|
||||
// Calculate U / (1 + Conj(X))
|
||||
TF_RETURN_IF_ERROR(Div(ctx->ctx, {upstream_grad, Conj_XP1},
|
||||
absl::MakeSpan(temp_outputs), name.c_str()));
|
||||
|
||||
(*grad_outputs)[0] = temp_outputs[0];
|
||||
TF_RETURN_IF_ERROR(
|
||||
Div(ctx, {upstream_grad, Conj_XP1.get()}, grad_inputs, name.c_str()));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
~Log1pGradientFunction() override {}
|
||||
~Log1pGradientFunction() override {
|
||||
for (auto input : forward_inputs_) {
|
||||
if (input) {
|
||||
input->Unref();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
vector<AbstractTensorHandle*> forward_inputs;
|
||||
// TODO(b/174778737): Only hold needed inputs.
|
||||
vector<AbstractTensorHandle*> forward_inputs_;
|
||||
};
|
||||
|
||||
class DivNoNanGradientFunction : public GradientFunction {
|
||||
public:
|
||||
explicit DivNoNanGradientFunction(vector<AbstractTensorHandle*> f_inputs,
|
||||
vector<AbstractTensorHandle*> f_outputs)
|
||||
: forward_inputs(f_inputs), forward_outputs(f_outputs) {}
|
||||
: forward_inputs_(f_inputs), forward_outputs_(f_outputs) {
|
||||
for (auto input : forward_inputs_) {
|
||||
if (input) {
|
||||
input->Ref();
|
||||
}
|
||||
}
|
||||
for (auto output : forward_outputs_) {
|
||||
if (output) {
|
||||
output->Ref();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
|
||||
vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
Status Compute(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> grad_outputs,
|
||||
absl::Span<AbstractTensorHandle*> grad_inputs) override {
|
||||
// TODO(vnvo2409): Add shape broadcasting
|
||||
/* Given upstream grad U and a Div op: Z = X/Y, the gradients are:
|
||||
*
|
||||
@ -365,126 +408,88 @@ class DivNoNanGradientFunction : public GradientFunction {
|
||||
*
|
||||
*/
|
||||
|
||||
AbstractTensorHandle* upstream_grad = grad_inputs[0];
|
||||
AbstractTensorHandle* Y = forward_inputs[1];
|
||||
AbstractTensorHandle* Z = forward_outputs[0];
|
||||
|
||||
grad_outputs->resize(2);
|
||||
vector<AbstractTensorHandle*> temp_outputs(1);
|
||||
AbstractTensorHandle* upstream_grad = grad_outputs[0];
|
||||
AbstractTensorHandle* Y = forward_inputs_[1];
|
||||
AbstractTensorHandle* Z = forward_outputs_[0];
|
||||
|
||||
// Calculate dX = U / Y
|
||||
std::string name = "Div_Grad_X";
|
||||
TF_RETURN_IF_ERROR(DivNoNan(ctx->ctx, {upstream_grad, Y},
|
||||
absl::MakeSpan(temp_outputs), name.c_str()));
|
||||
|
||||
(*grad_outputs)[0] = temp_outputs[0];
|
||||
TF_RETURN_IF_ERROR(DivNoNan(ctx, {upstream_grad, Y},
|
||||
grad_inputs.subspan(0, 1), name.c_str()));
|
||||
|
||||
vector<AbstractTensorHandle*> temp_outputs(1);
|
||||
// Calculate dY = -U*Z / Y
|
||||
name = "Neg_Div_Grad_Y";
|
||||
TF_RETURN_IF_ERROR(Neg(ctx->ctx, {upstream_grad},
|
||||
absl::MakeSpan(temp_outputs), name.c_str())); // -U
|
||||
AbstractTensorHandle* MinusU = temp_outputs[0];
|
||||
TF_RETURN_IF_ERROR(Neg(ctx, {upstream_grad}, absl::MakeSpan(temp_outputs),
|
||||
name.c_str())); // -U
|
||||
AbstractTensorHandlePtr MinusU(temp_outputs[0]);
|
||||
|
||||
name = "Mul_Div_Grad_Y";
|
||||
TF_RETURN_IF_ERROR(Mul(ctx->ctx, {MinusU, Z}, absl::MakeSpan(temp_outputs),
|
||||
TF_RETURN_IF_ERROR(Mul(ctx, {MinusU.get(), Z}, absl::MakeSpan(temp_outputs),
|
||||
name.c_str())); // -U*Z
|
||||
AbstractTensorHandle* UZ = temp_outputs[0];
|
||||
AbstractTensorHandlePtr UZ(temp_outputs[0]);
|
||||
|
||||
name = "Div_Grad_Y";
|
||||
TF_RETURN_IF_ERROR(DivNoNan(ctx->ctx, {UZ, Y}, absl::MakeSpan(temp_outputs),
|
||||
TF_RETURN_IF_ERROR(DivNoNan(ctx, {UZ.get(), Y}, grad_inputs.subspan(1, 1),
|
||||
name.c_str())); // -U*Z / Y
|
||||
|
||||
(*grad_outputs)[1] = temp_outputs[0];
|
||||
return Status::OK();
|
||||
}
|
||||
~DivNoNanGradientFunction() override {}
|
||||
~DivNoNanGradientFunction() override {
|
||||
for (auto input : forward_inputs_) {
|
||||
if (input) {
|
||||
input->Unref();
|
||||
}
|
||||
}
|
||||
for (auto output : forward_outputs_) {
|
||||
if (output) {
|
||||
output->Unref();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
vector<AbstractTensorHandle*> forward_inputs;
|
||||
vector<AbstractTensorHandle*> forward_outputs;
|
||||
// TODO(b/174778737): Only hold needed inputs and outputs.
|
||||
vector<AbstractTensorHandle*> forward_inputs_;
|
||||
vector<AbstractTensorHandle*> forward_outputs_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
BackwardFunction* AddRegisterer(const ForwardOperation& op) {
|
||||
auto gradient_function = new AddGradientFunction;
|
||||
// For ops with a single output, the gradient function is not called if there
|
||||
// is no incoming gradient. So we do not need to worry about creating zeros
|
||||
// grads in this case.
|
||||
auto default_gradients = new PassThroughDefaultGradients(op);
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
GradientFunction* AddRegisterer(const ForwardOperation& op) {
|
||||
return new AddGradientFunction;
|
||||
}
|
||||
|
||||
BackwardFunction* ExpRegisterer(const ForwardOperation& op) {
|
||||
auto gradient_function = new ExpGradientFunction(op.outputs[0]);
|
||||
// For ops with a single output, the gradient function is not called if there
|
||||
// is no incoming gradient. So we do not need to worry about creating zeros
|
||||
// grads in this case.
|
||||
auto default_gradients = new PassThroughDefaultGradients(op);
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
GradientFunction* ExpRegisterer(const ForwardOperation& op) {
|
||||
return new ExpGradientFunction(op.outputs[0]);
|
||||
}
|
||||
|
||||
BackwardFunction* MatMulRegisterer(const ForwardOperation& op) {
|
||||
auto gradient_function = new MatMulGradientFunction(op.inputs, op.attrs);
|
||||
// For ops with a single output, the gradient function is not called if there
|
||||
// is no incoming gradient. So we do not need to worry about creating zeros
|
||||
// grads in this case.
|
||||
auto default_gradients = new PassThroughDefaultGradients(op);
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
GradientFunction* MatMulRegisterer(const ForwardOperation& op) {
|
||||
return new MatMulGradientFunction(op.inputs, op.attrs);
|
||||
}
|
||||
|
||||
BackwardFunction* SqrtRegisterer(const ForwardOperation& op) {
|
||||
auto gradient_function = new SqrtGradientFunction(op.outputs[0]);
|
||||
// For ops with a single output, the gradient function is not called if there
|
||||
// is no incoming gradient. So we do not need to worry about creating zeros
|
||||
// grads in this case.
|
||||
auto default_gradients = new PassThroughDefaultGradients(op);
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
GradientFunction* SqrtRegisterer(const ForwardOperation& op) {
|
||||
return new SqrtGradientFunction(op.outputs[0]);
|
||||
}
|
||||
|
||||
BackwardFunction* NegRegisterer(const ForwardOperation& op) {
|
||||
auto gradient_function = new NegGradientFunction;
|
||||
// For ops with a single output, the gradient function is not called if there
|
||||
// is no incoming gradient. So we do not need to worry about creating zeros
|
||||
// grads in this case.
|
||||
auto default_gradients = new PassThroughDefaultGradients(op);
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
GradientFunction* NegRegisterer(const ForwardOperation& op) {
|
||||
return new NegGradientFunction;
|
||||
}
|
||||
|
||||
BackwardFunction* SubRegisterer(const ForwardOperation& op) {
|
||||
// For ops with a single output, the gradient function is not called if there
|
||||
// is no incoming gradient. So we do not need to worry about creating zeros
|
||||
// grads in this case.
|
||||
auto gradient_function = new SubGradientFunction;
|
||||
auto default_gradients = new PassThroughDefaultGradients(op);
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
GradientFunction* SubRegisterer(const ForwardOperation& op) {
|
||||
return new SubGradientFunction;
|
||||
}
|
||||
|
||||
BackwardFunction* MulRegisterer(const ForwardOperation& op) {
|
||||
// For ops with a single output, the gradient function is not called if there
|
||||
// is no incoming gradient. So we do not need to worry about creating zeros
|
||||
// grads in this case.
|
||||
auto gradient_function = new MulGradientFunction(op.inputs);
|
||||
auto default_gradients = new PassThroughDefaultGradients(op);
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
GradientFunction* MulRegisterer(const ForwardOperation& op) {
|
||||
return new MulGradientFunction(op.inputs);
|
||||
}
|
||||
|
||||
BackwardFunction* Log1pRegisterer(const ForwardOperation& op) {
|
||||
// For ops with a single output, the gradient function is not called if there
|
||||
// is no incoming gradient. So we do not need to worry about creating zeros
|
||||
// grads in this case.
|
||||
auto gradient_function = new Log1pGradientFunction(op.inputs);
|
||||
auto default_gradients = new PassThroughDefaultGradients(op);
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
GradientFunction* Log1pRegisterer(const ForwardOperation& op) {
|
||||
return new Log1pGradientFunction(op.inputs);
|
||||
}
|
||||
|
||||
BackwardFunction* DivNoNanRegisterer(const ForwardOperation& op) {
|
||||
// For ops with a single output, the gradient function is not called if there
|
||||
// is no incoming gradient. So we do not need to worry about creating zeros
|
||||
// grads in this case.
|
||||
auto gradient_function = new DivNoNanGradientFunction(op.inputs, op.outputs);
|
||||
auto default_gradients = new PassThroughDefaultGradients(op);
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
GradientFunction* DivNoNanRegisterer(const ForwardOperation& op) {
|
||||
return new DivNoNanGradientFunction(op.inputs, op.outputs);
|
||||
}
|
||||
|
||||
} // namespace gradients
|
||||
|
@ -20,15 +20,15 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
|
||||
BackwardFunction* AddRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* ExpRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* MatMulRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* SqrtRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* NegRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* SubRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* MulRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* Log1pRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* DivNoNanRegisterer(const ForwardOperation& op);
|
||||
GradientFunction* AddRegisterer(const ForwardOperation& op);
|
||||
GradientFunction* ExpRegisterer(const ForwardOperation& op);
|
||||
GradientFunction* MatMulRegisterer(const ForwardOperation& op);
|
||||
GradientFunction* SqrtRegisterer(const ForwardOperation& op);
|
||||
GradientFunction* NegRegisterer(const ForwardOperation& op);
|
||||
GradientFunction* SubRegisterer(const ForwardOperation& op);
|
||||
GradientFunction* MulRegisterer(const ForwardOperation& op);
|
||||
GradientFunction* Log1pRegisterer(const ForwardOperation& op);
|
||||
GradientFunction* DivNoNanRegisterer(const ForwardOperation& op);
|
||||
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
|
@ -36,29 +36,37 @@ namespace {
|
||||
class ReluGradientFunction : public GradientFunction {
|
||||
public:
|
||||
explicit ReluGradientFunction(vector<AbstractTensorHandle*> f_outputs)
|
||||
: forward_outputs(f_outputs) {}
|
||||
: forward_outputs_(f_outputs) {
|
||||
for (auto output : forward_outputs_) {
|
||||
if (output) {
|
||||
output->Ref();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
|
||||
vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
AbstractTensorHandle* upstream_grad = grad_inputs[0];
|
||||
AbstractTensorHandle* activations = forward_outputs[0];
|
||||
grad_outputs->resize(1);
|
||||
vector<AbstractTensorHandle*> relugrad_outputs(1);
|
||||
Status Compute(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> grad_outputs,
|
||||
absl::Span<AbstractTensorHandle*> grad_inputs) override {
|
||||
AbstractTensorHandle* upstream_grad = grad_outputs[0];
|
||||
AbstractTensorHandle* activations = forward_outputs_[0];
|
||||
|
||||
// Calculate Grad
|
||||
std::string name = "relu_grad";
|
||||
|
||||
TF_RETURN_IF_ERROR(ReluGrad(ctx->ctx, {upstream_grad, activations},
|
||||
absl::MakeSpan(relugrad_outputs),
|
||||
name.c_str()));
|
||||
(*grad_outputs)[0] = relugrad_outputs[0];
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
ReluGrad(ctx, {upstream_grad, activations}, grad_inputs, name.c_str()));
|
||||
return Status::OK();
|
||||
}
|
||||
~ReluGradientFunction() override {}
|
||||
~ReluGradientFunction() override {
|
||||
for (auto output : forward_outputs_) {
|
||||
if (output) {
|
||||
output->Unref();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
vector<AbstractTensorHandle*> forward_outputs;
|
||||
// TODO(b/174778737): Only hold needed outputs.
|
||||
vector<AbstractTensorHandle*> forward_outputs_;
|
||||
};
|
||||
|
||||
Status BroadcastMul(AbstractContext* ctx, AbstractTensorHandle* vec,
|
||||
@ -87,98 +95,79 @@ class SparseSoftmaxCrossEntropyWithLogitsGradientFunction
|
||||
public:
|
||||
explicit SparseSoftmaxCrossEntropyWithLogitsGradientFunction(
|
||||
vector<AbstractTensorHandle*> f_outputs)
|
||||
: forward_outputs(f_outputs) {}
|
||||
|
||||
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
|
||||
vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
grad_outputs->resize(2);
|
||||
: forward_outputs_(f_outputs) {}
|
||||
|
||||
Status Compute(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> grad_outputs,
|
||||
absl::Span<AbstractTensorHandle*> grad_inputs) override {
|
||||
// Grad for Softmax Input
|
||||
vector<AbstractTensorHandle*> mul_outputs(1);
|
||||
TF_RETURN_IF_ERROR(BroadcastMul(
|
||||
ctx->ctx, grad_inputs[0], forward_outputs[1],
|
||||
absl::MakeSpan(mul_outputs))); // upstream_grad * local softmax grad
|
||||
(*grad_outputs)[0] = mul_outputs[0];
|
||||
ctx, grad_outputs[0], forward_outputs_[1],
|
||||
grad_inputs.subspan(0, 1))); // upstream_grad * local softmax grad
|
||||
|
||||
// Grad for labels is null
|
||||
(*grad_outputs)[1] = nullptr;
|
||||
|
||||
grad_inputs[1] = nullptr;
|
||||
return Status::OK();
|
||||
}
|
||||
~SparseSoftmaxCrossEntropyWithLogitsGradientFunction() override {}
|
||||
|
||||
private:
|
||||
vector<AbstractTensorHandle*> forward_outputs;
|
||||
vector<AbstractTensorHandle*> forward_outputs_;
|
||||
};
|
||||
|
||||
// TODO(vnvo2409): Add python test
|
||||
class BiasAddGradientFunction : public GradientFunction {
|
||||
public:
|
||||
explicit BiasAddGradientFunction(AttrBuilder f_attrs)
|
||||
: forward_attrs(f_attrs) {}
|
||||
: forward_attrs_(f_attrs) {}
|
||||
|
||||
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
|
||||
vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
Status Compute(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> grad_outputs,
|
||||
absl::Span<AbstractTensorHandle*> grad_inputs) override {
|
||||
/* Given upstream grad U and a BiasAdd: A + bias, the gradients are:
|
||||
*
|
||||
* dA = U
|
||||
* dbias = reduceSum(U, dims = channel_dim)
|
||||
*/
|
||||
|
||||
AbstractTensorHandle* upstream_grad = grad_inputs[0];
|
||||
AbstractTensorHandle* upstream_grad = grad_outputs[0];
|
||||
DCHECK(upstream_grad);
|
||||
grad_outputs->resize(2);
|
||||
|
||||
// Recover data format from forward pass for gradient.
|
||||
std::string data_format;
|
||||
TF_RETURN_IF_ERROR(forward_attrs.Get("data_format", &data_format));
|
||||
TF_RETURN_IF_ERROR(forward_attrs_.Get("data_format", &data_format));
|
||||
|
||||
// Grad for A
|
||||
(*grad_outputs)[0] = upstream_grad;
|
||||
(*grad_outputs)[0]->Ref();
|
||||
grad_inputs[0] = upstream_grad;
|
||||
grad_inputs[0]->Ref();
|
||||
|
||||
// Grad for bias
|
||||
vector<AbstractTensorHandle*> bias_add_grad_outputs(1);
|
||||
std::string name = "bias_add_grad";
|
||||
TF_RETURN_IF_ERROR(BiasAddGrad(ctx->ctx, {upstream_grad},
|
||||
absl::MakeSpan(bias_add_grad_outputs),
|
||||
TF_RETURN_IF_ERROR(BiasAddGrad(ctx, {upstream_grad},
|
||||
grad_inputs.subspan(1, 1),
|
||||
data_format.c_str(), name.c_str()));
|
||||
|
||||
(*grad_outputs)[1] = bias_add_grad_outputs[0];
|
||||
return Status::OK();
|
||||
}
|
||||
~BiasAddGradientFunction() override {}
|
||||
|
||||
private:
|
||||
AttrBuilder forward_attrs;
|
||||
AttrBuilder forward_attrs_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
BackwardFunction* ReluRegisterer(const ForwardOperation& op) {
|
||||
auto gradient_function = new ReluGradientFunction(op.outputs);
|
||||
// For ops with a single output, the gradient function is not called if there
|
||||
// is no incoming gradient. So we do not need to worry about creating zeros
|
||||
// grads in this case.
|
||||
auto default_gradients = new PassThroughDefaultGradients(op);
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
GradientFunction* ReluRegisterer(const ForwardOperation& op) {
|
||||
return new ReluGradientFunction(op.outputs);
|
||||
}
|
||||
|
||||
BackwardFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer(
|
||||
GradientFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer(
|
||||
const ForwardOperation& op) {
|
||||
auto gradient_function =
|
||||
new SparseSoftmaxCrossEntropyWithLogitsGradientFunction(op.outputs);
|
||||
auto default_gradients = new PassThroughDefaultGradients(op);
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
return new SparseSoftmaxCrossEntropyWithLogitsGradientFunction(op.outputs);
|
||||
}
|
||||
|
||||
BackwardFunction* BiasAddRegisterer(const ForwardOperation& op) {
|
||||
// For ops with a single output, the gradient function is not called if there
|
||||
// is no incoming gradient. So we do not need to worry about creating zeros
|
||||
// grads in this case.
|
||||
auto gradient_function = new BiasAddGradientFunction(op.attrs);
|
||||
auto default_gradients = new PassThroughDefaultGradients(op);
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
GradientFunction* BiasAddRegisterer(const ForwardOperation& op) {
|
||||
return new BiasAddGradientFunction(op.attrs);
|
||||
}
|
||||
|
||||
} // namespace gradients
|
||||
|
@ -19,10 +19,10 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
BackwardFunction* ReluRegisterer(const ForwardOperation& op);
|
||||
BackwardFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer(
|
||||
GradientFunction* ReluRegisterer(const ForwardOperation& op);
|
||||
GradientFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer(
|
||||
const ForwardOperation& op);
|
||||
BackwardFunction* BiasAddRegisterer(const ForwardOperation& op);
|
||||
GradientFunction* BiasAddRegisterer(const ForwardOperation& op);
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -25,7 +25,6 @@ namespace internal {
|
||||
namespace {
|
||||
|
||||
using tensorflow::TF_StatusPtr;
|
||||
using tracing::TracingOperation;
|
||||
|
||||
Status BiasAddModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
@ -38,30 +37,20 @@ Status BiasAddGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch A.
|
||||
tape->Watch(ToId(inputs[1])); // Watch Bias.
|
||||
Tape tape(/*persistent=*/false);
|
||||
tape.Watch(inputs[0]); // Watch A.
|
||||
tape.Watch(inputs[1]); // Watch Bias.
|
||||
std::vector<AbstractTensorHandle*> temp_outputs(1);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
|
||||
TF_RETURN_IF_ERROR(ops::BiasAdd(tape_ctx.get(), inputs,
|
||||
absl::MakeSpan(temp_outputs), "BiasAddGrad"));
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
|
||||
std::vector<AbstractTensorHandle*> out_grads;
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(
|
||||
vspace, /*target_tensor_ids=*/{ToId(temp_outputs[0])},
|
||||
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
|
||||
source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads,
|
||||
/*build_default_zeros_grads=*/false));
|
||||
TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
|
||||
/*sources=*/inputs,
|
||||
/*output_gradients=*/{}, outputs));
|
||||
for (auto temp_output : temp_outputs) {
|
||||
temp_output->Unref();
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
outputs[1] = out_grads[1];
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
34
tensorflow/c/experimental/gradients/not_differentiable.cc
Normal file
34
tensorflow/c/experimental/gradients/not_differentiable.cc
Normal file
@ -0,0 +1,34 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/experimental/gradients/not_differentiable.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
Status NotDifferentiableGradientFunction::Compute(
|
||||
AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> grad_outputs,
|
||||
absl::Span<AbstractTensorHandle*> grad_inputs) {
|
||||
for (int i = 0; i < grad_inputs.size(); i++) {
|
||||
grad_inputs[i] = nullptr;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RegisterNotDifferentiable(GradientRegistry* registry, const string& op) {
|
||||
return registry->Register(op, [](const ForwardOperation& op) {
|
||||
return new NotDifferentiableGradientFunction;
|
||||
});
|
||||
}
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
34
tensorflow/c/experimental/gradients/not_differentiable.h
Normal file
34
tensorflow/c/experimental/gradients/not_differentiable.h
Normal file
@ -0,0 +1,34 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NOT_DIFFERENTIABLE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NOT_DIFFERENTIABLE_H_
|
||||
|
||||
#include "tensorflow/c/eager/abstract_context.h"
|
||||
#include "tensorflow/c/eager/gradients.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
// Ignores `grad_outputs` and sets all entries in grad_inputs to nullptr.
|
||||
class NotDifferentiableGradientFunction : public GradientFunction {
|
||||
Status Compute(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> grad_outputs,
|
||||
absl::Span<AbstractTensorHandle*> grad_inputs) override;
|
||||
};
|
||||
// Shorthand for registry->Register(op, new NotDifferentiableGradientFunction)
|
||||
Status RegisterNotDifferentiable(GradientRegistry* registry, const string& op);
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NOT_DIFFERENTIABLE_H_
|
@ -197,12 +197,6 @@ AbstractOperation* TapeOperation::GetBackingOperation() { return parent_op_; }
|
||||
Status TapeOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
|
||||
int* num_retvals) {
|
||||
TF_RETURN_IF_ERROR(parent_op_->Execute(retvals, num_retvals));
|
||||
std::vector<int64> input_ids(forward_op_.inputs.size());
|
||||
std::vector<tensorflow::DataType> input_dtypes(forward_op_.inputs.size());
|
||||
for (int i = 0; i < forward_op_.inputs.size(); i++) {
|
||||
input_ids[i] = ToId(forward_op_.inputs[i]);
|
||||
input_dtypes[i] = forward_op_.inputs[i]->DataType();
|
||||
}
|
||||
for (int i = 0; i < *num_retvals; i++) {
|
||||
// TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs.
|
||||
forward_op_.outputs.push_back(retvals[i]);
|
||||
@ -212,25 +206,11 @@ Status TapeOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
|
||||
// Consider getting rid of this and making the behavior between number types
|
||||
// and string consistent.
|
||||
forward_op_.attrs.BuildNodeDef();
|
||||
std::vector<TapeTensor> tape_tensors;
|
||||
for (auto t : retvals) {
|
||||
tape_tensors.push_back(TapeTensor(t));
|
||||
}
|
||||
tape_->RecordOperation(
|
||||
parent_op_->Name(), tape_tensors, input_ids, input_dtypes,
|
||||
[this]() -> BackwardFunction* {
|
||||
std::unique_ptr<BackwardFunction> backward_fn;
|
||||
Status s = registry_.Lookup(forward_op_, &backward_fn);
|
||||
if (!s.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return backward_fn.release();
|
||||
},
|
||||
[](BackwardFunction* ptr) {
|
||||
if (ptr) {
|
||||
delete ptr;
|
||||
}
|
||||
});
|
||||
// TODO(b/170307493): Populate skip_input_indices here.
|
||||
std::unique_ptr<GradientFunction> backward_fn;
|
||||
TF_RETURN_IF_ERROR(registry_.Lookup(forward_op_, &backward_fn));
|
||||
tape_->RecordOperation(forward_op_.inputs, forward_op_.outputs,
|
||||
backward_fn.release(), parent_op_->Name());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -1327,10 +1327,10 @@ class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction,
|
||||
}
|
||||
|
||||
tensorflow::Status CallBackwardFunction(
|
||||
PyBackwardFunction* backward_function,
|
||||
const string& op_type, PyBackwardFunction* backward_function,
|
||||
const std::vector<tensorflow::int64>& unneeded_gradients,
|
||||
tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
|
||||
std::vector<PyObject*>* result) const final {
|
||||
absl::Span<PyObject*> result) const final {
|
||||
PyObject* grads = PyTuple_New(output_gradients.size());
|
||||
for (int i = 0; i < output_gradients.size(); ++i) {
|
||||
if (output_gradients[i] == nullptr) {
|
||||
@ -1346,7 +1346,6 @@ class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction,
|
||||
if (py_result == nullptr) {
|
||||
return tensorflow::errors::Internal("gradient function threw exceptions");
|
||||
}
|
||||
result->clear();
|
||||
PyObject* seq =
|
||||
PySequence_Fast(py_result, "expected a sequence of gradients");
|
||||
if (seq == nullptr) {
|
||||
@ -1354,16 +1353,21 @@ class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction,
|
||||
"gradient function did not return a list");
|
||||
}
|
||||
int len = PySequence_Fast_GET_SIZE(seq);
|
||||
if (len != result.size()) {
|
||||
return tensorflow::errors::Internal(
|
||||
"Recorded operation '", op_type,
|
||||
"' returned too few gradients. Expected ", result.size(),
|
||||
" but received ", len);
|
||||
}
|
||||
PyObject** seq_array = PySequence_Fast_ITEMS(seq);
|
||||
VLOG(1) << "Gradient length is " << len;
|
||||
result->reserve(len);
|
||||
for (int i = 0; i < len; ++i) {
|
||||
PyObject* item = seq_array[i];
|
||||
if (item == Py_None) {
|
||||
result->push_back(nullptr);
|
||||
result[i] = nullptr;
|
||||
} else {
|
||||
Py_INCREF(item);
|
||||
result->push_back(item);
|
||||
result[i] = item;
|
||||
}
|
||||
}
|
||||
Py_DECREF(seq);
|
||||
@ -2774,10 +2778,10 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target,
|
||||
Py_INCREF(tensor);
|
||||
}
|
||||
}
|
||||
std::vector<PyObject*> result;
|
||||
std::vector<PyObject*> result(sources_vec.size());
|
||||
status->status = tape_obj->tape->ComputeGradient(
|
||||
*py_vspace, target_vec, sources_vec, source_tensors_that_are_targets,
|
||||
outgrad_vec, &result);
|
||||
outgrad_vec, absl::MakeSpan(result));
|
||||
if (!status->status.ok()) {
|
||||
if (PyErr_Occurred()) {
|
||||
// Do not propagate the erroneous status as that would swallow the
|
||||
|
@ -47,35 +47,19 @@ Status RegisterGradients(GradientRegistry* registry) {
|
||||
PYBIND11_MODULE(_tape, m) {
|
||||
py::class_<Tape>(m, "Tape")
|
||||
.def(py::init([](bool persistent) { return new Tape(persistent); }))
|
||||
.def("Watch",
|
||||
[](Tape* self, AbstractTensorHandle* t) { self->Watch(ToId(t)); })
|
||||
.def("Watch", [](Tape* self, AbstractTensorHandle* t) { self->Watch(t); })
|
||||
.def("ComputeGradient",
|
||||
[](Tape* self, TapeVSpace* vspace,
|
||||
[](Tape* self, AbstractContext* ctx,
|
||||
std::vector<AbstractTensorHandle*> target_tensors,
|
||||
std::vector<AbstractTensorHandle*> source_tensors,
|
||||
std::vector<AbstractTensorHandle*> output_gradients) {
|
||||
std::vector<int64> target_tensor_ids;
|
||||
std::vector<int64> source_tensor_ids;
|
||||
target_tensor_ids.reserve(target_tensors.size());
|
||||
source_tensor_ids.reserve(source_tensors.size());
|
||||
for (auto t : target_tensors) {
|
||||
target_tensor_ids.emplace_back(ToId(t));
|
||||
}
|
||||
for (auto t : source_tensors) {
|
||||
source_tensor_ids.emplace_back(ToId(t));
|
||||
}
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
std::vector<AbstractTensorHandle*> results;
|
||||
Status s = self->ComputeGradient(
|
||||
*vspace, target_tensor_ids, source_tensor_ids,
|
||||
source_tensors_that_are_targets, output_gradients, &results,
|
||||
/*build_default_zeros_grads=*/false);
|
||||
std::vector<AbstractTensorHandle*> results(source_tensors.size());
|
||||
Status s = self->ComputeGradient(ctx, target_tensors,
|
||||
source_tensors, output_gradients,
|
||||
absl::MakeSpan(results));
|
||||
MaybeRaiseRegisteredFromStatus(s);
|
||||
return results;
|
||||
});
|
||||
py::class_<TapeVSpace>(m, "TapeVSpace")
|
||||
.def(py::init([](AbstractContext* ctx) { return new TapeVSpace(ctx); }));
|
||||
py::class_<GradientRegistry>(m, "GradientRegistry").def(py::init([]() {
|
||||
auto registry = new GradientRegistry();
|
||||
MaybeRaiseRegisteredFromStatus(RegisterGradients(registry));
|
||||
|
@ -40,10 +40,9 @@ class GradientTape(object):
|
||||
# TODO(srbs): Add support for unconnected_gradients.
|
||||
def gradient(self, targets, sources, output_gradients=None):
|
||||
ctx = context_stack.get_default()
|
||||
vspace = _tape.TapeVSpace(ctx)
|
||||
flat_targets = nest.flatten(targets)
|
||||
flat_sources = nest.flatten(sources)
|
||||
out_grads = self._c_tape.ComputeGradient(vspace, flat_targets, flat_sources,
|
||||
out_grads = self._c_tape.ComputeGradient(ctx, flat_targets, flat_sources,
|
||||
output_gradients or [])
|
||||
return nest.pack_sequence_as(sources, out_grads)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user