diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index ae39c4522db..015b4025fc2 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -194,10 +194,9 @@ cc_library(
     ],
 )
 
-cc_library(
+tf_kernel_library(
     name = "fill_functor",
-    srcs = ["fill_functor.cc"],
-    hdrs = ["fill_functor.h"],
+    prefix = "fill_functor",
     deps = [
         "//tensorflow/core:framework",
         "//third_party/eigen3",
@@ -3043,6 +3042,7 @@ tf_kernel_library(
         "//conditions:default": [],
     }),
     hdrs = [
+        "fill_functor.h",
         "conv_grad_ops.h",
         "deep_conv2d.h",
         "gemm_functors.h",
@@ -3067,6 +3067,7 @@ tf_kernel_library(
         ":conv_2d",
         ":conv_3d",
         ":image_resizer_state",
+        ":fill_functor",
         ":ops_util",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc
index c485d02e232..c8bfb268592 100644
--- a/tensorflow/core/kernels/constant_op.cc
+++ b/tensorflow/core/kernels/constant_op.cc
@@ -151,18 +151,6 @@ typedef Eigen::GpuDevice GPUDevice;
 typedef Eigen::SyclDevice SYCLDevice;
 #endif  // TENSORFLOW_USE_SYCL
 
-namespace functor {
-
-// Partial specialization of FillFunctor<Device=CPUDevice, T>.
-template <typename T>
-struct FillFunctor<CPUDevice, T> {
-  void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
-                  typename TTypes<T>::ConstScalar in) {
-    out.device(d) = out.constant(in());
-  }
-};
-
-}  // end namespace functor
 
 template <typename Device, typename T, typename Index>
 class FillOp : public OpKernel {
@@ -191,28 +179,6 @@ class FillOp : public OpKernel {
   }
 };
 
-#ifdef TENSORFLOW_USE_SYCL
-
-namespace functor {
-// Partial specialization of FillFunctor<Device=SYCLDevice, T>.
-template <typename T>
-struct FillFunctor<SYCLDevice, T> {
-  void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
-                  typename TTypes<T>::ConstScalar in) {
-#if !defined(EIGEN_HAS_INDEX_LIST)
-    Eigen::array<int, 1> rank1{1};
-#else
-    Eigen::IndexList<Eigen::type2index<1> > rank1;
-#endif
-    const int size = out.dimension(0);
-    Eigen::array<int, 1> broadcast_dims{size};
-
-    To32Bit(out).device(d) = in.reshape(rank1).broadcast(broadcast_dims);
-  }
-};
-}  // namespace functor
-#endif  // TENSORFLOW_USE_SYCL
-
 #define REGISTER_KERNEL(D, TYPE)                   \
   REGISTER_KERNEL_BUILDER(Name("Fill")             \
                               .Device(DEVICE_##D)  \
diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc
index 1791c510966..5e4feb25848 100644
--- a/tensorflow/core/kernels/conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc
@@ -30,6 +30,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_slice.h"
 #include "tensorflow/core/kernels/conv_2d.h"
+#include "tensorflow/core/kernels/fill_functor.h"
 #ifdef TENSORFLOW_USE_LIBXSMM
 #include "tensorflow/core/kernels/xsmm_conv2d.h"
 #endif
@@ -595,6 +596,13 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
     if (filter_shape.num_elements() == 0) {
       return;
     }
 
+    // If input is empty, set gradients to zero.
+    if (input.shape().num_elements() == 0) {
+      functor::SetZeroFunctor<Device, T> f;
+      f(context->eigen_device<Device>(), filter_backprop->flat<T>());
+      return;
+    }
+
     // For now we take the stride from the second and third dimensions only (we
     // do not support striding on the batch or depth dimension).
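For context, a minimal standalone sketch of the empty-input guard that the conv_grad_filter_ops.cc hunk above applies. This is illustrative only and not part of the patch; the helper name and its arguments are hypothetical. The idea is that an empty input contributes nothing to the filter gradient, but the already-allocated output must still hold defined values, so it is zero-filled and the kernel returns early.

// Hypothetical helper, not part of the patch: zero-fill a gradient output and
// signal the caller to return early when the forward input is empty.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/fill_functor.h"

namespace tensorflow {

template <typename Device, typename T>
bool MaybeZeroGradForEmptyInput(OpKernelContext* context, const Tensor& input,
                                Tensor* grad_output) {
  if (input.shape().num_elements() != 0) return false;
  // Nothing contributes to the gradient, so give the output a defined value.
  functor::SetZeroFunctor<Device, T> set_zero;
  set_zero(context->eigen_device<Device>(), grad_output->flat<T>());
  return true;  // Caller is expected to return from Compute() at this point.
}

}  // namespace tensorflow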
diff --git a/tensorflow/core/kernels/fill_functor.cc b/tensorflow/core/kernels/fill_functor.cc
index ea0cc139f3d..35d9693f541 100644
--- a/tensorflow/core/kernels/fill_functor.cc
+++ b/tensorflow/core/kernels/fill_functor.cc
@@ -19,6 +19,7 @@ limitations under the License.
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/framework/variant_encode_decode.h"
 
@@ -74,6 +75,7 @@ DEFINE_SETZERO_SYCL(int32);
 DEFINE_SETZERO_SYCL(int64);
 #undef DEFINE_SETZERO_SYCL
 #endif  // TENSORFLOW_USE_SYCL
+
 template <typename T>
 void SetOneFunctor<Eigen::ThreadPoolDevice, T>::operator()(
     const Eigen::ThreadPoolDevice& d, typename TTypes<T>::Flat out) {
@@ -112,5 +114,47 @@ DEFINE_SETONE_SYCL(double);
 #undef DEFINE_SETONE_SYCL
 #endif  // TENSORFLOW_USE_SYCL
 
+template <typename T>
+struct FillFunctor<Eigen::ThreadPoolDevice, T> {
+  void operator()(const Eigen::ThreadPoolDevice& d, typename TTypes<T>::Flat out,
+                  typename TTypes<T>::ConstScalar in) {
+    out.device(d) = out.constant(in());
+  }
+};
+
+// Explicit instantiations.
+#define DEFINE_FILL_CPU(T) \
+  template struct FillFunctor<Eigen::ThreadPoolDevice, T>;
+
+TF_CALL_ALL_TYPES(DEFINE_FILL_CPU);
+DEFINE_FILL_CPU(quint8);
+DEFINE_FILL_CPU(quint16);
+#undef DEFINE_FILL_CPU
+
+#ifdef TENSORFLOW_USE_SYCL
+template <typename T>
+struct FillFunctor<Eigen::SyclDevice, T> {
+  void operator()(const Eigen::SyclDevice& d, typename TTypes<T>::Flat out,
+                  typename TTypes<T>::ConstScalar in) {
+#if !defined(EIGEN_HAS_INDEX_LIST)
+    Eigen::array<int, 1> rank1{1};
+#else
+    Eigen::IndexList<Eigen::type2index<1> > rank1;
+#endif
+    const int size = out.dimension(0);
+    Eigen::array<int, 1> broadcast_dims{size};
+
+    To32Bit(out).device(d) = in.reshape(rank1).broadcast(broadcast_dims);
+  }
+};
+
+#define DEFINE_FILL_SYCL(T) \
+  template struct FillFunctor<Eigen::SyclDevice, T>;
+DEFINE_FILL_SYCL(float);
+DEFINE_FILL_SYCL(double);
+TF_CALL_INTEGRAL_TYPES(DEFINE_FILL_SYCL)
+#undef DEFINE_FILL_SYCL
+#endif  // TENSORFLOW_USE_SYCL
+
 }  // namespace functor
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/constant_op_gpu.cu.cc b/tensorflow/core/kernels/fill_functor.cu.cc
similarity index 100%
rename from tensorflow/core/kernels/constant_op_gpu.cu.cc
rename to tensorflow/core/kernels/fill_functor.cu.cc
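For context, a sketch of how a kernel can now consume the CPU FillFunctor that fill_functor.cc explicitly instantiates above, instead of keeping a local partial specialization the way constant_op.cc used to. This is an illustrative assumption, not code from the patch; FillWithScalar and its arguments are made up.

// Hypothetical usage sketch, not part of the patch: with the explicit
// instantiations in fill_functor.cc, including fill_functor.h is enough to
// broadcast a scalar into a flat output on the CPU device.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/fill_functor.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

template <typename T>
void FillWithScalar(OpKernelContext* context, const Tensor& scalar_value,
                    Tensor* output) {
  functor::FillFunctor<CPUDevice, T> fill;
  fill(context->eigen_device<CPUDevice>(), output->flat<T>(),
       scalar_value.scalar<T>());
}

}  // namespace tensorflow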
#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/fill_functor.h" #include "tensorflow/core/kernels/fused_batch_norm_op.h" #include "tensorflow/core/util/tensor_format.h" @@ -239,6 +240,14 @@ struct FusedBatchNorm<GPUDevice, T, U> { << " offset shape: " << offset.shape().DebugString() << " tensor format: " << tensor_format; + // If input is empty, return NaN mean/variance + if (x.shape().num_elements() == 0) { + functor::SetNanFunctor<U> f; + f(context->eigen_device<GPUDevice>(), batch_mean->flat<U>()); + f(context->eigen_device<GPUDevice>(), batch_var->flat<U>()); + return; + } + Tensor x_maybe_transformed = x; Tensor x_transformed; Tensor y_transformed; @@ -656,6 +665,14 @@ class FusedBatchNormGradOp : public OpKernel { context, context->allocate_output(4, TensorShape({}), &placeholder_2)); FillZeros<Device>(placeholder_2); + // If input is empty, set gradients w.r.t scale/offset to zero. + if (x.shape().num_elements() == 0) { + functor::SetZeroFunctor<Device, U> f; + f(context->eigen_device<Device>(), scale_backprop->flat<U>()); + f(context->eigen_device<Device>(), offset_backprop->flat<U>()); + return; + } + if (is_training_) { functor::FusedBatchNormGrad<Device, T, U>()( context, y_backprop, x, scale, saved_mean_or_pop_mean, diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cu.cc b/tensorflow/core/kernels/fused_batch_norm_op.cu.cc index dc956066ecf..a8484390b92 100644 --- a/tensorflow/core/kernels/fused_batch_norm_op.cu.cc +++ b/tensorflow/core/kernels/fused_batch_norm_op.cu.cc @@ -65,8 +65,15 @@ void InvVarianceToVariance<T>::operator()(const Eigen::GpuDevice& d, epsilon, sample_size, variance); } +template <class T> +void SetNanFunctor<T>::operator()(const Eigen::GpuDevice& d, + typename TTypes<T>::Flat out) { + To32Bit(out).device(d) = To32Bit(out).constant(Eigen::NumTraits<T>::quiet_NaN()); +} + template class VarianceToInvVariance<float>; template class InvVarianceToVariance<float>; +template class SetNanFunctor<float>; } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/fused_batch_norm_op.h b/tensorflow/core/kernels/fused_batch_norm_op.h index 3af104bf954..d6c68df9861 100644 --- a/tensorflow/core/kernels/fused_batch_norm_op.h +++ b/tensorflow/core/kernels/fused_batch_norm_op.h @@ -49,6 +49,12 @@ struct InvVarianceToVariance { int channels, T* variance); }; +// This function sets a GPU tensor to NaNs. 
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.h b/tensorflow/core/kernels/fused_batch_norm_op.h
index 3af104bf954..d6c68df9861 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.h
+++ b/tensorflow/core/kernels/fused_batch_norm_op.h
@@ -49,6 +49,12 @@ struct InvVarianceToVariance {
                   int channels, T* variance);
 };
 
+// This function sets a GPU tensor to NaNs.
+template <class T>
+struct SetNanFunctor {
+  void operator()(const Eigen::GpuDevice& d, typename TTypes<T>::Flat out);
+};
+
 #endif  // GOOGLE_CUDA
 
 // Functor used by FusedBatchNormGradOp to do the computations when
diff --git a/tensorflow/core/kernels/pooling_ops_common.cc b/tensorflow/core/kernels/pooling_ops_common.cc
index ac90f67ce0b..6a52a15c931 100644
--- a/tensorflow/core/kernels/pooling_ops_common.cc
+++ b/tensorflow/core/kernels/pooling_ops_common.cc
@@ -147,6 +147,9 @@ void DnnPoolingOp<T>::Compute(
   Tensor* tensor_out = nullptr;
   OP_REQUIRES_OK(context,
                  context->allocate_output(0, tensor_out_shape, &tensor_out));
+  if (tensor_in.shape().num_elements() == 0) {
+    return;
+  }
 
   PoolParameters params{context, size,        stride,
                         padding, data_format, tensor_in.shape()};
@@ -247,6 +250,9 @@ void DnnPoolingGradOp<T>::Compute(
   Tensor* input_backprop = nullptr;
   OP_REQUIRES_OK(context,
                  context->allocate_output(0, tensor_in_shape, &input_backprop));
+  if (tensor_in_shape.num_elements() == 0) {
+    return;
+  }
 
   PoolParameters params{context, size,        stride,
                         padding, data_format, tensor_in_shape};
diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py
index a7cbc76b879..3e9bd3dade6 100644
--- a/tensorflow/python/kernel_tests/conv_ops_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_test.py
@@ -796,6 +796,20 @@ class Conv2DTest(test.TestCase):
         data_format=data_format,
         use_gpu=use_gpu)
 
+  @test_util.run_in_graph_and_eager_modes()
+  def testConv2DBackpropFilterWithEmptyInput(self):
+    expected = [0, 0, 0, 0]
+    for (data_format, use_gpu) in GetTestConfigs():
+      self._RunAndVerifyBackpropFilter(
+          input_sizes=[0, 2, 3, 1],
+          filter_sizes=[2, 2, 1, 1],
+          output_sizes=[0, 1, 2, 1],
+          strides=[1, 1],
+          padding="VALID",
+          expected=expected,
+          data_format=data_format,
+          use_gpu=use_gpu)
+
   @test_util.run_in_graph_and_eager_modes()
   def testConv2D2x2Depth3ValidBackpropFilter(self):
     expected = [
diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py
index 6be8997cabd..5c0ea8ec8ed 100644
--- a/tensorflow/python/kernel_tests/pooling_ops_test.py
+++ b/tensorflow/python/kernel_tests/pooling_ops_test.py
@@ -361,6 +361,16 @@ class PoolingTest(test.TestCase):
         expected=expected_output,
         use_gpu=use_gpu)
 
+  def _testAvgPoolEmptyInput(self, use_gpu):
+    self._VerifyValues(
+        nn_ops.avg_pool,
+        input_sizes=[0, 8, 8, 8],
+        ksize=[1, 3, 3, 1],
+        strides=[1, 2, 2, 1],
+        padding="SAME",
+        expected=[],
+        use_gpu=use_gpu)
+
   def testAvgPooling(self):
     for use_gpu in True, False:
       self._testAvgPoolValidPadding(use_gpu)
@@ -371,6 +381,7 @@ class PoolingTest(test.TestCase):
       self._testAvgPoolSamePadding4(use_gpu)
       self._testAvgPoolSamePaddingPacket4(use_gpu)
       self._testAvgPoolSamePaddingPacket8(use_gpu)
+      self._testAvgPoolEmptyInput(use_gpu)
 
   def _testMaxPoolValidPadding(self, use_gpu):
     expected_output = [13.0, 14.0, 15.0]
@@ -543,6 +554,16 @@ class PoolingTest(test.TestCase):
           use_gpu=use_gpu,
           v2=v2)
 
+  def _testMaxPoolEmptyInput(self, use_gpu):
+    self._VerifyValues(
+        gen_nn_ops._max_pool_v2,
+        input_sizes=[0, 8, 8, 8],
+        ksize=[1, 3, 3, 1],
+        strides=[1, 2, 2, 1],
+        padding="SAME",
+        expected=[],
+        use_gpu=use_gpu)
+
   def testMaxPooling(self):
     for use_gpu in True, False:
       self._testMaxPoolValidPadding(use_gpu)
@@ -551,6 +572,7 @@ class PoolingTest(test.TestCase):
      self._testMaxPoolSamePaddingNonSquareWindow(use_gpu)
      self._testMaxPoolValidPaddingUnevenStride(use_gpu)
      self._testMaxPoolSamePaddingPacket4(use_gpu)
      self._testMaxPoolSamePaddingPacket8(use_gpu)
+      self._testMaxPoolEmptyInput(use_gpu)
 
   # Tests for DepthwiseMaxPooling on CPU only.
   def testDepthwiseMaxPool1x1DepthWindow1(self):
diff --git a/tensorflow/python/ops/nn_fused_batchnorm_test.py b/tensorflow/python/ops/nn_fused_batchnorm_test.py
index ff7137d492c..0593ed2cfa6 100644
--- a/tensorflow/python/ops/nn_fused_batchnorm_test.py
+++ b/tensorflow/python/ops/nn_fused_batchnorm_test.py
@@ -171,6 +171,10 @@ class BatchNormalizationTest(test.TestCase):
         x, x_shape, y, y_shape, delta=1e-3, x_init_value=x_init_val)
     _, numerical_grad = gradient_checker.compute_gradient(
         x32, x_shape, y32, y_shape, delta=1e-3, x_init_value=x32_init_val)
+
+    # If grad is empty, no error.
+    if theoretical_grad.size == 0 and numerical_grad.size == 0:
+      return 0
     return np.fabs(theoretical_grad - numerical_grad).max()
 
   def _test_gradient(self,
@@ -371,6 +375,17 @@ class BatchNormalizationTest(test.TestCase):
       self._test_inference(
           x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
 
+  def testInferenceShape5(self):
+    x_shape = [0, 131, 127, 6]
+    for dtype in [np.float16, np.float32]:
+      if test.is_gpu_available(cuda_only=True):
+        self._test_inference(
+            x_shape, dtype, [131], np.float32, use_gpu=True, data_format='NCHW')
+        self._test_inference(
+            x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
+      self._test_inference(
+          x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
+
   def testTrainingShape1(self):
     x_shape = [1, 1, 6, 1]
     for dtype in [np.float16, np.float32]:
@@ -409,6 +424,17 @@ class BatchNormalizationTest(test.TestCase):
       self._test_training(
           x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
 
+  def testTrainingShape5(self):
+    x_shape = [0, 131, 127, 6]
+    for dtype in [np.float16, np.float32]:
+      if test.is_gpu_available(cuda_only=True):
+        self._test_training(
+            x_shape, dtype, [131], np.float32, use_gpu=True, data_format='NCHW')
+        self._test_training(
+            x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
+      self._test_training(
+          x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
+
   def testBatchNormGradShape1(self):
     for is_training in [True, False]:
       x_shape = [1, 1, 6, 1]
@@ -496,6 +522,33 @@ class BatchNormalizationTest(test.TestCase):
             data_format='NHWC',
             is_training=is_training)
 
+  def testBatchNormGradShape5(self):
+    for is_training in [True, False]:
+      x_shape = [0, 7, 11, 4]
+      for dtype in [np.float16, np.float32]:
+        if test.is_gpu_available(cuda_only=True):
+          self._test_gradient(
+              x_shape,
+              dtype, [7],
+              np.float32,
+              use_gpu=True,
+              data_format='NCHW',
+              is_training=is_training)
+          self._test_gradient(
+              x_shape,
+              dtype, [4],
+              np.float32,
+              use_gpu=True,
+              data_format='NHWC',
+              is_training=is_training)
+        self._test_gradient(
+            x_shape,
+            dtype, [4],
+            np.float32,
+            use_gpu=False,
+            data_format='NHWC',
+            is_training=is_training)
+
   def _testBatchNormGradGrad(self, config):
     shape = config['shape']
     err_tolerance = config['err_tolerance']