From 7b9e2ff8626f0d99d65ed58d2afa55780c28f287 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=B5=20V=C4=83n=20Ngh=C4=A9a?= Date: Fri, 20 Nov 2020 21:48:55 +0700 Subject: [PATCH 01/15] add BiasAddGradient --- tensorflow/c/experimental/gradients/BUILD | 82 ++++++++++ .../c/experimental/gradients/nn_grad.cc | 52 ++++++ tensorflow/c/experimental/gradients/nn_grad.h | 1 + .../c/experimental/gradients/nn_grad_test.cc | 154 ++++++++++++++++++ .../gradients/nn_grad_testutil.cc | 71 ++++++++ .../experimental/gradients/nn_grad_testutil.h | 46 ++++++ tensorflow/c/experimental/ops/nn_ops.cc | 33 ++++ tensorflow/c/experimental/ops/nn_ops.h | 9 + 8 files changed, 448 insertions(+) create mode 100644 tensorflow/c/experimental/gradients/nn_grad_test.cc create mode 100644 tensorflow/c/experimental/gradients/nn_grad_testutil.cc create mode 100644 tensorflow/c/experimental/gradients/nn_grad_testutil.h diff --git a/tensorflow/c/experimental/gradients/BUILD b/tensorflow/c/experimental/gradients/BUILD index e8a50e32216..728f9ab851c 100644 --- a/tensorflow/c/experimental/gradients/BUILD +++ b/tensorflow/c/experimental/gradients/BUILD @@ -1,5 +1,21 @@ load("//tensorflow:tensorflow.bzl", "filegroup") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") +load( + "//tensorflow:tensorflow.bzl", + "if_libtpu", + "tf_cc_test", + "tf_copts", + "tf_cuda_cc_test", + "tf_cuda_library", +) +load( + "//tensorflow/core/platform:build_config.bzl", + "tf_kernel_tests_linkstatic", +) +load( + "//tensorflow/core/platform:build_config_root.bzl", + "tf_cuda_tests_tags", +) # Library of gradient functions. package( @@ -95,3 +111,69 @@ filegroup( "//tensorflow/python:__pkg__", ], ) + +cc_library( + name = "nn_grad_testutil", + srcs = [ + "nn_grad_testutil.cc", + ], + hdrs = [ + "nn_grad_testutil.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:c_api_experimental", + "//tensorflow/c/eager:c_api_unified_internal", + "//tensorflow/c/eager:gradients_internal", + "//tensorflow/c/eager:gradients_util", + "//tensorflow/c/eager:tape", + "//tensorflow/c/experimental/gradients/tape:tape_context", + "//tensorflow/c/experimental/ops:array_ops", + "//tensorflow/c/experimental/ops:math_ops", + "//tensorflow/c/experimental/ops:nn_ops", + "//tensorflow/core/lib/llvm_rtti", + "//tensorflow/core/platform:status", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/types:span", + ], +) + +tf_cuda_cc_test( + name = "nn_grad_test", + size = "small", + srcs = [ + "nn_grad_test.cc", + ], + args = ["--heap_check=local"], + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags() + ["nomac"], + deps = [ + "//tensorflow/c:c_api", + "//tensorflow/c:c_test_util", + "//tensorflow/c:tf_status_helper", + "//tensorflow/c/eager:abstract_context", + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:c_api_experimental", + "//tensorflow/c/eager:c_api_test_util", + "//tensorflow/c/eager:c_api_unified_internal", + "//tensorflow/c/eager:gradients_internal", + "//tensorflow/c/eager:unified_api_testutil", + "//tensorflow/c/experimental/gradients:nn_grad", + "//tensorflow/c/experimental/gradients:nn_grad_testutil", + "//tensorflow/c/experimental/gradients/tape:tape_context", + "//tensorflow/c/experimental/ops", + "//tensorflow/cc/profiler", + "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + 
"//tensorflow/core:test_main", + "//tensorflow/core/lib/llvm_rtti", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) diff --git a/tensorflow/c/experimental/gradients/nn_grad.cc b/tensorflow/c/experimental/gradients/nn_grad.cc index 64532c8ffc0..8dcdaff9d8b 100644 --- a/tensorflow/c/experimental/gradients/nn_grad.cc +++ b/tensorflow/c/experimental/gradients/nn_grad.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/platform/errors.h" using std::vector; +using tensorflow::ops::BiasAddGrad; using tensorflow::ops::Mul; using tensorflow::ops::ReluGrad; @@ -110,6 +111,48 @@ class SparseSoftmaxCrossEntropyWithLogitsGradientFunction vector forward_outputs; }; +// TODO(vnvo2409): Add python test +class BiasAddGradientFunction : public GradientFunction { + public: + explicit BiasAddGradientFunction(AttrBuilder f_attrs) + : forward_attrs(f_attrs) {} + + Status Compute(Context* ctx, const IncomingGradients& grad_inputs, + vector* grad_outputs) override { + /* Given upstream grad U and a BiasAdd: A + bias, the gradients are: + * + * dA = U + * dbias = reduceSum(U, dims = channel_dim) + */ + + AbstractTensorHandle* upstream_grad = grad_inputs[0]; + DCHECK(upstream_grad); + grad_outputs->resize(2); + + // Recover data format from forward pass for gradient. + std::string data_format; + forward_attrs.Get("data_format", &data_format); + + // Grad for A + (*grad_outputs)[0] = upstream_grad; + (*grad_outputs)[0]->Ref(); + + // Grad for bias + vector bias_add_grad_outputs(1); + std::string name = "bias_add_grad"; + TF_RETURN_IF_ERROR(BiasAddGrad(ctx->ctx, {upstream_grad}, + absl::MakeSpan(bias_add_grad_outputs), + data_format.c_str(), name.c_str())); + + (*grad_outputs)[1] = bias_add_grad_outputs[0]; + return Status::OK(); + } + ~BiasAddGradientFunction() override {} + + private: + AttrBuilder forward_attrs; +}; + } // namespace BackwardFunction* ReluRegisterer(const ForwardOperation& op) { @@ -129,5 +172,14 @@ BackwardFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer( return new BackwardFunction(gradient_function, default_gradients); } +BackwardFunction* BiasAddRegisterer(const ForwardOperation& op) { + // For ops with a single output, the gradient function is not called if there + // is no incoming gradient. So we do not need to worry about creating zeros + // grads in this case. + auto gradient_function = new BiasAddGradientFunction(op.attrs); + auto default_gradients = new PassThroughDefaultGradients(op); + return new BackwardFunction(gradient_function, default_gradients); +} + } // namespace gradients } // namespace tensorflow diff --git a/tensorflow/c/experimental/gradients/nn_grad.h b/tensorflow/c/experimental/gradients/nn_grad.h index 034f20d7325..b4406bb5bbc 100644 --- a/tensorflow/c/experimental/gradients/nn_grad.h +++ b/tensorflow/c/experimental/gradients/nn_grad.h @@ -22,6 +22,7 @@ namespace gradients { BackwardFunction* ReluRegisterer(const ForwardOperation& op); BackwardFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer( const ForwardOperation& op); +BackwardFunction* BiasAddRegisterer(const ForwardOperation& op); } // namespace gradients } // namespace tensorflow diff --git a/tensorflow/c/experimental/gradients/nn_grad_test.cc b/tensorflow/c/experimental/gradients/nn_grad_test.cc new file mode 100644 index 00000000000..ee549b8333f --- /dev/null +++ b/tensorflow/c/experimental/gradients/nn_grad_test.cc @@ -0,0 +1,154 @@ +/* Copyright 2020 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/experimental/gradients/nn_grad.h" + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/eager/gradients.h" +#include "tensorflow/c/eager/gradients_internal.h" +#include "tensorflow/c/eager/gradients_util.h" +#include "tensorflow/c/experimental/gradients/nn_grad_testutil.h" +#include "tensorflow/c/experimental/gradients/tape/tape_context.h" +#include "tensorflow/c/experimental/ops/nn_ops.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace gradients { +namespace internal { +namespace { +using std::vector; +using tensorflow::TF_StatusPtr; +using tracing::TracingOperation; + +class CppGradients + : public ::testing::TestWithParam> { + protected: + void SetUp() override { + TF_StatusPtr status(TF_NewStatus()); + TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); + Status s = StatusFromTF_Status(status.get()); + CHECK_EQ(errors::OK, s.code()) << s.error_message(); + } +}; + +Status RegisterGradients(GradientRegistry* registry) { + TF_RETURN_IF_ERROR(registry->Register("BiasAdd", BiasAddRegisterer)); + return Status::OK(); +} + +TEST_P(CppGradients, TestBiasAddGrad) { + if (std::get<0>(GetParam()) == "mlir" && std::get<2>(GetParam()) == 0) { + GTEST_SKIP() << "SetAttrString has not been implemented yet.\n"; + } + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f}; + int64_t A_dims[] = {2, 2}; + float Bias_vals[] = {2.0f, 3.0f}; + int64_t Bias_dims[] = {2}; + + AbstractTensorHandlePtr A = + GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, 2); + AbstractTensorHandlePtr B = + GetTensorHandleUtilFloat(ctx.get(), Bias_vals, Bias_dims, 1); + + GradientRegistry registry; + Status s = RegisterGradients(®istry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + /* Pseudo-code: + * + * tape.watch(A) + * tape.watch(B) + * Y = BiasAdd(A, Bias) + * outputs = tape.gradient(Y, [A, Bias]) + */ + std::vector outputs(2); + s = RunModel(BiasAddGradModel, ctx.get(), {A.get(), B.get()}, + absl::MakeSpan(outputs), + /*use_function=*/!std::get<2>(GetParam()), 
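+               // The third tuple element is `executing_eagerly`: when it is
+               // false, RunModel traces the model and runs it as a function
+               // instead of executing its ops eagerly.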
registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + TF_Tensor* dA_tensor; + s = GetValue(outputs[0], &dA_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + float result_data_dA[4] = {0}; + memcpy(&result_data_dA[0], TF_TensorData(dA_tensor), + TF_TensorByteSize(dA_tensor)); + + float expected_dA[4] = {1.0f, 1.0f, 1.0f, 1.0f}; + float tolerance = 1e-3; + for (int j = 0; j < 4; j++) { + ASSERT_NEAR(result_data_dA[j], expected_dA[j], tolerance); + } + + TF_Tensor* dBias_tensor; + s = GetValue(outputs[1], &dBias_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + float result_data_dBias[2] = {0}; + memcpy(&result_data_dBias[0], TF_TensorData(dBias_tensor), + TF_TensorByteSize(dBias_tensor)); + + float expected_dBias[2] = {2.0f, 2.0f}; + for (int j = 0; j < 2; j++) { + ASSERT_NEAR(result_data_dBias[j], expected_dBias[j], tolerance); + } + + outputs[0]->Unref(); + outputs[1]->Unref(); + TF_DeleteTensor(dA_tensor); + TF_DeleteTensor(dBias_tensor); +} + +#ifdef PLATFORM_GOOGLE +INSTANTIATE_TEST_SUITE_P( + UnifiedCAPI, CppGradients, + ::testing::Combine(::testing::Values("graphdef", "mlir"), + /*tfrt*/ ::testing::Values(false), + /*executing_eagerly*/ ::testing::Values(true, false))); +#else +INSTANTIATE_TEST_SUITE_P( + UnifiedCAPI, CppGradients, + ::testing::Combine(::testing::Values("graphdef", "mlir"), + /*tfrt*/ ::testing::Values(false), + /*executing_eagerly*/ ::testing::Values(true, false))); +#endif +} // namespace +} // namespace internal +} // namespace gradients +} // namespace tensorflow diff --git a/tensorflow/c/experimental/gradients/nn_grad_testutil.cc b/tensorflow/c/experimental/gradients/nn_grad_testutil.cc new file mode 100644 index 00000000000..6866e3f878c --- /dev/null +++ b/tensorflow/c/experimental/gradients/nn_grad_testutil.cc @@ -0,0 +1,71 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/c/experimental/gradients/nn_grad_testutil.h" + +#include + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/eager/gradients.h" +#include "tensorflow/c/eager/gradients_internal.h" +#include "tensorflow/c/experimental/gradients/tape/tape_context.h" +#include "tensorflow/c/experimental/ops/nn_ops.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +namespace gradients { +namespace internal { + +// Computes +// y = BiasAdd(inputs[0], inputs[1]) +// return grad(y, {inputs[0], inputs[1]}) +Status BiasAddGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + TapeVSpace vspace(ctx); + auto tape = new Tape(/*persistent=*/false); + tape->Watch(ToId(inputs[0])); // Watch x. + tape->Watch(ToId(inputs[1])); // Watch y. + std::vector bias_add_outputs(1); + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR(ops::BiasAdd(tape_ctx.get(), inputs, + absl::MakeSpan(bias_add_outputs), "BiasAdd")); + std::unordered_map + source_tensors_that_are_targets; + + std::vector out_grads; + TF_RETURN_IF_ERROR(tape->ComputeGradient( + vspace, /*target_tensor_ids=*/{ToId(bias_add_outputs[0])}, + /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])}, + source_tensors_that_are_targets, + /*output_gradients=*/{}, &out_grads, + /*build_default_zeros_grads=*/false)); + for (auto bias_add_output : bias_add_outputs) { + bias_add_output->Unref(); + } + outputs[0] = out_grads[0]; + outputs[1] = out_grads[1]; + delete tape; + return Status::OK(); +} + +} // namespace internal +} // namespace gradients +} // namespace tensorflow diff --git a/tensorflow/c/experimental/gradients/nn_grad_testutil.h b/tensorflow/c/experimental/gradients/nn_grad_testutil.h new file mode 100644 index 00000000000..dff65906821 --- /dev/null +++ b/tensorflow/c/experimental/gradients/nn_grad_testutil.h @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_TESTUTIL_H_ +#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_TESTUTIL_H_ +#include + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/eager/gradients.h" +#include "tensorflow/c/eager/gradients_internal.h" +#include "tensorflow/c/experimental/ops/nn_ops.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { +namespace gradients { +namespace internal { + +// Computes +// y = BiasAdd(inputs[0], inputs[1]) +// return grad(y, {inputs[0], inputs[1]}) +Status BiasAddGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry); + +} // namespace internal +} // namespace gradients +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_TESTUTIL_H_ \ No newline at end of file diff --git a/tensorflow/c/experimental/ops/nn_ops.cc b/tensorflow/c/experimental/ops/nn_ops.cc index 6a97dbf0939..b1cc2ffc7d6 100644 --- a/tensorflow/c/experimental/ops/nn_ops.cc +++ b/tensorflow/c/experimental/ops/nn_ops.cc @@ -69,5 +69,38 @@ Status Relu(AbstractContext* ctx, return Status::OK(); } +Status BiasAdd(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr bias_add_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR( + bias_add_op->Reset("BiasAdd", /*raw_device_name=*/nullptr)); + TF_RETURN_IF_ERROR(MaybeSetOpName(bias_add_op.get(), name)); + TF_RETURN_IF_ERROR(bias_add_op->AddInput(inputs[0])); // tensor input + TF_RETURN_IF_ERROR(bias_add_op->AddInput(inputs[1])); // bias + + int num_retvals = 1; + TF_RETURN_IF_ERROR(bias_add_op->Execute(outputs, &num_retvals)); + return Status::OK(); +} + +// Computes Bias Add gradient given upstream grads +Status BiasAddGrad(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const char* data_format, const char* name) { + AbstractOperationPtr bias_add_grad_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR( + bias_add_grad_op->Reset("BiasAddGrad", /*raw_device_name=*/nullptr)); + TF_RETURN_IF_ERROR(MaybeSetOpName(bias_add_grad_op.get(), name)); + TF_RETURN_IF_ERROR(bias_add_grad_op->SetAttrString("data_format", data_format, + strlen(data_format))); + TF_RETURN_IF_ERROR(bias_add_grad_op->AddInput(inputs[0])); + + int num_retvals = 1; + TF_RETURN_IF_ERROR(bias_add_grad_op->Execute(outputs, &num_retvals)); + return Status::OK(); +} + } // namespace ops } // namespace tensorflow diff --git a/tensorflow/c/experimental/ops/nn_ops.h b/tensorflow/c/experimental/ops/nn_ops.h index 3c0e04579a1..d5b8cdde356 100644 --- a/tensorflow/c/experimental/ops/nn_ops.h +++ b/tensorflow/c/experimental/ops/nn_ops.h @@ -34,6 +34,15 @@ Status Relu(AbstractContext* ctx, absl::Span inputs, absl::Span outputs, const char* name); +Status BiasAdd(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name); + +Status BiasAddGrad(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const char* data_format, const char* name); + } // namespace ops } // namespace tensorflow From 22a7fd609f36346bc1077ff8e1544594e9fb540d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=B5=20V=C4=83n=20Ngh=C4=A9a?= 
Date: Thu, 26 Nov 2020 01:22:29 +0700 Subject: [PATCH 02/15] Use gradient checker for testing BiasAddGrad --- tensorflow/c/experimental/gradients/BUILD | 48 ++----- .../gradients/grad_test_helper.cc | 77 ++++++++++ .../experimental/gradients/grad_test_helper.h | 34 +++++ .../c/experimental/gradients/nn_grad_test.cc | 131 ++++++++---------- .../gradients/nn_grad_testutil.cc | 71 ---------- .../experimental/gradients/nn_grad_testutil.h | 46 ------ 6 files changed, 179 insertions(+), 228 deletions(-) create mode 100644 tensorflow/c/experimental/gradients/grad_test_helper.cc create mode 100644 tensorflow/c/experimental/gradients/grad_test_helper.h delete mode 100644 tensorflow/c/experimental/gradients/nn_grad_testutil.cc delete mode 100644 tensorflow/c/experimental/gradients/nn_grad_testutil.h diff --git a/tensorflow/c/experimental/gradients/BUILD b/tensorflow/c/experimental/gradients/BUILD index 728f9ab851c..74c7b54b935 100644 --- a/tensorflow/c/experimental/gradients/BUILD +++ b/tensorflow/c/experimental/gradients/BUILD @@ -113,31 +113,18 @@ filegroup( ) cc_library( - name = "nn_grad_testutil", - srcs = [ - "nn_grad_testutil.cc", - ], - hdrs = [ - "nn_grad_testutil.h", - ], + name = "grad_test_helper", + testonly = True, + srcs = ["grad_test_helper.cc"], + hdrs = ["grad_test_helper.h"], visibility = [ "//tensorflow:internal", ], deps = [ - "//tensorflow/c/eager:abstract_tensor_handle", - "//tensorflow/c/eager:c_api_experimental", - "//tensorflow/c/eager:c_api_unified_internal", - "//tensorflow/c/eager:gradients_internal", + "//tensorflow/c/eager:gradient_checker", "//tensorflow/c/eager:gradients_util", - "//tensorflow/c/eager:tape", - "//tensorflow/c/experimental/gradients/tape:tape_context", - "//tensorflow/c/experimental/ops:array_ops", - "//tensorflow/c/experimental/ops:math_ops", - "//tensorflow/c/experimental/ops:nn_ops", - "//tensorflow/core/lib/llvm_rtti", - "//tensorflow/core/platform:status", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/types:span", + "//tensorflow/core:test", + "//tensorflow/core:test_main", ], ) @@ -151,29 +138,10 @@ tf_cuda_cc_test( linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags() + ["nomac"], deps = [ - "//tensorflow/c:c_api", - "//tensorflow/c:c_test_util", - "//tensorflow/c:tf_status_helper", - "//tensorflow/c/eager:abstract_context", - "//tensorflow/c/eager:abstract_tensor_handle", - "//tensorflow/c/eager:c_api_experimental", + ":grad_test_helper", "//tensorflow/c/eager:c_api_test_util", - "//tensorflow/c/eager:c_api_unified_internal", - "//tensorflow/c/eager:gradients_internal", - "//tensorflow/c/eager:unified_api_testutil", - "//tensorflow/c/experimental/gradients:nn_grad", - "//tensorflow/c/experimental/gradients:nn_grad_testutil", "//tensorflow/c/experimental/gradients/tape:tape_context", - "//tensorflow/c/experimental/ops", - "//tensorflow/cc/profiler", - "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", - "//tensorflow/core/lib/llvm_rtti", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/c/experimental/gradients/grad_test_helper.cc b/tensorflow/c/experimental/gradients/grad_test_helper.cc new file mode 100644 index 00000000000..8db15951132 --- /dev/null +++ b/tensorflow/c/experimental/gradients/grad_test_helper.cc @@ -0,0 +1,77 @@ +/* Copyright 2020 The 
TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/experimental/gradients/grad_test_helper.h" + +#include "tensorflow/c/eager/gradient_checker.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace gradients { +namespace internal { + +void CompareWithGradientsCheckers(Model model, Model grad_model, + AbstractContext* ctx, + std::vector inputs, + bool use_function, + const GradientRegistry& registry) { + auto num_inputs = inputs.size(); + std::vector outputs(num_inputs); + auto s = + RunModel(grad_model, ctx, absl::MakeSpan(inputs), absl::MakeSpan(outputs), + /*use_function=*/use_function, registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + for (int i = 0; i < num_inputs; ++i) { + std::vector numerical_inputs{inputs}; + + AbstractTensorHandle* g; // Will contain numerical approximation data. + // TODO(vnvo2409): `CalcNumericalGrad` should not modify `inputs`. + s = CalcNumericalGrad(ctx, model, absl::MakeSpan(numerical_inputs), + /*input_index=*/i, + /*use_function=*/use_function, &g); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + TF_Tensor* numerical_tensor; + s = GetValue(g, &numerical_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + auto num_elem_numerical = TF_TensorElementCount(numerical_tensor); + + TF_Tensor* analytical_tensor; + s = GetValue(outputs[i], &analytical_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + auto num_elem_analytical = TF_TensorElementCount(analytical_tensor); + + ASSERT_EQ(num_elem_numerical, num_elem_analytical); + + float* dnumerical = new float[num_elem_numerical]{0}; + memcpy(&dnumerical[0], TF_TensorData(numerical_tensor), + TF_TensorByteSize(numerical_tensor)); + float* danalytical = new float[num_elem_analytical]{0}; + memcpy(&danalytical[0], TF_TensorData(analytical_tensor), + TF_TensorByteSize(analytical_tensor)); + + for (int j = 0; j < num_elem_numerical; j++) { + ASSERT_NEAR(dnumerical[j], danalytical[j], /*abs_error=*/1e-2); + } + TF_DeleteTensor(analytical_tensor); + TF_DeleteTensor(numerical_tensor); + delete[] danalytical; + delete[] dnumerical; + } +} + +} // namespace internal +} // namespace gradients +} // namespace tensorflow diff --git a/tensorflow/c/experimental/gradients/grad_test_helper.h b/tensorflow/c/experimental/gradients/grad_test_helper.h new file mode 100644 index 00000000000..d51fd6d7871 --- /dev/null +++ b/tensorflow/c/experimental/gradients/grad_test_helper.h @@ -0,0 +1,34 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_GRAD_TEST_HELPER_H_ +#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_GRAD_TEST_HELPER_H_ + +#include "tensorflow/c/eager/gradients_util.h" + +namespace tensorflow { +namespace gradients { +namespace internal { + +void CompareWithGradientsCheckers(Model model, Model grad_model, + AbstractContext* ctx, + std::vector inputs, + bool use_function, + const GradientRegistry& registry); + +} // namespace internal +} // namespace gradients +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_GRAD_TEST_HELPER_H_ diff --git a/tensorflow/c/experimental/gradients/nn_grad_test.cc b/tensorflow/c/experimental/gradients/nn_grad_test.cc index ee549b8333f..871e34998a8 100644 --- a/tensorflow/c/experimental/gradients/nn_grad_test.cc +++ b/tensorflow/c/experimental/gradients/nn_grad_test.cc @@ -14,36 +14,67 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/gradients/nn_grad.h" -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/types/span.h" -#include "tensorflow/c/eager/abstract_context.h" -#include "tensorflow/c/eager/abstract_tensor_handle.h" -#include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_test_util.h" -#include "tensorflow/c/eager/c_api_unified_experimental.h" -#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" -#include "tensorflow/c/eager/gradients.h" -#include "tensorflow/c/eager/gradients_internal.h" -#include "tensorflow/c/eager/gradients_util.h" -#include "tensorflow/c/experimental/gradients/nn_grad_testutil.h" +#include "tensorflow/c/experimental/gradients/grad_test_helper.h" #include "tensorflow/c/experimental/gradients/tape/tape_context.h" -#include "tensorflow/c/experimental/ops/nn_ops.h" -#include "tensorflow/c/tf_status_helper.h" -#include "tensorflow/c/tf_tensor.h" -#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" -#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { namespace gradients { namespace internal { namespace { -using std::vector; + using tensorflow::TF_StatusPtr; using tracing::TracingOperation; +Status BiasAddModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + auto tape = new Tape(/*persistent=*/false); + tape->Watch(ToId(inputs[0])); // Watch x. + tape->Watch(ToId(inputs[1])); // Watch y. + std::vector bias_add_outputs(1); + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR(ops::BiasAdd(tape_ctx.get(), inputs, + absl::MakeSpan(bias_add_outputs), "BiasAdd")); + outputs[0] = bias_add_outputs[0]; + + delete tape; + return Status::OK(); +} + +Status BiasAddGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + TapeVSpace vspace(ctx); + auto tape = new Tape(/*persistent=*/false); + tape->Watch(ToId(inputs[0])); // Watch x. 
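+  // `Watch` registers a tensor id as a differentiation source on the tape;
+  // the `ComputeGradient` call below replays the recorded ops backwards from
+  // the target to every watched source to produce the requested gradients.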
+ tape->Watch(ToId(inputs[1])); // Watch y. + std::vector bias_add_outputs(1); + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); + TF_RETURN_IF_ERROR(ops::BiasAdd( + tape_ctx.get(), inputs, absl::MakeSpan(bias_add_outputs), "BiasAddGrad")); + std::unordered_map + source_tensors_that_are_targets; + + std::vector out_grads; + TF_RETURN_IF_ERROR(tape->ComputeGradient( + vspace, /*target_tensor_ids=*/{ToId(bias_add_outputs[0])}, + /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])}, + source_tensors_that_are_targets, + /*output_gradients=*/{}, &out_grads, + /*build_default_zeros_grads=*/false)); + for (auto bias_add_output : bias_add_outputs) { + bias_add_output->Unref(); + } + outputs[0] = out_grads[0]; + outputs[1] = out_grads[1]; + delete tape; + return Status::OK(); +} + class CppGradients : public ::testing::TestWithParam> { protected: @@ -60,12 +91,10 @@ Status RegisterGradients(GradientRegistry* registry) { return Status::OK(); } -TEST_P(CppGradients, TestBiasAddGrad) { - if (std::get<0>(GetParam()) == "mlir" && std::get<2>(GetParam()) == 0) { +TEST_P(CppGradients, TestBiasAddGradChecker) { + if (std::get<0>(GetParam()) == "mlir" && !std::get<2>(GetParam())) { GTEST_SKIP() << "SetAttrString has not been implemented yet.\n"; } - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); AbstractContextPtr ctx; { AbstractContext* ctx_raw = nullptr; @@ -75,64 +104,24 @@ TEST_P(CppGradients, TestBiasAddGrad) { ctx.reset(ctx_raw); } + // A float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f}; int64_t A_dims[] = {2, 2}; - float Bias_vals[] = {2.0f, 3.0f}; - int64_t Bias_dims[] = {2}; - AbstractTensorHandlePtr A = GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, 2); - AbstractTensorHandlePtr B = + // Bias + float Bias_vals[] = {2.0f, 3.0f}; + int64_t Bias_dims[] = {2}; + AbstractTensorHandlePtr Bias = GetTensorHandleUtilFloat(ctx.get(), Bias_vals, Bias_dims, 1); GradientRegistry registry; Status s = RegisterGradients(®istry); ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - /* Pseudo-code: - * - * tape.watch(A) - * tape.watch(B) - * Y = BiasAdd(A, Bias) - * outputs = tape.gradient(Y, [A, Bias]) - */ - std::vector outputs(2); - s = RunModel(BiasAddGradModel, ctx.get(), {A.get(), B.get()}, - absl::MakeSpan(outputs), - /*use_function=*/!std::get<2>(GetParam()), registry); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - TF_Tensor* dA_tensor; - s = GetValue(outputs[0], &dA_tensor); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - float result_data_dA[4] = {0}; - memcpy(&result_data_dA[0], TF_TensorData(dA_tensor), - TF_TensorByteSize(dA_tensor)); - - float expected_dA[4] = {1.0f, 1.0f, 1.0f, 1.0f}; - float tolerance = 1e-3; - for (int j = 0; j < 4; j++) { - ASSERT_NEAR(result_data_dA[j], expected_dA[j], tolerance); - } - - TF_Tensor* dBias_tensor; - s = GetValue(outputs[1], &dBias_tensor); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - - float result_data_dBias[2] = {0}; - memcpy(&result_data_dBias[0], TF_TensorData(dBias_tensor), - TF_TensorByteSize(dBias_tensor)); - - float expected_dBias[2] = {2.0f, 2.0f}; - for (int j = 0; j < 2; j++) { - ASSERT_NEAR(result_data_dBias[j], expected_dBias[j], tolerance); - } - - outputs[0]->Unref(); - outputs[1]->Unref(); - TF_DeleteTensor(dA_tensor); - TF_DeleteTensor(dBias_tensor); + ASSERT_NO_FATAL_FAILURE(CompareWithGradientsCheckers( + BiasAddModel, BiasAddGradModel, ctx.get(), {A.get(), Bias.get()}, + /*use_function=*/!std::get<2>(GetParam()), registry)); } #ifdef PLATFORM_GOOGLE diff --git 
a/tensorflow/c/experimental/gradients/nn_grad_testutil.cc b/tensorflow/c/experimental/gradients/nn_grad_testutil.cc deleted file mode 100644 index 6866e3f878c..00000000000 --- a/tensorflow/c/experimental/gradients/nn_grad_testutil.cc +++ /dev/null @@ -1,71 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/c/experimental/gradients/nn_grad_testutil.h" - -#include - -#include "absl/types/span.h" -#include "tensorflow/c/eager/abstract_tensor_handle.h" -#include "tensorflow/c/eager/c_api_experimental.h" -#include "tensorflow/c/eager/c_api_unified_experimental.h" -#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" -#include "tensorflow/c/eager/gradients.h" -#include "tensorflow/c/eager/gradients_internal.h" -#include "tensorflow/c/experimental/gradients/tape/tape_context.h" -#include "tensorflow/c/experimental/ops/nn_ops.h" -#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" -#include "tensorflow/core/platform/status.h" - -namespace tensorflow { -namespace gradients { -namespace internal { - -// Computes -// y = BiasAdd(inputs[0], inputs[1]) -// return grad(y, {inputs[0], inputs[1]}) -Status BiasAddGradModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry) { - TapeVSpace vspace(ctx); - auto tape = new Tape(/*persistent=*/false); - tape->Watch(ToId(inputs[0])); // Watch x. - tape->Watch(ToId(inputs[1])); // Watch y. - std::vector bias_add_outputs(1); - AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); - TF_RETURN_IF_ERROR(ops::BiasAdd(tape_ctx.get(), inputs, - absl::MakeSpan(bias_add_outputs), "BiasAdd")); - std::unordered_map - source_tensors_that_are_targets; - - std::vector out_grads; - TF_RETURN_IF_ERROR(tape->ComputeGradient( - vspace, /*target_tensor_ids=*/{ToId(bias_add_outputs[0])}, - /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])}, - source_tensors_that_are_targets, - /*output_gradients=*/{}, &out_grads, - /*build_default_zeros_grads=*/false)); - for (auto bias_add_output : bias_add_outputs) { - bias_add_output->Unref(); - } - outputs[0] = out_grads[0]; - outputs[1] = out_grads[1]; - delete tape; - return Status::OK(); -} - -} // namespace internal -} // namespace gradients -} // namespace tensorflow diff --git a/tensorflow/c/experimental/gradients/nn_grad_testutil.h b/tensorflow/c/experimental/gradients/nn_grad_testutil.h deleted file mode 100644 index dff65906821..00000000000 --- a/tensorflow/c/experimental/gradients/nn_grad_testutil.h +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_TESTUTIL_H_ -#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_TESTUTIL_H_ -#include - -#include "absl/types/span.h" -#include "tensorflow/c/eager/abstract_tensor_handle.h" -#include "tensorflow/c/eager/c_api_experimental.h" -#include "tensorflow/c/eager/c_api_unified_experimental.h" -#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" -#include "tensorflow/c/eager/gradients.h" -#include "tensorflow/c/eager/gradients_internal.h" -#include "tensorflow/c/experimental/ops/nn_ops.h" -#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" -#include "tensorflow/core/platform/status.h" - -namespace tensorflow { -namespace gradients { -namespace internal { - -// Computes -// y = BiasAdd(inputs[0], inputs[1]) -// return grad(y, {inputs[0], inputs[1]}) -Status BiasAddGradModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry); - -} // namespace internal -} // namespace gradients -} // namespace tensorflow - -#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_TESTUTIL_H_ \ No newline at end of file From fbc285f8b1417fee5bb45062da8ca6f0808cd220 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=B5=20V=C4=83n=20Ngh=C4=A9a?= Date: Thu, 26 Nov 2020 17:40:04 +0700 Subject: [PATCH 03/15] cleanup `nn_grad_test.cc` Changes: - add `CompareWithGradientsCheckers` to automatically test the model against `gradient_checker`. - `TF_MODEL_FACTORY` and `TF_GRAD_MODEL_FACTORY` will expand to models that are needed for testing. - Move `RegisterGradients` and `BuildImmediateExecutionContext` to `Setup`. --- tensorflow/c/experimental/gradients/BUILD | 1 + .../gradients/grad_test_helper.cc | 2 + .../experimental/gradients/grad_test_helper.h | 80 ++++++++++++++++ .../c/experimental/gradients/nn_grad_test.cc | 92 ++++++------------- 4 files changed, 111 insertions(+), 64 deletions(-) diff --git a/tensorflow/c/experimental/gradients/BUILD b/tensorflow/c/experimental/gradients/BUILD index 74c7b54b935..0d0ebebf88e 100644 --- a/tensorflow/c/experimental/gradients/BUILD +++ b/tensorflow/c/experimental/gradients/BUILD @@ -139,6 +139,7 @@ tf_cuda_cc_test( tags = tf_cuda_tests_tags() + ["nomac"], deps = [ ":grad_test_helper", + ":nn_grad", "//tensorflow/c/eager:c_api_test_util", "//tensorflow/c/experimental/gradients/tape:tape_context", "//tensorflow/core:test", diff --git a/tensorflow/c/experimental/gradients/grad_test_helper.cc b/tensorflow/c/experimental/gradients/grad_test_helper.cc index 8db15951132..2764ce8d92d 100644 --- a/tensorflow/c/experimental/gradients/grad_test_helper.cc +++ b/tensorflow/c/experimental/gradients/grad_test_helper.cc @@ -34,6 +34,7 @@ void CompareWithGradientsCheckers(Model model, Model grad_model, ASSERT_EQ(errors::OK, s.code()) << s.error_message(); for (int i = 0; i < num_inputs; ++i) { + if (!outputs[i]) continue; std::vector numerical_inputs{inputs}; AbstractTensorHandle* g; // Will contain numerical approximation data. 
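   // `CalcNumericalGrad` perturbs one input elementwise to build a finite
   // difference estimate of the gradient; the ASSERT_NEAR loop further down
   // then compares that estimate against the tape-computed (analytical)
   // gradient within an absolute tolerance of 1e-2.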
@@ -69,6 +70,7 @@ void CompareWithGradientsCheckers(Model model, Model grad_model, TF_DeleteTensor(numerical_tensor); delete[] danalytical; delete[] dnumerical; + outputs[i]->Unref(); } } diff --git a/tensorflow/c/experimental/gradients/grad_test_helper.h b/tensorflow/c/experimental/gradients/grad_test_helper.h index d51fd6d7871..604495c9380 100644 --- a/tensorflow/c/experimental/gradients/grad_test_helper.h +++ b/tensorflow/c/experimental/gradients/grad_test_helper.h @@ -21,6 +21,86 @@ namespace tensorflow { namespace gradients { namespace internal { +// This macro will expand to a function that defines a `Model`. This `Model` is +// then used for testing by `nn_grad_test` and `math_grad_test`. `ops_call` is a +// statement that calls to a `ops::` and should be wrapped around by `{}`. +// `ops_call` has access to `inputs`. The output parameter of the ops should +// always be `absl::MakeSpan(temp_outputs)`. This macro supports most one-ops +// model. +// TODO(vnvo2409): Extend support for more complex model. +#define TF_MODEL_FACTORY(name, num_inputs, num_outputs, ops_call) \ + Status name(AbstractContext* ctx, \ + absl::Span inputs, \ + absl::Span outputs, \ + const GradientRegistry& registry) { \ + auto tape = new Tape(/*persistent=*/false); \ + for (int i{}; i < num_inputs; ++i) { \ + tape->Watch(ToId(inputs[i])); \ + } \ + \ + AbstractTensorHandle* temp_outputs[num_outputs] = {}; \ + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); \ + ops_call; \ + \ + for (int i{}; i < num_outputs; ++i) { \ + outputs[i] = temp_outputs[i]; \ + } \ + delete tape; \ + return Status::OK(); \ + } + +// This macro will expand to a function that defines a `GradModel`. This +// `GradModel` is then used for testing by `nn_grad_test` and `math_grad_test`. +// `ops_call` is a statement that calls to a `ops::` and should be wrapped +// around by `{}`. `ops_call` has access to `inputs`. The output parameter of +// the ops should always be `absl::MakeSpan(temp_outputs)`. This macro supports +// most one-ops model. +// TODO(vnvo2409): Extend support for more complex model. 
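+// For example, the invocation in `nn_grad_test.cc`,
+//   TF_GRAD_MODEL_FACTORY(BiasAddGradModel, 2, 1, 2, {...})
+// expands to `Status BiasAddGradModel(ctx, inputs, outputs, registry)`, which
+// watches the 2 inputs, runs the wrapped op to get 1 forward output, and
+// writes the 2 computed gradients into `outputs`.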
+#define TF_GRAD_MODEL_FACTORY(name, num_inputs, num_outputs, num_grad_outputs, \ + ops_call) \ + Status name(AbstractContext* ctx, \ + absl::Span inputs, \ + absl::Span outputs, \ + const GradientRegistry& registry) { \ + TapeVSpace vspace(ctx); \ + auto tape = new Tape(/*persistent=*/false); \ + for (int i{}; i < num_inputs; ++i) { \ + tape->Watch(ToId(inputs[i])); \ + } \ + \ + AbstractTensorHandle* temp_outputs[num_outputs] = {}; \ + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); \ + ops_call; \ + \ + std::unordered_map \ + source_tensors_that_are_targets; \ + std::vector out_grads(num_grad_outputs); \ + \ + int64 target_tensor_ids[num_outputs] = {}; \ + for (int i{}; i < num_outputs; ++i) { \ + target_tensor_ids[i] = ToId(temp_outputs[i]); \ + } \ + \ + int64 source_tensor_ids[num_inputs] = {}; \ + for (int i{}; i < num_inputs; ++i) { \ + source_tensor_ids[i] = ToId(inputs[i]); \ + } \ + \ + TF_RETURN_IF_ERROR(tape->ComputeGradient( \ + vspace, target_tensor_ids, source_tensor_ids, \ + source_tensors_that_are_targets, /*output_gradients=*/{}, &out_grads, \ + /*build_default_zeros_grads=*/false)); \ + \ + for (int i{}; i < num_outputs; ++i) { \ + temp_outputs[i]->Unref(); \ + } \ + for (int i{}; i < num_grad_outputs; ++i) { \ + outputs[i] = out_grads[i]; \ + } \ + delete tape; \ + return Status::OK(); \ + } + void CompareWithGradientsCheckers(Model model, Model grad_model, AbstractContext* ctx, std::vector inputs, diff --git a/tensorflow/c/experimental/gradients/nn_grad_test.cc b/tensorflow/c/experimental/gradients/nn_grad_test.cc index 871e34998a8..42cce9c5e24 100644 --- a/tensorflow/c/experimental/gradients/nn_grad_test.cc +++ b/tensorflow/c/experimental/gradients/nn_grad_test.cc @@ -27,51 +27,18 @@ namespace { using tensorflow::TF_StatusPtr; using tracing::TracingOperation; -Status BiasAddModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry) { - auto tape = new Tape(/*persistent=*/false); - tape->Watch(ToId(inputs[0])); // Watch x. - tape->Watch(ToId(inputs[1])); // Watch y. - std::vector bias_add_outputs(1); - AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); +TF_MODEL_FACTORY(BiasAddModel, 2, 1, { TF_RETURN_IF_ERROR(ops::BiasAdd(tape_ctx.get(), inputs, - absl::MakeSpan(bias_add_outputs), "BiasAdd")); - outputs[0] = bias_add_outputs[0]; + absl::MakeSpan(temp_outputs), "BiasAdd")); +}) - delete tape; - return Status::OK(); -} +TF_GRAD_MODEL_FACTORY(BiasAddGradModel, 2, 1, 2, { + TF_RETURN_IF_ERROR(ops::BiasAdd(tape_ctx.get(), inputs, + absl::MakeSpan(temp_outputs), "BiasAddGrad")); +}) -Status BiasAddGradModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry) { - TapeVSpace vspace(ctx); - auto tape = new Tape(/*persistent=*/false); - tape->Watch(ToId(inputs[0])); // Watch x. - tape->Watch(ToId(inputs[1])); // Watch y. 
- std::vector bias_add_outputs(1); - AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); - TF_RETURN_IF_ERROR(ops::BiasAdd( - tape_ctx.get(), inputs, absl::MakeSpan(bias_add_outputs), "BiasAddGrad")); - std::unordered_map - source_tensors_that_are_targets; - - std::vector out_grads; - TF_RETURN_IF_ERROR(tape->ComputeGradient( - vspace, /*target_tensor_ids=*/{ToId(bias_add_outputs[0])}, - /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])}, - source_tensors_that_are_targets, - /*output_gradients=*/{}, &out_grads, - /*build_default_zeros_grads=*/false)); - for (auto bias_add_output : bias_add_outputs) { - bias_add_output->Unref(); - } - outputs[0] = out_grads[0]; - outputs[1] = out_grads[1]; - delete tape; +Status RegisterGradients(GradientRegistry* registry) { + TF_RETURN_IF_ERROR(registry->Register("BiasAdd", BiasAddRegisterer)); return Status::OK(); } @@ -83,45 +50,42 @@ class CppGradients TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); Status s = StatusFromTF_Status(status.get()); CHECK_EQ(errors::OK, s.code()) << s.error_message(); + + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx_.reset(ctx_raw); + } + + s = RegisterGradients(®istry_); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); } + + GradientRegistry registry_; + AbstractContextPtr ctx_; }; -Status RegisterGradients(GradientRegistry* registry) { - TF_RETURN_IF_ERROR(registry->Register("BiasAdd", BiasAddRegisterer)); - return Status::OK(); -} - -TEST_P(CppGradients, TestBiasAddGradChecker) { +TEST_P(CppGradients, TestBiasAddGrad) { if (std::get<0>(GetParam()) == "mlir" && !std::get<2>(GetParam())) { GTEST_SKIP() << "SetAttrString has not been implemented yet.\n"; } - AbstractContextPtr ctx; - { - AbstractContext* ctx_raw = nullptr; - Status s = - BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); - ctx.reset(ctx_raw); - } // A float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f}; int64_t A_dims[] = {2, 2}; AbstractTensorHandlePtr A = - GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, 2); + GetTensorHandleUtilFloat(ctx_.get(), A_vals, A_dims, 2); // Bias float Bias_vals[] = {2.0f, 3.0f}; int64_t Bias_dims[] = {2}; AbstractTensorHandlePtr Bias = - GetTensorHandleUtilFloat(ctx.get(), Bias_vals, Bias_dims, 1); - - GradientRegistry registry; - Status s = RegisterGradients(®istry); - ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + GetTensorHandleUtilFloat(ctx_.get(), Bias_vals, Bias_dims, 1); ASSERT_NO_FATAL_FAILURE(CompareWithGradientsCheckers( - BiasAddModel, BiasAddGradModel, ctx.get(), {A.get(), Bias.get()}, - /*use_function=*/!std::get<2>(GetParam()), registry)); + BiasAddModel, BiasAddGradModel, ctx_.get(), {A.get(), Bias.get()}, + /*use_function=*/!std::get<2>(GetParam()), registry_)); } #ifdef PLATFORM_GOOGLE From 9bb01db0dd360c5ded9de36e7a53971a968a1f31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=B5=20V=C4=83n=20Ngh=C4=A9a?= Date: Sat, 28 Nov 2020 17:08:47 +0700 Subject: [PATCH 04/15] restore `inputs` in `gradients_checker` --- tensorflow/c/eager/gradient_checker.cc | 6 ++++-- tensorflow/c/eager/gradient_checker.h | 3 +++ .../c/experimental/gradients/grad_test_helper.cc | 11 ++++------- .../c/experimental/gradients/grad_test_helper.h | 2 +- tensorflow/c/experimental/gradients/nn_grad_test.cc | 4 +++- 5 files changed, 15 insertions(+), 11 deletions(-) diff --git 
a/tensorflow/c/eager/gradient_checker.cc b/tensorflow/c/eager/gradient_checker.cc index 640edc7228a..c2da03f28dd 100644 --- a/tensorflow/c/eager/gradient_checker.cc +++ b/tensorflow/c/eager/gradient_checker.cc @@ -99,8 +99,8 @@ Status RunAndMaybeSum(AbstractContext* ctx, Model forward, sum_inputs[0] = model_out; sum_inputs[1] = sum_dims.get(); - TF_RETURN_IF_ERROR(ops::Sum(ctx, absl::MakeSpan(sum_inputs), - absl::MakeSpan(model_outputs), "sum_output")); + TF_RETURN_IF_ERROR( + ops::Sum(ctx, sum_inputs, absl::MakeSpan(model_outputs), "sum_output")); outputs[0] = model_outputs[0]; return Status::OK(); } @@ -191,6 +191,8 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward, dtheta_approx[i] = grad_data[0]; } + // Restore the inputs + inputs[input_index] = theta; // Populate *numerical_grad with the data from dtheta_approx. TF_RETURN_IF_ERROR(TensorHandleWithDimsFloat( ctx, dtheta_approx.data(), theta_dims.data(), num_dims, numerical_grad)); diff --git a/tensorflow/c/eager/gradient_checker.h b/tensorflow/c/eager/gradient_checker.h index 8497f5af48e..f4b9ac544eb 100644 --- a/tensorflow/c/eager/gradient_checker.h +++ b/tensorflow/c/eager/gradient_checker.h @@ -43,6 +43,9 @@ namespace gradients { * * `numerical_grad` is the pointer to the AbstractTensorHandle* which will * hold the numerical gradient data at the end of the function. + * + * Note that this function will modify `inputs` and restore it to the original + * data before returning. */ Status CalcNumericalGrad(AbstractContext* ctx, Model forward, absl::Span inputs, diff --git a/tensorflow/c/experimental/gradients/grad_test_helper.cc b/tensorflow/c/experimental/gradients/grad_test_helper.cc index 2764ce8d92d..978e1f759a5 100644 --- a/tensorflow/c/experimental/gradients/grad_test_helper.cc +++ b/tensorflow/c/experimental/gradients/grad_test_helper.cc @@ -23,23 +23,20 @@ namespace internal { void CompareWithGradientsCheckers(Model model, Model grad_model, AbstractContext* ctx, - std::vector inputs, + absl::Span inputs, bool use_function, const GradientRegistry& registry) { auto num_inputs = inputs.size(); std::vector outputs(num_inputs); - auto s = - RunModel(grad_model, ctx, absl::MakeSpan(inputs), absl::MakeSpan(outputs), - /*use_function=*/use_function, registry); + auto s = RunModel(grad_model, ctx, inputs, absl::MakeSpan(outputs), + /*use_function=*/use_function, registry); ASSERT_EQ(errors::OK, s.code()) << s.error_message(); for (int i = 0; i < num_inputs; ++i) { if (!outputs[i]) continue; - std::vector numerical_inputs{inputs}; AbstractTensorHandle* g; // Will contain numerical approximation data. - // TODO(vnvo2409): `CalcNumericalGrad` should not modify `inputs`. 
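 // `CalcNumericalGrad` now saves and restores `inputs[input_index]` itself
 // (see the gradient_checker.cc hunk above), so the defensive
 // `numerical_inputs` copy removed here is no longer needed and `inputs`
 // can be passed directly.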
- s = CalcNumericalGrad(ctx, model, absl::MakeSpan(numerical_inputs), + s = CalcNumericalGrad(ctx, model, inputs, /*input_index=*/i, /*use_function=*/use_function, &g); ASSERT_EQ(errors::OK, s.code()) << s.error_message(); diff --git a/tensorflow/c/experimental/gradients/grad_test_helper.h b/tensorflow/c/experimental/gradients/grad_test_helper.h index 604495c9380..5ce7ed9856f 100644 --- a/tensorflow/c/experimental/gradients/grad_test_helper.h +++ b/tensorflow/c/experimental/gradients/grad_test_helper.h @@ -103,7 +103,7 @@ namespace internal { void CompareWithGradientsCheckers(Model model, Model grad_model, AbstractContext* ctx, - std::vector inputs, + absl::Span inputs, bool use_function, const GradientRegistry& registry); diff --git a/tensorflow/c/experimental/gradients/nn_grad_test.cc b/tensorflow/c/experimental/gradients/nn_grad_test.cc index 42cce9c5e24..4859ff27d2d 100644 --- a/tensorflow/c/experimental/gradients/nn_grad_test.cc +++ b/tensorflow/c/experimental/gradients/nn_grad_test.cc @@ -83,8 +83,10 @@ TEST_P(CppGradients, TestBiasAddGrad) { AbstractTensorHandlePtr Bias = GetTensorHandleUtilFloat(ctx_.get(), Bias_vals, Bias_dims, 1); + std::vector inputs{A.get(), Bias.get()}; + ASSERT_NO_FATAL_FAILURE(CompareWithGradientsCheckers( - BiasAddModel, BiasAddGradModel, ctx_.get(), {A.get(), Bias.get()}, + BiasAddModel, BiasAddGradModel, ctx_.get(), absl::MakeSpan(inputs), /*use_function=*/!std::get<2>(GetParam()), registry_)); } From 6644523743c191ca2308b06e6468e3ac5487206d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=B5=20V=C4=83n=20Ngh=C4=A9a?= Date: Sat, 28 Nov 2020 23:02:49 +0700 Subject: [PATCH 05/15] Move `TF_MODEL_FACTORY` and `TF_GRAD_MODEL_FACTORY` to a seperate header --- tensorflow/c/experimental/gradients/BUILD | 6 ++ .../experimental/gradients/grad_test_helper.h | 80 --------------- .../gradients/model_factory_helper.h | 98 +++++++++++++++++++ .../c/experimental/gradients/nn_grad_test.cc | 1 + 4 files changed, 105 insertions(+), 80 deletions(-) create mode 100644 tensorflow/c/experimental/gradients/model_factory_helper.h diff --git a/tensorflow/c/experimental/gradients/BUILD b/tensorflow/c/experimental/gradients/BUILD index 0d0ebebf88e..a5006b1da7c 100644 --- a/tensorflow/c/experimental/gradients/BUILD +++ b/tensorflow/c/experimental/gradients/BUILD @@ -112,6 +112,11 @@ filegroup( ], ) +cc_library( + name = "model_factory_helper", + hdrs = ["model_factory_helper.h"], +) + cc_library( name = "grad_test_helper", testonly = True, @@ -139,6 +144,7 @@ tf_cuda_cc_test( tags = tf_cuda_tests_tags() + ["nomac"], deps = [ ":grad_test_helper", + ":model_factory_helper", ":nn_grad", "//tensorflow/c/eager:c_api_test_util", "//tensorflow/c/experimental/gradients/tape:tape_context", diff --git a/tensorflow/c/experimental/gradients/grad_test_helper.h b/tensorflow/c/experimental/gradients/grad_test_helper.h index 5ce7ed9856f..be1f230168e 100644 --- a/tensorflow/c/experimental/gradients/grad_test_helper.h +++ b/tensorflow/c/experimental/gradients/grad_test_helper.h @@ -21,86 +21,6 @@ namespace tensorflow { namespace gradients { namespace internal { -// This macro will expand to a function that defines a `Model`. This `Model` is -// then used for testing by `nn_grad_test` and `math_grad_test`. `ops_call` is a -// statement that calls to a `ops::` and should be wrapped around by `{}`. -// `ops_call` has access to `inputs`. The output parameter of the ops should -// always be `absl::MakeSpan(temp_outputs)`. This macro supports most one-ops -// model. 
-// TODO(vnvo2409): Extend support for more complex model. -#define TF_MODEL_FACTORY(name, num_inputs, num_outputs, ops_call) \ - Status name(AbstractContext* ctx, \ - absl::Span inputs, \ - absl::Span outputs, \ - const GradientRegistry& registry) { \ - auto tape = new Tape(/*persistent=*/false); \ - for (int i{}; i < num_inputs; ++i) { \ - tape->Watch(ToId(inputs[i])); \ - } \ - \ - AbstractTensorHandle* temp_outputs[num_outputs] = {}; \ - AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); \ - ops_call; \ - \ - for (int i{}; i < num_outputs; ++i) { \ - outputs[i] = temp_outputs[i]; \ - } \ - delete tape; \ - return Status::OK(); \ - } - -// This macro will expand to a function that defines a `GradModel`. This -// `GradModel` is then used for testing by `nn_grad_test` and `math_grad_test`. -// `ops_call` is a statement that calls to a `ops::` and should be wrapped -// around by `{}`. `ops_call` has access to `inputs`. The output parameter of -// the ops should always be `absl::MakeSpan(temp_outputs)`. This macro supports -// most one-ops model. -// TODO(vnvo2409): Extend support for more complex model. -#define TF_GRAD_MODEL_FACTORY(name, num_inputs, num_outputs, num_grad_outputs, \ - ops_call) \ - Status name(AbstractContext* ctx, \ - absl::Span inputs, \ - absl::Span outputs, \ - const GradientRegistry& registry) { \ - TapeVSpace vspace(ctx); \ - auto tape = new Tape(/*persistent=*/false); \ - for (int i{}; i < num_inputs; ++i) { \ - tape->Watch(ToId(inputs[i])); \ - } \ - \ - AbstractTensorHandle* temp_outputs[num_outputs] = {}; \ - AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); \ - ops_call; \ - \ - std::unordered_map \ - source_tensors_that_are_targets; \ - std::vector out_grads(num_grad_outputs); \ - \ - int64 target_tensor_ids[num_outputs] = {}; \ - for (int i{}; i < num_outputs; ++i) { \ - target_tensor_ids[i] = ToId(temp_outputs[i]); \ - } \ - \ - int64 source_tensor_ids[num_inputs] = {}; \ - for (int i{}; i < num_inputs; ++i) { \ - source_tensor_ids[i] = ToId(inputs[i]); \ - } \ - \ - TF_RETURN_IF_ERROR(tape->ComputeGradient( \ - vspace, target_tensor_ids, source_tensor_ids, \ - source_tensors_that_are_targets, /*output_gradients=*/{}, &out_grads, \ - /*build_default_zeros_grads=*/false)); \ - \ - for (int i{}; i < num_outputs; ++i) { \ - temp_outputs[i]->Unref(); \ - } \ - for (int i{}; i < num_grad_outputs; ++i) { \ - outputs[i] = out_grads[i]; \ - } \ - delete tape; \ - return Status::OK(); \ - } - void CompareWithGradientsCheckers(Model model, Model grad_model, AbstractContext* ctx, absl::Span inputs, diff --git a/tensorflow/c/experimental/gradients/model_factory_helper.h b/tensorflow/c/experimental/gradients/model_factory_helper.h new file mode 100644 index 00000000000..02dc0e7b935 --- /dev/null +++ b/tensorflow/c/experimental/gradients/model_factory_helper.h @@ -0,0 +1,98 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MODEL_FACTORY_MACRO_H_ +#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MODEL_FACTORY_MACRO_H_ + +// This macro will expand to a function that defines a `Model`. This `Model` is +// then used for testing by `nn_grad_test` and `math_grad_test`. `ops_call` is a +// statement that calls to a `ops::` and should be wrapped around by `{}`. +// `ops_call` has access to `inputs`. The output parameter of the ops should +// always be `absl::MakeSpan(temp_outputs)`. This macro supports most one-ops +// model. +// TODO(vnvo2409): Extend support for more complex model. +#define TF_MODEL_FACTORY(name, num_inputs, num_outputs, ops_call) \ + Status name(AbstractContext* ctx, \ + absl::Span inputs, \ + absl::Span outputs, \ + const GradientRegistry& registry) { \ + auto tape = new Tape(/*persistent=*/false); \ + for (int i{}; i < num_inputs; ++i) { \ + tape->Watch(ToId(inputs[i])); \ + } \ + \ + AbstractTensorHandle* temp_outputs[num_outputs] = {}; \ + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); \ + ops_call; \ + \ + for (int i{}; i < num_outputs; ++i) { \ + outputs[i] = temp_outputs[i]; \ + } \ + delete tape; \ + return Status::OK(); \ + } + +// This macro will expand to a function that defines a `GradModel`. This +// `GradModel` is then used for testing by `nn_grad_test` and `math_grad_test`. +// `ops_call` is a statement that calls to a `ops::` and should be wrapped +// around by `{}`. `ops_call` has access to `inputs`. The output parameter of +// the ops should always be `absl::MakeSpan(temp_outputs)`. This macro supports +// most one-ops model. +// TODO(vnvo2409): Extend support for more complex model. +#define TF_GRAD_MODEL_FACTORY(name, num_inputs, num_outputs, num_grad_outputs, \ + ops_call) \ + Status name(AbstractContext* ctx, \ + absl::Span inputs, \ + absl::Span outputs, \ + const GradientRegistry& registry) { \ + TapeVSpace vspace(ctx); \ + auto tape = new Tape(/*persistent=*/false); \ + for (int i{}; i < num_inputs; ++i) { \ + tape->Watch(ToId(inputs[i])); \ + } \ + \ + AbstractTensorHandle* temp_outputs[num_outputs] = {}; \ + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); \ + ops_call; \ + \ + std::unordered_map \ + source_tensors_that_are_targets; \ + std::vector out_grads(num_grad_outputs); \ + \ + int64 target_tensor_ids[num_outputs] = {}; \ + for (int i{}; i < num_outputs; ++i) { \ + target_tensor_ids[i] = ToId(temp_outputs[i]); \ + } \ + \ + int64 source_tensor_ids[num_inputs] = {}; \ + for (int i{}; i < num_inputs; ++i) { \ + source_tensor_ids[i] = ToId(inputs[i]); \ + } \ + \ + TF_RETURN_IF_ERROR(tape->ComputeGradient( \ + vspace, target_tensor_ids, source_tensor_ids, \ + source_tensors_that_are_targets, /*output_gradients=*/{}, &out_grads, \ + /*build_default_zeros_grads=*/false)); \ + \ + for (int i{}; i < num_outputs; ++i) { \ + temp_outputs[i]->Unref(); \ + } \ + for (int i{}; i < num_grad_outputs; ++i) { \ + outputs[i] = out_grads[i]; \ + } \ + delete tape; \ + return Status::OK(); \ + } + +#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MODEL_FACTORY_MACRO_H_ diff --git a/tensorflow/c/experimental/gradients/nn_grad_test.cc b/tensorflow/c/experimental/gradients/nn_grad_test.cc index 4859ff27d2d..7ae80ed3da5 100644 --- a/tensorflow/c/experimental/gradients/nn_grad_test.cc +++ b/tensorflow/c/experimental/gradients/nn_grad_test.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/c/experimental/gradients/grad_test_helper.h" +#include "tensorflow/c/experimental/gradients/model_factory_helper.h" #include "tensorflow/c/experimental/gradients/tape/tape_context.h" #include "tensorflow/core/platform/test.h" From fa1cfa16ec7e1f65e6b37cc09c0a508f54e1a031 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=B5=20V=C4=83n=20Ngh=C4=A9a?= Date: Wed, 2 Dec 2020 15:42:42 +0700 Subject: [PATCH 06/15] remove MODEL macro --- tensorflow/c/experimental/gradients/BUILD | 8 +- .../gradients/model_factory_helper.h | 98 ------------------- .../c/experimental/gradients/nn_grad_test.cc | 41 ++++++-- 3 files changed, 35 insertions(+), 112 deletions(-) delete mode 100644 tensorflow/c/experimental/gradients/model_factory_helper.h diff --git a/tensorflow/c/experimental/gradients/BUILD b/tensorflow/c/experimental/gradients/BUILD index a5006b1da7c..c1f09a45d5e 100644 --- a/tensorflow/c/experimental/gradients/BUILD +++ b/tensorflow/c/experimental/gradients/BUILD @@ -112,11 +112,6 @@ filegroup( ], ) -cc_library( - name = "model_factory_helper", - hdrs = ["model_factory_helper.h"], -) - cc_library( name = "grad_test_helper", testonly = True, @@ -141,10 +136,9 @@ tf_cuda_cc_test( ], args = ["--heap_check=local"], linkstatic = tf_kernel_tests_linkstatic(), - tags = tf_cuda_tests_tags() + ["nomac"], + tags = tf_cuda_tests_tags(), deps = [ ":grad_test_helper", - ":model_factory_helper", ":nn_grad", "//tensorflow/c/eager:c_api_test_util", "//tensorflow/c/experimental/gradients/tape:tape_context", diff --git a/tensorflow/c/experimental/gradients/model_factory_helper.h b/tensorflow/c/experimental/gradients/model_factory_helper.h deleted file mode 100644 index 02dc0e7b935..00000000000 --- a/tensorflow/c/experimental/gradients/model_factory_helper.h +++ /dev/null @@ -1,98 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MODEL_FACTORY_MACRO_H_ -#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MODEL_FACTORY_MACRO_H_ - -// This macro will expand to a function that defines a `Model`. This `Model` is -// then used for testing by `nn_grad_test` and `math_grad_test`. `ops_call` is a -// statement that calls to a `ops::` and should be wrapped around by `{}`. -// `ops_call` has access to `inputs`. The output parameter of the ops should -// always be `absl::MakeSpan(temp_outputs)`. This macro supports most one-ops -// model. -// TODO(vnvo2409): Extend support for more complex model. 
-#define TF_MODEL_FACTORY(name, num_inputs, num_outputs, ops_call) \ - Status name(AbstractContext* ctx, \ - absl::Span inputs, \ - absl::Span outputs, \ - const GradientRegistry& registry) { \ - auto tape = new Tape(/*persistent=*/false); \ - for (int i{}; i < num_inputs; ++i) { \ - tape->Watch(ToId(inputs[i])); \ - } \ - \ - AbstractTensorHandle* temp_outputs[num_outputs] = {}; \ - AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); \ - ops_call; \ - \ - for (int i{}; i < num_outputs; ++i) { \ - outputs[i] = temp_outputs[i]; \ - } \ - delete tape; \ - return Status::OK(); \ - } - -// This macro will expand to a function that defines a `GradModel`. This -// `GradModel` is then used for testing by `nn_grad_test` and `math_grad_test`. -// `ops_call` is a statement that calls to a `ops::` and should be wrapped -// around by `{}`. `ops_call` has access to `inputs`. The output parameter of -// the ops should always be `absl::MakeSpan(temp_outputs)`. This macro supports -// most one-ops model. -// TODO(vnvo2409): Extend support for more complex model. -#define TF_GRAD_MODEL_FACTORY(name, num_inputs, num_outputs, num_grad_outputs, \ - ops_call) \ - Status name(AbstractContext* ctx, \ - absl::Span inputs, \ - absl::Span outputs, \ - const GradientRegistry& registry) { \ - TapeVSpace vspace(ctx); \ - auto tape = new Tape(/*persistent=*/false); \ - for (int i{}; i < num_inputs; ++i) { \ - tape->Watch(ToId(inputs[i])); \ - } \ - \ - AbstractTensorHandle* temp_outputs[num_outputs] = {}; \ - AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); \ - ops_call; \ - \ - std::unordered_map \ - source_tensors_that_are_targets; \ - std::vector out_grads(num_grad_outputs); \ - \ - int64 target_tensor_ids[num_outputs] = {}; \ - for (int i{}; i < num_outputs; ++i) { \ - target_tensor_ids[i] = ToId(temp_outputs[i]); \ - } \ - \ - int64 source_tensor_ids[num_inputs] = {}; \ - for (int i{}; i < num_inputs; ++i) { \ - source_tensor_ids[i] = ToId(inputs[i]); \ - } \ - \ - TF_RETURN_IF_ERROR(tape->ComputeGradient( \ - vspace, target_tensor_ids, source_tensor_ids, \ - source_tensors_that_are_targets, /*output_gradients=*/{}, &out_grads, \ - /*build_default_zeros_grads=*/false)); \ - \ - for (int i{}; i < num_outputs; ++i) { \ - temp_outputs[i]->Unref(); \ - } \ - for (int i{}; i < num_grad_outputs; ++i) { \ - outputs[i] = out_grads[i]; \ - } \ - delete tape; \ - return Status::OK(); \ - } - -#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MODEL_FACTORY_MACRO_H_ diff --git a/tensorflow/c/experimental/gradients/nn_grad_test.cc b/tensorflow/c/experimental/gradients/nn_grad_test.cc index 7ae80ed3da5..13ff0369096 100644 --- a/tensorflow/c/experimental/gradients/nn_grad_test.cc +++ b/tensorflow/c/experimental/gradients/nn_grad_test.cc @@ -16,7 +16,6 @@ limitations under the License. 
#include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/c/experimental/gradients/grad_test_helper.h" -#include "tensorflow/c/experimental/gradients/model_factory_helper.h" #include "tensorflow/c/experimental/gradients/tape/tape_context.h" #include "tensorflow/core/platform/test.h" @@ -28,15 +27,43 @@ namespace { using tensorflow::TF_StatusPtr; using tracing::TracingOperation; -TF_MODEL_FACTORY(BiasAddModel, 2, 1, { - TF_RETURN_IF_ERROR(ops::BiasAdd(tape_ctx.get(), inputs, - absl::MakeSpan(temp_outputs), "BiasAdd")); -}) +Status BiasAddModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + return ops::BiasAdd(ctx, inputs, outputs, "BiasAdd"); +} -TF_GRAD_MODEL_FACTORY(BiasAddGradModel, 2, 1, 2, { +Status BiasAddGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + TapeVSpace vspace(ctx); + auto tape = new Tape(/*persistent=*/false); + tape->Watch(ToId(inputs[0])); // Watch A. + tape->Watch(ToId(inputs[1])); // Watch Bias. + std::vector temp_outputs(1); + AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry)); TF_RETURN_IF_ERROR(ops::BiasAdd(tape_ctx.get(), inputs, absl::MakeSpan(temp_outputs), "BiasAddGrad")); -}) + std::unordered_map + source_tensors_that_are_targets; + + std::vector out_grads; + TF_RETURN_IF_ERROR(tape->ComputeGradient( + vspace, /*target_tensor_ids=*/{ToId(temp_outputs[0])}, + /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])}, + source_tensors_that_are_targets, + /*output_gradients=*/{}, &out_grads, + /*build_default_zeros_grads=*/false)); + for (auto temp_output : temp_outputs) { + temp_output->Unref(); + } + outputs[0] = out_grads[0]; + outputs[1] = out_grads[1]; + delete tape; + return Status::OK(); +} Status RegisterGradients(GradientRegistry* registry) { TF_RETURN_IF_ERROR(registry->Register("BiasAdd", BiasAddRegisterer)); From 80b1f26e2dde9a5299c48b3ec0ce7bfc4ffed620 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=B5=20V=C4=83n=20Ngh=C4=A9a?= Date: Wed, 2 Dec 2020 15:59:19 +0700 Subject: [PATCH 07/15] `gradient_checker` takes const inputs --- tensorflow/c/eager/gradient_checker.cc | 21 +++++++++++-------- tensorflow/c/eager/gradient_checker.h | 5 +---- .../gradients/grad_test_helper.cc | 9 ++++---- .../experimental/gradients/grad_test_helper.h | 9 ++++---- .../c/experimental/gradients/nn_grad_test.cc | 2 +- 5 files changed, 22 insertions(+), 24 deletions(-) diff --git a/tensorflow/c/eager/gradient_checker.cc b/tensorflow/c/eager/gradient_checker.cc index c2da03f28dd..50c67d76ae6 100644 --- a/tensorflow/c/eager/gradient_checker.cc +++ b/tensorflow/c/eager/gradient_checker.cc @@ -66,7 +66,7 @@ void GetDims(const TF_Tensor* t, int64_t* out_dims) { // Runs model as is if output is a scalar, // else sums the output tensor before returning. 
 Status RunAndMaybeSum(AbstractContext* ctx, Model forward,
-                      absl::Span<AbstractTensorHandle*> inputs,
+                      absl::Span<AbstractTensorHandle* const> inputs,
                       absl::Span<AbstractTensorHandle*> outputs,
                       bool use_function) {
   GradientRegistry registry;
@@ -107,11 +107,16 @@ Status RunAndMaybeSum(AbstractContext* ctx, Model forward,
 // ========================= End Helper Functions==============================

 Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
-                         absl::Span<AbstractTensorHandle*> inputs,
+                         absl::Span<AbstractTensorHandle* const> inputs,
                          int input_index, bool use_function,
                          AbstractTensorHandle** numerical_grad) {
+  vector<AbstractTensorHandle*> theta_inputs(inputs.size());
+  for (int i{}; i < inputs.size(); ++i) {
+    theta_inputs[i] = inputs[i];
+  }
+
   AbstractTensorHandle* theta =
-      inputs[input_index];  // parameter we are grad checking
+      theta_inputs[input_index];  // parameter we are grad checking

   // Convert from AbstractTensor to TF_Tensor.
   TF_Tensor* theta_tensor;
@@ -159,14 +164,14 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
       ctx, thetaMinus_data.data(), theta_dims.data(), num_dims);

   // Get f(theta + eps):
-  inputs[input_index] = thetaPlus.get();
-  TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, inputs,
+  theta_inputs[input_index] = thetaPlus.get();
+  TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, theta_inputs,
                                     absl::MakeSpan(f_outputs), use_function));
   AbstractTensorHandle* fPlus = f_outputs[0];

   // Get f(theta - eps):
-  inputs[input_index] = thetaMinus.get();
-  TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, inputs,
+  theta_inputs[input_index] = thetaMinus.get();
+  TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, theta_inputs,
                                     absl::MakeSpan(f_outputs), use_function));
   AbstractTensorHandle* fMinus = f_outputs[0];
@@ -191,8 +196,6 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
     dtheta_approx[i] = grad_data[0];
   }

-  // Restore the inputs
-  inputs[input_index] = theta;
   // Populate *numerical_grad with the data from dtheta_approx.
   TF_RETURN_IF_ERROR(TensorHandleWithDimsFloat(
       ctx, dtheta_approx.data(), theta_dims.data(), num_dims, numerical_grad));

diff --git a/tensorflow/c/eager/gradient_checker.h b/tensorflow/c/eager/gradient_checker.h
index f4b9ac544eb..705318b02e2 100644
--- a/tensorflow/c/eager/gradient_checker.h
+++ b/tensorflow/c/eager/gradient_checker.h
@@ -43,12 +43,9 @@ namespace gradients {
  *
  * `numerical_grad` is the pointer to the AbstractTensorHandle* which will
  * hold the numerical gradient data at the end of the function.
- *
- * Note that this function will modify `inputs` and restore it to the original
- * data before returning.
 */
 Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
-                         absl::Span<AbstractTensorHandle*> inputs,
+                         absl::Span<AbstractTensorHandle* const> inputs,
                          int input_index, bool use_function,
                          AbstractTensorHandle** numerical_grad);

diff --git a/tensorflow/c/experimental/gradients/grad_test_helper.cc b/tensorflow/c/experimental/gradients/grad_test_helper.cc
index 978e1f759a5..51ebdd7f75d 100644
--- a/tensorflow/c/experimental/gradients/grad_test_helper.cc
+++ b/tensorflow/c/experimental/gradients/grad_test_helper.cc
@@ -21,11 +21,10 @@ namespace tensorflow {
 namespace gradients {
 namespace internal {

-void CompareWithGradientsCheckers(Model model, Model grad_model,
-                                  AbstractContext* ctx,
-                                  absl::Span<AbstractTensorHandle*> inputs,
-                                  bool use_function,
-                                  const GradientRegistry& registry) {
+void CompareWithGradientsCheckers(
+    Model model, Model grad_model, AbstractContext* ctx,
+    absl::Span<AbstractTensorHandle* const> inputs, bool use_function,
+    const GradientRegistry& registry) {
   auto num_inputs = inputs.size();
   std::vector<AbstractTensorHandle*> outputs(num_inputs);
   auto s = RunModel(grad_model, ctx, inputs, absl::MakeSpan(outputs),

diff --git a/tensorflow/c/experimental/gradients/grad_test_helper.h b/tensorflow/c/experimental/gradients/grad_test_helper.h
index be1f230168e..f310c046b5d 100644
--- a/tensorflow/c/experimental/gradients/grad_test_helper.h
+++ b/tensorflow/c/experimental/gradients/grad_test_helper.h
@@ -21,11 +21,10 @@ namespace tensorflow {
 namespace gradients {
 namespace internal {

-void CompareWithGradientsCheckers(Model model, Model grad_model,
-                                  AbstractContext* ctx,
-                                  absl::Span<AbstractTensorHandle*> inputs,
-                                  bool use_function,
-                                  const GradientRegistry& registry);
+void CompareWithGradientsCheckers(
+    Model model, Model grad_model, AbstractContext* ctx,
+    absl::Span<AbstractTensorHandle* const> inputs, bool use_function,
+    const GradientRegistry& registry);

 }  // namespace internal
 }  // namespace gradients

diff --git a/tensorflow/c/experimental/gradients/nn_grad_test.cc b/tensorflow/c/experimental/gradients/nn_grad_test.cc
index 13ff0369096..7ab7742bbdc 100644
--- a/tensorflow/c/experimental/gradients/nn_grad_test.cc
+++ b/tensorflow/c/experimental/gradients/nn_grad_test.cc
@@ -114,7 +114,7 @@ TEST_P(CppGradients, TestBiasAddGrad) {
   std::vector<AbstractTensorHandle*> inputs{A.get(), Bias.get()};

   ASSERT_NO_FATAL_FAILURE(CompareWithGradientsCheckers(
-      BiasAddModel, BiasAddGradModel, ctx_.get(), absl::MakeSpan(inputs),
+      BiasAddModel, BiasAddGradModel, ctx_.get(), {A.get(), Bias.get()},
       /*use_function=*/!std::get<2>(GetParam()), registry_));
 }

From 2cdf8665715398ef4f1c0af2f2d19d12212e9dab Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?V=C3=B5=20V=C4=83n=20Ngh=C4=A9a?=
Date: Wed, 2 Dec 2020 23:49:18 +0700
Subject: [PATCH 08/15] rename `CompareWithGradientsCheckers` to `CompareNumericalAndAutodiffGradients`

---
 tensorflow/c/experimental/gradients/grad_test_helper.cc | 6 +++---
 tensorflow/c/experimental/gradients/grad_test_helper.h  | 4 ++--
 tensorflow/c/experimental/gradients/nn_grad_test.cc     | 2 +-
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/tensorflow/c/experimental/gradients/grad_test_helper.cc b/tensorflow/c/experimental/gradients/grad_test_helper.cc
index 51ebdd7f75d..4031f8c867d 100644
--- a/tensorflow/c/experimental/gradients/grad_test_helper.cc
+++ b/tensorflow/c/experimental/gradients/grad_test_helper.cc
@@ -21,10 +21,10 @@ namespace tensorflow {
 namespace gradients {
 namespace internal {

-void CompareWithGradientsCheckers(
+void CompareNumericalAndAutodiffGradients(
     Model model, Model grad_model, AbstractContext* ctx,
     absl::Span<AbstractTensorHandle* const> inputs, bool use_function,
-    const GradientRegistry& registry) {
+    const GradientRegistry& registry, double abs_error) {
   auto num_inputs = inputs.size();
   std::vector<AbstractTensorHandle*> outputs(num_inputs);
   auto s = RunModel(grad_model, ctx, inputs, absl::MakeSpan(outputs),
@@ -60,7 +60,7 @@ void CompareWithGradientsCheckers(
                TF_TensorByteSize(analytical_tensor));

     for (int j = 0; j < num_elem_numerical; j++) {
-      ASSERT_NEAR(dnumerical[j], danalytical[j], /*abs_error=*/1e-2);
+      ASSERT_NEAR(dnumerical[j], danalytical[j], abs_error);
     }
     TF_DeleteTensor(analytical_tensor);
     TF_DeleteTensor(numerical_tensor);

diff --git a/tensorflow/c/experimental/gradients/grad_test_helper.h b/tensorflow/c/experimental/gradients/grad_test_helper.h
index f310c046b5d..78b2d5b41ef 100644
--- a/tensorflow/c/experimental/gradients/grad_test_helper.h
+++ b/tensorflow/c/experimental/gradients/grad_test_helper.h
@@ -21,10 +21,10 @@ namespace tensorflow {
 namespace gradients {
 namespace internal {

-void CompareWithGradientsCheckers(
+void CompareNumericalAndAutodiffGradients(
     Model model, Model grad_model, AbstractContext* ctx,
     absl::Span<AbstractTensorHandle* const> inputs, bool use_function,
-    const GradientRegistry& registry);
+    const GradientRegistry& registry, double abs_error = 1e-2);

 }  // namespace internal
 }  // namespace gradients

diff --git a/tensorflow/c/experimental/gradients/nn_grad_test.cc b/tensorflow/c/experimental/gradients/nn_grad_test.cc
index 7ab7742bbdc..58d4f4071d8 100644
--- a/tensorflow/c/experimental/gradients/nn_grad_test.cc
+++ b/tensorflow/c/experimental/gradients/nn_grad_test.cc
@@ -113,7 +113,7 @@ TEST_P(CppGradients, TestBiasAddGrad) {

   std::vector<AbstractTensorHandle*> inputs{A.get(), Bias.get()};

-  ASSERT_NO_FATAL_FAILURE(CompareWithGradientsCheckers(
+  ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
       BiasAddModel, BiasAddGradModel, ctx_.get(), {A.get(), Bias.get()},
       /*use_function=*/!std::get<2>(GetParam()), registry_));
 }

From 21a22d44721b1a23db0277b0bbd5725163378af4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?V=C3=B5=20V=C4=83n=20Ngh=C4=A9a?=
Date: Thu, 3 Dec 2020 12:36:20 +0700
Subject: [PATCH 09/15] fix internal build

---
 tensorflow/c/experimental/gradients/BUILD      | 1 -
 tensorflow/c/experimental/gradients/nn_grad.cc | 2 +-
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/tensorflow/c/experimental/gradients/BUILD b/tensorflow/c/experimental/gradients/BUILD
index c1f09a45d5e..2349f6647c9 100644
--- a/tensorflow/c/experimental/gradients/BUILD
+++ b/tensorflow/c/experimental/gradients/BUILD
@@ -134,7 +134,6 @@ tf_cuda_cc_test(
     srcs = [
         "nn_grad_test.cc",
     ],
-    args = ["--heap_check=local"],
     linkstatic = tf_kernel_tests_linkstatic(),
     tags = tf_cuda_tests_tags(),
     deps = [

diff --git a/tensorflow/c/experimental/gradients/nn_grad.cc b/tensorflow/c/experimental/gradients/nn_grad.cc
index 8dcdaff9d8b..17adde35f9f 100644
--- a/tensorflow/c/experimental/gradients/nn_grad.cc
+++ b/tensorflow/c/experimental/gradients/nn_grad.cc
@@ -131,7 +131,7 @@ class BiasAddGradientFunction : public GradientFunction {

     // Recover data format from forward pass for gradient.
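As context for the `data_format` attribute recovered above: BiasAddGrad reduces the upstream gradient over every dimension except the channel dimension that `data_format` selects, which is what produces a bias-shaped gradient. A small worked example with hypothetical values, treating the last axis of a [2, 3] upstream gradient as the channel axis, as an NHWC-style layout would:

#include <array>
#include <cstdio>

int main() {
  // Upstream gradient U of shape [2, 3]; axis 1 (size 3) is the channel axis.
  std::array<std::array<float, 3>, 2> U = {{{1, 2, 3}, {4, 5, 6}}};
  std::array<float, 3> dbias = {};
  for (const auto& row : U) {
    for (int c = 0; c < 3; ++c) {
      dbias[c] += row[c];  // Sum over all non-channel dimensions.
    }
  }
  std::printf("%g %g %g\n", dbias[0], dbias[1], dbias[2]);  // Prints: 5 7 9
  return 0;
}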
std::string data_format; - forward_attrs.Get("data_format", &data_format); + TF_RETURN_IF_ERROR(forward_attrs.Get("data_format", &data_format)); // Grad for A (*grad_outputs)[0] = upstream_grad; From 145c5d6daa299a6ca6bbc7ef52e26e921661e3b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=B5=20V=C4=83n=20Ngh=C4=A9a?= Date: Thu, 3 Dec 2020 13:17:47 +0700 Subject: [PATCH 10/15] use `UseFunction` and `UseMlir` --- tensorflow/c/experimental/gradients/nn_grad_test.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tensorflow/c/experimental/gradients/nn_grad_test.cc b/tensorflow/c/experimental/gradients/nn_grad_test.cc index 58d4f4071d8..6e0278e4333 100644 --- a/tensorflow/c/experimental/gradients/nn_grad_test.cc +++ b/tensorflow/c/experimental/gradients/nn_grad_test.cc @@ -93,10 +93,14 @@ class CppGradients GradientRegistry registry_; AbstractContextPtr ctx_; + + public: + bool UseMlir() const { return strcmp(std::get<0>(GetParam()), "mlir") == 0; } + bool UseFunction() const { return std::get<2>(GetParam()); } }; TEST_P(CppGradients, TestBiasAddGrad) { - if (std::get<0>(GetParam()) == "mlir" && !std::get<2>(GetParam())) { + if (!UseFunction() && UseMlir()) { GTEST_SKIP() << "SetAttrString has not been implemented yet.\n"; } From e5bd5997ac43caa89dfda68aba83ab346a92b10e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=B5=20V=C4=83n=20Ngh=C4=A9a?= Date: Thu, 3 Dec 2020 13:34:50 +0700 Subject: [PATCH 11/15] use `UseFunction()` for `CompareNumericalAndAutodiffGradients` --- tensorflow/c/experimental/gradients/nn_grad_test.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tensorflow/c/experimental/gradients/nn_grad_test.cc b/tensorflow/c/experimental/gradients/nn_grad_test.cc index 6e0278e4333..3a6d1f31b49 100644 --- a/tensorflow/c/experimental/gradients/nn_grad_test.cc +++ b/tensorflow/c/experimental/gradients/nn_grad_test.cc @@ -96,11 +96,11 @@ class CppGradients public: bool UseMlir() const { return strcmp(std::get<0>(GetParam()), "mlir") == 0; } - bool UseFunction() const { return std::get<2>(GetParam()); } + bool UseFunction() const { return !std::get<2>(GetParam()); } }; TEST_P(CppGradients, TestBiasAddGrad) { - if (!UseFunction() && UseMlir()) { + if (UseFunction() && UseMlir()) { GTEST_SKIP() << "SetAttrString has not been implemented yet.\n"; } @@ -119,7 +119,7 @@ TEST_P(CppGradients, TestBiasAddGrad) { ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients( BiasAddModel, BiasAddGradModel, ctx_.get(), {A.get(), Bias.get()}, - /*use_function=*/!std::get<2>(GetParam()), registry_)); + /*use_function=*/UseFunction(), registry_)); } #ifdef PLATFORM_GOOGLE @@ -127,13 +127,13 @@ INSTANTIATE_TEST_SUITE_P( UnifiedCAPI, CppGradients, ::testing::Combine(::testing::Values("graphdef", "mlir"), /*tfrt*/ ::testing::Values(false), - /*executing_eagerly*/ ::testing::Values(true, false))); + /*use_function*/ ::testing::Values(true, false))); #else INSTANTIATE_TEST_SUITE_P( UnifiedCAPI, CppGradients, ::testing::Combine(::testing::Values("graphdef", "mlir"), /*tfrt*/ ::testing::Values(false), - /*executing_eagerly*/ ::testing::Values(true, false))); + /*use_function*/ ::testing::Values(true, false))); #endif } // namespace } // namespace internal From 669c2e4b178df4daac5c16245fd1fb03d43b63c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=B5=20V=C4=83n=20Ngh=C4=A9a?= Date: Thu, 3 Dec 2020 13:44:09 +0700 Subject: [PATCH 12/15] Revert `UseFunction()` to original --- tensorflow/c/experimental/gradients/nn_grad_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/tensorflow/c/experimental/gradients/nn_grad_test.cc b/tensorflow/c/experimental/gradients/nn_grad_test.cc index 3a6d1f31b49..ce196a9f8f9 100644 --- a/tensorflow/c/experimental/gradients/nn_grad_test.cc +++ b/tensorflow/c/experimental/gradients/nn_grad_test.cc @@ -96,7 +96,7 @@ class CppGradients public: bool UseMlir() const { return strcmp(std::get<0>(GetParam()), "mlir") == 0; } - bool UseFunction() const { return !std::get<2>(GetParam()); } + bool UseFunction() const { return std::get<2>(GetParam()); } }; TEST_P(CppGradients, TestBiasAddGrad) { From 36c901f93d665b9894606d116ce6aaea1693389d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=B5=20V=C4=83n=20Ngh=C4=A9a?= Date: Thu, 3 Dec 2020 20:19:51 +0700 Subject: [PATCH 13/15] Remove unused in bazel --- tensorflow/c/experimental/gradients/BUILD | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tensorflow/c/experimental/gradients/BUILD b/tensorflow/c/experimental/gradients/BUILD index 2349f6647c9..013e27ea594 100644 --- a/tensorflow/c/experimental/gradients/BUILD +++ b/tensorflow/c/experimental/gradients/BUILD @@ -1,12 +1,10 @@ load("//tensorflow:tensorflow.bzl", "filegroup") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + +# buildifier: disable=same-origin-load load( "//tensorflow:tensorflow.bzl", - "if_libtpu", - "tf_cc_test", - "tf_copts", "tf_cuda_cc_test", - "tf_cuda_library", ) load( "//tensorflow/core/platform:build_config.bzl", From 6ed0590d0e40bb0795d022db9462fafcb451aa6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=B5=20V=C4=83n=20Ngh=C4=A9a?= Date: Thu, 3 Dec 2020 22:08:17 +0700 Subject: [PATCH 14/15] Disable cuda asan for nn_grad_test --- tensorflow/c/experimental/gradients/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/c/experimental/gradients/BUILD b/tensorflow/c/experimental/gradients/BUILD index 013e27ea594..3ad2f67cc0a 100644 --- a/tensorflow/c/experimental/gradients/BUILD +++ b/tensorflow/c/experimental/gradients/BUILD @@ -133,7 +133,7 @@ tf_cuda_cc_test( "nn_grad_test.cc", ], linkstatic = tf_kernel_tests_linkstatic(), - tags = tf_cuda_tests_tags(), + tags = tf_cuda_tests_tags() + ["no_cuda_asan"], # b/173654156, deps = [ ":grad_test_helper", ":nn_grad", From 4df993c625d2e063ceb29357eab5c7ff925ffa54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=B5=20V=C4=83n=20Ngh=C4=A9a?= Date: Thu, 3 Dec 2020 23:10:32 +0700 Subject: [PATCH 15/15] add `--heap_check=local` --- tensorflow/c/experimental/gradients/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/c/experimental/gradients/BUILD b/tensorflow/c/experimental/gradients/BUILD index 3ad2f67cc0a..acdd823a858 100644 --- a/tensorflow/c/experimental/gradients/BUILD +++ b/tensorflow/c/experimental/gradients/BUILD @@ -132,6 +132,7 @@ tf_cuda_cc_test( srcs = [ "nn_grad_test.cc", ], + args = ["--heap_check=local"], linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags() + ["no_cuda_asan"], # b/173654156, deps = [