diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index 7a4824257c2..54771ffa840 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -313,6 +313,7 @@ cc_library(
         ":gradients_internal",
         ":gradients_util",
         ":tape",
+        "//tensorflow/c/experimental/gradients/tape:tape_context",
         "//tensorflow/c/experimental/ops:array_ops",
         "//tensorflow/c/experimental/ops:math_ops",
         "//tensorflow/c/experimental/ops:nn_ops",
diff --git a/tensorflow/c/eager/mnist_gradients_testutil.cc b/tensorflow/c/eager/mnist_gradients_testutil.cc
index 8354b37354e..6688d9d4e75 100644
--- a/tensorflow/c/eager/mnist_gradients_testutil.cc
+++ b/tensorflow/c/eager/mnist_gradients_testutil.cc
@@ -25,133 +25,18 @@ limitations under the License.
 #include "tensorflow/c/eager/gradients.h"
 #include "tensorflow/c/eager/gradients_internal.h"
 #include "tensorflow/c/eager/gradients_util.h"
+#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
 #include "tensorflow/c/experimental/ops/array_ops.h"
 #include "tensorflow/c/experimental/ops/math_ops.h"
 #include "tensorflow/c/experimental/ops/nn_ops.h"
 #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
 
-// ========================== Tape Ops ==============================
 namespace tensorflow {
 namespace gradients {
 namespace internal {
 
 using std::vector;
-using tensorflow::tracing::TracingOperation;
-
-// Computes `inputs[0] + inputs[1]` and records it on the tape.
-Status Add(AbstractContext* ctx, Tape* tape,
-           absl::Span<AbstractTensorHandle* const> inputs,
-           absl::Span<AbstractTensorHandle*> outputs,
-           const GradientRegistry& registry) {
-  AbstractOperationPtr add_op(ctx->CreateOperation());
-  ForwardOperation forward_op;
-  TF_RETURN_IF_ERROR(
-      Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op));
-  if (isa<TracingOperation>(add_op.get())) {
-    TF_RETURN_IF_ERROR(
-        dyn_cast<TracingOperation>(add_op.get())->SetOpName("my_add"));
-  }
-  TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op));
-  TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op));
-  int num_retvals = 1;
-  return Execute(add_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
-                 registry);
-}
-
-// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape.
-Status MatMul(AbstractContext* ctx, Tape* tape,
-              absl::Span<AbstractTensorHandle* const> inputs,
-              absl::Span<AbstractTensorHandle*> outputs, const char* name,
-              bool transpose_a, bool transpose_b,
-              const GradientRegistry& registry) {
-  AbstractOperationPtr matmul_op(ctx->CreateOperation());
-  ForwardOperation forward_op;
-  TF_RETURN_IF_ERROR(Reset(matmul_op.get(), "MatMul",
-                           /*raw_device_name=*/nullptr, &forward_op));
-  if (isa<TracingOperation>(matmul_op.get())) {
-    TF_RETURN_IF_ERROR(
-        dyn_cast<TracingOperation>(matmul_op.get())->SetOpName(name));
-  }
-
-  TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[0], &forward_op));
-  TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[1], &forward_op));
-  TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool(
-      matmul_op.get(), "transpose_a", transpose_a, &forward_op));
-  TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool(
-      matmul_op.get(), "transpose_b", transpose_b, &forward_op));
-
-  int num_retvals = 1;
-  return Execute(matmul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
-                 registry);
-}
-
-Status Mul(AbstractContext* ctx, Tape* tape,
-           absl::Span<AbstractTensorHandle* const> inputs,
-           absl::Span<AbstractTensorHandle*> outputs, const char* name,
-           const GradientRegistry& registry) {
-  AbstractOperationPtr mul_op(ctx->CreateOperation());
-  ForwardOperation forward_op;
-  TF_RETURN_IF_ERROR(
-      Reset(mul_op.get(), "Mul", /*raw_device_name=*/nullptr, &forward_op));
-  if (isa<TracingOperation>(mul_op.get())) {
-    TF_RETURN_IF_ERROR(
-        dyn_cast<TracingOperation>(mul_op.get())->SetOpName(name));
-  }
-
-  TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[0], &forward_op));
-  TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[1], &forward_op));
-
-  int num_retvals = 1;
-  return Execute(mul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
-                 registry);
-}
-
-// Computes `Relu(inputs[0])` and records it on the tape.
-Status Relu(AbstractContext* ctx, Tape* tape,
-            absl::Span<AbstractTensorHandle* const> inputs,
-            absl::Span<AbstractTensorHandle*> outputs, const char* name,
-            const GradientRegistry& registry) {
-  AbstractOperationPtr relu_op(ctx->CreateOperation());
-  ForwardOperation forward_op;
-  TF_RETURN_IF_ERROR(
-      Reset(relu_op.get(), "Relu", /*raw_device_name=*/nullptr, &forward_op));
-  if (isa<TracingOperation>(relu_op.get())) {
-    TF_RETURN_IF_ERROR(
-        dyn_cast<TracingOperation>(relu_op.get())->SetOpName(name));
-  }
-  TF_RETURN_IF_ERROR(AddInput(relu_op.get(), inputs[0], &forward_op));
-  int num_retvals = 1;
-  return Execute(relu_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
-                 registry);
-}
-
-// Computes `SoftmaxLoss(scores, labels)` where labels are categorical (not
-// one-hot) and records it on the tape.
-Status SparseSoftmaxCrossEntropyWithLogits(
-    AbstractContext* ctx, Tape* tape,
-    absl::Span<AbstractTensorHandle* const> inputs,
-    absl::Span<AbstractTensorHandle*> outputs, const char* name,
-    const GradientRegistry& registry) {
-  AbstractTensorHandle* scores = inputs[0];
-  AbstractTensorHandle* labels = inputs[1];
-
-  AbstractOperationPtr sm_op(ctx->CreateOperation());
-  ForwardOperation forward_op;
-  TF_RETURN_IF_ERROR(Reset(sm_op.get(), "SparseSoftmaxCrossEntropyWithLogits",
-                           /*raw_device_name=*/nullptr, &forward_op));
-  if (isa<TracingOperation>(sm_op.get())) {
-    TF_RETURN_IF_ERROR(
-        dyn_cast<TracingOperation>(sm_op.get())->SetOpName(name));
-  }
-
-  TF_RETURN_IF_ERROR(AddInput(sm_op.get(), scores, &forward_op));
-  TF_RETURN_IF_ERROR(AddInput(sm_op.get(), labels, &forward_op));
-
-  int num_retvals = 2;  // returns loss values and backprop
-  return Execute(sm_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
-                 registry);
-}
 
 //===================== Test Models to run =========================
@@ -167,8 +52,9 @@ Status AddGradModel(AbstractContext* ctx,
   tape->Watch(ToId(inputs[0]));  // Watch x.
   tape->Watch(ToId(inputs[1]));  // Watch y.
   std::vector<AbstractTensorHandle*> add_outputs(1);
-  TF_RETURN_IF_ERROR(Add(ctx, tape, inputs, absl::MakeSpan(add_outputs),
-                         registry));  // Compute x+y.
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(
+      ops::Add(tape_ctx.get(), inputs, absl::MakeSpan(add_outputs), "Add"));
 
   std::unordered_map<tensorflow::int64, TapeTensor>
       source_tensors_that_are_targets;
@@ -200,9 +86,11 @@ Status MatMulGradModel(AbstractContext* ctx,
   tape->Watch(ToId(inputs[0]));  // Watch x.
   tape->Watch(ToId(inputs[1]));  // Watch y.
   vector<AbstractTensorHandle*> mm_outputs(1);
-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, inputs, absl::MakeSpan(mm_outputs),
-                            "matmul0", /*transpose_a=*/false,
-                            /*transpose_b=*/false, registry));  // Compute x*y.
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), inputs,
+                                 absl::MakeSpan(mm_outputs), "matmul0",
+                                 /*transpose_a=*/false,
+                                 /*transpose_b=*/false));  // Compute x*y.
 
   std::unordered_map<tensorflow::int64, TapeTensor>
       source_tensors_that_are_targets;
@@ -256,25 +144,27 @@ Status MNISTForwardModel(AbstractContext* ctx,
   tape->Watch(ToId(W2));  // Watch W2.
 
   vector<AbstractTensorHandle*> temp_outputs(1);
-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
-                            "matmul0", /*transpose_a=*/false,
-                            /*transpose_b=*/false, registry));  // Compute X*W1
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
+                                 absl::MakeSpan(temp_outputs), "matmul0",
+                                 /*transpose_a=*/false,
+                                 /*transpose_b=*/false));  // Compute X*W1
 
-  TF_RETURN_IF_ERROR(Relu(ctx, tape, {temp_outputs[0]},
-                          absl::MakeSpan(temp_outputs), "relu",
-                          registry));  // Compute Relu(X*W1)
+  TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {temp_outputs[0]},
+                               absl::MakeSpan(temp_outputs),
+                               "relu"));  // Compute Relu(X*W1)
 
-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {temp_outputs[0], W2},
-                            absl::MakeSpan(temp_outputs), "matmul1",
-                            /*transpose_a=*/false, /*transpose_b=*/false,
-                            registry));  // Compute W2*Relu(X*W1)
+  TF_RETURN_IF_ERROR(ops::MatMul(
+      tape_ctx.get(), {temp_outputs[0], W2}, absl::MakeSpan(temp_outputs),
+      "matmul1",
+      /*transpose_a=*/false, /*transpose_b=*/false));  // Compute W2*Relu(X*W1)
 
   AbstractTensorHandle* scores = temp_outputs[0];
 
   temp_outputs.resize(2);
-  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
-      ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
-      "softmax_loss", registry));  // Compute Softmax(Scores,labels)
+  TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
+      tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs),
+      "softmax_loss"));  // Compute Softmax(Scores,labels)
 
   AbstractTensorHandle* loss_vals = temp_outputs[0];
 
@@ -297,9 +187,11 @@ Status MatMulTransposeModel(AbstractContext* ctx,
   tape->Watch(ToId(W1));
   vector<AbstractTensorHandle*> temp_outputs(1);
-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
-                            "matmul0", /*transpose_a=*/true,
-                            /*transpose_b=*/false, registry));  // Compute X*W1
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
+                                 absl::MakeSpan(temp_outputs), "matmul0",
+                                 /*transpose_a=*/true,
+                                 /*transpose_b=*/false));  // Compute X*W1
 
   outputs[0] = temp_outputs[0];
@@ -315,8 +207,10 @@ Status ReluGradModel(AbstractContext* ctx,
   auto tape = new Tape(/*persistent=*/false);
   tape->Watch(ToId(inputs[0]));  // Watch X
   vector<AbstractTensorHandle*> relu_outputs(1);
-  TF_RETURN_IF_ERROR(Relu(ctx, tape, inputs, absl::MakeSpan(relu_outputs),
-                          "relu0", registry));  // Relu(X)
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), inputs,
+                               absl::MakeSpan(relu_outputs),
+                               "relu0"));  // Relu(X)
 
   std::unordered_map<tensorflow::int64, TapeTensor>
       source_tensors_that_are_targets;
@@ -346,8 +240,9 @@ Status SoftmaxLossGradModel(AbstractContext* ctx,
   tape->Watch(ToId(inputs[0]));  // Watch scores.
   tape->Watch(ToId(inputs[1]));  // Watch labels.
   vector<AbstractTensorHandle*> sm_outputs(2);
-  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
-      ctx, tape, inputs, absl::MakeSpan(sm_outputs), "softmax0", registry));
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
+      tape_ctx.get(), inputs, absl::MakeSpan(sm_outputs), "softmax0"));
 
   std::unordered_map<tensorflow::int64, TapeTensor>
       source_tensors_that_are_targets;
@@ -381,29 +276,30 @@ Status MNISTGradModel(AbstractContext* ctx,
   tape->Watch(ToId(W1));  // Watch W1.
   tape->Watch(ToId(W2));  // Watch W1.
   vector<AbstractTensorHandle*> temp_outputs(1);
-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
-                            "matmul0", /*transpose_a=*/false,
-                            /*transpose_b=*/false, registry));  // Compute X*W1
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
+                                 absl::MakeSpan(temp_outputs), "matmul0",
+                                 /*transpose_a=*/false,
+                                 /*transpose_b=*/false));  // Compute X*W1
 
   AbstractTensorHandle* mm = temp_outputs[0];
 
-  TF_RETURN_IF_ERROR(Relu(ctx, tape, {mm},
-                          absl::MakeSpan(temp_outputs),  // Relu(X*W1)
-                          "relu0", registry));
+  TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {mm},
+                               absl::MakeSpan(temp_outputs),  // Relu(X*W1)
+                               "relu0"));
 
   AbstractTensorHandle* hidden = temp_outputs[0];
 
-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {hidden, W2},
-                            absl::MakeSpan(temp_outputs), "matmul1",
-                            /*transpose_a=*/false, /*transpose_b=*/false,
-                            registry));  // W2*Relu(X*W1)
+  TF_RETURN_IF_ERROR(ops::MatMul(
+      tape_ctx.get(), {hidden, W2}, absl::MakeSpan(temp_outputs), "matmul1",
+      /*transpose_a=*/false, /*transpose_b=*/false));  // W2*Relu(X*W1)
 
   AbstractTensorHandle* scores = temp_outputs[0];
 
   temp_outputs.resize(2);
-  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
-      ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
-      "softmaxloss", registry));  // W2*Relu(X*W1)
+  TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
+      tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs),
+      "softmaxloss"));  // W2*Relu(X*W1)
 
   AbstractTensorHandle* loss = temp_outputs[0];
 
@@ -440,8 +336,10 @@ Status ScalarMulModel(AbstractContext* ctx,
   auto tape = new Tape(/*persistent=*/false);
   vector<AbstractTensorHandle*> temp_outputs(1);
-  TF_RETURN_IF_ERROR(Mul(ctx, tape, {eta, A}, absl::MakeSpan(temp_outputs),
-                         "scalarMul0", registry));  // Compute eta*A
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), {eta, A},
+                              absl::MakeSpan(temp_outputs),
+                              "scalarMul0"));  // Compute eta*A
 
   outputs[0] = temp_outputs[0];
 
@@ -459,9 +357,11 @@ Status MatMulModel(AbstractContext* ctx,
   TapeVSpace vspace(ctx);
   auto tape = new Tape(/*persistent=*/false);
   std::vector<AbstractTensorHandle*> temp_outputs(1);
-  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
-                            "matmul0", /*transpose_a=*/false,
-                            /*transpose_b=*/false, registry));  // Compute X*W1
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
+                                 absl::MakeSpan(temp_outputs), "matmul0",
+                                 /*transpose_a=*/false,
+                                 /*transpose_b=*/false));  // Compute X*W1
 
   outputs[0] = temp_outputs[0];
   delete tape;
@@ -478,8 +378,10 @@ Status MulModel(AbstractContext* ctx,
   TapeVSpace vspace(ctx);
   auto tape = new Tape(/*persistent=*/false);
   std::vector<AbstractTensorHandle*> temp_outputs(1);
-  TF_RETURN_IF_ERROR(Mul(ctx, tape, {x, y}, absl::MakeSpan(temp_outputs),
-                         "mul0", registry));  // Compute x*y
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), {x, y},
+                              absl::MakeSpan(temp_outputs),
+                              "mul0"));  // Compute x*y
 
   outputs[0] = temp_outputs[0];
   delete tape;
@@ -496,9 +398,9 @@ Status SoftmaxModel(AbstractContext* ctx,
   TapeVSpace vspace(ctx);
   auto tape = new Tape(/*persistent=*/false);
   std::vector<AbstractTensorHandle*> temp_outputs(2);
-  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
-      ctx, tape, {x, labels}, absl::MakeSpan(temp_outputs), "sm_loss",
-      registry));
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
+      tape_ctx.get(), {x, labels}, absl::MakeSpan(temp_outputs), "sm_loss"));
 
   outputs[0] = temp_outputs[0];  // loss values
diff --git a/tensorflow/c/eager/mnist_gradients_testutil.h b/tensorflow/c/eager/mnist_gradients_testutil.h
index 1cf87bb9dee..b173446ac9b 100644
--- a/tensorflow/c/eager/mnist_gradients_testutil.h
+++ b/tensorflow/c/eager/mnist_gradients_testutil.h
@@ -29,45 +29,10 @@ limitations under the License.
 #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
 #include "tensorflow/core/platform/status.h"
 
-// ========================== Tape Ops ==============================
 namespace tensorflow {
 namespace gradients {
 namespace internal {
 
-// Computes `inputs[0] + inputs[1]` and records it on the tape.
-Status Add(AbstractContext* ctx, Tape* tape,
-           absl::Span<AbstractTensorHandle* const> inputs,
-           absl::Span<AbstractTensorHandle*> outputs,
-           const GradientRegistry& registry);
-
-// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape.
-Status MatMul(AbstractContext* ctx, Tape* tape,
-              absl::Span<AbstractTensorHandle* const> inputs,
-              absl::Span<AbstractTensorHandle*> outputs, const char* name,
-              bool transpose_a, bool transpose_b,
-              const GradientRegistry& registry);
-
-// Computes `inputs[0] * inputs[1]` and records it on the tape.
-Status Mul(AbstractContext* ctx, Tape* tape,
-           absl::Span<AbstractTensorHandle* const> inputs,
-           absl::Span<AbstractTensorHandle*> outputs, const char* name,
-           const GradientRegistry& registry);
-
-// Computes `Relu(inputs[0])` and records it on the tape.
-Status Relu(AbstractContext* ctx, Tape* tape,
-            absl::Span<AbstractTensorHandle* const> inputs,
-            absl::Span<AbstractTensorHandle*> outputs, const char* name,
-            const GradientRegistry& registry);
-
-// Computes `SoftmaxLoss(scores, labels)` for matrices and records it on the
-// tape.
-Status SparseSoftmaxCrossEntropyWithLogits(
-    AbstractContext* ctx, Tape* tape,
-    absl::Span<AbstractTensorHandle* const> inputs,
-    absl::Span<AbstractTensorHandle*> outputs, const char* name,
-    const GradientRegistry& registry);
-
-// ====================== End Tape Ops ============================
 
 // Computes
 // y = inputs[0] + inputs[1]