ENH: The gradient op of bias_add supports 3/4/5D NCHW format

Yan Facai (颜发才) 2018-12-18 08:51:25 +08:00
parent add19e0e56
commit 17bc7e61e5
5 changed files with 58 additions and 38 deletions
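As a usage-level sketch of what this commit enables (hedged: written against the TF 1.x API of this era, with made-up tensor values), the gradient of bias_add can now be taken through NCHW tensors of rank 3 and 5, not just rank 4:

    import numpy as np
    import tensorflow as tf

    # Illustrative only, not part of the commit: gradient of bias_add on a
    # 5D NCHW tensor; before this change the kernel rejected non-4D NCHW.
    x = tf.constant(np.arange(48.0).reshape(1, 2, 3, 4, 2), dtype=tf.float32)
    bias = tf.constant([1.3, 2.4], dtype=tf.float32)  # one entry per channel
    y = tf.nn.bias_add(x, bias, data_format="NCHW")   # channel is dimension 1
    # BiasAddGrad sums the upstream gradient over every non-channel axis.
    dbias = tf.gradients(tf.reduce_sum(y), bias)[0]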

tensorflow/core/kernels/bias_op.cc

@@ -18,13 +18,13 @@ limitations under the License.
 #define EIGEN_USE_THREADS
 
 #include "tensorflow/core/kernels/bias_op.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/numeric_op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/kernels/bounds_check.h"
 #include "tensorflow/core/util/tensor_format.h"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
 #if GOOGLE_CUDA
 #include "tensorflow/core/kernels/bias_op_gpu.h"
@@ -153,13 +153,13 @@ class BiasOp : public BinaryOp<T> {
             bias.tensor<T, 1>().reshape(four_dims).broadcast(broad_cast_dims);
       } break;
       case 5: {
-        Eigen::DSizes<int32, 5> four_dims(1, channel, 1, 1, 1);
+        Eigen::DSizes<int32, 5> five_dims(1, channel, 1, 1, 1);
         Eigen::DSizes<int32, 5> broad_cast_dims(batch, 1, height, width,
                                                 depth);
         const Device& d = context->eigen_device<Device>();
         output->tensor<T, 5>().device(d) =
             input.tensor<T, 5>() +
-            bias.tensor<T, 1>().reshape(four_dims).broadcast(broad_cast_dims);
+            bias.tensor<T, 1>().reshape(five_dims).broadcast(broad_cast_dims);
       } break;
       default:
         OP_REQUIRES(context, false,
@@ -269,28 +269,24 @@ class BiasGradOp : public OpKernel {
       output->template flat<T>().setZero();
     } else {
       // Added by intel_tf to support NCHW on CPU regardless of MKL used or not.
-      // TODO(yongtang): Add 3/4/5 dimensional data support for NCHW format.
       if (data_format_ == FORMAT_NCHW) {
-        OP_REQUIRES(context, output_backprop.dims() == 4,
-                    errors::InvalidArgument(
-                        "NCHW format supports only 4D input/output tensor."));
-        Eigen::DSizes<Eigen::Index, 4> four_dims(batch, channel, height, width);
+        Eigen::DSizes<Eigen::Index, 3> three_dims(batch, channel,
+                                                  height * width * depth);
 #ifdef EIGEN_HAS_INDEX_LIST
         using idx0 = Eigen::type2index<0>;
         using idx2 = Eigen::type2index<2>;
-        using idx3 = Eigen::type2index<3>;
-        Eigen::IndexList<idx0, idx2, idx3> reduction_axes;
+        Eigen::IndexList<idx0, idx2> reduction_axes;
 #else
-        Eigen::array<Eigen::Index, 3> reduction_axes = {0, 2, 3};
+        Eigen::array<Eigen::Index, 2> reduction_axes = {0, 2};
 #endif
         output->template flat<T>().device(context->eigen_device<Device>()) =
             output_backprop.flat<T>()
                 .template cast<typename AccumulatorType<T>::type>()
-                .reshape(four_dims)
+                .reshape(three_dims)
                 .sum(reduction_axes)
                 .template cast<T>();  // End of code by intel_tf.
       } else {
-        Eigen::DSizes<Eigen::Index, 2> two_dims(batch * height * width,
+        Eigen::DSizes<Eigen::Index, 2> two_dims(batch * height * width * depth,
                                                 channel);
 #ifdef EIGEN_HAS_INDEX_LIST
         Eigen::IndexList<Eigen::type2index<0> > reduction_axis;
@@ -496,21 +492,21 @@ class BiasGradOp<GPUDevice, T> : public OpKernel {
   void ComputeWithCustomKernel(OpKernelContext* context,
                                const Tensor& output_backprop, int32 batch,
-                               int32 width, int32 height, int32 channel,
-                               Tensor* output) {
+                               int32 width, int32 height, int32 depth,
+                               int32 channel, Tensor* output) {
     BiasGradGPU<T>::compute(context->template eigen_device<Device>(),
                             output_backprop.template flat<T>().data(),
                             output->flat<T>().data(), batch, width, height,
-                            channel, data_format_);
+                            depth, channel, data_format_);
   }
 
   void ComputeWithReduceSum(OpKernelContext* context,
                             const Tensor& output_backprop, int32 batch,
-                            int32 width, int32 height, int32 channel,
-                            Tensor* output) {
+                            int32 width, int32 height, int32 depth,
+                            int32 channel, Tensor* output) {
     if (data_format_ == FORMAT_NCHW) {
       int32 row_count = batch * channel;
-      int32 col_count = height * width;
+      int32 col_count = height * width * depth;
       Tensor temp_grad_outputs;
       // For 'NCHW' format, we perform reduction twice: first HW, then N.
       TensorShape temp_grad_output_shape{row_count, col_count};
@@ -528,7 +524,7 @@ class BiasGradOp<GPUDevice, T> : public OpKernel {
                                      row_count, col_count);
     } else {
       // For 'NHWC', we simply apply reduction once on NHW.
-      int32 row_count = batch * height * width;
+      int32 row_count = batch * height * width * depth;
       int32 col_count = channel;
       BiasGradGPU<T>::DoColReduction(
           context, const_cast<T*>(output->flat<T>().data()),
@@ -561,7 +557,7 @@ class BiasGradOp<GPUDevice, T> : public OpKernel {
     int device_id = stream->parent()->device_ordinal();
     DataType dtype = output_backprop.dtype();
     BiasAddParams bias_parameters = {
-        {batch, height * width, channel},
+        {batch, height * width * depth, channel},
         data_format_,
         dtype,
         device_id,
@@ -576,7 +572,7 @@ class BiasGradOp<GPUDevice, T> : public OpKernel {
       stream->InitTimer(&timer);
       stream->ThenStartTimer(&timer);
       ComputeWithCustomKernel(context, output_backprop, batch, width, height,
-                              channel, output);
+                              depth, channel, output);
       stream->ThenStopTimer(&timer);
       uint64 elapsed_microseconds = timer.Microseconds();
       VLOG(1) << "BiasAddGrad " << bias_parameters.ToString()
@@ -589,7 +585,7 @@ class BiasGradOp<GPUDevice, T> : public OpKernel {
       // Try reduction and profile.
       stream->ThenStartTimer(&timer);
       ComputeWithReduceSum(context, output_backprop, batch, width, height,
-                           channel, output);
+                           depth, channel, output);
       stream->ThenStopTimer(&timer);
       elapsed_microseconds = timer.Microseconds();
@@ -610,11 +606,11 @@ class BiasGradOp<GPUDevice, T> : public OpKernel {
     // Choose the best algorithm based on autotune results.
     if (algo_config.get_mode() == BiasAddGradGPUMode::kReduction) {
       ComputeWithReduceSum(context, output_backprop, batch, width, height,
-                           channel, output);
+                           depth, channel, output);
     } else {
       // Default to the customized kernel.
       ComputeWithCustomKernel(context, output_backprop, batch, width, height,
-                              channel, output);
+                              depth, channel, output);
     }
   }
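The core of the CPU-side change: rather than requiring rank 4 and reducing over axes {0, 2, 3}, BiasGradOp now collapses all trailing spatial dimensions into a single axis and reduces over axes {0, 2}, which covers 3/4/5D inputs uniformly. A NumPy sketch of the same computation (function name and shapes are mine, for illustration):

    import numpy as np

    def bias_grad_nchw(output_backprop, channel):
        # Reshape to (batch, channel, height * width * depth), then sum over
        # the batch axis (0) and the collapsed spatial axis (2).
        batch = output_backprop.shape[0]
        three_dims = output_backprop.reshape(batch, channel, -1)
        return three_dims.sum(axis=(0, 2))

    # The same code handles 3D, 4D, and 5D NCHW gradients:
    grad_5d = np.arange(48.0).reshape(1, 2, 3, 4, 2)
    print(bias_grad_nchw(grad_5d, channel=2))  # shape (2,), one sum per channel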

tensorflow/core/kernels/bias_op_gpu.cu.cc

@@ -195,10 +195,10 @@ __global__ void BiasGradNCHW_SharedAtomics(const T* output_backprop,
 template <typename T>
 void BiasGradGPU<T>::compute(const GPUDevice& d, const T* output_backprop,
                              T* bias_backprop, int32 batch, int32 height,
-                             int32 width, int32 channel,
+                             int32 width, int32 depth, int32 channel,
                              TensorFormat data_format) {
   const int32 bias_size = channel;
-  const int32 image_size = height * width;
+  const int32 image_size = height * width * depth;
   const int32 total_count = batch * bias_size * image_size;
   if (total_count == 0) {
     return;
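On the GPU reduction path, ComputeWithReduceSum keeps its two-stage NCHW strategy; only the row width grows from height * width to height * width * depth. A NumPy sketch of the two stages (a hypothetical helper standing in for DoRowReduction plus the batch reduction, not the CUDA code itself):

    import numpy as np

    def bias_grad_nchw_two_stage(output_backprop, channel):
        # Stage 1: reduce each (batch, channel) row over its
        # height * width * depth elements.
        batch = output_backprop.shape[0]
        row_count = batch * channel
        col_count = output_backprop.size // row_count  # height * width * depth
        temp = output_backprop.reshape(row_count, col_count).sum(axis=1)
        # Stage 2: reduce the (batch, channel) intermediate over batch.
        return temp.reshape(batch, channel).sum(axis=0)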

tensorflow/core/kernels/bias_op_gpu.h

@@ -39,7 +39,7 @@ template <typename T>
 struct BiasGradGPU {
   static void compute(const GPUDevice& device, const T* output_backprop,
                       T* bias_backprop, int32 batch, int32 height, int32 width,
-                      int32 channel, TensorFormat data_format);
+                      int32 depth, int32 channel, TensorFormat data_format);
 
   static void DoRowReduction(OpKernelContext* context, T* output,
                              const T* input, int rows, int cols);

tensorflow/python/kernel_tests/bias_op_test.py

@@ -196,9 +196,7 @@ class BiasAddTest(test.TestCase):
       self.assertAllClose(grad_jacob_t, grad_jacob_n, threshold, threshold)
 
   @test_util.run_deprecated_v1
-  def testGradientTensor(self):
-    # TODO(yongtang): BiasAddGrad with NCHW only works 4D. Reenable once
-    # all dimensions are supported.
+  def testGradientTensor2D(self):
     for (data_format, use_gpu) in ("NHWC", False), ("NHWC", True):
       for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
         np_input = np.array(
@@ -207,9 +205,19 @@ class BiasAddTest(test.TestCase):
         bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype)
         self._testGradient(np_input, bias, dtype, data_format, use_gpu)
 
+  @test_util.run_deprecated_v1
+  def testGradientTensor3D(self):
+    for (data_format, use_gpu) in [("NHWC", False), ("NHWC", True),
+                                   ("NCHW", False), ("NCHW", True)]:
+      for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
+        np_input = np.array(
+            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
+            dtype=dtype.as_numpy_dtype).reshape(1, 3, 2)
+        bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype)
+        self._testGradient(np_input, bias, dtype, data_format, use_gpu)
+
   @test_util.run_deprecated_v1
   def testGradientTensor4D(self):
     # BiasAddGrad with NCHW support 4D so all are enabled.
     for (data_format, use_gpu) in [("NHWC", False), ("NHWC", True),
                                    ("NCHW", False), ("NCHW", True)]:
       for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
@@ -219,6 +227,17 @@ class BiasAddTest(test.TestCase):
         bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype)
         self._testGradient(np_input, bias, dtype, data_format, use_gpu)
 
+  @test_util.run_deprecated_v1
+  def testGradientTensor5D(self):
+    for (data_format, use_gpu) in [("NHWC", False), ("NHWC", True),
+                                   ("NCHW", False), ("NCHW", True)]:
+      for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
+        np_input = np.arange(
+            1.0, 49.0, dtype=dtype.as_numpy_dtype).reshape(
+                [1, 2, 3, 4, 2]).astype(np.float32)
+        bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype)
+        self._testGradient(np_input, bias, dtype, data_format, use_gpu)
+
   @test_util.run_deprecated_v1
   def testEmpty(self):
     np.random.seed(7)
@@ -227,10 +246,15 @@ class BiasAddTest(test.TestCase):
   @test_util.run_deprecated_v1
   def testEmptyGradient(self):
-    # TODO(yongtang): BiasAddGrad with NCHW only works 4D. Reenable once
-    # all dimensions are supported.
     for (data_format, use_gpu) in ("NHWC", False), ("NHWC", True):
-      for shape in (0, 0), (2, 0), (0, 2), (4, 3, 0), (4, 0, 3), (0, 4, 3):
+      for shape in (0, 0), (2, 0), (0, 2):
         self._testGradient(
             np.random.randn(*shape),
             np.random.randn(shape[-1]), dtypes.float64, data_format, use_gpu)
+
+    for (data_format, use_gpu) in [("NHWC", False), ("NHWC", True),
+                                   ("NCHW", False), ("NCHW", True)]:
+      for shape in (4, 3, 0), (4, 0, 3), (0, 4, 3):
+        self._testGradient(
+            np.random.randn(*shape),
+            np.random.randn(shape[-1]), dtypes.float64, data_format, use_gpu)
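These tests go through the _testGradient helper defined earlier in this file, which also converts the NHWC-shaped inputs into NCHW layout before exercising that path. A rough standalone equivalent of the new 3D NCHW case (an assumption-laden sketch using the TF 1.x gradient checker, not the actual test code):

    import numpy as np
    import tensorflow as tf

    with tf.Graph().as_default(), tf.Session():
        # Input built directly in NCHW layout: N=1, C=2, W=3, so bias has 2 entries.
        np_input = np.arange(1.0, 7.0, dtype=np.float64).reshape(1, 2, 3)
        bias = np.array([1.3, 2.4], dtype=np.float64)
        bias_t = tf.constant(bias)
        output = tf.nn.bias_add(tf.constant(np_input), bias_t, data_format="NCHW")
        # Max elementwise gap between the analytic and numeric Jacobians.
        err = tf.test.compute_gradient_error(
            bias_t, bias.shape, output, np_input.shape, x_init_value=bias)
        assert err < 1e-8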

tensorflow/python/ops/nn_grad.py

@@ -314,10 +314,10 @@ def _BiasAddGradGrad(op, received_grad):
   if data_format == b"NCHW":
     expanded_shape = array_ops.concat([
-        array_ops.ones_like(shape[:-3]), bias_shape,
-        array_ops.ones_like(shape[-2:])
+        array_ops.ones_like(shape[:1]), bias_shape,
+        array_ops.ones_like(shape[2:])
     ], 0)
-    tile_mults = array_ops.concat([shape[:-3], [1], shape[-2:]], 0)
+    tile_mults = array_ops.concat([shape[:1], [1], shape[2:]], 0)
   else:
     expanded_shape = array_ops.concat(
         [array_ops.ones_like(shape[:-1]), bias_shape], 0)
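The _BiasAddGradGrad change mirrors the kernel change: the NCHW branch now slices around channel position 1 (shape[:1] and shape[2:]) instead of assuming exactly two trailing spatial dimensions, so it works for any rank. A NumPy sketch of the resulting reshape-and-tile broadcast (names are mine, for illustration):

    import numpy as np

    def nchw_expand_and_tile(received_grad, shape):
        # Ones for batch, the channel count, then ones for every spatial dim.
        expanded_shape = np.concatenate(
            [np.ones_like(shape[:1]), [shape[1]], np.ones_like(shape[2:])])
        tile_mults = np.concatenate([shape[:1], [1], shape[2:]])
        return np.tile(received_grad.reshape(expanded_shape), tile_mults)

    shape = np.array([2, 3, 4, 5, 6])            # 5D NCHW input, C = 3
    g = np.array([1.0, 2.0, 3.0])                # one value per channel
    print(nchw_expand_and_tile(g, shape).shape)  # (2, 3, 4, 5, 6)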