Replace C++ gradient helpers with TapeContext in mnist_gradients_testutil.
PiperOrigin-RevId: 334255849
Change-Id: I065dc828ec6485822116bbad088827c6c6b3ef46
This commit is contained in:
parent 8d6c46237c
commit 534cb9ab79
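The change is mechanical across all the test models in this diff. Previously each hand-rolled helper (Add, MatMul, Mul, Relu, SparseSoftmaxCrossEntropyWithLogits) built a ForwardOperation manually via Reset/AddInput/SetAttr* and recorded it on the tape through Execute. Now each call site wraps the context in a TapeContext once and calls the plain ops::* builders, which get recorded on the tape through the wrapped context. A minimal before/after sketch of one call site, assuming the surrounding declarations (ctx, tape, registry, inputs, add_outputs) visible in AddGradModel below:

// Before: hand-rolled helper threading the tape and registry explicitly
// through every op it builds.
TF_RETURN_IF_ERROR(Add(ctx, tape, inputs, absl::MakeSpan(add_outputs),
                       registry));  // Compute x+y.

// After: wrap the context once; ops executed through the TapeContext
// wrapper are recorded on the tape automatically, so the ordinary
// ops::* builders can be used directly.
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(
    ops::Add(tape_ctx.get(), inputs, absl::MakeSpan(add_outputs), "Add"));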
tensorflow/c/eager/BUILD
@@ -313,6 +313,7 @@ cc_library(
         ":gradients_internal",
         ":gradients_util",
         ":tape",
+        "//tensorflow/c/experimental/gradients/tape:tape_context",
        "//tensorflow/c/experimental/ops:array_ops",
        "//tensorflow/c/experimental/ops:math_ops",
        "//tensorflow/c/experimental/ops:nn_ops",
tensorflow/c/eager/mnist_gradients_testutil.cc
@@ -25,133 +25,18 @@ limitations under the License.
 #include "tensorflow/c/eager/gradients.h"
 #include "tensorflow/c/eager/gradients_internal.h"
 #include "tensorflow/c/eager/gradients_util.h"
+#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
 #include "tensorflow/c/experimental/ops/array_ops.h"
 #include "tensorflow/c/experimental/ops/math_ops.h"
 #include "tensorflow/c/experimental/ops/nn_ops.h"
 #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
 
-// ========================== Tape Ops ==============================
-
 namespace tensorflow {
 namespace gradients {
 namespace internal {
 
 using std::vector;
 using tensorflow::tracing::TracingOperation;
 
-// Computes `inputs[0] + inputs[1]` and records it on the tape.
-Status Add(AbstractContext* ctx, Tape* tape,
-           absl::Span<AbstractTensorHandle* const> inputs,
-           absl::Span<AbstractTensorHandle*> outputs,
-           const GradientRegistry& registry) {
-  AbstractOperationPtr add_op(ctx->CreateOperation());
-  ForwardOperation forward_op;
-  TF_RETURN_IF_ERROR(
-      Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op));
-  if (isa<TracingOperation>(add_op.get())) {
-    TF_RETURN_IF_ERROR(
-        dyn_cast<TracingOperation>(add_op.get())->SetOpName("my_add"));
-  }
-  TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op));
-  TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op));
-  int num_retvals = 1;
-  return Execute(add_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
-                 registry);
-}
-
-// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape.
-Status MatMul(AbstractContext* ctx, Tape* tape,
-              absl::Span<AbstractTensorHandle* const> inputs,
-              absl::Span<AbstractTensorHandle*> outputs, const char* name,
-              bool transpose_a, bool transpose_b,
-              const GradientRegistry& registry) {
-  AbstractOperationPtr matmul_op(ctx->CreateOperation());
-  ForwardOperation forward_op;
-  TF_RETURN_IF_ERROR(Reset(matmul_op.get(), "MatMul",
-                           /*raw_device_name=*/nullptr, &forward_op));
-  if (isa<TracingOperation>(matmul_op.get())) {
-    TF_RETURN_IF_ERROR(
-        dyn_cast<TracingOperation>(matmul_op.get())->SetOpName(name));
-  }
-
-  TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[0], &forward_op));
-  TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[1], &forward_op));
-  TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool(
-      matmul_op.get(), "transpose_a", transpose_a, &forward_op));
-  TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool(
-      matmul_op.get(), "transpose_b", transpose_b, &forward_op));
-
-  int num_retvals = 1;
-  return Execute(matmul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
-                 registry);
-}
-
-Status Mul(AbstractContext* ctx, Tape* tape,
-           absl::Span<AbstractTensorHandle* const> inputs,
-           absl::Span<AbstractTensorHandle*> outputs, const char* name,
-           const GradientRegistry& registry) {
-  AbstractOperationPtr mul_op(ctx->CreateOperation());
-  ForwardOperation forward_op;
-  TF_RETURN_IF_ERROR(
-      Reset(mul_op.get(), "Mul", /*raw_device_name=*/nullptr, &forward_op));
-  if (isa<TracingOperation>(mul_op.get())) {
-    TF_RETURN_IF_ERROR(
-        dyn_cast<TracingOperation>(mul_op.get())->SetOpName(name));
-  }
-
-  TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[0], &forward_op));
-  TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[1], &forward_op));
-
-  int num_retvals = 1;
-  return Execute(mul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
-                 registry);
-}
-
-// Computes `Relu(inputs[0])` and records it on the tape.
-Status Relu(AbstractContext* ctx, Tape* tape,
-            absl::Span<AbstractTensorHandle* const> inputs,
-            absl::Span<AbstractTensorHandle*> outputs, const char* name,
-            const GradientRegistry& registry) {
-  AbstractOperationPtr relu_op(ctx->CreateOperation());
-  ForwardOperation forward_op;
-  TF_RETURN_IF_ERROR(
-      Reset(relu_op.get(), "Relu", /*raw_device_name=*/nullptr, &forward_op));
-  if (isa<TracingOperation>(relu_op.get())) {
-    TF_RETURN_IF_ERROR(
-        dyn_cast<TracingOperation>(relu_op.get())->SetOpName(name));
-  }
-  TF_RETURN_IF_ERROR(AddInput(relu_op.get(), inputs[0], &forward_op));
-  int num_retvals = 1;
-  return Execute(relu_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
-                 registry);
-}
-
-// Computes `SoftmaxLoss(scores, labels)` where labels are categorical (not
-// one-hot) and records it on the tape.
-Status SparseSoftmaxCrossEntropyWithLogits(
-    AbstractContext* ctx, Tape* tape,
-    absl::Span<AbstractTensorHandle* const> inputs,
-    absl::Span<AbstractTensorHandle*> outputs, const char* name,
-    const GradientRegistry& registry) {
-  AbstractTensorHandle* scores = inputs[0];
-  AbstractTensorHandle* labels = inputs[1];
-
-  AbstractOperationPtr sm_op(ctx->CreateOperation());
-  ForwardOperation forward_op;
-  TF_RETURN_IF_ERROR(Reset(sm_op.get(), "SparseSoftmaxCrossEntropyWithLogits",
-                           /*raw_device_name=*/nullptr, &forward_op));
-  if (isa<TracingOperation>(sm_op.get())) {
-    TF_RETURN_IF_ERROR(
-        dyn_cast<TracingOperation>(sm_op.get())->SetOpName(name));
-  }
-
-  TF_RETURN_IF_ERROR(AddInput(sm_op.get(), scores, &forward_op));
-  TF_RETURN_IF_ERROR(AddInput(sm_op.get(), labels, &forward_op));
-
-  int num_retvals = 2;  // returns loss values and backprop
-  return Execute(sm_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
-                 registry);
-}
 
 //===================== Test Models to run =========================
 
@@ -167,8 +52,9 @@ Status AddGradModel(AbstractContext* ctx,
   tape->Watch(ToId(inputs[0]));  // Watch x.
   tape->Watch(ToId(inputs[1]));  // Watch y.
   std::vector<AbstractTensorHandle*> add_outputs(1);
-  TF_RETURN_IF_ERROR(Add(ctx, tape, inputs, absl::MakeSpan(add_outputs),
-                         registry));  // Compute x+y.
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(
+      ops::Add(tape_ctx.get(), inputs, absl::MakeSpan(add_outputs), "Add"));
   std::unordered_map<tensorflow::int64, TapeTensor>
       source_tensors_that_are_targets;
 
@@ -200,9 +86,11 @@ Status MatMulGradModel(AbstractContext* ctx,
   tape->Watch(ToId(inputs[0]));  // Watch x.
   tape->Watch(ToId(inputs[1]));  // Watch y.
   vector<AbstractTensorHandle*> mm_outputs(1);
-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, inputs, absl::MakeSpan(mm_outputs),
-                            "matmul0", /*transpose_a=*/false,
-                            /*transpose_b=*/false, registry));  // Compute x*y.
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), inputs,
+                                 absl::MakeSpan(mm_outputs), "matmul0",
+                                 /*transpose_a=*/false,
+                                 /*transpose_b=*/false));  // Compute x*y.
 
   std::unordered_map<tensorflow::int64, TapeTensor>
       source_tensors_that_are_targets;
@@ -256,25 +144,27 @@ Status MNISTForwardModel(AbstractContext* ctx,
   tape->Watch(ToId(W2));  // Watch W2.
   vector<AbstractTensorHandle*> temp_outputs(1);
 
-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
-                            "matmul0", /*transpose_a=*/false,
-                            /*transpose_b=*/false, registry));  // Compute X*W1
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
+                                 absl::MakeSpan(temp_outputs), "matmul0",
+                                 /*transpose_a=*/false,
+                                 /*transpose_b=*/false));  // Compute X*W1
 
-  TF_RETURN_IF_ERROR(Relu(ctx, tape, {temp_outputs[0]},
-                          absl::MakeSpan(temp_outputs), "relu",
-                          registry));  // Compute Relu(X*W1)
+  TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {temp_outputs[0]},
+                               absl::MakeSpan(temp_outputs),
+                               "relu"));  // Compute Relu(X*W1)
 
-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {temp_outputs[0], W2},
-                            absl::MakeSpan(temp_outputs), "matmul1",
-                            /*transpose_a=*/false, /*transpose_b=*/false,
-                            registry));  // Compute W2*Relu(X*W1)
+  TF_RETURN_IF_ERROR(ops::MatMul(
+      tape_ctx.get(), {temp_outputs[0], W2}, absl::MakeSpan(temp_outputs),
+      "matmul1",
+      /*transpose_a=*/false, /*transpose_b=*/false));  // Compute W2*Relu(X*W1)
 
   AbstractTensorHandle* scores = temp_outputs[0];
 
   temp_outputs.resize(2);
-  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
-      ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
-      "softmax_loss", registry));  // Compute Softmax(Scores,labels)
+  TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
+      tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs),
+      "softmax_loss"));  // Compute Softmax(Scores,labels)
 
   AbstractTensorHandle* loss_vals = temp_outputs[0];
 
@@ -297,9 +187,11 @@ Status MatMulTransposeModel(AbstractContext* ctx,
   tape->Watch(ToId(W1));
   vector<AbstractTensorHandle*> temp_outputs(1);
 
-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
-                            "matmul0", /*transpose_a=*/true,
-                            /*transpose_b=*/false, registry));  // Compute X*W1
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
+                                 absl::MakeSpan(temp_outputs), "matmul0",
+                                 /*transpose_a=*/true,
+                                 /*transpose_b=*/false));  // Compute X*W1
 
   outputs[0] = temp_outputs[0];
 
@@ -315,8 +207,10 @@ Status ReluGradModel(AbstractContext* ctx,
   auto tape = new Tape(/*persistent=*/false);
   tape->Watch(ToId(inputs[0]));  // Watch X
   vector<AbstractTensorHandle*> relu_outputs(1);
-  TF_RETURN_IF_ERROR(Relu(ctx, tape, inputs, absl::MakeSpan(relu_outputs),
-                          "relu0", registry));  // Relu(X)
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), inputs,
+                               absl::MakeSpan(relu_outputs),
+                               "relu0"));  // Relu(X)
 
   std::unordered_map<tensorflow::int64, TapeTensor>
       source_tensors_that_are_targets;
@@ -346,8 +240,9 @@ Status SoftmaxLossGradModel(AbstractContext* ctx,
   tape->Watch(ToId(inputs[0]));  // Watch scores.
   tape->Watch(ToId(inputs[1]));  // Watch labels.
   vector<AbstractTensorHandle*> sm_outputs(2);
-  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
-      ctx, tape, inputs, absl::MakeSpan(sm_outputs), "softmax0", registry));
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
+      tape_ctx.get(), inputs, absl::MakeSpan(sm_outputs), "softmax0"));
 
   std::unordered_map<tensorflow::int64, TapeTensor>
       source_tensors_that_are_targets;
@@ -381,29 +276,30 @@ Status MNISTGradModel(AbstractContext* ctx,
   tape->Watch(ToId(W1));  // Watch W1.
   tape->Watch(ToId(W2));  // Watch W2.
   vector<AbstractTensorHandle*> temp_outputs(1);
-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
-                            "matmul0", /*transpose_a=*/false,
-                            /*transpose_b=*/false, registry));  // Compute X*W1
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
+                                 absl::MakeSpan(temp_outputs), "matmul0",
+                                 /*transpose_a=*/false,
+                                 /*transpose_b=*/false));  // Compute X*W1
 
   AbstractTensorHandle* mm = temp_outputs[0];
 
-  TF_RETURN_IF_ERROR(Relu(ctx, tape, {mm},
-                          absl::MakeSpan(temp_outputs),  // Relu(X*W1)
-                          "relu0", registry));
+  TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {mm},
+                               absl::MakeSpan(temp_outputs),  // Relu(X*W1)
+                               "relu0"));
 
   AbstractTensorHandle* hidden = temp_outputs[0];
 
-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {hidden, W2},
-                            absl::MakeSpan(temp_outputs), "matmul1",
-                            /*transpose_a=*/false, /*transpose_b=*/false,
-                            registry));  // W2*Relu(X*W1)
+  TF_RETURN_IF_ERROR(ops::MatMul(
+      tape_ctx.get(), {hidden, W2}, absl::MakeSpan(temp_outputs), "matmul1",
+      /*transpose_a=*/false, /*transpose_b=*/false));  // W2*Relu(X*W1)
 
   AbstractTensorHandle* scores = temp_outputs[0];
 
   temp_outputs.resize(2);
-  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
-      ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
-      "softmaxloss", registry));  // Compute SoftmaxLoss(scores, labels)
+  TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
+      tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs),
+      "softmaxloss"));  // Compute SoftmaxLoss(scores, labels)
 
   AbstractTensorHandle* loss = temp_outputs[0];
 
@@ -440,8 +336,10 @@ Status ScalarMulModel(AbstractContext* ctx,
   auto tape = new Tape(/*persistent=*/false);
   vector<AbstractTensorHandle*> temp_outputs(1);
 
-  TF_RETURN_IF_ERROR(Mul(ctx, tape, {eta, A}, absl::MakeSpan(temp_outputs),
-                         "scalarMul0", registry));  // Compute eta*A
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), {eta, A},
+                              absl::MakeSpan(temp_outputs),
+                              "scalarMul0"));  // Compute eta*A
 
   outputs[0] = temp_outputs[0];
 
@@ -459,9 +357,11 @@ Status MatMulModel(AbstractContext* ctx,
   TapeVSpace vspace(ctx);
   auto tape = new Tape(/*persistent=*/false);
   std::vector<AbstractTensorHandle*> temp_outputs(1);
-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
-                            "matmul0", /*transpose_a=*/false,
-                            /*transpose_b=*/false, registry));  // Compute X*W1
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
+                                 absl::MakeSpan(temp_outputs), "matmul0",
+                                 /*transpose_a=*/false,
+                                 /*transpose_b=*/false));  // Compute X*W1
 
   outputs[0] = temp_outputs[0];
   delete tape;
@@ -478,8 +378,10 @@ Status MulModel(AbstractContext* ctx,
   TapeVSpace vspace(ctx);
   auto tape = new Tape(/*persistent=*/false);
   std::vector<AbstractTensorHandle*> temp_outputs(1);
-  TF_RETURN_IF_ERROR(Mul(ctx, tape, {x, y}, absl::MakeSpan(temp_outputs),
-                         "mul0", registry));  // Compute x*y
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), {x, y},
+                              absl::MakeSpan(temp_outputs),
+                              "mul0"));  // Compute x*y
 
   outputs[0] = temp_outputs[0];
   delete tape;
@@ -496,9 +398,9 @@ Status SoftmaxModel(AbstractContext* ctx,
   TapeVSpace vspace(ctx);
   auto tape = new Tape(/*persistent=*/false);
   std::vector<AbstractTensorHandle*> temp_outputs(2);
-  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
-      ctx, tape, {x, labels}, absl::MakeSpan(temp_outputs), "sm_loss",
-      registry));
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
+      tape_ctx.get(), {x, labels}, absl::MakeSpan(temp_outputs), "sm_loss"));
 
   outputs[0] = temp_outputs[0];  // loss values
 
tensorflow/c/eager/mnist_gradients_testutil.h
@@ -29,45 +29,10 @@ limitations under the License.
 #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
 #include "tensorflow/core/platform/status.h"
 
-// ========================== Tape Ops ==============================
-
 namespace tensorflow {
 namespace gradients {
 namespace internal {
 
-// Computes `inputs[0] + inputs[1]` and records it on the tape.
-Status Add(AbstractContext* ctx, Tape* tape,
-           absl::Span<AbstractTensorHandle* const> inputs,
-           absl::Span<AbstractTensorHandle*> outputs,
-           const GradientRegistry& registry);
-
-// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape.
-Status MatMul(AbstractContext* ctx, Tape* tape,
-              absl::Span<AbstractTensorHandle* const> inputs,
-              absl::Span<AbstractTensorHandle*> outputs, const char* name,
-              bool transpose_a, bool transpose_b,
-              const GradientRegistry& registry);
-
-// Computes `inputs[0] * inputs[1]` and records it on the tape.
-Status Mul(AbstractContext* ctx, Tape* tape,
-           absl::Span<AbstractTensorHandle* const> inputs,
-           absl::Span<AbstractTensorHandle*> outputs, const char* name,
-           const GradientRegistry& registry);
-
-// Computes `Relu(inputs[0])` and records it on the tape.
-Status Relu(AbstractContext* ctx, Tape* tape,
-            absl::Span<AbstractTensorHandle* const> inputs,
-            absl::Span<AbstractTensorHandle*> outputs, const char* name,
-            const GradientRegistry& registry);
-
-// Computes `SoftmaxLoss(scores, labels)` for matrices and records it on the
-// tape.
-Status SparseSoftmaxCrossEntropyWithLogits(
-    AbstractContext* ctx, Tape* tape,
-    absl::Span<AbstractTensorHandle* const> inputs,
-    absl::Span<AbstractTensorHandle*> outputs, const char* name,
-    const GradientRegistry& registry);
-
-// ====================== End Tape Ops ============================
-
 // Computes
 // y = inputs[0] + inputs[1]