diff --git a/tensorflow/core/kernels/bias_op.cc b/tensorflow/core/kernels/bias_op.cc index d4f4b43d63b..df12dafdeb3 100644 --- a/tensorflow/core/kernels/bias_op.cc +++ b/tensorflow/core/kernels/bias_op.cc @@ -18,13 +18,13 @@ limitations under the License. #define EIGEN_USE_THREADS #include "tensorflow/core/kernels/bias_op.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/util/tensor_format.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #if GOOGLE_CUDA #include "tensorflow/core/kernels/bias_op_gpu.h" @@ -153,13 +153,13 @@ class BiasOp : public BinaryOp<T> { bias.tensor<T, 1>().reshape(four_dims).broadcast(broad_cast_dims); } break; case 5: { - Eigen::DSizes<int32, 5> four_dims(1, channel, 1, 1, 1); + Eigen::DSizes<int32, 5> five_dims(1, channel, 1, 1, 1); Eigen::DSizes<int32, 5> broad_cast_dims(batch, 1, height, width, depth); const Device& d = context->eigen_device<Device>(); output->tensor<T, 5>().device(d) = input.tensor<T, 5>() + - bias.tensor<T, 1>().reshape(four_dims).broadcast(broad_cast_dims); + bias.tensor<T, 1>().reshape(five_dims).broadcast(broad_cast_dims); } break; default: OP_REQUIRES(context, false, @@ -269,28 +269,24 @@ class BiasGradOp : public OpKernel { output->template flat<T>().setZero(); } else { // Added by intel_tf to support NCHW on CPU regardless of MKL used or not. - // TODO(yongtang): Add 3/4/5 dimensional data support for NCHW format. if (data_format_ == FORMAT_NCHW) { - OP_REQUIRES(context, output_backprop.dims() == 4, - errors::InvalidArgument( - "NCHW format supports only 4D input/output tensor.")); - Eigen::DSizes<Eigen::Index, 4> four_dims(batch, channel, height, width); + Eigen::DSizes<Eigen::Index, 3> three_dims(batch, channel, + height * width * depth); #ifdef EIGEN_HAS_INDEX_LIST using idx0 = Eigen::type2index<0>; using idx2 = Eigen::type2index<2>; - using idx3 = Eigen::type2index<3>; - Eigen::IndexList<idx0, idx2, idx3> reduction_axes; + Eigen::IndexList<idx0, idx2> reduction_axes; #else - Eigen::array<Eigen::Index, 3> reduction_axes = {0, 2, 3}; + Eigen::array<Eigen::Index, 2> reduction_axes = {0, 2}; #endif output->template flat<T>().device(context->eigen_device<Device>()) = output_backprop.flat<T>() .template cast<typename AccumulatorType<T>::type>() - .reshape(four_dims) + .reshape(three_dims) .sum(reduction_axes) .template cast<T>(); // End of code by intel_tf. } else { - Eigen::DSizes<Eigen::Index, 2> two_dims(batch * height * width, + Eigen::DSizes<Eigen::Index, 2> two_dims(batch * height * width * depth, channel); #ifdef EIGEN_HAS_INDEX_LIST Eigen::IndexList<Eigen::type2index<0> > reduction_axis; @@ -496,21 +492,21 @@ class BiasGradOp<GPUDevice, T> : public OpKernel { void ComputeWithCustomKernel(OpKernelContext* context, const Tensor& output_backprop, int32 batch, - int32 width, int32 height, int32 channel, - Tensor* output) { + int32 width, int32 height, int32 depth, + int32 channel, Tensor* output) { BiasGradGPU<T>::compute(context->template eigen_device<Device>(), output_backprop.template flat<T>().data(), output->flat<T>().data(), batch, width, height, - channel, data_format_); + depth, channel, data_format_); } void ComputeWithReduceSum(OpKernelContext* context, const Tensor& output_backprop, int32 batch, - int32 width, int32 height, int32 channel, - Tensor* output) { + int32 width, int32 height, int32 depth, + int32 channel, Tensor* output) { if (data_format_ == FORMAT_NCHW) { int32 row_count = batch * channel; - int32 col_count = height * width; + int32 col_count = height * width * depth; Tensor temp_grad_outputs; // For 'NCHW' format, we perform reduction twice: first HW, then N. TensorShape temp_grad_output_shape{row_count, col_count}; @@ -528,7 +524,7 @@ class BiasGradOp<GPUDevice, T> : public OpKernel { row_count, col_count); } else { // For 'NHWC', we simply apply reduction once on NHW. - int32 row_count = batch * height * width; + int32 row_count = batch * height * width * depth; int32 col_count = channel; BiasGradGPU<T>::DoColReduction( context, const_cast<T*>(output->flat<T>().data()), @@ -561,7 +557,7 @@ class BiasGradOp<GPUDevice, T> : public OpKernel { int device_id = stream->parent()->device_ordinal(); DataType dtype = output_backprop.dtype(); BiasAddParams bias_parameters = { - {batch, height * width, channel}, + {batch, height * width * depth, channel}, data_format_, dtype, device_id, @@ -576,7 +572,7 @@ class BiasGradOp<GPUDevice, T> : public OpKernel { stream->InitTimer(&timer); stream->ThenStartTimer(&timer); ComputeWithCustomKernel(context, output_backprop, batch, width, height, - channel, output); + depth, channel, output); stream->ThenStopTimer(&timer); uint64 elapsed_microseconds = timer.Microseconds(); VLOG(1) << "BiasAddGrad " << bias_parameters.ToString() @@ -589,7 +585,7 @@ class BiasGradOp<GPUDevice, T> : public OpKernel { // Try reduction and profile. stream->ThenStartTimer(&timer); ComputeWithReduceSum(context, output_backprop, batch, width, height, - channel, output); + depth, channel, output); stream->ThenStopTimer(&timer); elapsed_microseconds = timer.Microseconds(); @@ -610,11 +606,11 @@ class BiasGradOp<GPUDevice, T> : public OpKernel { // Choose the best algorithm based on autotune results. if (algo_config.get_mode() == BiasAddGradGPUMode::kReduction) { ComputeWithReduceSum(context, output_backprop, batch, width, height, - channel, output); + depth, channel, output); } else { // Default to the customized kernel. ComputeWithCustomKernel(context, output_backprop, batch, width, height, - channel, output); + depth, channel, output); } } diff --git a/tensorflow/core/kernels/bias_op_gpu.cu.cc b/tensorflow/core/kernels/bias_op_gpu.cu.cc index 24fea8a8e6f..006fa1dc712 100644 --- a/tensorflow/core/kernels/bias_op_gpu.cu.cc +++ b/tensorflow/core/kernels/bias_op_gpu.cu.cc @@ -195,10 +195,10 @@ __global__ void BiasGradNCHW_SharedAtomics(const T* output_backprop, template <typename T> void BiasGradGPU<T>::compute(const GPUDevice& d, const T* output_backprop, T* bias_backprop, int32 batch, int32 height, - int32 width, int32 channel, + int32 width, int32 depth, int32 channel, TensorFormat data_format) { const int32 bias_size = channel; - const int32 image_size = height * width; + const int32 image_size = height * width * depth; const int32 total_count = batch * bias_size * image_size; if (total_count == 0) { return; diff --git a/tensorflow/core/kernels/bias_op_gpu.h b/tensorflow/core/kernels/bias_op_gpu.h index a0b2ce4f9b3..372a403e687 100644 --- a/tensorflow/core/kernels/bias_op_gpu.h +++ b/tensorflow/core/kernels/bias_op_gpu.h @@ -39,7 +39,7 @@ template <typename T> struct BiasGradGPU { static void compute(const GPUDevice& device, const T* output_backprop, T* bias_backprop, int32 batch, int32 height, int32 width, - int32 channel, TensorFormat data_format); + int32 depth, int32 channel, TensorFormat data_format); static void DoRowReduction(OpKernelContext* context, T* output, const T* input, int rows, int cols); diff --git a/tensorflow/python/kernel_tests/bias_op_test.py b/tensorflow/python/kernel_tests/bias_op_test.py index 66f442dbddb..c3976194a0f 100644 --- a/tensorflow/python/kernel_tests/bias_op_test.py +++ b/tensorflow/python/kernel_tests/bias_op_test.py @@ -196,9 +196,7 @@ class BiasAddTest(test.TestCase): self.assertAllClose(grad_jacob_t, grad_jacob_n, threshold, threshold) @test_util.run_deprecated_v1 - def testGradientTensor(self): - # TODO(yongtang): BiasAddGrad with NCHW only works 4D. Reenable once - # all dimensions are supported. + def testGradientTensor2D(self): for (data_format, use_gpu) in ("NHWC", False), ("NHWC", True): for dtype in (dtypes.float16, dtypes.float32, dtypes.float64): np_input = np.array( @@ -207,9 +205,19 @@ class BiasAddTest(test.TestCase): bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype) self._testGradient(np_input, bias, dtype, data_format, use_gpu) + @test_util.run_deprecated_v1 + def testGradientTensor3D(self): + for (data_format, use_gpu) in [("NHWC", False), ("NHWC", True), + ("NCHW", False), ("NCHW", True)]: + for dtype in (dtypes.float16, dtypes.float32, dtypes.float64): + np_input = np.array( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + dtype=dtype.as_numpy_dtype).reshape(1, 3, 2) + bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype) + self._testGradient(np_input, bias, dtype, data_format, use_gpu) + @test_util.run_deprecated_v1 def testGradientTensor4D(self): - # BiasAddGrad with NCHW support 4D so all are enabled. for (data_format, use_gpu) in [("NHWC", False), ("NHWC", True), ("NCHW", False), ("NCHW", True)]: for dtype in (dtypes.float16, dtypes.float32, dtypes.float64): @@ -219,6 +227,17 @@ class BiasAddTest(test.TestCase): bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype) self._testGradient(np_input, bias, dtype, data_format, use_gpu) + @test_util.run_deprecated_v1 + def testGradientTensor5D(self): + for (data_format, use_gpu) in [("NHWC", False), ("NHWC", True), + ("NCHW", False), ("NCHW", True)]: + for dtype in (dtypes.float16, dtypes.float32, dtypes.float64): + np_input = np.arange( + 1.0, 49.0, dtype=dtype.as_numpy_dtype).reshape( + [1, 2, 3, 4, 2]).astype(np.float32) + bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype) + self._testGradient(np_input, bias, dtype, data_format, use_gpu) + @test_util.run_deprecated_v1 def testEmpty(self): np.random.seed(7) @@ -227,10 +246,15 @@ class BiasAddTest(test.TestCase): @test_util.run_deprecated_v1 def testEmptyGradient(self): - # TODO(yongtang): BiasAddGrad with NCHW only works 4D. Reenable once - # all dimensions are supported. for (data_format, use_gpu) in ("NHWC", False), ("NHWC", True): - for shape in (0, 0), (2, 0), (0, 2), (4, 3, 0), (4, 0, 3), (0, 4, 3): + for shape in (0, 0), (2, 0), (0, 2): + self._testGradient( + np.random.randn(*shape), + np.random.randn(shape[-1]), dtypes.float64, data_format, use_gpu) + + for (data_format, use_gpu) in [("NHWC", False), ("NHWC", True), + ("NCHW", False), ("NCHW", True)]: + for shape in (4, 3, 0), (4, 0, 3), (0, 4, 3): self._testGradient( np.random.randn(*shape), np.random.randn(shape[-1]), dtypes.float64, data_format, use_gpu) diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py index 34404edc9a1..7131e4abc45 100644 --- a/tensorflow/python/ops/nn_grad.py +++ b/tensorflow/python/ops/nn_grad.py @@ -314,10 +314,10 @@ def _BiasAddGradGrad(op, received_grad): if data_format == b"NCHW": expanded_shape = array_ops.concat([ - array_ops.ones_like(shape[:-3]), bias_shape, - array_ops.ones_like(shape[-2:]) + array_ops.ones_like(shape[:1]), bias_shape, + array_ops.ones_like(shape[2:]) ], 0) - tile_mults = array_ops.concat([shape[:-3], [1], shape[-2:]], 0) + tile_mults = array_ops.concat([shape[:1], [1], shape[2:]], 0) else: expanded_shape = array_ops.concat( [array_ops.ones_like(shape[:-1]), bias_shape], 0)