add BiasAddGradient

Võ Văn Nghĩa 2020-11-20 21:48:55 +07:00
parent 74ecc3ec25
commit 7b9e2ff862
8 changed files with 448 additions and 0 deletions

View File

@@ -1,5 +1,21 @@
load("//tensorflow:tensorflow.bzl", "filegroup")
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load(
"//tensorflow:tensorflow.bzl",
"if_libtpu",
"tf_cc_test",
"tf_copts",
"tf_cuda_cc_test",
"tf_cuda_library",
)
load(
"//tensorflow/core/platform:build_config.bzl",
"tf_kernel_tests_linkstatic",
)
load(
"//tensorflow/core/platform:build_config_root.bzl",
"tf_cuda_tests_tags",
)
# Library of gradient functions.
package(
@@ -95,3 +111,69 @@ filegroup(
"//tensorflow/python:__pkg__",
],
)
cc_library(
name = "nn_grad_testutil",
srcs = [
"nn_grad_testutil.cc",
],
hdrs = [
"nn_grad_testutil.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/c/eager:gradients_internal",
"//tensorflow/c/eager:gradients_util",
"//tensorflow/c/eager:tape",
"//tensorflow/c/experimental/gradients/tape:tape_context",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/core/platform:status",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/types:span",
],
)
tf_cuda_cc_test(
name = "nn_grad_test",
size = "small",
srcs = [
"nn_grad_test.cc",
],
args = ["--heap_check=local"],
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + ["nomac"],
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/eager:abstract_context",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/c/eager:c_api_test_util",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/c/eager:gradients_internal",
"//tensorflow/c/eager:unified_api_testutil",
"//tensorflow/c/experimental/gradients:nn_grad",
"//tensorflow/c/experimental/gradients:nn_grad_testutil",
"//tensorflow/c/experimental/gradients/tape:tape_context",
"//tensorflow/c/experimental/ops",
"//tensorflow/cc/profiler",
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)

View File

@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/platform/errors.h"
using std::vector;
using tensorflow::ops::BiasAddGrad;
using tensorflow::ops::Mul;
using tensorflow::ops::ReluGrad;
@@ -110,6 +111,48 @@ class SparseSoftmaxCrossEntropyWithLogitsGradientFunction
vector<AbstractTensorHandle*> forward_outputs;
};
// TODO(vnvo2409): Add python test
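// Gradient function for BiasAdd: passes the upstream gradient through to the
// value input and reduces it to the bias shape for the bias input.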
class BiasAddGradientFunction : public GradientFunction {
public:
explicit BiasAddGradientFunction(AttrBuilder f_attrs)
: forward_attrs(f_attrs) {}
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
/* Given upstream grad U and a BiasAdd: A + bias, the gradients are:
*
* dA = U
* dbias = reduceSum(U, dims = all dims except channel_dim)
*/
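// Example (matches nn_grad_test.cc): with the default NHWC layout, a [2, 2]
// input, and an all-ones upstream grad U, dA = [[1, 1], [1, 1]] and
// dbias = [2, 2] (U summed over the non-channel dimension).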
AbstractTensorHandle* upstream_grad = grad_inputs[0];
DCHECK(upstream_grad);
grad_outputs->resize(2);
// Recover data format from forward pass for gradient.
std::string data_format;
forward_attrs.Get("data_format", &data_format);
// Grad for A
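// dA is the upstream gradient itself; take an extra reference because this
// handle is also handed back to the caller as an output.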
(*grad_outputs)[0] = upstream_grad;
(*grad_outputs)[0]->Ref();
// Grad for bias
vector<AbstractTensorHandle*> bias_add_grad_outputs(1);
std::string name = "bias_add_grad";
TF_RETURN_IF_ERROR(BiasAddGrad(ctx->ctx, {upstream_grad},
absl::MakeSpan(bias_add_grad_outputs),
data_format.c_str(), name.c_str()));
(*grad_outputs)[1] = bias_add_grad_outputs[0];
return Status::OK();
}
~BiasAddGradientFunction() override {}
private:
AttrBuilder forward_attrs;
};
} // namespace
BackwardFunction* ReluRegisterer(const ForwardOperation& op) {
@@ -129,5 +172,14 @@ BackwardFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer(
return new BackwardFunction(gradient_function, default_gradients);
}
BackwardFunction* BiasAddRegisterer(const ForwardOperation& op) {
// For ops with a single output, the gradient function is not called if there
// is no incoming gradient. So we do not need to worry about creating zeros
// grads in this case.
auto gradient_function = new BiasAddGradientFunction(op.attrs);
auto default_gradients = new PassThroughDefaultGradients(op);
return new BackwardFunction(gradient_function, default_gradients);
}
} // namespace gradients
} // namespace tensorflow

View File

@@ -22,6 +22,7 @@ namespace gradients {
BackwardFunction* ReluRegisterer(const ForwardOperation& op);
BackwardFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer(
const ForwardOperation& op);
BackwardFunction* BiasAddRegisterer(const ForwardOperation& op);
} // namespace gradients
} // namespace tensorflow

View File

@@ -0,0 +1,154 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/nn_grad.h"
#include <cstring>
#include <memory>
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/gradients_util.h"
#include "tensorflow/c/experimental/gradients/nn_grad_testutil.h"
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace gradients {
namespace internal {
namespace {
using std::vector;
using tensorflow::TF_StatusPtr;
using tracing::TracingOperation;
class CppGradients
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
protected:
void SetUp() override {
TF_StatusPtr status(TF_NewStatus());
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
Status s = StatusFromTF_Status(status.get());
CHECK_EQ(errors::OK, s.code()) << s.error_message();
}
};
Status RegisterGradients(GradientRegistry* registry) {
TF_RETURN_IF_ERROR(registry->Register("BiasAdd", BiasAddRegisterer));
return Status::OK();
}
TEST_P(CppGradients, TestBiasAddGrad) {
if (!strcmp(std::get<0>(GetParam()), "mlir") && !std::get<2>(GetParam())) {
GTEST_SKIP() << "SetAttrString has not been implemented yet.\n";
}
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
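// A is a 2x2 input; Bias is a length-2 vector added along the last
// (channel) dimension.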
float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
int64_t A_dims[] = {2, 2};
float Bias_vals[] = {2.0f, 3.0f};
int64_t Bias_dims[] = {2};
AbstractTensorHandlePtr A =
GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, 2);
AbstractTensorHandlePtr Bias =
GetTensorHandleUtilFloat(ctx.get(), Bias_vals, Bias_dims, 1);
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
/* Pseudo-code:
*
* tape.watch(A)
* tape.watch(Bias)
* Y = BiasAdd(A, Bias)
* outputs = tape.gradient(Y, [A, Bias])
*/
std::vector<AbstractTensorHandle*> outputs(2);
s = RunModel(BiasAddGradModel, ctx.get(), {A.get(), Bias.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
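// With no explicit output gradient, the tape seeds the target with ones,
// so dA should equal an all-ones tensor.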
TF_Tensor* dA_tensor;
s = GetValue(outputs[0], &dA_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float result_data_dA[4] = {0};
memcpy(&result_data_dA[0], TF_TensorData(dA_tensor),
TF_TensorByteSize(dA_tensor));
float expected_dA[4] = {1.0f, 1.0f, 1.0f, 1.0f};
float tolerance = 1e-3;
for (int j = 0; j < 4; j++) {
ASSERT_NEAR(result_data_dA[j], expected_dA[j], tolerance);
}
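// dBias is the upstream gradient summed over the non-channel dimension:
// [1 + 1, 1 + 1] = [2, 2].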
TF_Tensor* dBias_tensor;
s = GetValue(outputs[1], &dBias_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float result_data_dBias[2] = {0};
memcpy(&result_data_dBias[0], TF_TensorData(dBias_tensor),
TF_TensorByteSize(dBias_tensor));
float expected_dBias[2] = {2.0f, 2.0f};
for (int j = 0; j < 2; j++) {
ASSERT_NEAR(result_data_dBias[j], expected_dBias[j], tolerance);
}
outputs[0]->Unref();
outputs[1]->Unref();
TF_DeleteTensor(dA_tensor);
TF_DeleteTensor(dBias_tensor);
}
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients,
::testing::Combine(::testing::Values("graphdef", "mlir"),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients,
::testing::Combine(::testing::Values("graphdef", "mlir"),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#endif
} // namespace
} // namespace internal
} // namespace gradients
} // namespace tensorflow

View File

@@ -0,0 +1,71 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/nn_grad_testutil.h"
#include <memory>
#include <unordered_map>
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
namespace gradients {
namespace internal {
// Computes
// y = BiasAdd(inputs[0], inputs[1])
// return grad(y, {inputs[0], inputs[1]})
Status BiasAddGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch A (the value input).
tape->Watch(ToId(inputs[1])); // Watch Bias.
std::vector<AbstractTensorHandle*> bias_add_outputs(1);
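// Wrap the context in a TapeContext so the forward BiasAdd below is
// recorded on the tape.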
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::BiasAdd(tape_ctx.get(), inputs,
absl::MakeSpan(bias_add_outputs), "BiasAdd"));
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
std::vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(bias_add_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads,
/*build_default_zeros_grads=*/false));
for (auto bias_add_output : bias_add_outputs) {
bias_add_output->Unref();
}
outputs[0] = out_grads[0];
outputs[1] = out_grads[1];
delete tape;
return Status::OK();
}
} // namespace internal
} // namespace gradients
} // namespace tensorflow

View File

@@ -0,0 +1,46 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_TESTUTIL_H_
#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_TESTUTIL_H_
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
namespace gradients {
namespace internal {
// Computes
// y = BiasAdd(inputs[0], inputs[1])
// return grad(y, {inputs[0], inputs[1]})
Status BiasAddGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
} // namespace internal
} // namespace gradients
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_TESTUTIL_H_

View File

@@ -69,5 +69,38 @@ Status Relu(AbstractContext* ctx,
return Status::OK();
}
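// Computes outputs[0] = inputs[0] + inputs[1], broadcasting the bias
// (inputs[1]) along the channel dimension.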
Status BiasAdd(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr bias_add_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(
bias_add_op->Reset("BiasAdd", /*raw_device_name=*/nullptr));
TF_RETURN_IF_ERROR(MaybeSetOpName(bias_add_op.get(), name));
TF_RETURN_IF_ERROR(bias_add_op->AddInput(inputs[0])); // tensor input
TF_RETURN_IF_ERROR(bias_add_op->AddInput(inputs[1])); // bias
int num_retvals = 1;
TF_RETURN_IF_ERROR(bias_add_op->Execute(outputs, &num_retvals));
return Status::OK();
}
// Computes the bias gradient of BiasAdd given the upstream gradient in
// inputs[0].
Status BiasAddGrad(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const char* data_format, const char* name) {
AbstractOperationPtr bias_add_grad_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(
bias_add_grad_op->Reset("BiasAddGrad", /*raw_device_name=*/nullptr));
TF_RETURN_IF_ERROR(MaybeSetOpName(bias_add_grad_op.get(), name));
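// data_format ("NHWC" or "NCHW") tells BiasAddGrad which dimension holds
// the channels.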
TF_RETURN_IF_ERROR(bias_add_grad_op->SetAttrString("data_format", data_format,
strlen(data_format)));
TF_RETURN_IF_ERROR(bias_add_grad_op->AddInput(inputs[0]));
int num_retvals = 1;
TF_RETURN_IF_ERROR(bias_add_grad_op->Execute(outputs, &num_retvals));
return Status::OK();
}
} // namespace ops
} // namespace tensorflow

View File

@@ -34,6 +34,15 @@ Status Relu(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
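// Computes inputs[0] + inputs[1] with the bias (inputs[1]) broadcast along
// the channel dimension.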
Status BiasAdd(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
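// Sums the upstream gradient in inputs[0] over every dimension except the
// channel dimension given by data_format, producing the bias gradient.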
Status BiasAddGrad(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const char* data_format, const char* name);
} // namespace ops
} // namespace tensorflow