Replace C++ gradient helpers with TapeContext in mnist_gradients_testutil.

PiperOrigin-RevId: 334255849
Change-Id: I065dc828ec6485822116bbad088827c6c6b3ef46
Saurabh Saxena 2020-09-28 15:56:58 -07:00 committed by TensorFlower Gardener
parent 8d6c46237c
commit 534cb9ab79
3 changed files with 67 additions and 199 deletions
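
The change deletes the hand-rolled per-op helpers (Add, MatMul, Mul, Relu, SparseSoftmaxCrossEntropyWithLogits) that threaded a Tape and GradientRegistry through Reset/AddInput/Execute by hand, and has each test model wrap the context in a TapeContext and call the generated ops directly, so that ops executed through it are recorded on the tape. A minimal sketch of the pattern, lifted from the hunks below:

    // Before: a custom helper forwards the tape and registry on every call.
    TF_RETURN_IF_ERROR(
        Add(ctx, tape, inputs, absl::MakeSpan(add_outputs), registry));

    // After: wrap the context once; ops run against the TapeContext are
    // recorded on the tape via the registry.
    AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
    TF_RETURN_IF_ERROR(
        ops::Add(tape_ctx.get(), inputs, absl::MakeSpan(add_outputs), "Add"));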


@@ -313,6 +313,7 @@ cc_library(
":gradients_internal",
":gradients_util",
":tape",
"//tensorflow/c/experimental/gradients/tape:tape_context",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops",


@@ -25,133 +25,18 @@ limitations under the License.
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/gradients_util.h"
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
// ========================== Tape Ops ==============================
namespace tensorflow {
namespace gradients {
namespace internal {
using std::vector;
using tensorflow::tracing::TracingOperation;
// Computes `inputs[0] + inputs[1]` and records it on the tape.
Status Add(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractOperationPtr add_op(ctx->CreateOperation());
ForwardOperation forward_op;
TF_RETURN_IF_ERROR(
Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(add_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(add_op.get())->SetOpName("my_add"));
}
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op));
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op));
int num_retvals = 1;
return Execute(add_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
registry);
}
// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape.
Status MatMul(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
bool transpose_a, bool transpose_b,
const GradientRegistry& registry) {
AbstractOperationPtr matmul_op(ctx->CreateOperation());
ForwardOperation forward_op;
TF_RETURN_IF_ERROR(Reset(matmul_op.get(), "MatMul",
/*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(matmul_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(matmul_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[0], &forward_op));
TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[1], &forward_op));
TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool(
matmul_op.get(), "transpose_a", transpose_a, &forward_op));
TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool(
matmul_op.get(), "transpose_b", transpose_b, &forward_op));
int num_retvals = 1;
return Execute(matmul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
registry);
}
Status Mul(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
const GradientRegistry& registry) {
AbstractOperationPtr mul_op(ctx->CreateOperation());
ForwardOperation forward_op;
TF_RETURN_IF_ERROR(
Reset(mul_op.get(), "Mul", /*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(mul_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(mul_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[0], &forward_op));
TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[1], &forward_op));
int num_retvals = 1;
return Execute(mul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
registry);
}
// Computes `Relu(inputs[0])` and records it on the tape.
Status Relu(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
const GradientRegistry& registry) {
AbstractOperationPtr relu_op(ctx->CreateOperation());
ForwardOperation forward_op;
TF_RETURN_IF_ERROR(
Reset(relu_op.get(), "Relu", /*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(relu_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(relu_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(AddInput(relu_op.get(), inputs[0], &forward_op));
int num_retvals = 1;
return Execute(relu_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
registry);
}
// Computes `SoftmaxLoss(scores, labels)` where labels are categorical (not
// one-hot) and records it on the tape.
Status SparseSoftmaxCrossEntropyWithLogits(
AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
const GradientRegistry& registry) {
AbstractTensorHandle* scores = inputs[0];
AbstractTensorHandle* labels = inputs[1];
AbstractOperationPtr sm_op(ctx->CreateOperation());
ForwardOperation forward_op;
TF_RETURN_IF_ERROR(Reset(sm_op.get(), "SparseSoftmaxCrossEntropyWithLogits",
/*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(sm_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(sm_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(AddInput(sm_op.get(), scores, &forward_op));
TF_RETURN_IF_ERROR(AddInput(sm_op.get(), labels, &forward_op));
int num_retvals = 2; // returns loss values and backprop
return Execute(sm_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
registry);
}
//===================== Test Models to run =========================
@@ -167,8 +52,9 @@ Status AddGradModel(AbstractContext* ctx,
tape->Watch(ToId(inputs[0])); // Watch x.
tape->Watch(ToId(inputs[1])); // Watch y.
std::vector<AbstractTensorHandle*> add_outputs(1);
TF_RETURN_IF_ERROR(Add(ctx, tape, inputs, absl::MakeSpan(add_outputs),
registry)); // Compute x+y.
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(
ops::Add(tape_ctx.get(), inputs, absl::MakeSpan(add_outputs), "Add"));
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
@@ -200,9 +86,11 @@ Status MatMulGradModel(AbstractContext* ctx,
tape->Watch(ToId(inputs[0])); // Watch x.
tape->Watch(ToId(inputs[1])); // Watch y.
vector<AbstractTensorHandle*> mm_outputs(1);
TF_RETURN_IF_ERROR(MatMul(ctx, tape, inputs, absl::MakeSpan(mm_outputs),
"matmul0", /*transpose_a=*/false,
/*transpose_b=*/false, registry)); // Compute x*y.
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), inputs,
absl::MakeSpan(mm_outputs), "matmul0",
/*transpose_a=*/false,
/*transpose_b=*/false)); // Compute x*y.
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
@@ -256,25 +144,27 @@ Status MNISTForwardModel(AbstractContext* ctx,
tape->Watch(ToId(W2)); // Watch W2.
vector<AbstractTensorHandle*> temp_outputs(1);
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
"matmul0", /*transpose_a=*/false,
/*transpose_b=*/false, registry)); // Compute X*W1
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
absl::MakeSpan(temp_outputs), "matmul0",
/*transpose_a=*/false,
/*transpose_b=*/false)); // Compute X*W1
TF_RETURN_IF_ERROR(Relu(ctx, tape, {temp_outputs[0]},
absl::MakeSpan(temp_outputs), "relu",
registry)); // Compute Relu(X*W1)
TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {temp_outputs[0]},
absl::MakeSpan(temp_outputs),
"relu")); // Compute Relu(X*W1)
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {temp_outputs[0], W2},
absl::MakeSpan(temp_outputs), "matmul1",
/*transpose_a=*/false, /*transpose_b=*/false,
registry)); // Compute W2*Relu(X*W1)
TF_RETURN_IF_ERROR(ops::MatMul(
tape_ctx.get(), {temp_outputs[0], W2}, absl::MakeSpan(temp_outputs),
"matmul1",
/*transpose_a=*/false, /*transpose_b=*/false)); // Compute W2*Relu(X*W1)
AbstractTensorHandle* scores = temp_outputs[0];
temp_outputs.resize(2);
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
"softmax_loss", registry)); // Compute Softmax(Scores,labels)
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs),
"softmax_loss")); // Compute Softmax(Scores,labels)
AbstractTensorHandle* loss_vals = temp_outputs[0];
@@ -297,9 +187,11 @@ Status MatMulTransposeModel(AbstractContext* ctx,
tape->Watch(ToId(W1));
vector<AbstractTensorHandle*> temp_outputs(1);
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
"matmul0", /*transpose_a=*/true,
/*transpose_b=*/false, registry)); // Compute X*W1
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
absl::MakeSpan(temp_outputs), "matmul0",
/*transpose_a=*/true,
/*transpose_b=*/false)); // Compute X*W1
outputs[0] = temp_outputs[0];
@@ -315,8 +207,10 @@ Status ReluGradModel(AbstractContext* ctx,
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch X
vector<AbstractTensorHandle*> relu_outputs(1);
TF_RETURN_IF_ERROR(Relu(ctx, tape, inputs, absl::MakeSpan(relu_outputs),
"relu0", registry)); // Relu(X)
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), inputs,
absl::MakeSpan(relu_outputs),
"relu0")); // Relu(X)
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
@@ -346,8 +240,9 @@ Status SoftmaxLossGradModel(AbstractContext* ctx,
tape->Watch(ToId(inputs[0])); // Watch scores.
tape->Watch(ToId(inputs[1])); // Watch labels.
vector<AbstractTensorHandle*> sm_outputs(2);
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
ctx, tape, inputs, absl::MakeSpan(sm_outputs), "softmax0", registry));
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
tape_ctx.get(), inputs, absl::MakeSpan(sm_outputs), "softmax0"));
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
@@ -381,29 +276,30 @@ Status MNISTGradModel(AbstractContext* ctx,
tape->Watch(ToId(W1)); // Watch W1.
tape->Watch(ToId(W2)); // Watch W2.
vector<AbstractTensorHandle*> temp_outputs(1);
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
"matmul0", /*transpose_a=*/false,
/*transpose_b=*/false, registry)); // Compute X*W1
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
absl::MakeSpan(temp_outputs), "matmul0",
/*transpose_a=*/false,
/*transpose_b=*/false)); // Compute X*W1
AbstractTensorHandle* mm = temp_outputs[0];
TF_RETURN_IF_ERROR(Relu(ctx, tape, {mm},
absl::MakeSpan(temp_outputs), // Relu(X*W1)
"relu0", registry));
TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {mm},
absl::MakeSpan(temp_outputs), // Relu(X*W1)
"relu0"));
AbstractTensorHandle* hidden = temp_outputs[0];
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {hidden, W2},
absl::MakeSpan(temp_outputs), "matmul1",
/*transpose_a=*/false, /*transpose_b=*/false,
registry)); // W2*Relu(X*W1)
TF_RETURN_IF_ERROR(ops::MatMul(
tape_ctx.get(), {hidden, W2}, absl::MakeSpan(temp_outputs), "matmul1",
/*transpose_a=*/false, /*transpose_b=*/false)); // W2*Relu(X*W1)
AbstractTensorHandle* scores = temp_outputs[0];
temp_outputs.resize(2);
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
"softmaxloss", registry)); // W2*Relu(X*W1)
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs),
"softmaxloss")); // W2*Relu(X*W1)
AbstractTensorHandle* loss = temp_outputs[0];
@@ -440,8 +336,10 @@ Status ScalarMulModel(AbstractContext* ctx,
auto tape = new Tape(/*persistent=*/false);
vector<AbstractTensorHandle*> temp_outputs(1);
TF_RETURN_IF_ERROR(Mul(ctx, tape, {eta, A}, absl::MakeSpan(temp_outputs),
"scalarMul0", registry)); // Compute eta*A
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), {eta, A},
absl::MakeSpan(temp_outputs),
"scalarMul0")); // Compute eta*A
outputs[0] = temp_outputs[0];
@@ -459,9 +357,11 @@ Status MatMulModel(AbstractContext* ctx,
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
std::vector<AbstractTensorHandle*> temp_outputs(1);
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
"matmul0", /*transpose_a=*/false,
/*transpose_b=*/false, registry)); // Compute X*W1
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
absl::MakeSpan(temp_outputs), "matmul0",
/*transpose_a=*/false,
/*transpose_b=*/false)); // Compute X*W1
outputs[0] = temp_outputs[0];
delete tape;
@@ -478,8 +378,10 @@ Status MulModel(AbstractContext* ctx,
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
std::vector<AbstractTensorHandle*> temp_outputs(1);
TF_RETURN_IF_ERROR(Mul(ctx, tape, {x, y}, absl::MakeSpan(temp_outputs),
"mul0", registry)); // Compute x*y
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), {x, y},
absl::MakeSpan(temp_outputs),
"mul0")); // Compute x*y
outputs[0] = temp_outputs[0];
delete tape;
@@ -496,9 +398,9 @@ Status SoftmaxModel(AbstractContext* ctx,
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
std::vector<AbstractTensorHandle*> temp_outputs(2);
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
ctx, tape, {x, labels}, absl::MakeSpan(temp_outputs), "sm_loss",
registry));
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
tape_ctx.get(), {x, labels}, absl::MakeSpan(temp_outputs), "sm_loss"));
outputs[0] = temp_outputs[0]; // loss values


@@ -29,45 +29,10 @@ limitations under the License.
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/status.h"
// ========================== Tape Ops ==============================
namespace tensorflow {
namespace gradients {
namespace internal {
// Computes `inputs[0] + inputs[1]` and records it on the tape.
Status Add(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape.
Status MatMul(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
bool transpose_a, bool transpose_b,
const GradientRegistry& registry);
// Computes `inputs[0] * inputs[1]` and records it on the tape.
Status Mul(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
const GradientRegistry& registry);
// Computes `Relu(inputs[0])` and records it on the tape.
Status Relu(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
const GradientRegistry& registry);
// Computes `SoftmaxLoss(scores, labels)` for matrices and records it on the
// tape.
Status SparseSoftmaxCrossEntropyWithLogits(
AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
const GradientRegistry& registry);
// ====================== End Tape Ops ============================
// Computes
// y = inputs[0] + inputs[1]