Replace C++ gradient helpers with TapeContext in mnist_gradients_testutil.

PiperOrigin-RevId: 334255849
Change-Id: I065dc828ec6485822116bbad088827c6c6b3ef46
Authored by Saurabh Saxena on 2020-09-28 15:56:58 -07:00; committed by TensorFlower Gardener
parent 8d6c46237c
commit 534cb9ab79
3 changed files with 67 additions and 199 deletions
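The change swaps out the per-op tape helpers (Add, MatMul, Mul, Relu, SparseSoftmaxCrossEntropyWithLogits, removed below) for a single TapeContext wrapper: each test model wraps the incoming AbstractContext once and then calls the ordinary ops:: builders, which record onto the tape via the registry. The following is a minimal sketch of the new call pattern, distilled from the hunks below; the function name AddOnTape and the handle names x, y, and out are illustrative and not part of the change.

// Sketch only: mirrors the call pattern this commit introduces.
// `ctx`, `x`, `y`, and `registry` are assumed to be set up by the caller.
Status AddOnTape(AbstractContext* ctx, AbstractTensorHandle* x,
                 AbstractTensorHandle* y, const GradientRegistry& registry,
                 AbstractTensorHandle** out) {
  auto tape = new Tape(/*persistent=*/false);
  tape->Watch(ToId(x));
  tape->Watch(ToId(y));
  std::vector<AbstractTensorHandle*> add_outputs(1);
  // TapeContext forwards op creation to `ctx` while recording the op and its
  // inputs/outputs on `tape`, using `registry` to look up gradient functions.
  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
  TF_RETURN_IF_ERROR(
      ops::Add(tape_ctx.get(), {x, y}, absl::MakeSpan(add_outputs), "Add"));
  *out = add_outputs[0];
  delete tape;
  return Status::OK();
}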


@@ -313,6 +313,7 @@ cc_library(
         ":gradients_internal",
         ":gradients_util",
         ":tape",
+        "//tensorflow/c/experimental/gradients/tape:tape_context",
         "//tensorflow/c/experimental/ops:array_ops",
         "//tensorflow/c/experimental/ops:math_ops",
         "//tensorflow/c/experimental/ops:nn_ops",


@@ -25,133 +25,18 @@ limitations under the License.
 #include "tensorflow/c/eager/gradients.h"
 #include "tensorflow/c/eager/gradients_internal.h"
 #include "tensorflow/c/eager/gradients_util.h"
+#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
 #include "tensorflow/c/experimental/ops/array_ops.h"
 #include "tensorflow/c/experimental/ops/math_ops.h"
 #include "tensorflow/c/experimental/ops/nn_ops.h"
 #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
 
-// ========================== Tape Ops ==============================
 namespace tensorflow {
 namespace gradients {
 namespace internal {
 
 using std::vector;
-using tensorflow::tracing::TracingOperation;
-
-// Computes `inputs[0] + inputs[1]` and records it on the tape.
-Status Add(AbstractContext* ctx, Tape* tape,
-           absl::Span<AbstractTensorHandle* const> inputs,
-           absl::Span<AbstractTensorHandle*> outputs,
-           const GradientRegistry& registry) {
-  AbstractOperationPtr add_op(ctx->CreateOperation());
-  ForwardOperation forward_op;
-  TF_RETURN_IF_ERROR(
-      Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op));
-  if (isa<TracingOperation>(add_op.get())) {
-    TF_RETURN_IF_ERROR(
-        dyn_cast<TracingOperation>(add_op.get())->SetOpName("my_add"));
-  }
-  TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op));
-  TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op));
-  int num_retvals = 1;
-  return Execute(add_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
-                 registry);
-}
-
-// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape.
-Status MatMul(AbstractContext* ctx, Tape* tape,
-              absl::Span<AbstractTensorHandle* const> inputs,
-              absl::Span<AbstractTensorHandle*> outputs, const char* name,
-              bool transpose_a, bool transpose_b,
-              const GradientRegistry& registry) {
-  AbstractOperationPtr matmul_op(ctx->CreateOperation());
-  ForwardOperation forward_op;
-  TF_RETURN_IF_ERROR(Reset(matmul_op.get(), "MatMul",
-                           /*raw_device_name=*/nullptr, &forward_op));
-  if (isa<TracingOperation>(matmul_op.get())) {
-    TF_RETURN_IF_ERROR(
-        dyn_cast<TracingOperation>(matmul_op.get())->SetOpName(name));
-  }
-  TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[0], &forward_op));
-  TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[1], &forward_op));
-  TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool(
-      matmul_op.get(), "transpose_a", transpose_a, &forward_op));
-  TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool(
-      matmul_op.get(), "transpose_b", transpose_b, &forward_op));
-  int num_retvals = 1;
-  return Execute(matmul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
-                 registry);
-}
-
-Status Mul(AbstractContext* ctx, Tape* tape,
-           absl::Span<AbstractTensorHandle* const> inputs,
-           absl::Span<AbstractTensorHandle*> outputs, const char* name,
-           const GradientRegistry& registry) {
-  AbstractOperationPtr mul_op(ctx->CreateOperation());
-  ForwardOperation forward_op;
-  TF_RETURN_IF_ERROR(
-      Reset(mul_op.get(), "Mul", /*raw_device_name=*/nullptr, &forward_op));
-  if (isa<TracingOperation>(mul_op.get())) {
-    TF_RETURN_IF_ERROR(
-        dyn_cast<TracingOperation>(mul_op.get())->SetOpName(name));
-  }
-  TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[0], &forward_op));
-  TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[1], &forward_op));
-  int num_retvals = 1;
-  return Execute(mul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
-                 registry);
-}
-
-// Computes `Relu(inputs[0])` and records it on the tape.
-Status Relu(AbstractContext* ctx, Tape* tape,
-            absl::Span<AbstractTensorHandle* const> inputs,
-            absl::Span<AbstractTensorHandle*> outputs, const char* name,
-            const GradientRegistry& registry) {
-  AbstractOperationPtr relu_op(ctx->CreateOperation());
-  ForwardOperation forward_op;
-  TF_RETURN_IF_ERROR(
-      Reset(relu_op.get(), "Relu", /*raw_device_name=*/nullptr, &forward_op));
-  if (isa<TracingOperation>(relu_op.get())) {
-    TF_RETURN_IF_ERROR(
-        dyn_cast<TracingOperation>(relu_op.get())->SetOpName(name));
-  }
-  TF_RETURN_IF_ERROR(AddInput(relu_op.get(), inputs[0], &forward_op));
-  int num_retvals = 1;
-  return Execute(relu_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
-                 registry);
-}
-
-// Computes `SoftmaxLoss(scores, labels)` where labels are categorical (not
-// one-hot) and records it on the tape.
-Status SparseSoftmaxCrossEntropyWithLogits(
-    AbstractContext* ctx, Tape* tape,
-    absl::Span<AbstractTensorHandle* const> inputs,
-    absl::Span<AbstractTensorHandle*> outputs, const char* name,
-    const GradientRegistry& registry) {
-  AbstractTensorHandle* scores = inputs[0];
-  AbstractTensorHandle* labels = inputs[1];
-  AbstractOperationPtr sm_op(ctx->CreateOperation());
-  ForwardOperation forward_op;
-  TF_RETURN_IF_ERROR(Reset(sm_op.get(), "SparseSoftmaxCrossEntropyWithLogits",
-                           /*raw_device_name=*/nullptr, &forward_op));
-  if (isa<TracingOperation>(sm_op.get())) {
-    TF_RETURN_IF_ERROR(
-        dyn_cast<TracingOperation>(sm_op.get())->SetOpName(name));
-  }
-  TF_RETURN_IF_ERROR(AddInput(sm_op.get(), scores, &forward_op));
-  TF_RETURN_IF_ERROR(AddInput(sm_op.get(), labels, &forward_op));
-  int num_retvals = 2;  // returns loss values and backprop
-  return Execute(sm_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
-                 registry);
-}
-
 //===================== Test Models to run =========================
@@ -167,8 +52,9 @@ Status AddGradModel(AbstractContext* ctx,
   tape->Watch(ToId(inputs[0]));  // Watch x.
   tape->Watch(ToId(inputs[1]));  // Watch y.
   std::vector<AbstractTensorHandle*> add_outputs(1);
-  TF_RETURN_IF_ERROR(Add(ctx, tape, inputs, absl::MakeSpan(add_outputs),
-                         registry));  // Compute x+y.
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(
+      ops::Add(tape_ctx.get(), inputs, absl::MakeSpan(add_outputs), "Add"));
 
   std::unordered_map<tensorflow::int64, TapeTensor>
       source_tensors_that_are_targets;
@@ -200,9 +86,11 @@ Status MatMulGradModel(AbstractContext* ctx,
   tape->Watch(ToId(inputs[0]));  // Watch x.
   tape->Watch(ToId(inputs[1]));  // Watch y.
   vector<AbstractTensorHandle*> mm_outputs(1);
-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, inputs, absl::MakeSpan(mm_outputs),
-                            "matmul0", /*transpose_a=*/false,
-                            /*transpose_b=*/false, registry));  // Compute x*y.
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), inputs,
+                                 absl::MakeSpan(mm_outputs), "matmul0",
+                                 /*transpose_a=*/false,
+                                 /*transpose_b=*/false));  // Compute x*y.
 
   std::unordered_map<tensorflow::int64, TapeTensor>
       source_tensors_that_are_targets;
@@ -256,25 +144,27 @@ Status MNISTForwardModel(AbstractContext* ctx,
   tape->Watch(ToId(W2));  // Watch W2.
   vector<AbstractTensorHandle*> temp_outputs(1);
 
-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
-                            "matmul0", /*transpose_a=*/false,
-                            /*transpose_b=*/false, registry));  // Compute X*W1
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
+                                 absl::MakeSpan(temp_outputs), "matmul0",
+                                 /*transpose_a=*/false,
+                                 /*transpose_b=*/false));  // Compute X*W1
 
-  TF_RETURN_IF_ERROR(Relu(ctx, tape, {temp_outputs[0]},
-                          absl::MakeSpan(temp_outputs), "relu",
-                          registry));  // Compute Relu(X*W1)
+  TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {temp_outputs[0]},
+                               absl::MakeSpan(temp_outputs),
+                               "relu"));  // Compute Relu(X*W1)
 
-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {temp_outputs[0], W2},
-                            absl::MakeSpan(temp_outputs), "matmul1",
-                            /*transpose_a=*/false, /*transpose_b=*/false,
-                            registry));  // Compute W2*Relu(X*W1)
+  TF_RETURN_IF_ERROR(ops::MatMul(
+      tape_ctx.get(), {temp_outputs[0], W2}, absl::MakeSpan(temp_outputs),
+      "matmul1",
+      /*transpose_a=*/false, /*transpose_b=*/false));  // Compute W2*Relu(X*W1)
 
   AbstractTensorHandle* scores = temp_outputs[0];
 
   temp_outputs.resize(2);
-  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
-      ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
-      "softmax_loss", registry));  // Compute Softmax(Scores,labels)
+  TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
+      tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs),
+      "softmax_loss"));  // Compute Softmax(Scores,labels)
 
   AbstractTensorHandle* loss_vals = temp_outputs[0];
@@ -297,9 +187,11 @@ Status MatMulTransposeModel(AbstractContext* ctx,
   tape->Watch(ToId(W1));
   vector<AbstractTensorHandle*> temp_outputs(1);
 
-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
-                            "matmul0", /*transpose_a=*/true,
-                            /*transpose_b=*/false, registry));  // Compute X*W1
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
+                                 absl::MakeSpan(temp_outputs), "matmul0",
+                                 /*transpose_a=*/true,
+                                 /*transpose_b=*/false));  // Compute X*W1
 
   outputs[0] = temp_outputs[0];
@@ -315,8 +207,10 @@ Status ReluGradModel(AbstractContext* ctx,
   auto tape = new Tape(/*persistent=*/false);
   tape->Watch(ToId(inputs[0]));  // Watch X
   vector<AbstractTensorHandle*> relu_outputs(1);
-  TF_RETURN_IF_ERROR(Relu(ctx, tape, inputs, absl::MakeSpan(relu_outputs),
-                          "relu0", registry));  // Relu(X)
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), inputs,
+                               absl::MakeSpan(relu_outputs),
+                               "relu0"));  // Relu(X)
 
   std::unordered_map<tensorflow::int64, TapeTensor>
       source_tensors_that_are_targets;
@@ -346,8 +240,9 @@ Status SoftmaxLossGradModel(AbstractContext* ctx,
   tape->Watch(ToId(inputs[0]));  // Watch scores.
   tape->Watch(ToId(inputs[1]));  // Watch labels.
   vector<AbstractTensorHandle*> sm_outputs(2);
-  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
-      ctx, tape, inputs, absl::MakeSpan(sm_outputs), "softmax0", registry));
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
+      tape_ctx.get(), inputs, absl::MakeSpan(sm_outputs), "softmax0"));
 
   std::unordered_map<tensorflow::int64, TapeTensor>
       source_tensors_that_are_targets;
@@ -381,29 +276,30 @@ Status MNISTGradModel(AbstractContext* ctx,
   tape->Watch(ToId(W1));  // Watch W1.
   tape->Watch(ToId(W2));  // Watch W1.
   vector<AbstractTensorHandle*> temp_outputs(1);
-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
-                            "matmul0", /*transpose_a=*/false,
-                            /*transpose_b=*/false, registry));  // Compute X*W1
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
+                                 absl::MakeSpan(temp_outputs), "matmul0",
+                                 /*transpose_a=*/false,
+                                 /*transpose_b=*/false));  // Compute X*W1
 
   AbstractTensorHandle* mm = temp_outputs[0];
 
-  TF_RETURN_IF_ERROR(Relu(ctx, tape, {mm},
-                          absl::MakeSpan(temp_outputs),  // Relu(X*W1)
-                          "relu0", registry));
+  TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {mm},
+                               absl::MakeSpan(temp_outputs),  // Relu(X*W1)
+                               "relu0"));
 
   AbstractTensorHandle* hidden = temp_outputs[0];
 
-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {hidden, W2},
-                            absl::MakeSpan(temp_outputs), "matmul1",
-                            /*transpose_a=*/false, /*transpose_b=*/false,
-                            registry));  // W2*Relu(X*W1)
+  TF_RETURN_IF_ERROR(ops::MatMul(
+      tape_ctx.get(), {hidden, W2}, absl::MakeSpan(temp_outputs), "matmul1",
+      /*transpose_a=*/false, /*transpose_b=*/false));  // W2*Relu(X*W1)
 
   AbstractTensorHandle* scores = temp_outputs[0];
 
   temp_outputs.resize(2);
-  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
-      ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
-      "softmaxloss", registry));  // W2*Relu(X*W1)
+  TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
+      tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs),
+      "softmaxloss"));  // W2*Relu(X*W1)
 
   AbstractTensorHandle* loss = temp_outputs[0];
@@ -440,8 +336,10 @@ Status ScalarMulModel(AbstractContext* ctx,
   auto tape = new Tape(/*persistent=*/false);
   vector<AbstractTensorHandle*> temp_outputs(1);
 
-  TF_RETURN_IF_ERROR(Mul(ctx, tape, {eta, A}, absl::MakeSpan(temp_outputs),
-                         "scalarMul0", registry));  // Compute eta*A
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), {eta, A},
+                              absl::MakeSpan(temp_outputs),
+                              "scalarMul0"));  // Compute eta*A
 
   outputs[0] = temp_outputs[0];
@@ -459,9 +357,11 @@ Status MatMulModel(AbstractContext* ctx,
   TapeVSpace vspace(ctx);
   auto tape = new Tape(/*persistent=*/false);
   std::vector<AbstractTensorHandle*> temp_outputs(1);
-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
-                            "matmul0", /*transpose_a=*/false,
-                            /*transpose_b=*/false, registry));  // Compute X*W1
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
+                                 absl::MakeSpan(temp_outputs), "matmul0",
+                                 /*transpose_a=*/false,
+                                 /*transpose_b=*/false));  // Compute X*W1
 
   outputs[0] = temp_outputs[0];
   delete tape;
@@ -478,8 +378,10 @@ Status MulModel(AbstractContext* ctx,
   TapeVSpace vspace(ctx);
   auto tape = new Tape(/*persistent=*/false);
   std::vector<AbstractTensorHandle*> temp_outputs(1);
-  TF_RETURN_IF_ERROR(Mul(ctx, tape, {x, y}, absl::MakeSpan(temp_outputs),
-                         "mul0", registry));  // Compute x*y
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), {x, y},
+                              absl::MakeSpan(temp_outputs),
+                              "mul0"));  // Compute x*y
 
   outputs[0] = temp_outputs[0];
   delete tape;
@@ -496,9 +398,9 @@ Status SoftmaxModel(AbstractContext* ctx,
   TapeVSpace vspace(ctx);
   auto tape = new Tape(/*persistent=*/false);
   std::vector<AbstractTensorHandle*> temp_outputs(2);
-  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
-      ctx, tape, {x, labels}, absl::MakeSpan(temp_outputs), "sm_loss",
-      registry));
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
+      tape_ctx.get(), {x, labels}, absl::MakeSpan(temp_outputs), "sm_loss"));
 
   outputs[0] = temp_outputs[0];  // loss values


@@ -29,45 +29,10 @@ limitations under the License.
 #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
 #include "tensorflow/core/platform/status.h"
 
-// ========================== Tape Ops ==============================
 namespace tensorflow {
 namespace gradients {
 namespace internal {
 
-// Computes `inputs[0] + inputs[1]` and records it on the tape.
-Status Add(AbstractContext* ctx, Tape* tape,
-           absl::Span<AbstractTensorHandle* const> inputs,
-           absl::Span<AbstractTensorHandle*> outputs,
-           const GradientRegistry& registry);
-
-// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape.
-Status MatMul(AbstractContext* ctx, Tape* tape,
-              absl::Span<AbstractTensorHandle* const> inputs,
-              absl::Span<AbstractTensorHandle*> outputs, const char* name,
-              bool transpose_a, bool transpose_b,
-              const GradientRegistry& registry);
-
-// Computes `inputs[0] * inputs[1]` and records it on the tape.
-Status Mul(AbstractContext* ctx, Tape* tape,
-           absl::Span<AbstractTensorHandle* const> inputs,
-           absl::Span<AbstractTensorHandle*> outputs, const char* name,
-           const GradientRegistry& registry);
-
-// Computes `Relu(inputs[0])` and records it on the tape.
-Status Relu(AbstractContext* ctx, Tape* tape,
-            absl::Span<AbstractTensorHandle* const> inputs,
-            absl::Span<AbstractTensorHandle*> outputs, const char* name,
-            const GradientRegistry& registry);
-
-// Computes `SoftmaxLoss(scores, labels)` for matrices and records it on the
-// tape.
-Status SparseSoftmaxCrossEntropyWithLogits(
-    AbstractContext* ctx, Tape* tape,
-    absl::Span<AbstractTensorHandle* const> inputs,
-    absl::Span<AbstractTensorHandle*> outputs, const char* name,
-    const GradientRegistry& registry);
-
-// ====================== End Tape Ops ============================
-
 // Computes
 // y = inputs[0] + inputs[1]