Replace C++ gradient helpers with TapeContext in mnist_gradients_testutil.
PiperOrigin-RevId: 334255849
Change-Id: I065dc828ec6485822116bbad088827c6c6b3ef46
commit 534cb9ab79
parent 8d6c46237c

tensorflow/c/eager/BUILD
@@ -313,6 +313,7 @@ cc_library(
         ":gradients_internal",
         ":gradients_util",
         ":tape",
+        "//tensorflow/c/experimental/gradients/tape:tape_context",
         "//tensorflow/c/experimental/ops:array_ops",
         "//tensorflow/c/experimental/ops:math_ops",
         "//tensorflow/c/experimental/ops:nn_ops",
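A note on the new dependency: TapeContext, declared in tensorflow/c/experimental/gradients/tape/tape_context.h, is an AbstractContext decorator. Because the shared ops:: builders only ever see an AbstractContext, wrapping the real context in a TapeContext means every operation they execute is recorded on the tape, which is what lets this commit delete the hand-written helpers below. A minimal sketch of the idea; illustrative only, with the member names and the kTape kind tag being assumptions rather than the verbatim TensorFlow source:

// Sketch: an AbstractContext decorator whose operations record themselves.
class TapeContext : public AbstractContext {
 public:
  TapeContext(AbstractContext* ctx, Tape* tape, const GradientRegistry& registry)
      : AbstractContext(kTape), parent_ctx_(ctx), tape_(tape), registry_(registry) {}

  AbstractOperation* CreateOperation() override {
    // Wrap the underlying op so its Execute() also records onto the tape.
    return new TapeOperation(parent_ctx_->CreateOperation(), tape_, registry_);
  }

 private:
  AbstractContext* parent_ctx_;  // Not owned.
  Tape* tape_;                   // Not owned.
  const GradientRegistry& registry_;
};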
tensorflow/c/eager/mnist_gradients_testutil.cc
@@ -25,133 +25,18 @@ limitations under the License.
 #include "tensorflow/c/eager/gradients.h"
 #include "tensorflow/c/eager/gradients_internal.h"
 #include "tensorflow/c/eager/gradients_util.h"
+#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
 #include "tensorflow/c/experimental/ops/array_ops.h"
 #include "tensorflow/c/experimental/ops/math_ops.h"
 #include "tensorflow/c/experimental/ops/nn_ops.h"
 #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
 
-// ========================== Tape Ops ==============================
 
 namespace tensorflow {
 namespace gradients {
 namespace internal {
 
 using std::vector;
-using tensorflow::tracing::TracingOperation;
 
-// Computes `inputs[0] + inputs[1]` and records it on the tape.
-Status Add(AbstractContext* ctx, Tape* tape,
-           absl::Span<AbstractTensorHandle* const> inputs,
-           absl::Span<AbstractTensorHandle*> outputs,
-           const GradientRegistry& registry) {
-  AbstractOperationPtr add_op(ctx->CreateOperation());
-  ForwardOperation forward_op;
-  TF_RETURN_IF_ERROR(
-      Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op));
-  if (isa<TracingOperation>(add_op.get())) {
-    TF_RETURN_IF_ERROR(
-        dyn_cast<TracingOperation>(add_op.get())->SetOpName("my_add"));
-  }
-  TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op));
-  TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op));
-  int num_retvals = 1;
-  return Execute(add_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
-                 registry);
-}
-
-// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape.
-Status MatMul(AbstractContext* ctx, Tape* tape,
-              absl::Span<AbstractTensorHandle* const> inputs,
-              absl::Span<AbstractTensorHandle*> outputs, const char* name,
-              bool transpose_a, bool transpose_b,
-              const GradientRegistry& registry) {
-  AbstractOperationPtr matmul_op(ctx->CreateOperation());
-  ForwardOperation forward_op;
-  TF_RETURN_IF_ERROR(Reset(matmul_op.get(), "MatMul",
-                           /*raw_device_name=*/nullptr, &forward_op));
-  if (isa<TracingOperation>(matmul_op.get())) {
-    TF_RETURN_IF_ERROR(
-        dyn_cast<TracingOperation>(matmul_op.get())->SetOpName(name));
-  }
-
-  TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[0], &forward_op));
-  TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[1], &forward_op));
-  TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool(
-      matmul_op.get(), "transpose_a", transpose_a, &forward_op));
-  TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool(
-      matmul_op.get(), "transpose_b", transpose_b, &forward_op));
-
-  int num_retvals = 1;
-  return Execute(matmul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
-                 registry);
-}
-
-Status Mul(AbstractContext* ctx, Tape* tape,
-           absl::Span<AbstractTensorHandle* const> inputs,
-           absl::Span<AbstractTensorHandle*> outputs, const char* name,
-           const GradientRegistry& registry) {
-  AbstractOperationPtr mul_op(ctx->CreateOperation());
-  ForwardOperation forward_op;
-  TF_RETURN_IF_ERROR(
-      Reset(mul_op.get(), "Mul", /*raw_device_name=*/nullptr, &forward_op));
-  if (isa<TracingOperation>(mul_op.get())) {
-    TF_RETURN_IF_ERROR(
-        dyn_cast<TracingOperation>(mul_op.get())->SetOpName(name));
-  }
-
-  TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[0], &forward_op));
-  TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[1], &forward_op));
-
-  int num_retvals = 1;
-  return Execute(mul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
-                 registry);
-}
-
-// Computes `Relu(inputs[0])` and records it on the tape.
-Status Relu(AbstractContext* ctx, Tape* tape,
-            absl::Span<AbstractTensorHandle* const> inputs,
-            absl::Span<AbstractTensorHandle*> outputs, const char* name,
-            const GradientRegistry& registry) {
-  AbstractOperationPtr relu_op(ctx->CreateOperation());
-  ForwardOperation forward_op;
-  TF_RETURN_IF_ERROR(
-      Reset(relu_op.get(), "Relu", /*raw_device_name=*/nullptr, &forward_op));
-  if (isa<TracingOperation>(relu_op.get())) {
-    TF_RETURN_IF_ERROR(
-        dyn_cast<TracingOperation>(relu_op.get())->SetOpName(name));
-  }
-  TF_RETURN_IF_ERROR(AddInput(relu_op.get(), inputs[0], &forward_op));
-  int num_retvals = 1;
-  return Execute(relu_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
-                 registry);
-}
-
-// Computes `SoftmaxLoss(scores, labels)` where labels are categorical (not
-// one-hot) and records it on the tape.
-Status SparseSoftmaxCrossEntropyWithLogits(
-    AbstractContext* ctx, Tape* tape,
-    absl::Span<AbstractTensorHandle* const> inputs,
-    absl::Span<AbstractTensorHandle*> outputs, const char* name,
-    const GradientRegistry& registry) {
-  AbstractTensorHandle* scores = inputs[0];
-  AbstractTensorHandle* labels = inputs[1];
-
-  AbstractOperationPtr sm_op(ctx->CreateOperation());
-  ForwardOperation forward_op;
-  TF_RETURN_IF_ERROR(Reset(sm_op.get(), "SparseSoftmaxCrossEntropyWithLogits",
-                           /*raw_device_name=*/nullptr, &forward_op));
-  if (isa<TracingOperation>(sm_op.get())) {
-    TF_RETURN_IF_ERROR(
-        dyn_cast<TracingOperation>(sm_op.get())->SetOpName(name));
-  }
-
-  TF_RETURN_IF_ERROR(AddInput(sm_op.get(), scores, &forward_op));
-  TF_RETURN_IF_ERROR(AddInput(sm_op.get(), labels, &forward_op));
-
-  int num_retvals = 2;  // returns loss values and backprop
-  return Execute(sm_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
-                 registry);
-}
 
 //===================== Test Models to run =========================
 
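Each helper deleted above repeated the same Reset / SetOpName / AddInput / Execute boilerplate while threading the tape and registry through by hand. The shared builders under tensorflow/c/experimental/ops have the same shape minus the tape plumbing. Roughly, as a sketch of ops::Add assuming the call signature visible in the hunks below (not the verbatim library source):

Status Add(AbstractContext* ctx,
           absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs, const char* name) {
  AbstractOperationPtr add_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(add_op->Reset("Add", /*raw_device_name=*/nullptr));
  if (isa<tracing::TracingOperation>(add_op.get())) {
    TF_RETURN_IF_ERROR(
        dyn_cast<tracing::TracingOperation>(add_op.get())->SetOpName(name));
  }
  TF_RETURN_IF_ERROR(add_op->AddInput(inputs[0]));
  TF_RETURN_IF_ERROR(add_op->AddInput(inputs[1]));
  int num_retvals = 1;
  // When ctx is a TapeContext, this Execute() is the tape-aware wrapper's.
  return add_op->Execute(outputs, &num_retvals);
}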
@@ -167,8 +52,9 @@ Status AddGradModel(AbstractContext* ctx,
   tape->Watch(ToId(inputs[0]));  // Watch x.
   tape->Watch(ToId(inputs[1]));  // Watch y.
   std::vector<AbstractTensorHandle*> add_outputs(1);
-  TF_RETURN_IF_ERROR(Add(ctx, tape, inputs, absl::MakeSpan(add_outputs),
-                         registry));  // Compute x+y.
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(
+      ops::Add(tape_ctx.get(), inputs, absl::MakeSpan(add_outputs), "Add"));
   std::unordered_map<tensorflow::int64, TapeTensor>
       source_tensors_that_are_targets;
 
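For orientation, the backward half of AddGradModel, which this commit does not touch and the hunk cuts off, continues roughly as follows. This is a sketch of the surrounding file rather than part of the diff, so argument details may differ:

  // Ask the tape for gradients of the sum w.r.t. both inputs.
  std::vector<AbstractTensorHandle*> out_grads;
  TF_RETURN_IF_ERROR(tape->ComputeGradient(
      vspace, /*target_tensor_ids=*/{ToId(add_outputs[0])},
      /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
      source_tensors_that_are_targets,
      /*output_gradients=*/{}, &out_grads,
      /*build_default_zeros_grads=*/false));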
@@ -200,9 +86,11 @@ Status MatMulGradModel(AbstractContext* ctx,
   tape->Watch(ToId(inputs[0]));  // Watch x.
   tape->Watch(ToId(inputs[1]));  // Watch y.
   vector<AbstractTensorHandle*> mm_outputs(1);
-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, inputs, absl::MakeSpan(mm_outputs),
-                            "matmul0", /*transpose_a=*/false,
-                            /*transpose_b=*/false, registry));  // Compute x*y.
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), inputs,
+                                 absl::MakeSpan(mm_outputs), "matmul0",
+                                 /*transpose_a=*/false,
+                                 /*transpose_b=*/false));  // Compute x*y.
 
   std::unordered_map<tensorflow::int64, TapeTensor>
       source_tensors_that_are_targets;
@@ -256,25 +144,27 @@ Status MNISTForwardModel(AbstractContext* ctx,
   tape->Watch(ToId(W2));  // Watch W2.
   vector<AbstractTensorHandle*> temp_outputs(1);
 
-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
-                            "matmul0", /*transpose_a=*/false,
-                            /*transpose_b=*/false, registry));  // Compute X*W1
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
+                                 absl::MakeSpan(temp_outputs), "matmul0",
+                                 /*transpose_a=*/false,
+                                 /*transpose_b=*/false));  // Compute X*W1
 
-  TF_RETURN_IF_ERROR(Relu(ctx, tape, {temp_outputs[0]},
-                          absl::MakeSpan(temp_outputs), "relu",
-                          registry));  // Compute Relu(X*W1)
+  TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {temp_outputs[0]},
+                               absl::MakeSpan(temp_outputs),
+                               "relu"));  // Compute Relu(X*W1)
 
-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {temp_outputs[0], W2},
-                            absl::MakeSpan(temp_outputs), "matmul1",
-                            /*transpose_a=*/false, /*transpose_b=*/false,
-                            registry));  // Compute W2*Relu(X*W1)
+  TF_RETURN_IF_ERROR(ops::MatMul(
+      tape_ctx.get(), {temp_outputs[0], W2}, absl::MakeSpan(temp_outputs),
+      "matmul1",
+      /*transpose_a=*/false, /*transpose_b=*/false));  // Compute W2*Relu(X*W1)
 
   AbstractTensorHandle* scores = temp_outputs[0];
 
   temp_outputs.resize(2);
-  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
-      ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
-      "softmax_loss", registry));  // Compute Softmax(Scores,labels)
+  TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
+      tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs),
+      "softmax_loss"));  // Compute Softmax(Scores,labels)
 
   AbstractTensorHandle* loss_vals = temp_outputs[0];
 
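A single tape_ctx now serves the whole forward model, whichever builder it calls: ops::MatMul comes from math_ops, while ops::Relu and ops::SparseSoftmaxCrossEntropyWithLogits come from nn_ops (hence the three ops deps in the BUILD hunk). The recording itself happens inside the wrapped operation's Execute. A hedged sketch, where parent_op_ and RecordOnTape are illustrative stand-ins for the logic in tensorflow/c/experimental/gradients/tape/tape_operation.cc:

Status TapeOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
                              int* num_retvals) {
  // Run the real operation first.
  TF_RETURN_IF_ERROR(parent_op_->Execute(retvals, num_retvals));
  // Then record inputs, outputs, and the GradientFunction registered for
  // this op type, so Tape::ComputeGradient can later replay it backward.
  return RecordOnTape(retvals, *num_retvals);  // Hypothetical helper.
}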
@@ -297,9 +187,11 @@ Status MatMulTransposeModel(AbstractContext* ctx,
   tape->Watch(ToId(W1));
   vector<AbstractTensorHandle*> temp_outputs(1);
 
-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
-                            "matmul0", /*transpose_a=*/true,
-                            /*transpose_b=*/false, registry));  // Compute X*W1
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
+                                 absl::MakeSpan(temp_outputs), "matmul0",
+                                 /*transpose_a=*/true,
+                                 /*transpose_b=*/false));  // Compute X*W1
 
   outputs[0] = temp_outputs[0];
 
@@ -315,8 +207,10 @@ Status ReluGradModel(AbstractContext* ctx,
   auto tape = new Tape(/*persistent=*/false);
   tape->Watch(ToId(inputs[0]));  // Watch X
   vector<AbstractTensorHandle*> relu_outputs(1);
-  TF_RETURN_IF_ERROR(Relu(ctx, tape, inputs, absl::MakeSpan(relu_outputs),
-                          "relu0", registry));  // Relu(X)
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), inputs,
+                               absl::MakeSpan(relu_outputs),
+                               "relu0"));  // Relu(X)
 
   std::unordered_map<tensorflow::int64, TapeTensor>
       source_tensors_that_are_targets;
@@ -346,8 +240,9 @@ Status SoftmaxLossGradModel(AbstractContext* ctx,
   tape->Watch(ToId(inputs[0]));  // Watch scores.
   tape->Watch(ToId(inputs[1]));  // Watch labels.
   vector<AbstractTensorHandle*> sm_outputs(2);
-  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
-      ctx, tape, inputs, absl::MakeSpan(sm_outputs), "softmax0", registry));
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
+      tape_ctx.get(), inputs, absl::MakeSpan(sm_outputs), "softmax0"));
 
   std::unordered_map<tensorflow::int64, TapeTensor>
       source_tensors_that_are_targets;
@@ -381,29 +276,30 @@ Status MNISTGradModel(AbstractContext* ctx,
   tape->Watch(ToId(W1));  // Watch W1.
   tape->Watch(ToId(W2));  // Watch W2.
   vector<AbstractTensorHandle*> temp_outputs(1);
-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
-                            "matmul0", /*transpose_a=*/false,
-                            /*transpose_b=*/false, registry));  // Compute X*W1
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
+                                 absl::MakeSpan(temp_outputs), "matmul0",
+                                 /*transpose_a=*/false,
+                                 /*transpose_b=*/false));  // Compute X*W1
 
   AbstractTensorHandle* mm = temp_outputs[0];
 
-  TF_RETURN_IF_ERROR(Relu(ctx, tape, {mm},
-                          absl::MakeSpan(temp_outputs),  // Relu(X*W1)
-                          "relu0", registry));
+  TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {mm},
+                               absl::MakeSpan(temp_outputs),  // Relu(X*W1)
+                               "relu0"));
 
   AbstractTensorHandle* hidden = temp_outputs[0];
 
-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {hidden, W2},
-                            absl::MakeSpan(temp_outputs), "matmul1",
-                            /*transpose_a=*/false, /*transpose_b=*/false,
-                            registry));  // W2*Relu(X*W1)
+  TF_RETURN_IF_ERROR(ops::MatMul(
+      tape_ctx.get(), {hidden, W2}, absl::MakeSpan(temp_outputs), "matmul1",
+      /*transpose_a=*/false, /*transpose_b=*/false));  // W2*Relu(X*W1)
 
   AbstractTensorHandle* scores = temp_outputs[0];
 
   temp_outputs.resize(2);
-  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
-      ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
-      "softmaxloss", registry));  // W2*Relu(X*W1)
+  TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
+      tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs),
+      "softmaxloss"));  // W2*Relu(X*W1)
 
   AbstractTensorHandle* loss = temp_outputs[0];
 
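The tests drive these models through the RunModel helper from gradients_util (the ":gradients_util" dependency in the BUILD hunk above). The sketch below assumes its signature and is meant purely as a usage illustration:

  // Hypothetical test-side invocation of MNISTGradModel.
  GradientRegistry registry;
  TF_RETURN_IF_ERROR(RegisterGradients(&registry));  // See the note at the end.
  std::vector<AbstractTensorHandle*> outputs(3);     // dW1, dW2, loss.
  TF_RETURN_IF_ERROR(RunModel(MNISTGradModel, ctx.get(),
                              {X.get(), W1.get(), W2.get(), y_labels.get()},
                              absl::MakeSpan(outputs),
                              /*use_function=*/false, registry));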
@@ -440,8 +336,10 @@ Status ScalarMulModel(AbstractContext* ctx,
   auto tape = new Tape(/*persistent=*/false);
   vector<AbstractTensorHandle*> temp_outputs(1);
 
-  TF_RETURN_IF_ERROR(Mul(ctx, tape, {eta, A}, absl::MakeSpan(temp_outputs),
-                         "scalarMul0", registry));  // Compute eta*A
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), {eta, A},
+                              absl::MakeSpan(temp_outputs),
+                              "scalarMul0"));  // Compute eta*A
 
   outputs[0] = temp_outputs[0];
 
@@ -459,9 +357,11 @@ Status MatMulModel(AbstractContext* ctx,
   TapeVSpace vspace(ctx);
   auto tape = new Tape(/*persistent=*/false);
   std::vector<AbstractTensorHandle*> temp_outputs(1);
-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
-                            "matmul0", /*transpose_a=*/false,
-                            /*transpose_b=*/false, registry));  // Compute X*W1
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
+                                 absl::MakeSpan(temp_outputs), "matmul0",
+                                 /*transpose_a=*/false,
+                                 /*transpose_b=*/false));  // Compute X*W1
 
   outputs[0] = temp_outputs[0];
   delete tape;
@@ -478,8 +378,10 @@ Status MulModel(AbstractContext* ctx,
   TapeVSpace vspace(ctx);
   auto tape = new Tape(/*persistent=*/false);
   std::vector<AbstractTensorHandle*> temp_outputs(1);
-  TF_RETURN_IF_ERROR(Mul(ctx, tape, {x, y}, absl::MakeSpan(temp_outputs),
-                         "mul0", registry));  // Compute x*y
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), {x, y},
+                              absl::MakeSpan(temp_outputs),
+                              "mul0"));  // Compute x*y
 
   outputs[0] = temp_outputs[0];
   delete tape;
@@ -496,9 +398,9 @@ Status SoftmaxModel(AbstractContext* ctx,
   TapeVSpace vspace(ctx);
   auto tape = new Tape(/*persistent=*/false);
   std::vector<AbstractTensorHandle*> temp_outputs(2);
-  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
-      ctx, tape, {x, labels}, absl::MakeSpan(temp_outputs), "sm_loss",
-      registry));
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
+      tape_ctx.get(), {x, labels}, absl::MakeSpan(temp_outputs), "sm_loss"));
 
   outputs[0] = temp_outputs[0];  // loss values
 

tensorflow/c/eager/mnist_gradients_testutil.h
@@ -29,45 +29,10 @@ limitations under the License.
 #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
 #include "tensorflow/core/platform/status.h"
 
-// ========================== Tape Ops ==============================
 
 namespace tensorflow {
 namespace gradients {
 namespace internal {
-// Computes `inputs[0] + inputs[1]` and records it on the tape.
-Status Add(AbstractContext* ctx, Tape* tape,
-           absl::Span<AbstractTensorHandle* const> inputs,
-           absl::Span<AbstractTensorHandle*> outputs,
-           const GradientRegistry& registry);
-
-// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape.
-Status MatMul(AbstractContext* ctx, Tape* tape,
-              absl::Span<AbstractTensorHandle* const> inputs,
-              absl::Span<AbstractTensorHandle*> outputs, const char* name,
-              bool transpose_a, bool transpose_b,
-              const GradientRegistry& registry);
-
-// Computes `inputs[0] * inputs[1]` and records it on the tape.
-Status Mul(AbstractContext* ctx, Tape* tape,
-           absl::Span<AbstractTensorHandle* const> inputs,
-           absl::Span<AbstractTensorHandle*> outputs, const char* name,
-           const GradientRegistry& registry);
-
-// Computes `Relu(inputs[0])` and records it on the tape.
-Status Relu(AbstractContext* ctx, Tape* tape,
-            absl::Span<AbstractTensorHandle* const> inputs,
-            absl::Span<AbstractTensorHandle*> outputs, const char* name,
-            const GradientRegistry& registry);
-
-// Computes `SoftmaxLoss(scores, labels)` for matrices and records it on the
-// tape.
-Status SparseSoftmaxCrossEntropyWithLogits(
-    AbstractContext* ctx, Tape* tape,
-    absl::Span<AbstractTensorHandle* const> inputs,
-    absl::Span<AbstractTensorHandle*> outputs, const char* name,
-    const GradientRegistry& registry);
-
-// ====================== End Tape Ops ============================
 
 // Computes
 // y = inputs[0] + inputs[1]
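With the local helpers gone, the op-to-gradient wiring lives entirely in the GradientRegistry each model receives. Test-side registration looks roughly like the sketch below; the registerer names are assumptions (see tensorflow/c/experimental/gradients for the exact symbols):

Status RegisterGradients(GradientRegistry* registry) {
  TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("MatMul", MatMulRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("Relu", ReluRegisterer));
  TF_RETURN_IF_ERROR(registry->Register(
      "SparseSoftmaxCrossEntropyWithLogits",
      SparseSoftmaxCrossEntropyLossRegisterer));
  return Status::OK();
}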