diff --git a/tensorflow/c/experimental/gradients/grad_test_helper.cc b/tensorflow/c/experimental/gradients/grad_test_helper.cc
index 0c01ab95cf8..77e26aaf0dd 100644
--- a/tensorflow/c/experimental/gradients/grad_test_helper.cc
+++ b/tensorflow/c/experimental/gradients/grad_test_helper.cc
@@ -74,37 +74,35 @@ void CompareNumericalAndAutodiffGradients(
   }
 }
 
-void CompareManualAndAutodiffGradients(
-    Model grad_model, AbstractContext* ctx,
-    absl::Span<AbstractTensorHandle* const> inputs,
-    absl::Span<const float> manuals, bool use_function, double abs_error) {
-  auto num_inputs = inputs.size();
-  std::vector<AbstractTensorHandle*> outputs(num_inputs);
-  auto s = RunModel(grad_model, ctx, inputs, absl::MakeSpan(outputs),
-                    /*use_function=*/use_function);
+void CheckTensorValue(AbstractTensorHandle* t, absl::Span<const float> manuals,
+                      absl::Span<const int64_t> dims, double abs_error) {
+  TF_Tensor* analytical_tensor;
+  auto s = GetValue(t, &analytical_tensor);
   ASSERT_EQ(errors::OK, s.code()) << s.error_message();
 
-  int current_index_manual = 0;
-  for (int i = 0; i < num_inputs; ++i) {
-    if (!outputs[i]) continue;
-
-    TF_Tensor* analytical_tensor;
-    s = GetValue(outputs[i], &analytical_tensor);
-    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
-    auto num_elem_analytical = TF_TensorElementCount(analytical_tensor);
-
-    float* danalytical = new float[num_elem_analytical]{0};
-    memcpy(&danalytical[0], TF_TensorData(analytical_tensor),
-           TF_TensorByteSize(analytical_tensor));
-
-    for (int j = 0; j < num_elem_analytical; j++) {
-      ASSERT_NEAR(manuals[current_index_manual], danalytical[j], abs_error);
-      ++current_index_manual;
-    }
-    TF_DeleteTensor(analytical_tensor);
-    delete[] danalytical;
-    outputs[i]->Unref();
+  int64_t num_elem_analytical = 1;
+  auto num_dims_analytical = TF_NumDims(analytical_tensor);
+  ASSERT_EQ(dims.size(), num_dims_analytical);
+  for (int j = 0; j < num_dims_analytical; j++) {
+    auto dim_analytical = TF_Dim(analytical_tensor, j);
+    ASSERT_EQ(dims[j], dim_analytical);
+    num_elem_analytical *= dim_analytical;
   }
+
+  float* danalytical = new float[num_elem_analytical]{0};
+  memcpy(&danalytical[0], TF_TensorData(analytical_tensor),
+         TF_TensorByteSize(analytical_tensor));
+
+  int64_t current_index_manual = 0;
+  for (int64_t j = 0; j < num_elem_analytical; j++) {
+    if (abs_error == 0)
+      ASSERT_EQ(manuals[current_index_manual], danalytical[j]);
+    else
+      ASSERT_NEAR(manuals[current_index_manual], danalytical[j], abs_error);
+    ++current_index_manual;
+  }
+  TF_DeleteTensor(analytical_tensor);
+  delete[] danalytical;
 }
 
 }  // namespace internal
diff --git a/tensorflow/c/experimental/gradients/grad_test_helper.h b/tensorflow/c/experimental/gradients/grad_test_helper.h
index 0e1f5b397dc..ca378eaef39 100644
--- a/tensorflow/c/experimental/gradients/grad_test_helper.h
+++ b/tensorflow/c/experimental/gradients/grad_test_helper.h
@@ -26,14 +26,8 @@ void CompareNumericalAndAutodiffGradients(
     absl::Span<AbstractTensorHandle* const> inputs, bool use_function,
     double abs_error = 1e-2);
 
-// `manuals` should be a flat array of expected results of `grad_model`. e.g if
-// `grad_model` output is `[[1, 2], nullptr, [3, 4]]`, `manuals` will be `[1,
-// 2, 3, 4]`.
-void CompareManualAndAutodiffGradients(
-    Model grad_model, AbstractContext* ctx,
-    absl::Span<AbstractTensorHandle* const> inputs,
-    absl::Span<const float> manuals, bool use_function,
-    double abs_error = 1e-2);
+void CheckTensorValue(AbstractTensorHandle* t, absl::Span<const float> manuals,
+                      absl::Span<const int64_t> dims, double abs_error = 1e-2);
 
 }  // namespace internal
 }  // namespace gradients
diff --git a/tensorflow/c/experimental/gradients/nn_grad_test.cc b/tensorflow/c/experimental/gradients/nn_grad_test.cc
index 805b79547cc..c85fb692117 100644
--- a/tensorflow/c/experimental/gradients/nn_grad_test.cc
+++ b/tensorflow/c/experimental/gradients/nn_grad_test.cc
@@ -154,10 +154,8 @@ class CppGradients
 };
 
 TEST_P(CppGradients, TestReluGrad) {
-  // Mathematically, Relu isn't differentiable at `0`. So `gradient_checker`
-  // does not work with it.
   float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 10.0f, -1.0f};
-  int64_t X_dims[] = {2, 2};
+  int64_t X_dims[] = {3, 3};
   AbstractTensorHandlePtr X;
   {
     AbstractTensorHandle* X_raw;
@@ -170,6 +168,8 @@ TEST_P(CppGradients, TestReluGrad) {
   ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
       ReluModel, ReluGradModel, ctx_.get(), {X.get()}, UseFunction()));
 
+  // Mathematically, Relu isn't differentiable at `0`. So `gradient_checker`
+  // does not work with it.
   AbstractTensorHandlePtr Y;
   {
     AbstractTensorHandle* Y_raw;
@@ -178,8 +178,13 @@ TEST_P(CppGradients, TestReluGrad) {
     Y.reset(Y_raw);
   }
 
-  ASSERT_NO_FATAL_FAILURE(CompareManualAndAutodiffGradients(
-      ReluGradModel, ctx_.get(), {Y.get()}, {0.0f}, UseFunction()));
+  std::vector<AbstractTensorHandle*> outputs(1);
+  auto s = RunModel(ReluGradModel, ctx_.get(), {Y.get()},
+                    absl::MakeSpan(outputs), UseFunction());
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+  ASSERT_NO_FATAL_FAILURE(CheckTensorValue(outputs[0], {0.0f}, /*dims*/ {},
+                                           /*abs_error*/ 0));
+  outputs[0]->Unref();
 }
 
 TEST_P(CppGradients, TestSparseSoftmaxCrossEntropyWithLogitsGrad) {
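
For reference (not part of the patch): a minimal sketch of how another test could exercise the new CheckTensorValue helper on a non-scalar gradient, assuming the usual CppGradients fixture (ctx_, UseFunction()). The grad model name, input handle, expected values, and shape below are hypothetical and only illustrate the manuals/dims arguments; the default abs_error of 1e-2 applies since it is omitted.

  // `SomeGradModel` and `X` are placeholders for a single-input grad model and
  // its input handle.
  std::vector<AbstractTensorHandle*> outputs(1);
  auto s = RunModel(SomeGradModel, ctx_.get(), {X.get()},
                    absl::MakeSpan(outputs), UseFunction());
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
  // `manuals` is the expected tensor flattened in row-major order; `dims` must
  // match the shape reported by the runtime for `outputs[0]`.
  ASSERT_NO_FATAL_FAILURE(CheckTensorValue(outputs[0],
                                           /*manuals=*/{1.0f, 1.0f, 0.0f, 0.0f},
                                           /*dims=*/{2, 2}));
  outputs[0]->Unref();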