- The main change is to get rid of int64 tensor ids from Tape and directly use
  AbstractTensorHandles.
- Get rid of tensorflow::gradients::Context and directly use AbstractContext*.
- Get rid of DefaultGradientFunction and BackwardFunction (which was a wrapper
  for a GradientFunction and DefaultGradientFunction). We had introduced
  DefaultGradientFunction in order to support existing python gradient
  functions which expect all necessary incoming grads to be non-None. This is
  only relevant for ops with more than one output, which are few. We could
  handle those by creating a wrapper GradientFunction that builds the zeros if
  needed. Getting rid of DefaultGradientFunction greatly simplifies the API.
- Introduce ForwardOperation::skip_input_indices. This will be filled in a
  follow-up change. There is a bug tracking this.
- Introduce helpers for implementing the behavior of tf.no_gradient and
  tf.stop_gradient, i.e. RegisterNotDifferentiable and
  NotDifferentiableGradientFunction.
- One slight behavior change: currently, when an op does not have a
  GradientFunction registered, we silently record a nullptr GradientFunction
  on the tape. This sometimes leads to uninformative error messages. Now we
  loudly raise an error when GradientRegistry::Lookup fails in
  TapeContext::Execute. So any op executing under a TapeContext must have a
  registered GradientFunction. Non-differentiable ops need to be explicitly
  registered using RegisterNotDifferentiable, e.g. CheckNumerics in
  gradients_test.cc.

c/eager/tape.h: I changed the signatures of gradient functions to use
`absl::Span<Gradient*>` instead of `vector<Gradient*>*` for the result grads.
This makes it consistent with the new Tape API and generally makes things
cleaner.

PiperOrigin-RevId: 345534016
Change-Id: Ie1bf5dff88f87390e6b470acc379d3852ce68b5c
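Below is a minimal sketch of what gradient registration looks like under the
simplified API described in this change. It is illustrative only:
IdentityGradientFunction and RegisterExampleGradients are hypothetical names,
and the exact signature of RegisterNotDifferentiable is an assumption based on
the description above; GradientRegistry, GradientFunction, and
ForwardOperation are the types declared in the header below.

// Illustrative sketch only; names are hypothetical and the
// RegisterNotDifferentiable signature is assumed from the change description.
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace gradients {

// Gradient of Identity: pass the incoming gradient straight through.
class IdentityGradientFunction : public GradientFunction {
 public:
  Status Compute(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> grad_outputs,
                 absl::Span<AbstractTensorHandle*> grad_inputs) override {
    // Incoming grads may be null (see the Tape comment in the header); a
    // null incoming grad simply propagates as a null input grad.
    grad_inputs[0] = grad_outputs[0];
    if (grad_inputs[0] != nullptr) grad_inputs[0]->Ref();
    return Status::OK();
  }
};

Status RegisterExampleGradients(GradientRegistry* registry) {
  // Every op executed under a TapeContext now needs a registered factory;
  // GradientRegistry::Lookup fails loudly otherwise.
  TF_RETURN_IF_ERROR(registry->Register(
      "Identity",
      [](const ForwardOperation& op) -> GradientFunction* {
        return new IdentityGradientFunction;
      }));
  // Non-differentiable ops are registered explicitly (assumed signature).
  TF_RETURN_IF_ERROR(RegisterNotDifferentiable(registry, "CheckNumerics"));
  return Status::OK();
}

}  // namespace gradients
}  // namespace tensorflow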
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_GRADIENTS_H_
#define TENSORFLOW_C_EAGER_GRADIENTS_H_

#include "absl/container/flat_hash_map.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/tape.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"

namespace tensorflow {
namespace gradients {

// =============== Experimental C++ API for computing gradients ===============

// Sample gradient function:
//
// class AddGradientFunction : public GradientFunction {
//  public:
//   Status Compute(AbstractContext* ctx,
//                  absl::Span<AbstractTensorHandle* const> grad_outputs,
//                  absl::Span<AbstractTensorHandle*> grad_inputs) override {
//     grad_inputs[0] = grad_outputs[0];
//     grad_inputs[1] = grad_outputs[0];
//     grad_inputs[0]->Ref();
//     grad_inputs[1]->Ref();
//     return Status::OK();
//   }
//   ~AddGradientFunction() override {}
// };
//
// GradientFunction* AddRegisterer(const ForwardOperation& op) {
//   // More complex gradient functions can use inputs/attrs etc. from the
//   // forward `op`.
//   return new AddGradientFunction;
// }
//
// Status RegisterGradients(GradientRegistry* registry) {
//   return registry->Register("Add", AddRegisterer);
// }
class GradientFunction {
 public:
  virtual Status Compute(AbstractContext* ctx,
                         absl::Span<AbstractTensorHandle* const> grad_outputs,
                         absl::Span<AbstractTensorHandle*> grad_inputs) = 0;
  virtual ~GradientFunction() {}
};

// Metadata from the forward operation that is made available to the
// gradient registerer to instantiate a GradientFunction.
struct ForwardOperation {
 public:
  string op_name;
  std::vector<AbstractTensorHandle*> inputs;
  std::vector<AbstractTensorHandle*> outputs;
  std::vector<int64> skip_input_indices;
  AttrBuilder attrs;
};

using GradientFunctionFactory =
    std::function<GradientFunction*(const ForwardOperation& op)>;

// Map from op name to a `GradientFunctionFactory`.
class GradientRegistry {
 public:
  Status Register(const string& op,
                  GradientFunctionFactory gradient_function_factory);
  Status Lookup(const ForwardOperation& op,
                std::unique_ptr<GradientFunction>* gradient_function) const;

 private:
  absl::flat_hash_map<string, GradientFunctionFactory> registry_;
};

// TODO(srbs): Figure out if we can avoid declaring this in the public header.
// Wrapper for a tensor output of an operation executing under a tape.
//
// `GetID` returns a unique id for the wrapped tensor which is used to maintain
// a map (`tensorflow::eager::TensorTape`) from the wrapped tensor to the id of
// the op that produced it (or -1 if this tensor was watched using
// `GradientTape::Watch`). The op_id is simply a unique index assigned to each
// op executed under the tape. A separate map (`tensorflow::eager::OpTape`)
// maintains the map from `op_id` to an `OpTapeEntry` which stores the
// `op_type`, inputs, outputs and the gradient function. These data structures
// combined allow us to trace the data dependencies between operations and
// hence compute gradients.
//
// `ZerosLike` is not expected to be called and returns a nullptr. Gradient
// functions are expected to handle null incoming gradients themselves, so no
// default zeros grads are created here.
// TODO(srbs): We need to define `ZerosLike` here to keep the compiler happy.
// Figure out a way to avoid this.
// TODO(srbs): Should ZerosLike check-fail instead of returning nullptr?
class TapeTensor {
 public:
  explicit TapeTensor(AbstractTensorHandle* handle);
  TapeTensor(const TapeTensor& other);
  ~TapeTensor();

  tensorflow::int64 GetID() const;
  tensorflow::DataType GetDType() const;

  AbstractTensorHandle* ZerosLike() const;

  AbstractTensorHandle* GetHandle() const;

 private:
  AbstractTensorHandle* handle_;
};

// A tracing/immediate-execution agnostic tape.
//
// Gradient functions defined for this tape must support handling null incoming
// gradients.
class Tape : protected eager::GradientTape<AbstractTensorHandle,
                                           GradientFunction, TapeTensor> {
 public:
  using GradientTape<AbstractTensorHandle, GradientFunction,
                     TapeTensor>::GradientTape;
  // Returns whether the tape is persistent, i.e., whether the tape will hold
  // onto its internal state after a call to `ComputeGradient`.
  using GradientTape<AbstractTensorHandle, GradientFunction,
                     TapeTensor>::IsPersistent;

  // Adds this tensor to the list of watched tensors.
  //
  // This is a no-op if the tensor is already being watched either from an
  // earlier call to `GradientTape::Watch` or being an output of an op with
  // watched inputs.
  void Watch(const AbstractTensorHandle*);
  // Records an operation with given inputs and outputs on the tape and marks
  // all its outputs as watched if at least one input of the op is watched and
  // has a trainable dtype. op_name is optional and is used for debugging only.
  void RecordOperation(absl::Span<AbstractTensorHandle* const> inputs,
                       absl::Span<AbstractTensorHandle* const> outputs,
                       GradientFunction* gradient_function,
                       const string& op_name = "");
  // Returns whether any tensor in a list of tensors is being watched and has
  // a trainable dtype.
  bool ShouldRecord(
      absl::Span<const AbstractTensorHandle* const> tensors) const;
  // Unwatches this tensor on the tape. Mainly used for cleanup when deleting
  // eager tensors.
  void DeleteTrace(const AbstractTensorHandle*);

  // Consumes the internal state of the tape (so cannot be called more than
  // once unless the tape is persistent) and produces the gradient of the
  // target tensors with respect to the source tensors. The output gradients
  // are used if not empty and not null. The result is populated with one
  // tensor per target element.
  Status ComputeGradient(
      AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> targets,
      absl::Span<AbstractTensorHandle* const> sources,
      absl::Span<AbstractTensorHandle* const> output_gradients,
      absl::Span<AbstractTensorHandle*> result);
};

}  // namespace gradients
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EAGER_GRADIENTS_H_
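For context, here is a minimal, hypothetical usage sketch of the Tape declared
above. The context `ctx`, the handles `x` and `y`, and the gradient function
`add_grad_fn` are assumed to be produced by the caller (e.g. y = Add(x, x)
executed beforehand), and error handling is elided.

// Hypothetical helper built only from the declarations in gradients.h.
#include "tensorflow/c/eager/gradients.h"

namespace tensorflow {
namespace gradients {

Status ComputeExampleGradient(AbstractContext* ctx, AbstractTensorHandle* x,
                              AbstractTensorHandle* y,
                              GradientFunction* add_grad_fn,
                              absl::Span<AbstractTensorHandle*> result) {
  Tape tape(/*persistent=*/false);
  tape.Watch(x);
  // Record the forward op so the tape can invoke `add_grad_fn` during the
  // backward pass.
  tape.RecordOperation(/*inputs=*/{x, x}, /*outputs=*/{y}, add_grad_fn, "Add");
  // Differentiate `y` with respect to `x`; no explicit output gradients are
  // seeded (empty span). `result` receives one tensor per target.
  return tape.ComputeGradient(ctx, /*targets=*/{y}, /*sources=*/{x},
                              /*output_gradients=*/{}, result);
}

}  // namespace gradients
}  // namespace tensorflow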