refactor gradient_checker to use unified_api_testutil

Võ Văn Nghĩa 2020-12-10 00:30:44 +07:00
parent 2234086df0
commit c18a4cade5
8 changed files with 205 additions and 226 deletions

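The heart of the change is swapping the tensor-handle helpers from `gradients_util` (`GetTensorHandleUtilFloat`, `GetScalarTensorHandleUtil`, `GetTensorHandleUtilInt`) for the `Status`-returning test helpers in `unified_api_testutil` (`TestTensorHandleWithDimsFloat`, `TestScalarTensorHandle`, `TestTensorHandleWithDimsInt`), and dropping the `GradientRegistry` parameter that `RunModel` and the test models no longer take. A minimal sketch of the new calling convention, using only calls visible in the hunks below (the wrapper function and its values are illustrative, not part of the commit):

// Sketch only: builds a 2x2 float tensor handle the unified_api_testutil way.
Status MakeExampleTensor(AbstractContext* ctx, AbstractTensorHandlePtr* out) {
  float vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
  int64_t dims[] = {2, 2};
  AbstractTensorHandle* raw = nullptr;
  // Old style returned the handle directly:
  //   *out = GetTensorHandleUtilFloat(ctx, vals, dims, /*num_dims=*/2);
  // New style reports failure through a Status and an out-parameter:
  TF_RETURN_IF_ERROR(
      TestTensorHandleWithDimsFloat(ctx, vals, dims, /*num_dims=*/2, &raw));
  out->reset(raw);
  return Status::OK();
}

Test code wraps the same calls in a local scope and checks the Status with ASSERT_EQ(errors::OK, s.code()) instead of TF_RETURN_IF_ERROR, as the updated tests below do.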
View File

@@ -389,6 +389,7 @@ cc_library(
 
 cc_library(
     name = "gradient_checker",
+    testonly = 1,
     srcs = [
         "gradient_checker.cc",
     ],
@@ -399,28 +400,11 @@ cc_library(
         "//tensorflow:internal",
     ],
     deps = [
-        ":abstract_tensor_handle",
-        ":c_api_experimental",
-        ":c_api_unified_internal",
-        ":gradients_internal",
-        ":gradients_util",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/types:span",
-        "//tensorflow/c:c_api",
-        "//tensorflow/c:tf_status_helper",
-        "//tensorflow/c/experimental/gradients:math_grad",
-        "//tensorflow/c/experimental/gradients:nn_grad",
-        "//tensorflow/c/experimental/ops:array_ops",
+        ":unified_api_testutil",
+        "//tensorflow/c/eager:abstract_tensor_handle",
         "//tensorflow/c/experimental/ops:math_ops",
-        "//tensorflow/c/experimental/ops:nn_ops",
-        "//tensorflow/cc/profiler",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core/lib/llvm_rtti",
-    ] + if_libtpu(
-        if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"],
-        if_true = [],
-    ),
+        "@com_google_absl//absl/types:span",
+    ],
 )
 
 tf_cuda_cc_test(
@@ -432,36 +416,17 @@ tf_cuda_cc_test(
     args = ["--heap_check=local"],
     linkstatic = tf_kernel_tests_linkstatic(),
     tags = tf_cuda_tests_tags() + [
-        "nomac",
         "no_cuda_asan",  # b/175330074
         "notap",  # b/175330074
     ],
     deps = [
-        ":abstract_tensor_handle",
-        ":c_api_experimental",
-        ":c_api_test_util",
-        ":c_api_unified_internal",
         ":gradient_checker",
-        ":gradients_internal",
-        ":gradients_util",
-        ":mnist_gradients_testutil",
-        "//tensorflow/c:c_api",
-        "//tensorflow/c:c_test_util",
+        ":unified_api_testutil",
         "//tensorflow/c:tf_status_helper",
-        "//tensorflow/c/experimental/gradients:math_grad",
-        "//tensorflow/c/experimental/gradients:nn_grad",
-        "//tensorflow/c/experimental/ops:array_ops",
-        "//tensorflow/c/experimental/ops:math_ops",
-        "//tensorflow/c/experimental/ops:nn_ops",
-        "//tensorflow/cc/profiler",
-        "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/c/eager:abstract_tensor_handle",
+        "//tensorflow/c/eager:c_api_experimental",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
-        "//tensorflow/core/lib/llvm_rtti",
-        "//tensorflow/core/platform:tensor_float_32_utils",
-        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
     ],
 )

View File

@@ -18,18 +18,8 @@ limitations under the License.
 
 #include "absl/types/span.h"
 #include "tensorflow/c/eager/abstract_tensor_handle.h"
-#include "tensorflow/c/eager/c_api_experimental.h"
-#include "tensorflow/c/eager/c_api_unified_experimental.h"
-#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
-#include "tensorflow/c/eager/gradients.h"
-#include "tensorflow/c/eager/gradients_internal.h"
-#include "tensorflow/c/experimental/gradients/math_grad.h"
-#include "tensorflow/c/experimental/gradients/nn_grad.h"
-#include "tensorflow/c/experimental/ops/array_ops.h"
-#include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/c/experimental/ops/math_ops.h"
 #include "tensorflow/c/tf_tensor.h"
-#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
-#include "tensorflow/core/platform/errors.h"
 
 namespace tensorflow {
 namespace gradients {
@@ -45,16 +35,6 @@ void Range(vector<int>* data, int start, int end, int step = 1) {
   }
 }
 
-// Returns AbstractTensorHandlePtr containing [0, ..., n-1].
-AbstractTensorHandlePtr GetRangeTensorHandleUtil(AbstractContext* ctx, int n) {
-  vector<int> vals(n);
-  int64_t vals_shape[] = {n};
-  Range(&vals, 0, n);
-  AbstractTensorHandlePtr r =
-      GetTensorHandleUtilInt(ctx, vals.data(), vals_shape, 1);
-  return r;
-}
-
 // Fills out_dims with the dimensions of the given tensor.
 void GetDims(const TF_Tensor* t, int64_t* out_dims) {
   int num_dims = TF_NumDims(t);
@@ -69,13 +49,11 @@ Status RunAndMaybeSum(AbstractContext* ctx, Model forward,
                       absl::Span<AbstractTensorHandle* const> inputs,
                       absl::Span<AbstractTensorHandle*> outputs,
                       bool use_function) {
-  GradientRegistry registry;
   std::vector<AbstractTensorHandle*> model_outputs(1);
 
   // Run the model.
   TF_RETURN_IF_ERROR(RunModel(forward, ctx, inputs,
-                              absl::MakeSpan(model_outputs), use_function,
-                              registry));
+                              absl::MakeSpan(model_outputs), use_function));
 
   AbstractTensorHandle* model_out = model_outputs[0];
   TF_Tensor* model_out_tensor;
@@ -91,8 +69,16 @@ Status RunAndMaybeSum(AbstractContext* ctx, Model forward,
 
   // Else, reduce sum the output to get a scalar
   // Will sum all dimensions, so get a Tensor containing [0,...,num_dims_out-1].
-  AbstractTensorHandlePtr sum_dims =
-      GetRangeTensorHandleUtil(ctx, num_dims_out);
+  AbstractTensorHandlePtr sum_dims;
+  {
+    vector<int> vals(num_dims_out);
+    int64_t vals_shape[] = {num_dims_out};
+    Range(&vals, 0, num_dims_out);
+    AbstractTensorHandle* sum_dims_raw = nullptr;
+    TF_RETURN_IF_ERROR(TestTensorHandleWithDimsInt(ctx, vals.data(), vals_shape,
+                                                   1, &sum_dims_raw));
+    sum_dims.reset(sum_dims_raw);
+  }
 
   // Reduce sum the output on all dimensions.
   std::vector<AbstractTensorHandle*> sum_inputs(2);
@@ -145,22 +131,39 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
   for (int i = 0; i < num_elems; i++) {
     // Get relative epsilon value
    float epsilon = theta_data[i] == 0 ? 1e-4 : std::abs(theta_data[i] * 1e-4);
-    AbstractTensorHandlePtr two_eps =
-        GetScalarTensorHandleUtil(ctx, 2 * epsilon);
+    AbstractTensorHandlePtr two_eps;
+    {
+      AbstractTensorHandle* two_eps_raw = nullptr;
+      TF_RETURN_IF_ERROR(
+          TestScalarTensorHandle(ctx, 2 * epsilon, &two_eps_raw));
+      two_eps.reset(two_eps_raw);
+    }
 
     // Initialize theta[i] + epsilon.
     memcpy(thetaPlus_data.data(), TF_TensorData(theta_tensor),
            TF_TensorByteSize(theta_tensor));
     thetaPlus_data[i] += epsilon;
-    AbstractTensorHandlePtr thetaPlus = GetTensorHandleUtilFloat(
-        ctx, thetaPlus_data.data(), theta_dims.data(), num_dims);
+    AbstractTensorHandlePtr thetaPlus;
+    {
+      AbstractTensorHandle* thetaPlus_raw = nullptr;
+      TF_RETURN_IF_ERROR(TestTensorHandleWithDimsFloat(
+          ctx, thetaPlus_data.data(), theta_dims.data(), num_dims,
+          &thetaPlus_raw));
+      thetaPlus.reset(thetaPlus_raw);
+    }
 
     // Initialize theta[i] - epsilon.
     memcpy(&thetaMinus_data[0], TF_TensorData(theta_tensor),
            TF_TensorByteSize(theta_tensor));
     thetaMinus_data[i] -= epsilon;
-    AbstractTensorHandlePtr thetaMinus = GetTensorHandleUtilFloat(
-        ctx, thetaMinus_data.data(), theta_dims.data(), num_dims);
+    AbstractTensorHandlePtr thetaMinus;
+    {
+      AbstractTensorHandle* thetaMinus_raw = nullptr;
+      TF_RETURN_IF_ERROR(TestTensorHandleWithDimsFloat(
+          ctx, thetaMinus_data.data(), theta_dims.data(), num_dims,
+          &thetaMinus_raw));
+      thetaMinus.reset(thetaMinus_raw);
+    }
 
     // Get f(theta + eps):
     theta_inputs[input_index] = thetaPlus.get();
@@ -195,7 +198,7 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
   }
 
   // Populate *numerical_grad with the data from dtheta_approx.
-  TF_RETURN_IF_ERROR(TensorHandleWithDimsFloat(
+  TF_RETURN_IF_ERROR(TestTensorHandleWithDimsFloat(
       ctx, dtheta_approx.data(), theta_dims.data(), num_dims, numerical_grad));
   return Status::OK();
 }

View File

@@ -12,23 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#ifndef TENSORFLOW_C_EAGER_GRADIENT_CHECKER_H_
+#define TENSORFLOW_C_EAGER_GRADIENT_CHECKER_H_
+
 #include <memory>
 
 #include "absl/types/span.h"
 #include "tensorflow/c/eager/abstract_tensor_handle.h"
-#include "tensorflow/c/eager/c_api_experimental.h"
-#include "tensorflow/c/eager/c_api_unified_experimental.h"
-#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
-#include "tensorflow/c/eager/gradients.h"
-#include "tensorflow/c/eager/gradients_internal.h"
-#include "tensorflow/c/eager/gradients_util.h"
-#include "tensorflow/c/experimental/gradients/math_grad.h"
-#include "tensorflow/c/experimental/gradients/nn_grad.h"
-#include "tensorflow/c/experimental/ops/array_ops.h"
-#include "tensorflow/c/tf_status_helper.h"
-#include "tensorflow/c/tf_tensor.h"
-#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
-#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/c/eager/unified_api_testutil.h"
 
 namespace tensorflow {
 namespace gradients {
@@ -51,3 +42,5 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
 
 }  // namespace gradients
 }  // namespace tensorflow
+
+#endif  // TENSORFLOW_C_EAGER_GRADIENT_CHECKER_H_

View File

@@ -15,21 +15,11 @@ limitations under the License.
 
 #include "absl/types/span.h"
 #include "tensorflow/c/eager/abstract_tensor_handle.h"
-#include "tensorflow/c/eager/c_api_experimental.h"
 #include "tensorflow/c/eager/c_api_unified_experimental.h"
-#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
-#include "tensorflow/c/eager/gradients.h"
-#include "tensorflow/c/eager/gradients_internal.h"
-#include "tensorflow/c/eager/gradients_util.h"
-#include "tensorflow/c/eager/mnist_gradients_testutil.h"
-#include "tensorflow/c/experimental/gradients/math_grad.h"
-#include "tensorflow/c/experimental/gradients/nn_grad.h"
-#include "tensorflow/c/experimental/ops/array_ops.h"
+#include "tensorflow/c/eager/unified_api_testutil.h"
+#include "tensorflow/c/experimental/ops/math_ops.h"
 #include "tensorflow/c/tf_status_helper.h"
 #include "tensorflow/c/tf_tensor.h"
-#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
-#include "tensorflow/core/platform/errors.h"
-#include "tensorflow/core/platform/tensor_float_32_utils.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
@@ -37,6 +27,54 @@ namespace gradients {
 namespace internal {
 namespace {
 
+using tensorflow::TF_StatusPtr;
+
+void CompareNumericalAndManualGradients(
+    Model model, AbstractContext* ctx,
+    absl::Span<AbstractTensorHandle* const> inputs, int input_index,
+    float* expected_grad, int num_grad, bool use_function,
+    double abs_error = 1e-2) {
+  AbstractTensorHandle* numerical_grad;
+  Status s = CalcNumericalGrad(ctx, model, inputs, input_index, use_function,
+                               &numerical_grad);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  TF_Tensor* numerical_tensor;
+  s = GetValue(numerical_grad, &numerical_tensor);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+  auto num_elem_numerical = TF_TensorElementCount(numerical_tensor);
+  ASSERT_EQ(num_elem_numerical, num_grad);
+
+  float* dnumerical = new float[num_elem_numerical]{0};
+  memcpy(&dnumerical[0], TF_TensorData(numerical_tensor),
+         TF_TensorByteSize(numerical_tensor));
+
+  for (int j = 0; j < num_grad; j++) {
+    ASSERT_NEAR(dnumerical[j], expected_grad[j], abs_error);
+  }
+  delete[] dnumerical;
+  TF_DeleteTensor(numerical_tensor);
+}
+
+Status MatMulModel(AbstractContext* ctx,
+                   absl::Span<AbstractTensorHandle* const> inputs,
+                   absl::Span<AbstractTensorHandle*> outputs) {
+  return ops::MatMul(ctx, inputs, outputs, "MatMul",
+                     /*transpose_a=*/false,
+                     /*transpose_b=*/false);
+}
+
+Status MulModel(AbstractContext* ctx,
+                absl::Span<AbstractTensorHandle* const> inputs,
+                absl::Span<AbstractTensorHandle*> outputs) {
+  return ops::Mul(ctx, inputs, outputs, "Mul");
+}
+
+// TODO(vnvo2409): Add more tests from `python/ops/gradient_checker_v2_test.py`.
+// These tests should not be confused with `[*]_grad_test` which compare the
+// result of `gradient_checker` and `[*]_grad`. The tests here test the
+// functionality of `gradient_checker` by comparing the result with expected
+// manual user-provided gradients.
 class GradientCheckerTest
     : public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
  protected:
@@ -45,84 +83,56 @@ class GradientCheckerTest
     TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
     Status s = StatusFromTF_Status(status.get());
     CHECK_EQ(errors::OK, s.code()) << s.error_message();
+
+    {
+      AbstractContext* ctx_raw = nullptr;
+      Status s =
+          BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
+      ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+      ctx_.reset(ctx_raw);
+    }
   }
 
+  AbstractContextPtr ctx_;
+
+ public:
+  bool UseMlir() const { return strcmp(std::get<0>(GetParam()), "mlir") == 0; }
+  bool UseFunction() const { return std::get<2>(GetParam()); }
 };
 
-Status RegisterGradients(GradientRegistry* registry) {
-  TF_RETURN_IF_ERROR(registry->Register("MatMul", MatMulRegisterer));
-  TF_RETURN_IF_ERROR(
-      registry->Register("SparseSoftmaxCrossEntropyWithLogits",
-                         SparseSoftmaxCrossEntropyWithLogitsRegisterer));
-  return Status::OK();
-}
-
-TEST_P(GradientCheckerTest, TestGradCheckMatMul) {
-  // Computing numerical gradients with TensorFloat-32 is numerically unstable
-  enable_tensor_float_32_execution(false);
-
-  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
-      TF_NewStatus(), TF_DeleteStatus);
-  AbstractContextPtr ctx;
-  {
-    AbstractContext* ctx_raw = nullptr;
-    Status s =
-        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
-    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
-    ctx.reset(ctx_raw);
-  }
-
+TEST_P(GradientCheckerTest, TestMatMul) {
   float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
   int64_t A_dims[] = {2, 2};
+  AbstractTensorHandlePtr A;
+  {
+    AbstractTensorHandle* A_raw;
+    Status s =
+        TestTensorHandleWithDimsFloat(ctx_.get(), A_vals, A_dims, 2, &A_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    A.reset(A_raw);
+  }
+
   float B_vals[] = {.5f, -1.0f, 1.0f, 1.0f};
   int64_t B_dims[] = {2, 2};
-  int num_dims = 2;
-
-  AbstractTensorHandlePtr A =
-      GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims);
-  AbstractTensorHandlePtr B =
-      GetTensorHandleUtilFloat(ctx.get(), B_vals, B_dims, num_dims);
-
-  std::vector<AbstractTensorHandle*> inputs;
-  inputs.push_back(A.get());
-  inputs.push_back(B.get());
-
-  AbstractTensorHandle* grad_approx;
-  Status s = CalcNumericalGrad(
-      ctx.get(), MatMulModel, absl::MakeSpan(inputs), /*input_index=*/0,
-      /*use_function=*/!std::get<2>(GetParam()), &grad_approx);
-  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
-
-  TF_Tensor* gt;
-  s = GetValue(grad_approx, &gt);
-  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
-  float result_data[4] = {0};
-  memcpy(&result_data[0], TF_TensorData(gt), TF_TensorByteSize(gt));
+  AbstractTensorHandlePtr B;
+  {
+    AbstractTensorHandle* B_raw;
+    Status s =
+        TestTensorHandleWithDimsFloat(ctx_.get(), B_vals, B_dims, 2, &B_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    B.reset(B_raw);
+  }
 
   float expected_dA[4] = {-.5f, 2.0f, -.5f, 2.0f};
-  float tolerance = 1e-2;
-  for (int j = 0; j < 4; j++) {
-    ASSERT_NEAR(expected_dA[j], result_data[j], tolerance);
-  }
-  TF_DeleteTensor(gt);
+  ASSERT_NO_FATAL_FAILURE(CompareNumericalAndManualGradients(
+      MatMulModel, ctx_.get(), {A.get(), B.get()}, 0, expected_dA, 4,
+      UseFunction()));
 }
 
-TEST_P(GradientCheckerTest, TestGradCheckMul) {
-  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
-      TF_NewStatus(), TF_DeleteStatus);
-  AbstractContextPtr ctx;
-  {
-    AbstractContext* ctx_raw = nullptr;
-    Status s =
-        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
-    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
-    ctx.reset(ctx_raw);
-  }
-
+TEST_P(GradientCheckerTest, TestMul) {
   AbstractTensorHandlePtr x;
   {
     AbstractTensorHandle* x_raw = nullptr;
-    Status s = ScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
+    Status s = TestScalarTensorHandle(ctx_.get(), 2.0f, &x_raw);
     ASSERT_EQ(errors::OK, s.code()) << s.error_message();
     x.reset(x_raw);
   }
@@ -130,32 +140,15 @@ TEST_P(GradientCheckerTest, TestGradCheckMul) {
   AbstractTensorHandlePtr y;
   {
     AbstractTensorHandle* y_raw = nullptr;
-    Status s = ScalarTensorHandle(ctx.get(), 7.0f, &y_raw);
+    Status s = TestScalarTensorHandle(ctx_.get(), 7.0f, &y_raw);
     ASSERT_EQ(errors::OK, s.code()) << s.error_message();
     y.reset(y_raw);
   }
 
-  // Will perform z = x*y.
-  // dz/dx = y
-
-  std::vector<AbstractTensorHandle*> inputs;
-  inputs.push_back(x.get());
-  inputs.push_back(y.get());
-
-  AbstractTensorHandle* g;
-  Status s = CalcNumericalGrad(ctx.get(), MulModel, absl::MakeSpan(inputs),
-                               /*input_index=*/0,
-                               /*use_function=*/!std::get<2>(GetParam()), &g);
-  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
-
-  TF_Tensor* gt;
-  s = GetValue(g, &gt);
-  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
-  float result_data[1] = {0};
-  memcpy(&result_data[0], TF_TensorData(gt), TF_TensorByteSize(gt));
-
-  ASSERT_NEAR(result_data[0], 7.0f, /*abs_error=*/1e-2);
-  TF_DeleteTensor(gt);
+  float expected_dx[1] = {7.0f};
+  ASSERT_NO_FATAL_FAILURE(CompareNumericalAndManualGradients(
+      MulModel, ctx_.get(), {x.get(), y.get()}, 0, expected_dx, 1,
+      UseFunction()));
 }
 
 #ifdef PLATFORM_GOOGLE
@@ -163,13 +156,13 @@ INSTANTIATE_TEST_SUITE_P(
     UnifiedCAPI, GradientCheckerTest,
     ::testing::Combine(::testing::Values("graphdef"),
                        /*tfrt*/ ::testing::Values(false),
-                       /*executing_eagerly*/ ::testing::Values(true, false)));
+                       /*use_function*/ ::testing::Values(true, false)));
 #else
 INSTANTIATE_TEST_SUITE_P(
     UnifiedCAPI, GradientCheckerTest,
     ::testing::Combine(::testing::Values("graphdef"),
                        /*tfrt*/ ::testing::Values(false),
-                       /*executing_eagerly*/ ::testing::Values(true, false)));
+                       /*use_function*/ ::testing::Values(true, false)));
 #endif
 }  // namespace
 }  // namespace internal

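As the new comment in this test file notes, these checks compare `gradient_checker` against hand-derived expected gradients rather than against a registered `[*]_grad` kernel. For reference, the hard-coded `expected_dA` in `TestMatMul` follows from f(A, B) = sum(matmul(A, B)), whose gradient with respect to A is matmul(ones, transpose(B)). A standalone sketch (not part of the commit) that reproduces the numbers:

// Illustrative only: recomputes expected_dA = {-0.5, 2.0, -0.5, 2.0} from B.
#include <cstdio>

int main() {
  const float B[2][2] = {{0.5f, -1.0f}, {1.0f, 1.0f}};
  float dA[2][2];
  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 2; ++j) {
      // (ones(2,2) * transpose(B))[i][j] reduces to the sum of row j of B.
      dA[i][j] = B[j][0] + B[j][1];
    }
  }
  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 2; ++j) std::printf("%g ", dA[i][j]);  // -0.5 2 -0.5 2
  }
  std::printf("\n");
  return 0;
}

The TestMul expectation is the same idea: z = x * y gives dz/dx = y = 7.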
View File

@@ -4,6 +4,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
 # buildifier: disable=same-origin-load
 load(
     "//tensorflow:tensorflow.bzl",
+    "if_libtpu",
     "tf_cuda_cc_test",
 )
 load(
@@ -165,7 +166,7 @@ cc_library(
     ],
     deps = [
         "//tensorflow/c/eager:gradient_checker",
-        "//tensorflow/c/eager:gradients_util",
+        "//tensorflow/c/eager:unified_api_testutil",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
     ],
@@ -183,9 +184,14 @@ tf_cuda_cc_test(
     deps = [
         ":grad_test_helper",
         ":nn_grad",
+        "//tensorflow/c:tf_status_helper",
        "//tensorflow/c/eager:c_api_test_util",
         "//tensorflow/c/experimental/gradients/tape:tape_context",
+        "//tensorflow/c/experimental/ops:nn_ops",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
-    ],
+    ] + if_libtpu(
+        if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"],
+        if_true = [],
+    ),
 )

View File

@@ -24,11 +24,11 @@ namespace internal {
 void CompareNumericalAndAutodiffGradients(
     Model model, Model grad_model, AbstractContext* ctx,
     absl::Span<AbstractTensorHandle* const> inputs, bool use_function,
-    const GradientRegistry& registry, double abs_error) {
+    double abs_error) {
   auto num_inputs = inputs.size();
   std::vector<AbstractTensorHandle*> outputs(num_inputs);
   auto s = RunModel(grad_model, ctx, inputs, absl::MakeSpan(outputs),
-                    /*use_function=*/use_function, registry);
+                    /*use_function=*/use_function);
   ASSERT_EQ(errors::OK, s.code()) << s.error_message();
 
   for (int i = 0; i < num_inputs; ++i) {

View File

@@ -15,7 +15,7 @@ limitations under the License.
 #ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_GRAD_TEST_HELPER_H_
 #define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_GRAD_TEST_HELPER_H_
 
-#include "tensorflow/c/eager/gradients_util.h"
+#include "tensorflow/c/eager/unified_api_testutil.h"
 
 namespace tensorflow {
 namespace gradients {
@@ -24,7 +24,7 @@ namespace internal {
 void CompareNumericalAndAutodiffGradients(
     Model model, Model grad_model, AbstractContext* ctx,
     absl::Span<AbstractTensorHandle* const> inputs, bool use_function,
-    const GradientRegistry& registry, double abs_error = 1e-2);
+    double abs_error = 1e-2);
 
 }  // namespace internal
 }  // namespace gradients

View File

@@ -15,8 +15,11 @@ limitations under the License.
 #include "tensorflow/c/experimental/gradients/nn_grad.h"
 
 #include "tensorflow/c/eager/c_api_test_util.h"
+#include "tensorflow/c/eager/unified_api_testutil.h"
 #include "tensorflow/c/experimental/gradients/grad_test_helper.h"
 #include "tensorflow/c/experimental/gradients/tape/tape_context.h"
+#include "tensorflow/c/experimental/ops/nn_ops.h"
+#include "tensorflow/c/tf_status_helper.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
@@ -28,16 +31,19 @@ using tensorflow::TF_StatusPtr;
 
 Status SparseSoftmaxCrossEntropyWithLogitsModel(
     AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
-    absl::Span<AbstractTensorHandle*> outputs,
-    const GradientRegistry& registry) {
+    absl::Span<AbstractTensorHandle*> outputs) {
   return ops::SparseSoftmaxCrossEntropyWithLogits(
       ctx, inputs, outputs, "SparseSoftmaxCrossEntropyWithLogits");
 }
 
 Status SparseSoftmaxCrossEntropyWithLogitsGradModel(
     AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
-    absl::Span<AbstractTensorHandle*> outputs,
-    const GradientRegistry& registry) {
+    absl::Span<AbstractTensorHandle*> outputs) {
+  GradientRegistry registry;
+  TF_RETURN_IF_ERROR(
+      registry.Register("SparseSoftmaxCrossEntropyWithLogits",
+                        SparseSoftmaxCrossEntropyWithLogitsRegisterer));
+
   Tape tape(/*persistent=*/false);
   tape.Watch(inputs[0]);  // Watch score.
   tape.Watch(inputs[1]);  // Watch label.
@@ -58,15 +64,16 @@ Status SparseSoftmaxCrossEntropyWithLogitsGradModel(
 
 Status BiasAddModel(AbstractContext* ctx,
                     absl::Span<AbstractTensorHandle* const> inputs,
-                    absl::Span<AbstractTensorHandle*> outputs,
-                    const GradientRegistry& registry) {
+                    absl::Span<AbstractTensorHandle*> outputs) {
   return ops::BiasAdd(ctx, inputs, outputs, "BiasAdd");
 }
 
 Status BiasAddGradModel(AbstractContext* ctx,
                         absl::Span<AbstractTensorHandle* const> inputs,
-                        absl::Span<AbstractTensorHandle*> outputs,
-                        const GradientRegistry& registry) {
+                        absl::Span<AbstractTensorHandle*> outputs) {
+  GradientRegistry registry;
+  TF_RETURN_IF_ERROR(registry.Register("BiasAdd", BiasAddRegisterer));
+
   Tape tape(/*persistent=*/false);
   tape.Watch(inputs[0]);  // Watch A.
   tape.Watch(inputs[1]);  // Watch Bias.
@@ -84,14 +91,6 @@ Status BiasAddGradModel(AbstractContext* ctx,
   return Status::OK();
 }
 
-Status RegisterGradients(GradientRegistry* registry) {
-  TF_RETURN_IF_ERROR(registry->Register("BiasAdd", BiasAddRegisterer));
-  TF_RETURN_IF_ERROR(
-      registry->Register("SparseSoftmaxCrossEntropyWithLogits",
-                         SparseSoftmaxCrossEntropyWithLogitsRegisterer));
-  return Status::OK();
-}
-
 class CppGradients
     : public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
  protected:
@@ -99,7 +98,7 @@ class CppGradients
     TF_StatusPtr status(TF_NewStatus());
     TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
     Status s = StatusFromTF_Status(status.get());
-    CHECK_EQ(errors::OK, s.code()) << s.error_message();
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
 
     {
       AbstractContext* ctx_raw = nullptr;
@@ -108,12 +107,8 @@ class CppGradients
       ASSERT_EQ(errors::OK, s.code()) << s.error_message();
       ctx_.reset(ctx_raw);
     }
-
-    s = RegisterGradients(&registry_);
-    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
   }
 
-  GradientRegistry registry_;
   AbstractContextPtr ctx_;
 
  public:
@@ -131,19 +126,31 @@ TEST_P(CppGradients, TestSparseSoftmaxCrossEntropyWithLogitsGrad) {
 
   // Score
   float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 0.0f, -1.0f};
   int64_t X_dims[] = {3, 3};
-  AbstractTensorHandlePtr X =
-      GetTensorHandleUtilFloat(ctx_.get(), X_vals, X_dims, 2);
+  AbstractTensorHandlePtr X;
+  {
+    AbstractTensorHandle* X_raw;
+    Status s =
+        TestTensorHandleWithDimsFloat(ctx_.get(), X_vals, X_dims, 2, &X_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    X.reset(X_raw);
+  }
 
   // Label
   int Y_vals[] = {1, 0, 1};
   int64_t Y_dims[] = {3};
-  AbstractTensorHandlePtr Y =
-      GetTensorHandleUtilInt(ctx_.get(), Y_vals, Y_dims, 1);
+  AbstractTensorHandlePtr Y;
+  {
+    AbstractTensorHandle* Y_raw;
+    Status s =
+        TestTensorHandleWithDimsInt(ctx_.get(), Y_vals, Y_dims, 1, &Y_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    Y.reset(Y_raw);
+  }
 
   ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
       SparseSoftmaxCrossEntropyWithLogitsModel,
       SparseSoftmaxCrossEntropyWithLogitsGradModel, ctx_.get(),
       {X.get(), Y.get()},
-      /*use_function=*/UseFunction(), registry_));
+      /*use_function=*/UseFunction()));
 }
 
@@ -154,17 +161,29 @@ TEST_P(CppGradients, TestBiasAddGrad) {
 
   // A
   float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
   int64_t A_dims[] = {2, 2};
-  AbstractTensorHandlePtr A =
-      GetTensorHandleUtilFloat(ctx_.get(), A_vals, A_dims, 2);
+  AbstractTensorHandlePtr A;
+  {
+    AbstractTensorHandle* A_raw;
+    Status s =
+        TestTensorHandleWithDimsFloat(ctx_.get(), A_vals, A_dims, 2, &A_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    A.reset(A_raw);
+  }
 
   // Bias
   float Bias_vals[] = {2.0f, 3.0f};
   int64_t Bias_dims[] = {2};
-  AbstractTensorHandlePtr Bias =
-      GetTensorHandleUtilFloat(ctx_.get(), Bias_vals, Bias_dims, 1);
+  AbstractTensorHandlePtr Bias;
+  {
+    AbstractTensorHandle* Bias_raw;
+    Status s = TestTensorHandleWithDimsFloat(ctx_.get(), Bias_vals, Bias_dims,
+                                             1, &Bias_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    Bias.reset(Bias_raw);
+  }
 
   ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
       BiasAddModel, BiasAddGradModel, ctx_.get(), {A.get(), Bias.get()},
-      /*use_function=*/UseFunction(), registry_));
+      /*use_function=*/UseFunction()));
 }
 
 #ifdef PLATFORM_GOOGLE