Support empty input tensor for some ops (fix ) ()

* Support empty input tensor for FusedBatchNorm,FusedBatchNormGrad,Conv2DBackpropFilter (fix )

* Also fix pooling ops

* Add some comments in ops

* Add tests for conv/pooling/bn.

* Return NaN mean/variance when input is empty

* update comments

* fix typo

* Move fill_functor implementations to :fill_functor
This commit is contained in:
Yuxin Wu 2017-12-28 16:38:35 -08:00 committed by drpngx
parent 1c5f27a713
commit 3a3b7530eb
12 changed files with 181 additions and 37 deletions

View File

@ -194,10 +194,9 @@ cc_library(
],
)
cc_library(
tf_kernel_library(
name = "fill_functor",
srcs = ["fill_functor.cc"],
hdrs = ["fill_functor.h"],
prefix = "fill_functor",
deps = [
"//tensorflow/core:framework",
"//third_party/eigen3",
@ -3043,6 +3042,7 @@ tf_kernel_library(
"//conditions:default": [],
}),
hdrs = [
"fill_functor.h",
"conv_grad_ops.h",
"deep_conv2d.h",
"gemm_functors.h",
@ -3067,6 +3067,7 @@ tf_kernel_library(
":conv_2d",
":conv_3d",
":image_resizer_state",
":fill_functor",
":ops_util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",

View File

@ -151,18 +151,6 @@ typedef Eigen::GpuDevice GPUDevice;
typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL
namespace functor {
// Partial specialization of FillFunctor<Device=CPUDevice, T>.
template <typename T>
struct FillFunctor<CPUDevice, T> {
void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
typename TTypes<T>::ConstScalar in) {
out.device(d) = out.constant(in());
}
};
} // end namespace functor
template <typename Device, typename T, typename Index>
class FillOp : public OpKernel {
@ -191,28 +179,6 @@ class FillOp : public OpKernel {
}
};
#ifdef TENSORFLOW_USE_SYCL
namespace functor {
// Partial specialization of FillFunctor<Device=SYCLDevice, T>.
template <typename T>
struct FillFunctor<SYCLDevice, T> {
void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
typename TTypes<T>::ConstScalar in) {
#if !defined(EIGEN_HAS_INDEX_LIST)
Eigen::array<int, 1> rank1{1};
#else
Eigen::IndexList<Eigen::type2index<1> > rank1;
#endif
const int size = out.dimension(0);
Eigen::array<int, 1> broadcast_dims{size};
To32Bit(out).device(d) = in.reshape(rank1).broadcast(broadcast_dims);
}
};
} // namespace functor
#endif // TENSORFLOW_USE_SYCL
#define REGISTER_KERNEL(D, TYPE) \
REGISTER_KERNEL_BUILDER(Name("Fill") \
.Device(DEVICE_##D) \

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/fill_functor.h"
#ifdef TENSORFLOW_USE_LIBXSMM
#include "tensorflow/core/kernels/xsmm_conv2d.h"
#endif
@ -595,6 +596,13 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
if (filter_shape.num_elements() == 0) {
return;
}
// If input is empty, set gradients to zero.
if (input.shape().num_elements() == 0) {
functor::SetZeroFunctor<Device, T> f;
f(context->eigen_device<Device>(), filter_backprop->flat<T>());
return;
}
// For now we take the stride from the second and third dimensions only (we
// do not support striding on the batch or depth dimension).

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
@ -74,6 +75,7 @@ DEFINE_SETZERO_SYCL(int32);
DEFINE_SETZERO_SYCL(int64);
#undef DEFINE_SETZERO_SYCL
#endif // TENSORFLOW_USE_SYCL
template <typename T>
void SetOneFunctor<Eigen::ThreadPoolDevice, T>::operator()(
const Eigen::ThreadPoolDevice& d, typename TTypes<T>::Flat out) {
@ -112,5 +114,47 @@ DEFINE_SETONE_SYCL(double);
#undef DEFINE_SETONE_SYCL
#endif // TENSORFLOW_USE_SYCL
template <typename T>
struct FillFunctor<Eigen::ThreadPoolDevice, T> {
void operator()(const Eigen::ThreadPoolDevice& d, typename TTypes<T>::Flat out,
typename TTypes<T>::ConstScalar in) {
out.device(d) = out.constant(in());
}
};
// Explicit instantiations.
#define DEFINE_FILL_CPU(T) \
template struct FillFunctor<Eigen::ThreadPoolDevice, T>;
TF_CALL_ALL_TYPES(DEFINE_FILL_CPU);
DEFINE_FILL_CPU(quint8);
DEFINE_FILL_CPU(quint16);
#undef DEFINE_FILL_CPU
#ifdef TENSORFLOW_USE_SYCL
template <typename T>
struct FillFunctor<Eigen::SyclDevice, T> {
void operator()(const Eigen::SyclDevice& d, typename TTypes<T>::Flat out,
typename TTypes<T>::ConstScalar in) {
#if !defined(EIGEN_HAS_INDEX_LIST)
Eigen::array<int, 1> rank1{1};
#else
Eigen::IndexList<Eigen::type2index<1> > rank1;
#endif
const int size = out.dimension(0);
Eigen::array<int, 1> broadcast_dims{size};
To32Bit(out).device(d) = in.reshape(rank1).broadcast(broadcast_dims);
}
};
#define DEFINE_FILL_SYCL(T) \
template struct FillFunctor<Eigen::SyclDevice, T>;
DEFINE_FILL_SYCL(float);
DEFINE_FILL_SYCL(double);
TF_CALL_INTEGRAL_TYPES(DEFINE_FILL_SYCL)
#undef DEFINE_FILL_SYCL
#endif // TENSORFLOW_USE_SYCL
} // namespace functor
} // namespace tensorflow

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/fused_batch_norm_op.h"
#include "tensorflow/core/util/tensor_format.h"
@ -239,6 +240,14 @@ struct FusedBatchNorm<GPUDevice, T, U> {
<< " offset shape: " << offset.shape().DebugString()
<< " tensor format: " << tensor_format;
// If input is empty, return NaN mean/variance
if (x.shape().num_elements() == 0) {
functor::SetNanFunctor<U> f;
f(context->eigen_device<GPUDevice>(), batch_mean->flat<U>());
f(context->eigen_device<GPUDevice>(), batch_var->flat<U>());
return;
}
Tensor x_maybe_transformed = x;
Tensor x_transformed;
Tensor y_transformed;
@ -656,6 +665,14 @@ class FusedBatchNormGradOp : public OpKernel {
context, context->allocate_output(4, TensorShape({}), &placeholder_2));
FillZeros<Device>(placeholder_2);
// If input is empty, set gradients w.r.t scale/offset to zero.
if (x.shape().num_elements() == 0) {
functor::SetZeroFunctor<Device, U> f;
f(context->eigen_device<Device>(), scale_backprop->flat<U>());
f(context->eigen_device<Device>(), offset_backprop->flat<U>());
return;
}
if (is_training_) {
functor::FusedBatchNormGrad<Device, T, U>()(
context, y_backprop, x, scale, saved_mean_or_pop_mean,

View File

@ -65,8 +65,15 @@ void InvVarianceToVariance<T>::operator()(const Eigen::GpuDevice& d,
epsilon, sample_size, variance);
}
template <class T>
void SetNanFunctor<T>::operator()(const Eigen::GpuDevice& d,
typename TTypes<T>::Flat out) {
To32Bit(out).device(d) = To32Bit(out).constant(Eigen::NumTraits<T>::quiet_NaN());
}
template class VarianceToInvVariance<float>;
template class InvVarianceToVariance<float>;
template class SetNanFunctor<float>;
} // namespace functor
} // namespace tensorflow

View File

@ -49,6 +49,12 @@ struct InvVarianceToVariance {
int channels, T* variance);
};
// This function sets a GPU tensor to NaNs.
template <class T>
struct SetNanFunctor {
void operator()(const Eigen::GpuDevice& d, typename TTypes<T>::Flat out);
};
#endif // GOOGLE_CUDA
// Functor used by FusedBatchNormGradOp to do the computations when

View File

@ -147,6 +147,9 @@ void DnnPoolingOp<T>::Compute(
Tensor* tensor_out = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, tensor_out_shape, &tensor_out));
if (tensor_in.shape().num_elements() == 0) {
return;
}
PoolParameters params{context, size, stride,
padding, data_format, tensor_in.shape()};
@ -247,6 +250,9 @@ void DnnPoolingGradOp<T>::Compute(
Tensor* input_backprop = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, tensor_in_shape, &input_backprop));
if (tensor_in_shape.num_elements() == 0) {
return;
}
PoolParameters params{context, size, stride,
padding, data_format, tensor_in_shape};

View File

@ -796,6 +796,20 @@ class Conv2DTest(test.TestCase):
data_format=data_format,
use_gpu=use_gpu)
@test_util.run_in_graph_and_eager_modes()
def testConv2DBackpropFilterWithEmptyInput(self):
expected = [0, 0, 0, 0]
for (data_format, use_gpu) in GetTestConfigs():
self._RunAndVerifyBackpropFilter(
input_sizes=[0, 2, 3, 1],
filter_sizes=[2, 2, 1, 1],
output_sizes=[0, 1, 2, 1],
strides=[1, 1],
padding="VALID",
expected=expected,
data_format=data_format,
use_gpu=use_gpu)
@test_util.run_in_graph_and_eager_modes()
def testConv2D2x2Depth3ValidBackpropFilter(self):
expected = [

View File

@ -361,6 +361,16 @@ class PoolingTest(test.TestCase):
expected=expected_output,
use_gpu=use_gpu)
def _testAvgPoolEmptyInput(self, use_gpu):
self._VerifyValues(
nn_ops.avg_pool,
input_sizes=[0, 8, 8, 8],
ksize=[1, 3, 3, 1],
strides=[1, 2, 2, 1],
padding="SAME",
expected=[],
use_gpu=use_gpu)
def testAvgPooling(self):
for use_gpu in True, False:
self._testAvgPoolValidPadding(use_gpu)
@ -371,6 +381,7 @@ class PoolingTest(test.TestCase):
self._testAvgPoolSamePadding4(use_gpu)
self._testAvgPoolSamePaddingPacket4(use_gpu)
self._testAvgPoolSamePaddingPacket8(use_gpu)
self._testAvgPoolEmptyInput(use_gpu)
def _testMaxPoolValidPadding(self, use_gpu):
expected_output = [13.0, 14.0, 15.0]
@ -543,6 +554,16 @@ class PoolingTest(test.TestCase):
use_gpu=use_gpu,
v2=v2)
def _testMaxPoolEmptyInput(self, use_gpu):
self._VerifyValues(
gen_nn_ops._max_pool_v2,
input_sizes=[0, 8, 8, 8],
ksize=[1, 3, 3, 1],
strides=[1, 2, 2, 1],
padding="SAME",
expected=[],
use_gpu=use_gpu)
def testMaxPooling(self):
for use_gpu in True, False:
self._testMaxPoolValidPadding(use_gpu)
@ -551,6 +572,7 @@ class PoolingTest(test.TestCase):
self._testMaxPoolValidPaddingUnevenStride(use_gpu)
self._testMaxPoolSamePaddingPacket4(use_gpu)
self._testMaxPoolSamePaddingPacket8(use_gpu)
self._testMaxPoolEmptyInput(use_gpu)
# Tests for DepthwiseMaxPooling on CPU only.
def testDepthwiseMaxPool1x1DepthWindow1(self):

View File

@ -171,6 +171,10 @@ class BatchNormalizationTest(test.TestCase):
x, x_shape, y, y_shape, delta=1e-3, x_init_value=x_init_val)
_, numerical_grad = gradient_checker.compute_gradient(
x32, x_shape, y32, y_shape, delta=1e-3, x_init_value=x32_init_val)
# If grad is empty, no error.
if theoretical_grad.size == 0 and numerical_grad.size == 0:
return 0
return np.fabs(theoretical_grad - numerical_grad).max()
def _test_gradient(self,
@ -371,6 +375,17 @@ class BatchNormalizationTest(test.TestCase):
self._test_inference(
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
def testInferenceShape5(self):
x_shape = [0, 131, 127, 6]
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_inference(
x_shape, dtype, [131], np.float32, use_gpu=True, data_format='NCHW')
self._test_inference(
x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
self._test_inference(
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
def testTrainingShape1(self):
x_shape = [1, 1, 6, 1]
for dtype in [np.float16, np.float32]:
@ -409,6 +424,17 @@ class BatchNormalizationTest(test.TestCase):
self._test_training(
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
def testTrainingShape5(self):
x_shape = [0, 131, 127, 6]
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_training(
x_shape, dtype, [131], np.float32, use_gpu=True, data_format='NCHW')
self._test_training(
x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
self._test_training(
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
def testBatchNormGradShape1(self):
for is_training in [True, False]:
x_shape = [1, 1, 6, 1]
@ -496,6 +522,33 @@ class BatchNormalizationTest(test.TestCase):
data_format='NHWC',
is_training=is_training)
def testBatchNormGradShape5(self):
for is_training in [True, False]:
x_shape = [0, 7, 11, 4]
for dtype in [np.float16, np.float32]:
if test.is_gpu_available(cuda_only=True):
self._test_gradient(
x_shape,
dtype, [7],
np.float32,
use_gpu=True,
data_format='NCHW',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [4],
np.float32,
use_gpu=True,
data_format='NHWC',
is_training=is_training)
self._test_gradient(
x_shape,
dtype, [4],
np.float32,
use_gpu=False,
data_format='NHWC',
is_training=is_training)
def _testBatchNormGradGrad(self, config):
shape = config['shape']
err_tolerance = config['err_tolerance']