From 7b9e2ff8626f0d99d65ed58d2afa55780c28f287 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?V=C3=B5=20V=C4=83n=20Ngh=C4=A9a?=
Date: Fri, 20 Nov 2020 21:48:55 +0700
Subject: [PATCH] add BiasAddGradient
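
Adds a gradient function and registerer for BiasAdd to the experimental
C++ gradients library, along with BiasAdd/BiasAddGrad op wrappers and a
unit test.

Given an upstream gradient U flowing into Y = BiasAdd(A, Bias), the
gradients are dA = U and dBias = reduce_sum(U, channel_dim). With the
unit test's all-ones 2x2 upstream gradient this yields
dA = {1, 1, 1, 1} and dBias = {1 + 1, 1 + 1} = {2, 2}.

Consumers register the gradient the same way the unit test does (a
minimal sketch; error handling elided):

    GradientRegistry registry;
    Status s = registry.Register("BiasAdd", BiasAddRegisterer);
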
"//tensorflow/core:test_main", + "//tensorflow/core/lib/llvm_rtti", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) diff --git a/tensorflow/c/experimental/gradients/nn_grad.cc b/tensorflow/c/experimental/gradients/nn_grad.cc index 64532c8ffc0..8dcdaff9d8b 100644 --- a/tensorflow/c/experimental/gradients/nn_grad.cc +++ b/tensorflow/c/experimental/gradients/nn_grad.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/platform/errors.h" using std::vector; +using tensorflow::ops::BiasAddGrad; using tensorflow::ops::Mul; using tensorflow::ops::ReluGrad; @@ -110,6 +111,48 @@ class SparseSoftmaxCrossEntropyWithLogitsGradientFunction vector forward_outputs; }; +// TODO(vnvo2409): Add python test +class BiasAddGradientFunction : public GradientFunction { + public: + explicit BiasAddGradientFunction(AttrBuilder f_attrs) + : forward_attrs(f_attrs) {} + + Status Compute(Context* ctx, const IncomingGradients& grad_inputs, + vector* grad_outputs) override { + /* Given upstream grad U and a BiasAdd: A + bias, the gradients are: + * + * dA = U + * dbias = reduceSum(U, dims = channel_dim) + */ + + AbstractTensorHandle* upstream_grad = grad_inputs[0]; + DCHECK(upstream_grad); + grad_outputs->resize(2); + + // Recover data format from forward pass for gradient. + std::string data_format; + forward_attrs.Get("data_format", &data_format); + + // Grad for A + (*grad_outputs)[0] = upstream_grad; + (*grad_outputs)[0]->Ref(); + + // Grad for bias + vector bias_add_grad_outputs(1); + std::string name = "bias_add_grad"; + TF_RETURN_IF_ERROR(BiasAddGrad(ctx->ctx, {upstream_grad}, + absl::MakeSpan(bias_add_grad_outputs), + data_format.c_str(), name.c_str())); + + (*grad_outputs)[1] = bias_add_grad_outputs[0]; + return Status::OK(); + } + ~BiasAddGradientFunction() override {} + + private: + AttrBuilder forward_attrs; +}; + } // namespace BackwardFunction* ReluRegisterer(const ForwardOperation& op) { @@ -129,5 +172,14 @@ BackwardFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer( return new BackwardFunction(gradient_function, default_gradients); } +BackwardFunction* BiasAddRegisterer(const ForwardOperation& op) { + // For ops with a single output, the gradient function is not called if there + // is no incoming gradient. So we do not need to worry about creating zeros + // grads in this case. + auto gradient_function = new BiasAddGradientFunction(op.attrs); + auto default_gradients = new PassThroughDefaultGradients(op); + return new BackwardFunction(gradient_function, default_gradients); +} + } // namespace gradients } // namespace tensorflow diff --git a/tensorflow/c/experimental/gradients/nn_grad.h b/tensorflow/c/experimental/gradients/nn_grad.h index 034f20d7325..b4406bb5bbc 100644 --- a/tensorflow/c/experimental/gradients/nn_grad.h +++ b/tensorflow/c/experimental/gradients/nn_grad.h @@ -22,6 +22,7 @@ namespace gradients { BackwardFunction* ReluRegisterer(const ForwardOperation& op); BackwardFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer( const ForwardOperation& op); +BackwardFunction* BiasAddRegisterer(const ForwardOperation& op); } // namespace gradients } // namespace tensorflow diff --git a/tensorflow/c/experimental/gradients/nn_grad_test.cc b/tensorflow/c/experimental/gradients/nn_grad_test.cc new file mode 100644 index 00000000000..ee549b8333f --- /dev/null +++ b/tensorflow/c/experimental/gradients/nn_grad_test.cc @@ -0,0 +1,154 @@ +/* Copyright 2020 The TensorFlow Authors. 
+
+    // Grad for A
+    (*grad_outputs)[0] = upstream_grad;
+    (*grad_outputs)[0]->Ref();
+
+    // Grad for bias
+    vector<AbstractTensorHandle*> bias_add_grad_outputs(1);
+    std::string name = "bias_add_grad";
+    TF_RETURN_IF_ERROR(BiasAddGrad(ctx->ctx, {upstream_grad},
+                                   absl::MakeSpan(bias_add_grad_outputs),
+                                   data_format.c_str(), name.c_str()));
+
+    (*grad_outputs)[1] = bias_add_grad_outputs[0];
+    return Status::OK();
+  }
+  ~BiasAddGradientFunction() override {}
+
+ private:
+  AttrBuilder forward_attrs;
+};
+
 }  // namespace
 
 BackwardFunction* ReluRegisterer(const ForwardOperation& op) {
@@ -129,5 +174,14 @@ BackwardFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer(
   return new BackwardFunction(gradient_function, default_gradients);
 }
 
+BackwardFunction* BiasAddRegisterer(const ForwardOperation& op) {
+  // For ops with a single output, the gradient function is not called if there
+  // is no incoming gradient. So we do not need to worry about creating zeros
+  // grads in this case.
+  auto gradient_function = new BiasAddGradientFunction(op.attrs);
+  auto default_gradients = new PassThroughDefaultGradients(op);
+  return new BackwardFunction(gradient_function, default_gradients);
+}
+
 }  // namespace gradients
 }  // namespace tensorflow
diff --git a/tensorflow/c/experimental/gradients/nn_grad.h b/tensorflow/c/experimental/gradients/nn_grad.h
index 034f20d7325..b4406bb5bbc 100644
--- a/tensorflow/c/experimental/gradients/nn_grad.h
+++ b/tensorflow/c/experimental/gradients/nn_grad.h
@@ -22,6 +22,7 @@ namespace gradients {
 BackwardFunction* ReluRegisterer(const ForwardOperation& op);
 BackwardFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer(
     const ForwardOperation& op);
+BackwardFunction* BiasAddRegisterer(const ForwardOperation& op);
 
 }  // namespace gradients
 }  // namespace tensorflow
diff --git a/tensorflow/c/experimental/gradients/nn_grad_test.cc b/tensorflow/c/experimental/gradients/nn_grad_test.cc
new file mode 100644
index 00000000000..ee549b8333f
--- /dev/null
+++ b/tensorflow/c/experimental/gradients/nn_grad_test.cc
@@ -0,0 +1,156 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/c/experimental/gradients/nn_grad.h"
+
+#include <memory>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/types/span.h"
+#include "tensorflow/c/eager/abstract_context.h"
+#include "tensorflow/c/eager/abstract_tensor_handle.h"
+#include "tensorflow/c/eager/c_api_experimental.h"
+#include "tensorflow/c/eager/c_api_test_util.h"
+#include "tensorflow/c/eager/c_api_unified_experimental.h"
+#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
+#include "tensorflow/c/eager/gradients.h"
+#include "tensorflow/c/eager/gradients_internal.h"
+#include "tensorflow/c/eager/gradients_util.h"
+#include "tensorflow/c/experimental/gradients/nn_grad_testutil.h"
+#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
+#include "tensorflow/c/experimental/ops/nn_ops.h"
+#include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/c/tf_tensor.h"
+#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace gradients {
+namespace internal {
+namespace {
+using std::vector;
+using tensorflow::TF_StatusPtr;
+using tracing::TracingOperation;
+
+class CppGradients
+    : public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
+ protected:
+  void SetUp() override {
+    TF_StatusPtr status(TF_NewStatus());
+    TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
+    Status s = StatusFromTF_Status(status.get());
+    CHECK_EQ(errors::OK, s.code()) << s.error_message();
+  }
+};
+
+Status RegisterGradients(GradientRegistry* registry) {
+  TF_RETURN_IF_ERROR(registry->Register("BiasAdd", BiasAddRegisterer));
+  return Status::OK();
+}
+
+TEST_P(CppGradients, TestBiasAddGrad) {
+  if (std::get<0>(GetParam()) == "mlir" && std::get<2>(GetParam()) == 0) {
+    GTEST_SKIP() << "SetAttrString has not been implemented yet.\n";
+  }
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  AbstractContextPtr ctx;
+  {
+    AbstractContext* ctx_raw = nullptr;
+    Status s =
+        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
+    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+    ctx.reset(ctx_raw);
+  }
+
+  float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
+  int64_t A_dims[] = {2, 2};
+  float Bias_vals[] = {2.0f, 3.0f};
+  int64_t Bias_dims[] = {2};
+
+  AbstractTensorHandlePtr A =
+      GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, 2);
+  AbstractTensorHandlePtr B =
+      GetTensorHandleUtilFloat(ctx.get(), Bias_vals, Bias_dims, 1);
+
+  GradientRegistry registry;
+  Status s = RegisterGradients(&registry);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  /* Pseudo-code:
+   *
+   * tape.watch(A)
+   * tape.watch(Bias)
+   * Y = BiasAdd(A, Bias)
+   * outputs = tape.gradient(Y, [A, Bias])
+   */
+  std::vector<AbstractTensorHandle*> outputs(2);
+  s = RunModel(BiasAddGradModel, ctx.get(), {A.get(), B.get()},
+               absl::MakeSpan(outputs),
+               /*use_function=*/!std::get<2>(GetParam()), registry);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  TF_Tensor* dA_tensor;
+  s = GetValue(outputs[0], &dA_tensor);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  float result_data_dA[4] = {0};
+  memcpy(&result_data_dA[0], TF_TensorData(dA_tensor),
+         TF_TensorByteSize(dA_tensor));
+
+  float expected_dA[4] = {1.0f, 1.0f, 1.0f, 1.0f};
+  float tolerance = 1e-3;
+  for (int j = 0; j < 4; j++) {
+    ASSERT_NEAR(result_data_dA[j], expected_dA[j], tolerance);
+  }
+
+  TF_Tensor* dBias_tensor;
+  s = GetValue(outputs[1], &dBias_tensor);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
+
+  float result_data_dBias[2] = {0};
+  memcpy(&result_data_dBias[0], TF_TensorData(dBias_tensor),
+         TF_TensorByteSize(dBias_tensor));
+
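+  // dBias sums the upstream gradient over the batch dimension: each column
+  // of the all-ones 2x2 upstream gradient contributes 1 + 1 = 2.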
+  float expected_dBias[2] = {2.0f, 2.0f};
+  for (int j = 0; j < 2; j++) {
+    ASSERT_NEAR(result_data_dBias[j], expected_dBias[j], tolerance);
+  }
+
+  outputs[0]->Unref();
+  outputs[1]->Unref();
+  TF_DeleteTensor(dA_tensor);
+  TF_DeleteTensor(dBias_tensor);
+}
+
+#ifdef PLATFORM_GOOGLE
+INSTANTIATE_TEST_SUITE_P(
+    UnifiedCAPI, CppGradients,
+    ::testing::Combine(::testing::Values("graphdef", "mlir"),
+                       /*tfrt*/ ::testing::Values(false),
+                       /*executing_eagerly*/ ::testing::Values(true, false)));
+#else
+INSTANTIATE_TEST_SUITE_P(
+    UnifiedCAPI, CppGradients,
+    ::testing::Combine(::testing::Values("graphdef", "mlir"),
+                       /*tfrt*/ ::testing::Values(false),
+                       /*executing_eagerly*/ ::testing::Values(true, false)));
+#endif
+}  // namespace
+}  // namespace internal
+}  // namespace gradients
+}  // namespace tensorflow
diff --git a/tensorflow/c/experimental/gradients/nn_grad_testutil.cc b/tensorflow/c/experimental/gradients/nn_grad_testutil.cc
new file mode 100644
index 00000000000..6866e3f878c
--- /dev/null
+++ b/tensorflow/c/experimental/gradients/nn_grad_testutil.cc
@@ -0,0 +1,73 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/c/experimental/gradients/nn_grad_testutil.h"
+
+#include <memory>
+
+#include "absl/types/span.h"
+#include "tensorflow/c/eager/abstract_tensor_handle.h"
+#include "tensorflow/c/eager/c_api_experimental.h"
+#include "tensorflow/c/eager/c_api_unified_experimental.h"
+#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
+#include "tensorflow/c/eager/gradients.h"
+#include "tensorflow/c/eager/gradients_internal.h"
+#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
+#include "tensorflow/c/experimental/ops/nn_ops.h"
+#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
+#include "tensorflow/core/platform/status.h"
+
+namespace tensorflow {
+namespace gradients {
+namespace internal {
+
+// Computes
+// y = BiasAdd(inputs[0], inputs[1])
+// return grad(y, {inputs[0], inputs[1]})
+Status BiasAddGradModel(AbstractContext* ctx,
+                        absl::Span<AbstractTensorHandle* const> inputs,
+                        absl::Span<AbstractTensorHandle*> outputs,
+                        const GradientRegistry& registry) {
+  TapeVSpace vspace(ctx);
+  auto tape = new Tape(/*persistent=*/false);
+  tape->Watch(ToId(inputs[0]));  // Watch A.
+  tape->Watch(ToId(inputs[1]));  // Watch Bias.
+  std::vector<AbstractTensorHandle*> bias_add_outputs(1);
+  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
+  TF_RETURN_IF_ERROR(ops::BiasAdd(tape_ctx.get(), inputs,
+                                  absl::MakeSpan(bias_add_outputs), "BiasAdd"));
+  std::unordered_map<tensorflow::int64, TapeTensor>
+      source_tensors_that_are_targets;
+
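+  // No output_gradients are supplied below, so the tape seeds the target
+  // with a default all-ones gradient of the output's shape.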
+  std::vector<AbstractTensorHandle*> out_grads;
+  TF_RETURN_IF_ERROR(tape->ComputeGradient(
+      vspace, /*target_tensor_ids=*/{ToId(bias_add_outputs[0])},
+      /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
+      source_tensors_that_are_targets,
+      /*output_gradients=*/{}, &out_grads,
+      /*build_default_zeros_grads=*/false));
+  for (auto bias_add_output : bias_add_outputs) {
+    bias_add_output->Unref();
+  }
+  outputs[0] = out_grads[0];
+  outputs[1] = out_grads[1];
+  delete tape;
+  return Status::OK();
+}
+
+}  // namespace internal
+}  // namespace gradients
+}  // namespace tensorflow
diff --git a/tensorflow/c/experimental/gradients/nn_grad_testutil.h b/tensorflow/c/experimental/gradients/nn_grad_testutil.h
new file mode 100644
index 00000000000..dff65906821
--- /dev/null
+++ b/tensorflow/c/experimental/gradients/nn_grad_testutil.h
@@ -0,0 +1,46 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_TESTUTIL_H_
+#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_TESTUTIL_H_
+#include <memory>
+
+#include "absl/types/span.h"
+#include "tensorflow/c/eager/abstract_tensor_handle.h"
+#include "tensorflow/c/eager/c_api_experimental.h"
+#include "tensorflow/c/eager/c_api_unified_experimental.h"
+#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
+#include "tensorflow/c/eager/gradients.h"
+#include "tensorflow/c/eager/gradients_internal.h"
+#include "tensorflow/c/experimental/ops/nn_ops.h"
+#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
+#include "tensorflow/core/platform/status.h"
+
+namespace tensorflow {
+namespace gradients {
+namespace internal {
+
+// Computes
+// y = BiasAdd(inputs[0], inputs[1])
+// return grad(y, {inputs[0], inputs[1]})
+Status BiasAddGradModel(AbstractContext* ctx,
+                        absl::Span<AbstractTensorHandle* const> inputs,
+                        absl::Span<AbstractTensorHandle*> outputs,
+                        const GradientRegistry& registry);
+
+}  // namespace internal
+}  // namespace gradients
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_TESTUTIL_H_
\ No newline at end of file
diff --git a/tensorflow/c/experimental/ops/nn_ops.cc b/tensorflow/c/experimental/ops/nn_ops.cc
index 6a97dbf0939..b1cc2ffc7d6 100644
--- a/tensorflow/c/experimental/ops/nn_ops.cc
+++ b/tensorflow/c/experimental/ops/nn_ops.cc
@@ -69,5 +69,40 @@ Status Relu(AbstractContext* ctx,
   return Status::OK();
 }
 
+Status BiasAdd(AbstractContext* ctx,
+               absl::Span<AbstractTensorHandle* const> inputs,
+               absl::Span<AbstractTensorHandle*> outputs, const char* name) {
+  AbstractOperationPtr bias_add_op(ctx->CreateOperation());
+  TF_RETURN_IF_ERROR(
+      bias_add_op->Reset("BiasAdd", /*raw_device_name=*/nullptr));
+  TF_RETURN_IF_ERROR(MaybeSetOpName(bias_add_op.get(), name));
+  TF_RETURN_IF_ERROR(bias_add_op->AddInput(inputs[0]));  // tensor input
+  TF_RETURN_IF_ERROR(bias_add_op->AddInput(inputs[1]));  // bias
+
+  int num_retvals = 1;
+  TF_RETURN_IF_ERROR(bias_add_op->Execute(outputs, &num_retvals));
+  return Status::OK();
+}
+
+// Computes BiasAdd gradient given upstream grads
+Status BiasAddGrad(AbstractContext* ctx,
+                   absl::Span<AbstractTensorHandle* const> inputs,
+                   absl::Span<AbstractTensorHandle*> outputs,
+                   const char* data_format, const char* name) {
+  AbstractOperationPtr bias_add_grad_op(ctx->CreateOperation());
+  TF_RETURN_IF_ERROR(
+      bias_add_grad_op->Reset("BiasAddGrad", /*raw_device_name=*/nullptr));
+  TF_RETURN_IF_ERROR(MaybeSetOpName(bias_add_grad_op.get(), name));
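+  // BiasAddGrad sums the upstream gradient over all dimensions except the
+  // feature dimension selected by data_format ("NHWC" or "NCHW").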
+  TF_RETURN_IF_ERROR(bias_add_grad_op->SetAttrString("data_format", data_format,
+                                                     strlen(data_format)));
+  TF_RETURN_IF_ERROR(bias_add_grad_op->AddInput(inputs[0]));
+
+  int num_retvals = 1;
+  TF_RETURN_IF_ERROR(bias_add_grad_op->Execute(outputs, &num_retvals));
+  return Status::OK();
+}
+
 }  // namespace ops
 }  // namespace tensorflow
diff --git a/tensorflow/c/experimental/ops/nn_ops.h b/tensorflow/c/experimental/ops/nn_ops.h
index 3c0e04579a1..d5b8cdde356 100644
--- a/tensorflow/c/experimental/ops/nn_ops.h
+++ b/tensorflow/c/experimental/ops/nn_ops.h
@@ -34,6 +34,15 @@ Status Relu(AbstractContext* ctx,
             absl::Span<AbstractTensorHandle* const> inputs,
             absl::Span<AbstractTensorHandle*> outputs, const char* name);
 
+Status BiasAdd(AbstractContext* ctx,
+               absl::Span<AbstractTensorHandle* const> inputs,
+               absl::Span<AbstractTensorHandle*> outputs, const char* name);
+
+Status BiasAddGrad(AbstractContext* ctx,
+                   absl::Span<AbstractTensorHandle* const> inputs,
+                   absl::Span<AbstractTensorHandle*> outputs,
+                   const char* data_format, const char* name);
+
 }  // namespace ops
 }  // namespace tensorflow