ENH: The gradient op of bias_add supports 3/4/5D NCHW format
This commit is contained in:
parent add19e0e56
commit 17bc7e61e5
tensorflow
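Before the file-by-file diff, a quick sketch of what the change enables (illustration only, not part of the commit; assumes a TF 1.x graph-mode build that includes it): taking the bias gradient of a 5D NCHW tensor, which previously failed with "NCHW format supports only 4D input/output tensor."

    import numpy as np
    import tensorflow as tf  # graph-mode TF 1.x, the era of this commit

    x = tf.constant(np.arange(48.0).reshape(1, 2, 3, 4, 2), dtype=tf.float32)
    b = tf.constant([1.3, 2.4], dtype=tf.float32)
    y = tf.nn.bias_add(x, b, data_format="NCHW")  # 5D input, channels on axis 1
    grad_b, = tf.gradients(y, b)                  # runs BiasAddGrad on a 5D NCHW tensor
    with tf.Session() as sess:
        print(sess.run(grad_b))  # [24. 24.]: each channel sums N*D*H*W = 24 ones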
@@ -18,13 +18,13 @@ limitations under the License.
 #define EIGEN_USE_THREADS
 
 #include "tensorflow/core/kernels/bias_op.h"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/numeric_op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/kernels/bounds_check.h"
 #include "tensorflow/core/util/tensor_format.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
 #if GOOGLE_CUDA
 #include "tensorflow/core/kernels/bias_op_gpu.h"
@@ -153,13 +153,13 @@ class BiasOp : public BinaryOp<T> {
             bias.tensor<T, 1>().reshape(four_dims).broadcast(broad_cast_dims);
       } break;
       case 5: {
-        Eigen::DSizes<int32, 5> four_dims(1, channel, 1, 1, 1);
+        Eigen::DSizes<int32, 5> five_dims(1, channel, 1, 1, 1);
         Eigen::DSizes<int32, 5> broad_cast_dims(batch, 1, height, width,
                                                 depth);
         const Device& d = context->eigen_device<Device>();
         output->tensor<T, 5>().device(d) =
             input.tensor<T, 5>() +
-            bias.tensor<T, 1>().reshape(four_dims).broadcast(broad_cast_dims);
+            bias.tensor<T, 1>().reshape(five_dims).broadcast(broad_cast_dims);
       } break;
       default:
         OP_REQUIRES(context, false,
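The five_dims reshape plus broadcast above is the standard broadcasting trick: place the bias on the channel axis and let replication cover the batch and spatial axes. A minimal NumPy model of the 5D NCHW forward pass (illustration, not the Eigen kernel):

    import numpy as np

    batch, channel, height, width, depth = 1, 2, 3, 4, 2
    x = np.arange(batch * channel * height * width * depth,
                  dtype=np.float32).reshape(batch, channel, height, width, depth)
    bias = np.array([1.3, 2.4], dtype=np.float32)

    # Mirrors five_dims = (1, channel, 1, 1, 1) and the broadcast over
    # broad_cast_dims = (batch, 1, height, width, depth):
    y = x + bias.reshape(1, channel, 1, 1, 1)
    assert y.shape == x.shape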
@@ -269,28 +269,24 @@ class BiasGradOp : public OpKernel {
       output->template flat<T>().setZero();
     } else {
       // Added by intel_tf to support NCHW on CPU regardless of MKL used or not.
-      // TODO(yongtang): Add 3/4/5 dimensional data support for NCHW format.
       if (data_format_ == FORMAT_NCHW) {
-        OP_REQUIRES(context, output_backprop.dims() == 4,
-                    errors::InvalidArgument(
-                        "NCHW format supports only 4D input/output tensor."));
-        Eigen::DSizes<Eigen::Index, 4> four_dims(batch, channel, height, width);
+        Eigen::DSizes<Eigen::Index, 3> three_dims(batch, channel,
+                                                  height * width * depth);
 #ifdef EIGEN_HAS_INDEX_LIST
         using idx0 = Eigen::type2index<0>;
         using idx2 = Eigen::type2index<2>;
-        using idx3 = Eigen::type2index<3>;
-        Eigen::IndexList<idx0, idx2, idx3> reduction_axes;
+        Eigen::IndexList<idx0, idx2> reduction_axes;
 #else
-        Eigen::array<Eigen::Index, 3> reduction_axes = {0, 2, 3};
+        Eigen::array<Eigen::Index, 2> reduction_axes = {0, 2};
 #endif
         output->template flat<T>().device(context->eigen_device<Device>()) =
             output_backprop.flat<T>()
                 .template cast<typename AccumulatorType<T>::type>()
-                .reshape(four_dims)
+                .reshape(three_dims)
                 .sum(reduction_axes)
                 .template cast<T>();  // End of code by intel_tf.
       } else {
-        Eigen::DSizes<Eigen::Index, 2> two_dims(batch * height * width,
+        Eigen::DSizes<Eigen::Index, 2> two_dims(batch * height * width * depth,
                                                 channel);
 #ifdef EIGEN_HAS_INDEX_LIST
         Eigen::IndexList<Eigen::type2index<0> > reduction_axis;
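The rewrite above drops the 4D-only check: folding every trailing spatial dimension into one lets a single (batch, channel, spatial) reshape plus a sum over axes {0, 2} handle 3D, 4D, and 5D NCHW inputs alike. A NumPy model of both branches (illustration, not the Eigen code):

    import numpy as np

    def bias_grad(output_backprop, data_format):
        # NumPy model of the reductions above; not the TF kernel itself.
        if data_format == "NCHW":
            batch, channel = output_backprop.shape[:2]
            # three_dims = (batch, channel, height * width * depth)
            g = output_backprop.reshape(batch, channel, -1)
            return g.sum(axis=(0, 2))       # reduction_axes = {0, 2}
        else:  # NHWC: two_dims = (batch * height * width * depth, channel)
            channel = output_backprop.shape[-1]
            return output_backprop.reshape(-1, channel).sum(axis=0)

    g = np.ones((2, 3, 4, 5, 6), dtype=np.float32)  # 5D NCHW
    print(bias_grad(g, "NCHW"))                     # [240. 240. 240.]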
@@ -496,21 +492,21 @@ class BiasGradOp<GPUDevice, T> : public OpKernel {
 
   void ComputeWithCustomKernel(OpKernelContext* context,
                                const Tensor& output_backprop, int32 batch,
-                               int32 width, int32 height, int32 channel,
-                               Tensor* output) {
+                               int32 width, int32 height, int32 depth,
+                               int32 channel, Tensor* output) {
     BiasGradGPU<T>::compute(context->template eigen_device<Device>(),
                             output_backprop.template flat<T>().data(),
                             output->flat<T>().data(), batch, width, height,
-                            channel, data_format_);
+                            depth, channel, data_format_);
   }
 
   void ComputeWithReduceSum(OpKernelContext* context,
                             const Tensor& output_backprop, int32 batch,
-                            int32 width, int32 height, int32 channel,
-                            Tensor* output) {
+                            int32 width, int32 height, int32 depth,
+                            int32 channel, Tensor* output) {
     if (data_format_ == FORMAT_NCHW) {
       int32 row_count = batch * channel;
-      int32 col_count = height * width;
+      int32 col_count = height * width * depth;
       Tensor temp_grad_outputs;
       // For 'NCHW' format, we perform reduction twice: first HW, then N.
       TensorShape temp_grad_output_shape{row_count, col_count};
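ComputeWithReduceSum models the NCHW gradient as a (batch * channel, spatial) matrix: DoRowReduction first collapses the spatial columns, then a second reduction collapses the batch rows. A NumPy sketch of the two stages (names in the comments refer to the code above):

    import numpy as np

    batch, channel, spatial = 2, 3, 4 * 5 * 6       # spatial = height * width * depth
    g = np.ones((batch, channel, 4, 5, 6), dtype=np.float32)

    # Stage 1: row-reduce the (row_count, col_count) view over HW(D).
    rows = g.reshape(batch * channel, spatial)      # row_count x col_count
    temp = rows.sum(axis=1)                         # per-(batch, channel) partial sums

    # Stage 2: reduce over the batch dimension N.
    bias_grad = temp.reshape(batch, channel).sum(axis=0)
    print(bias_grad)                                # [240. 240. 240.]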
@@ -528,7 +524,7 @@ class BiasGradOp<GPUDevice, T> : public OpKernel {
                                    row_count, col_count);
     } else {
       // For 'NHWC', we simply apply reduction once on NHW.
-      int32 row_count = batch * height * width;
+      int32 row_count = batch * height * width * depth;
       int32 col_count = channel;
       BiasGradGPU<T>::DoColReduction(
           context, const_cast<T*>(output->flat<T>().data()),
@@ -561,7 +557,7 @@ class BiasGradOp<GPUDevice, T> : public OpKernel {
     int device_id = stream->parent()->device_ordinal();
     DataType dtype = output_backprop.dtype();
     BiasAddParams bias_parameters = {
-        {batch, height * width, channel},
+        {batch, height * width * depth, channel},
         data_format_,
         dtype,
         device_id,
@@ -576,7 +572,7 @@ class BiasGradOp<GPUDevice, T> : public OpKernel {
       stream->InitTimer(&timer);
       stream->ThenStartTimer(&timer);
       ComputeWithCustomKernel(context, output_backprop, batch, width, height,
-                              channel, output);
+                              depth, channel, output);
       stream->ThenStopTimer(&timer);
       uint64 elapsed_microseconds = timer.Microseconds();
       VLOG(1) << "BiasAddGrad " << bias_parameters.ToString()
@@ -589,7 +585,7 @@ class BiasGradOp<GPUDevice, T> : public OpKernel {
       // Try reduction and profile.
       stream->ThenStartTimer(&timer);
       ComputeWithReduceSum(context, output_backprop, batch, width, height,
-                           channel, output);
+                           depth, channel, output);
       stream->ThenStopTimer(&timer);
 
       elapsed_microseconds = timer.Microseconds();
@@ -610,11 +606,11 @@ class BiasGradOp<GPUDevice, T> : public OpKernel {
     // Choose the best algorithm based on autotune results.
     if (algo_config.get_mode() == BiasAddGradGPUMode::kReduction) {
       ComputeWithReduceSum(context, output_backprop, batch, width, height,
-                           channel, output);
+                           depth, channel, output);
     } else {
       // Default to the customized kernel.
       ComputeWithCustomKernel(context, output_backprop, batch, width, height,
-                              channel, output);
+                              depth, channel, output);
     }
   }
 
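For context, the surrounding Compute method autotunes between the two paths: it times ComputeWithCustomKernel against ComputeWithReduceSum once per BiasAddParams and caches the winner in algo_config. A schematic of that pattern (hypothetical helper, not the TF autotune classes):

    import time

    _autotune_cache = {}  # params -> name of the fastest implementation

    def autotuned_bias_grad(params, grad, impls):
        # impls: dict of name -> callable; times each once, then reuses the winner.
        if params not in _autotune_cache:
            timings = {}
            for name, fn in impls.items():
                start = time.perf_counter()
                fn(grad)                  # profile this implementation once
                timings[name] = time.perf_counter() - start
            _autotune_cache[params] = min(timings, key=timings.get)
        return impls[_autotune_cache[params]](grad)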

@@ -195,10 +195,10 @@ __global__ void BiasGradNCHW_SharedAtomics(const T* output_backprop,
 template <typename T>
 void BiasGradGPU<T>::compute(const GPUDevice& d, const T* output_backprop,
                              T* bias_backprop, int32 batch, int32 height,
-                             int32 width, int32 channel,
+                             int32 width, int32 depth, int32 channel,
                              TensorFormat data_format) {
   const int32 bias_size = channel;
-  const int32 image_size = height * width;
+  const int32 image_size = height * width * depth;
   const int32 total_count = batch * bias_size * image_size;
   if (total_count == 0) {
     return;
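Folding depth into image_size is what lets the same NCHW kernels serve 3/4/5D: in flat NCHW order the channel of element i is (i / image_size) % channel, however many spatial axes image_size folds together. A NumPy model of that indexing (illustration; the actual kernels, like BiasGradNCHW_SharedAtomics above, accumulate with shared-memory atomics):

    import numpy as np

    batch, channel, height, width, depth = 2, 3, 4, 5, 6
    image_size = height * width * depth  # as in the updated kernel

    g = np.arange(batch * channel * image_size, dtype=np.float32)

    # Channel of each flat NCHW index, mirroring index / image_size % channel:
    chan = (np.arange(g.size) // image_size) % channel
    bias_grad = np.bincount(chan, weights=g, minlength=channel)

    # Cross-check against a straight reshape-and-sum.
    ref = g.reshape(batch, channel, image_size).sum(axis=(0, 2))
    assert np.allclose(bias_grad, ref)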

@@ -39,7 +39,7 @@ template <typename T>
 struct BiasGradGPU {
   static void compute(const GPUDevice& device, const T* output_backprop,
                       T* bias_backprop, int32 batch, int32 height, int32 width,
-                      int32 channel, TensorFormat data_format);
+                      int32 depth, int32 channel, TensorFormat data_format);
 
   static void DoRowReduction(OpKernelContext* context, T* output,
                              const T* input, int rows, int cols);

@@ -196,9 +196,7 @@ class BiasAddTest(test.TestCase):
       self.assertAllClose(grad_jacob_t, grad_jacob_n, threshold, threshold)
 
   @test_util.run_deprecated_v1
-  def testGradientTensor(self):
-    # TODO(yongtang): BiasAddGrad with NCHW only works 4D. Reenable once
-    # all dimensions are supported.
+  def testGradientTensor2D(self):
     for (data_format, use_gpu) in ("NHWC", False), ("NHWC", True):
       for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
         np_input = np.array(
@@ -207,9 +205,19 @@ class BiasAddTest(test.TestCase):
         bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype)
         self._testGradient(np_input, bias, dtype, data_format, use_gpu)
 
+  @test_util.run_deprecated_v1
+  def testGradientTensor3D(self):
+    for (data_format, use_gpu) in [("NHWC", False), ("NHWC", True),
+                                   ("NCHW", False), ("NCHW", True)]:
+      for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
+        np_input = np.array(
+            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
+            dtype=dtype.as_numpy_dtype).reshape(1, 3, 2)
+        bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype)
+        self._testGradient(np_input, bias, dtype, data_format, use_gpu)
+
   @test_util.run_deprecated_v1
   def testGradientTensor4D(self):
-    # BiasAddGrad with NCHW support 4D so all are enabled.
     for (data_format, use_gpu) in [("NHWC", False), ("NHWC", True),
                                    ("NCHW", False), ("NCHW", True)]:
       for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
@@ -219,6 +227,17 @@ class BiasAddTest(test.TestCase):
         bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype)
         self._testGradient(np_input, bias, dtype, data_format, use_gpu)
 
+  @test_util.run_deprecated_v1
+  def testGradientTensor5D(self):
+    for (data_format, use_gpu) in [("NHWC", False), ("NHWC", True),
+                                   ("NCHW", False), ("NCHW", True)]:
+      for dtype in (dtypes.float16, dtypes.float32, dtypes.float64):
+        np_input = np.arange(
+            1.0, 49.0, dtype=dtype.as_numpy_dtype).reshape(
+                [1, 2, 3, 4, 2]).astype(np.float32)
+        bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype)
+        self._testGradient(np_input, bias, dtype, data_format, use_gpu)
+
   @test_util.run_deprecated_v1
   def testEmpty(self):
     np.random.seed(7)
@@ -227,10 +246,15 @@ class BiasAddTest(test.TestCase):
 
   @test_util.run_deprecated_v1
   def testEmptyGradient(self):
-    # TODO(yongtang): BiasAddGrad with NCHW only works 4D. Reenable once
-    # all dimensions are supported.
     for (data_format, use_gpu) in ("NHWC", False), ("NHWC", True):
-      for shape in (0, 0), (2, 0), (0, 2), (4, 3, 0), (4, 0, 3), (0, 4, 3):
+      for shape in (0, 0), (2, 0), (0, 2):
+        self._testGradient(
+            np.random.randn(*shape),
+            np.random.randn(shape[-1]), dtypes.float64, data_format, use_gpu)
+
+    for (data_format, use_gpu) in [("NHWC", False), ("NHWC", True),
+                                   ("NCHW", False), ("NCHW", True)]:
+      for shape in (4, 3, 0), (4, 0, 3), (0, 4, 3):
         self._testGradient(
             np.random.randn(*shape),
             np.random.randn(shape[-1]), dtypes.float64, data_format, use_gpu)

@@ -314,10 +314,10 @@ def _BiasAddGradGrad(op, received_grad):
 
   if data_format == b"NCHW":
     expanded_shape = array_ops.concat([
-        array_ops.ones_like(shape[:-3]), bias_shape,
-        array_ops.ones_like(shape[-2:])
+        array_ops.ones_like(shape[:1]), bias_shape,
+        array_ops.ones_like(shape[2:])
     ], 0)
-    tile_mults = array_ops.concat([shape[:-3], [1], shape[-2:]], 0)
+    tile_mults = array_ops.concat([shape[:1], [1], shape[2:]], 0)
   else:
     expanded_shape = array_ops.concat(
         [array_ops.ones_like(shape[:-1]), bias_shape], 0)
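The old slicing (shape[:-3], shape[-2:]) baked a 4D layout into the second-order gradient; shape[:1] and shape[2:] express the same intent rank-agnostically, since NCHW always has batch on axis 0 and channels on axis 1. A NumPy sketch of the shapes built for a 5D input (illustration, not the array_ops code):

    import numpy as np

    shape = np.array([2, 3, 4, 5, 6])  # N, C, and three spatial dims
    bias_shape = shape[1:2]            # [C]

    # expanded_shape: ones everywhere except the channel axis.
    expanded_shape = np.concatenate([np.ones_like(shape[:1]), bias_shape,
                                     np.ones_like(shape[2:])])
    # tile_mults: replicate over the batch and spatial axes only.
    tile_mults = np.concatenate([shape[:1], [1], shape[2:]])

    print(expanded_shape)  # [1 3 1 1 1]
    print(tile_mults)      # [2 1 4 5 6]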