diff --git a/tensorflow/c/experimental/gradients/BUILD b/tensorflow/c/experimental/gradients/BUILD
index 2de74cd32ce..69ba9edefa5 100644
--- a/tensorflow/c/experimental/gradients/BUILD
+++ b/tensorflow/c/experimental/gradients/BUILD
@@ -166,7 +166,9 @@ cc_library(
     ],
     deps = [
         "//tensorflow/c/eager:gradient_checker",
+        "//tensorflow/c/eager:gradients_internal",
         "//tensorflow/c/eager:unified_api_testutil",
+        "//tensorflow/c/experimental/gradients/tape:tape_context",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
     ],
diff --git a/tensorflow/c/experimental/gradients/grad_test_helper.cc b/tensorflow/c/experimental/gradients/grad_test_helper.cc
index e8e2164239b..dae395cc349 100644
--- a/tensorflow/c/experimental/gradients/grad_test_helper.cc
+++ b/tensorflow/c/experimental/gradients/grad_test_helper.cc
@@ -15,6 +15,7 @@ limitations under the License.
 #include "tensorflow/c/experimental/gradients/grad_test_helper.h"
 
 #include "tensorflow/c/eager/gradient_checker.h"
+#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
@@ -105,6 +106,34 @@ void CheckTensorValue(AbstractTensorHandle* t, absl::Span<const float> manuals,
   delete[] danalytical;
 }
 
+Model BuildGradModel(Model forward, size_t num_inputs, const string& op,
+                     GradientFunctionFactory gradient_function_factory) {
+  return [&, forward, gradient_function_factory](
+             AbstractContext* ctx,
+             absl::Span<AbstractTensorHandle* const> inputs,
+             absl::Span<AbstractTensorHandle*> outputs) -> Status {
+    GradientRegistry registry;
+    TF_RETURN_IF_ERROR(registry.Register(op, gradient_function_factory));
+
+    Tape tape(/*persistent=*/false);
+    for (size_t i{}; i < num_inputs; ++i) {
+      tape.Watch(inputs[i]);
+    }
+    std::vector<AbstractTensorHandle*> temp_outputs(1);
+    AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
+    TF_RETURN_IF_ERROR(
+        forward(tape_ctx.get(), inputs, absl::MakeSpan(temp_outputs)));
+
+    TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
+                                            /*sources=*/inputs,
+                                            /*output_gradients=*/{}, outputs));
+    for (auto temp_output : temp_outputs) {
+      temp_output->Unref();
+    }
+    return Status::OK();
+  };
+}
+
 }  // namespace internal
 }  // namespace gradients
 }  // namespace tensorflow
diff --git a/tensorflow/c/experimental/gradients/grad_test_helper.h b/tensorflow/c/experimental/gradients/grad_test_helper.h
index ca378eaef39..653b5470a97 100644
--- a/tensorflow/c/experimental/gradients/grad_test_helper.h
+++ b/tensorflow/c/experimental/gradients/grad_test_helper.h
@@ -15,6 +15,7 @@ limitations under the License.
 #ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_GRAD_TEST_HELPER_H_
 #define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_GRAD_TEST_HELPER_H_
 
+#include "tensorflow/c/eager/gradients.h"
 #include "tensorflow/c/eager/unified_api_testutil.h"
 
 namespace tensorflow {
@@ -29,6 +30,9 @@ void CheckTensorValue(AbstractTensorHandle* t, absl::Span<const float> manuals,
                       absl::Span<const int64_t> dims, double abs_error = 1e-2);
 
+Model BuildGradModel(Model forward, size_t num_inputs, const string& op,
+                     GradientFunctionFactory gradient_function_factory);
+
 }  // namespace internal
 }  // namespace gradients
 }  // namespace tensorflow
 
diff --git a/tensorflow/c/experimental/gradients/math_grad_test.cc b/tensorflow/c/experimental/gradients/math_grad_test.cc
index bce5352d24e..3fea07da914 100644
--- a/tensorflow/c/experimental/gradients/math_grad_test.cc
+++ b/tensorflow/c/experimental/gradients/math_grad_test.cc
@@ -36,199 +36,42 @@ Status AddModel(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> inputs,
                 absl::Span<AbstractTensorHandle*> outputs) {
   return ops::Add(ctx, inputs, outputs, "Add");
 }
 
-Status AddGradModel(AbstractContext* ctx,
-                    absl::Span<AbstractTensorHandle* const> inputs,
-                    absl::Span<AbstractTensorHandle*> outputs) {
-  GradientRegistry registry;
-  TF_RETURN_IF_ERROR(registry.Register("AddV2", AddRegisterer));
-
-  Tape tape(/*persistent=*/false);
-  tape.Watch(inputs[0]);
-  tape.Watch(inputs[1]);
-  std::vector<AbstractTensorHandle*> temp_outputs(1);
-  AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
-  TF_RETURN_IF_ERROR(ops::Add(tape_ctx.get(), inputs,
-                              absl::MakeSpan(temp_outputs), "AddGrad"));
-
-  TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
-                                          /*sources=*/inputs,
-                                          /*output_gradients=*/{}, outputs));
-  for (auto temp_output : temp_outputs) {
-    temp_output->Unref();
-  }
-  return Status::OK();
-}
-
 Status ExpModel(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> inputs,
                 absl::Span<AbstractTensorHandle*> outputs) {
   return ops::Exp(ctx, inputs, outputs, "Exp");
 }
 
-Status ExpGradModel(AbstractContext* ctx,
-                    absl::Span<AbstractTensorHandle* const> inputs,
-                    absl::Span<AbstractTensorHandle*> outputs) {
-  GradientRegistry registry;
-  TF_RETURN_IF_ERROR(registry.Register("Exp", ExpRegisterer));
-
-  Tape tape(/*persistent=*/false);
-  tape.Watch(inputs[0]);
-  std::vector<AbstractTensorHandle*> temp_outputs(1);
-  AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
-  TF_RETURN_IF_ERROR(ops::Exp(tape_ctx.get(), inputs,
-                              absl::MakeSpan(temp_outputs), "ExpGrad"));
-
-  TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
-                                          /*sources=*/inputs,
-                                          /*output_gradients=*/{}, outputs));
-  for (auto temp_output : temp_outputs) {
-    temp_output->Unref();
-  }
-  return Status::OK();
-}
-
 Status SqrtModel(AbstractContext* ctx,
                  absl::Span<AbstractTensorHandle* const> inputs,
                  absl::Span<AbstractTensorHandle*> outputs) {
   return ops::Sqrt(ctx, inputs, outputs, "Sqrt");
 }
 
-Status SqrtGradModel(AbstractContext* ctx,
-                     absl::Span<AbstractTensorHandle* const> inputs,
-                     absl::Span<AbstractTensorHandle*> outputs) {
-  GradientRegistry registry;
-  TF_RETURN_IF_ERROR(registry.Register("Sqrt", SqrtRegisterer));
-
-  Tape tape(/*persistent=*/false);
-  tape.Watch(inputs[0]);
-  std::vector<AbstractTensorHandle*> temp_outputs(1);
-  AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
-  TF_RETURN_IF_ERROR(ops::Sqrt(tape_ctx.get(), inputs,
-                               absl::MakeSpan(temp_outputs), "SqrtGrad"));
-
-  TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
-                                          /*sources=*/inputs,
-                                          /*output_gradients=*/{}, outputs));
-  for (auto temp_output : temp_outputs) {
-    temp_output->Unref();
-  }
-  return Status::OK();
-}
-
 Status NegModel(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> inputs,
                 absl::Span<AbstractTensorHandle*> outputs) {
   return ops::Neg(ctx, inputs, outputs, "Neg");
 }
 
-Status NegGradModel(AbstractContext* ctx,
-                    absl::Span<AbstractTensorHandle* const> inputs,
-                    absl::Span<AbstractTensorHandle*> outputs) {
-  GradientRegistry registry;
-  TF_RETURN_IF_ERROR(registry.Register("Neg", NegRegisterer));
-
-  Tape tape(/*persistent=*/false);
-  tape.Watch(inputs[0]);
-  std::vector<AbstractTensorHandle*> temp_outputs(1);
-  AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
-  TF_RETURN_IF_ERROR(ops::Neg(tape_ctx.get(), inputs,
-                              absl::MakeSpan(temp_outputs), "NegGrad"));
-
-  TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
-                                          /*sources=*/inputs,
-                                          /*output_gradients=*/{}, outputs));
-  for (auto temp_output : temp_outputs) {
-    temp_output->Unref();
-  }
-  return Status::OK();
-}
-
 Status SubModel(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> inputs,
                 absl::Span<AbstractTensorHandle*> outputs) {
   return ops::Sub(ctx, inputs, outputs, "Sub");
 }
 
-Status SubGradModel(AbstractContext* ctx,
-                    absl::Span<AbstractTensorHandle* const> inputs,
-                    absl::Span<AbstractTensorHandle*> outputs) {
-  GradientRegistry registry;
-  TF_RETURN_IF_ERROR(registry.Register("Sub", SubRegisterer));
-
-  Tape tape(/*persistent=*/false);
-  tape.Watch(inputs[0]);
-  tape.Watch(inputs[1]);
-  std::vector<AbstractTensorHandle*> temp_outputs(1);
-  AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
-  TF_RETURN_IF_ERROR(ops::Sub(tape_ctx.get(), inputs,
-                              absl::MakeSpan(temp_outputs), "SubGrad"));
-
-  TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
-                                          /*sources=*/inputs,
-                                          /*output_gradients=*/{}, outputs));
-  for (auto temp_output : temp_outputs) {
-    temp_output->Unref();
-  }
-  return Status::OK();
-}
-
 Status MulModel(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> inputs,
                 absl::Span<AbstractTensorHandle*> outputs) {
   return ops::Mul(ctx, inputs, outputs, "Mul");
 }
 
-Status MulGradModel(AbstractContext* ctx,
-                    absl::Span<AbstractTensorHandle* const> inputs,
-                    absl::Span<AbstractTensorHandle*> outputs) {
-  GradientRegistry registry;
-  TF_RETURN_IF_ERROR(registry.Register("Mul", MulRegisterer));
-
-  Tape tape(/*persistent=*/false);
-  tape.Watch(inputs[0]);
-  tape.Watch(inputs[1]);
-  std::vector<AbstractTensorHandle*> temp_outputs(1);
-  AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
-  TF_RETURN_IF_ERROR(ops::Mul(tape_ctx.get(), inputs,
-                              absl::MakeSpan(temp_outputs), "MulGrad"));
-
-  TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
-                                          /*sources=*/inputs,
-                                          /*output_gradients=*/{}, outputs));
-  for (auto temp_output : temp_outputs) {
-    temp_output->Unref();
-  }
-  return Status::OK();
-}
-
 Status Log1pModel(AbstractContext* ctx,
                   absl::Span<AbstractTensorHandle* const> inputs,
                   absl::Span<AbstractTensorHandle*> outputs) {
   return ops::Log1p(ctx, inputs, outputs, "Log1p");
 }
 
-Status Log1pGradModel(AbstractContext* ctx,
-                      absl::Span<AbstractTensorHandle* const> inputs,
-                      absl::Span<AbstractTensorHandle*> outputs) {
-  GradientRegistry registry;
-  TF_RETURN_IF_ERROR(registry.Register("Log1p", Log1pRegisterer));
-
-  Tape tape(/*persistent=*/false);
-  tape.Watch(inputs[0]);
-  std::vector<AbstractTensorHandle*> temp_outputs(1);
-  AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
-  TF_RETURN_IF_ERROR(ops::Log1p(tape_ctx.get(), inputs,
-                                absl::MakeSpan(temp_outputs), "Log1pGrad"));
-
-  TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
-                                          /*sources=*/inputs,
-                                          /*output_gradients=*/{}, outputs));
-  for (auto temp_output : temp_outputs) {
-    temp_output->Unref();
-  }
-  return Status::OK();
-}
-
 Status DivNoNanModel(AbstractContext* ctx,
                      absl::Span<AbstractTensorHandle* const> inputs,
                      absl::Span<AbstractTensorHandle*> outputs) {
@@ -301,7 +144,8 @@ TEST_P(CppGradients, TestAddGrad) {
   }
 
   ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
-      AddModel, AddGradModel, ctx_.get(), {x.get(), y.get()}, UseFunction()));
+      AddModel, BuildGradModel(AddModel, 2, "AddV2", AddRegisterer), ctx_.get(),
+      {x.get(), y.get()}, UseFunction()));
 }
 
 TEST_P(CppGradients, TestExpGrad) {
@@ -314,7 +158,8 @@ TEST_P(CppGradients, TestExpGrad) {
   }
 
   ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
-      ExpModel, ExpGradModel, ctx_.get(), {x.get()}, UseFunction()));
+      ExpModel, BuildGradModel(ExpModel, 1, "Exp", ExpRegisterer), ctx_.get(),
+      {x.get()}, UseFunction()));
 }
 
 TEST_P(CppGradients, TestSqrtGrad) {
@@ -327,7 +172,8 @@ TEST_P(CppGradients, TestSqrtGrad) {
   }
 
   ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
-      SqrtModel, SqrtGradModel, ctx_.get(), {x.get()}, UseFunction()));
+      SqrtModel, BuildGradModel(SqrtModel, 1, "Sqrt", SqrtRegisterer),
+      ctx_.get(), {x.get()}, UseFunction()));
 }
 
 TEST_P(CppGradients, TestNegGrad) {
@@ -340,7 +186,8 @@ TEST_P(CppGradients, TestNegGrad) {
   }
 
   ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
-      NegModel, NegGradModel, ctx_.get(), {x.get()}, UseFunction()));
+      NegModel, BuildGradModel(NegModel, 1, "Neg", NegRegisterer), ctx_.get(),
+      {x.get()}, UseFunction()));
 }
 
 TEST_P(CppGradients, TestSubGrad) {
@@ -361,7 +208,8 @@ TEST_P(CppGradients, TestSubGrad) {
   }
 
   ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
-      SubModel, SubGradModel, ctx_.get(), {x.get(), y.get()}, UseFunction()));
+      SubModel, BuildGradModel(SubModel, 2, "Sub", SubRegisterer), ctx_.get(),
+      {x.get(), y.get()}, UseFunction()));
 }
 
 TEST_P(CppGradients, TestMulGrad) {
@@ -382,7 +230,8 @@ TEST_P(CppGradients, TestMulGrad) {
   }
 
   ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
-      MulModel, MulGradModel, ctx_.get(), {x.get(), y.get()}, UseFunction()));
+      MulModel, BuildGradModel(MulModel, 2, "Mul", MulRegisterer), ctx_.get(),
+      {x.get(), y.get()}, UseFunction()));
 }
 
 TEST_P(CppGradients, TestLog1pGrad) {
@@ -395,10 +244,13 @@ TEST_P(CppGradients, TestLog1pGrad) {
   }
 
   ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
-      Log1pModel, Log1pGradModel, ctx_.get(), {x.get()}, UseFunction()));
+      Log1pModel, BuildGradModel(Log1pModel, 1, "Log1p", Log1pRegisterer),
+      ctx_.get(), {x.get()}, UseFunction()));
 }
 
 TEST_P(CppGradients, TestDivNoNanGrad) {
+  // TODO(vnvo2409): Figure out why `BuildGradModel` does not work with
+  // `DivNoNan`.
   AbstractTensorHandlePtr x;
   {
     AbstractTensorHandle* x_raw = nullptr;