* Support empty input tensor for FusedBatchNorm, FusedBatchNormGrad, Conv2DBackpropFilter (fix #14657) * Also fix pooling ops * Add some comments in ops * Add tests for conv/pooling/bn. * Return NaN mean/variance when input is empty * Update comments * Fix typo * Move fill_functor implementations to :fill_functor
This commit is contained in:
parent
1c5f27a713
commit
3a3b7530eb
@ -194,10 +194,9 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
tf_kernel_library(
|
||||
name = "fill_functor",
|
||||
srcs = ["fill_functor.cc"],
|
||||
hdrs = ["fill_functor.h"],
|
||||
prefix = "fill_functor",
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//third_party/eigen3",
|
||||
@ -3043,6 +3042,7 @@ tf_kernel_library(
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
hdrs = [
|
||||
"fill_functor.h",
|
||||
"conv_grad_ops.h",
|
||||
"deep_conv2d.h",
|
||||
"gemm_functors.h",
|
||||
@ -3067,6 +3067,7 @@ tf_kernel_library(
|
||||
":conv_2d",
|
||||
":conv_3d",
|
||||
":image_resizer_state",
|
||||
":fill_functor",
|
||||
":ops_util",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
|
@ -151,18 +151,6 @@ typedef Eigen::GpuDevice GPUDevice;
|
||||
typedef Eigen::SyclDevice SYCLDevice;
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
namespace functor {
|
||||
|
||||
// Partial specialization of FillFunctor<Device=CPUDevice, T>.
// Assigns the scalar value read from `in` to every element of `out`.
|
||||
template <typename T>
|
||||
struct FillFunctor<CPUDevice, T> {
|
||||
// d:   device on which the assignment expression is evaluated.
// out: flat tensor that receives the fill value.
// in:  scalar whose value (in()) is written to each element of `out`.
void operator()(const CPUDevice& d, typename TTypes<T>::Flat out,
|
||||
typename TTypes<T>::ConstScalar in) {
|
||||
// out.constant(v) yields an expression of `out`'s shape holding v everywhere.
out.device(d) = out.constant(in());
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace functor
|
||||
|
||||
template <typename Device, typename T, typename Index>
|
||||
class FillOp : public OpKernel {
|
||||
@ -191,28 +179,6 @@ class FillOp : public OpKernel {
|
||||
}
|
||||
};
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
|
||||
namespace functor {
|
||||
// Partial specialization of FillFunctor<Device=SYCLDevice, T>.
// Fills `out` by reshaping the scalar `in` to a one-element rank-1 tensor and
// broadcasting it along that single dimension to the length of `out`.
|
||||
template <typename T>
|
||||
struct FillFunctor<SYCLDevice, T> {
|
||||
// d:   device on which the assignment expression is evaluated.
// out: flat tensor that receives the fill value.
// in:  scalar providing the fill value.
void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
|
||||
typename TTypes<T>::ConstScalar in) {
|
||||
// `rank1` describes the one-element shape; its type depends on whether
// EIGEN_HAS_INDEX_LIST is defined.
#if !defined(EIGEN_HAS_INDEX_LIST)
|
||||
Eigen::array<int, 1> rank1{1};
|
||||
#else
|
||||
Eigen::IndexList<Eigen::type2index<1> > rank1;
|
||||
#endif
|
||||
// NOTE(review): the length of `out` is stored in an int here; presumably it
// is expected to fit in 32 bits, in line with the To32Bit use below - confirm.
const int size = out.dimension(0);
|
||||
Eigen::array<int, 1> broadcast_dims{size};
|
||||
|
||||
To32Bit(out).device(d) = in.reshape(rank1).broadcast(broadcast_dims);
|
||||
}
|
||||
};
|
||||
} // namespace functor
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
#define REGISTER_KERNEL(D, TYPE) \
|
||||
REGISTER_KERNEL_BUILDER(Name("Fill") \
|
||||
.Device(DEVICE_##D) \
|
||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_slice.h"
|
||||
#include "tensorflow/core/kernels/conv_2d.h"
|
||||
#include "tensorflow/core/kernels/fill_functor.h"
|
||||
#ifdef TENSORFLOW_USE_LIBXSMM
|
||||
#include "tensorflow/core/kernels/xsmm_conv2d.h"
|
||||
#endif
|
||||
@ -595,6 +596,13 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
|
||||
if (filter_shape.num_elements() == 0) {
|
||||
return;
|
||||
}
|
||||
// If input is empty, set gradients to zero.
|
||||
if (input.shape().num_elements() == 0) {
|
||||
functor::SetZeroFunctor<Device, T> f;
|
||||
f(context->eigen_device<Device>(), filter_backprop->flat<T>());
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
// For now we take the stride from the second and third dimensions only (we
|
||||
// do not support striding on the batch or depth dimension).
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/variant_encode_decode.h"
|
||||
|
||||
@ -74,6 +75,7 @@ DEFINE_SETZERO_SYCL(int32);
|
||||
DEFINE_SETZERO_SYCL(int64);
|
||||
#undef DEFINE_SETZERO_SYCL
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
template <typename T>
|
||||
void SetOneFunctor<Eigen::ThreadPoolDevice, T>::operator()(
|
||||
const Eigen::ThreadPoolDevice& d, typename TTypes<T>::Flat out) {
|
||||
@ -112,5 +114,47 @@ DEFINE_SETONE_SYCL(double);
|
||||
#undef DEFINE_SETONE_SYCL
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
// Partial specialization of FillFunctor for Eigen::ThreadPoolDevice.
// Assigns the scalar value read from `in` to every element of `out`.
template <typename T>
|
||||
struct FillFunctor<Eigen::ThreadPoolDevice, T> {
|
||||
// d:   device on which the assignment expression is evaluated.
// out: flat tensor that receives the fill value.
// in:  scalar whose value (in()) is written to each element of `out`.
void operator()(const Eigen::ThreadPoolDevice& d, typename TTypes<T>::Flat out,
|
||||
typename TTypes<T>::ConstScalar in) {
|
||||
out.device(d) = out.constant(in());
|
||||
}
|
||||
};
|
||||
|
||||
// Explicit instantiations.
// DEFINE_FILL_CPU(T) instantiates FillFunctor<Eigen::ThreadPoolDevice, T>.
// It is applied to the types enumerated by TF_CALL_ALL_TYPES and, in
// addition, to quint8 and quint16; the helper macro is undefined afterwards.
|
||||
#define DEFINE_FILL_CPU(T) \
|
||||
template struct FillFunctor<Eigen::ThreadPoolDevice, T>;
|
||||
|
||||
TF_CALL_ALL_TYPES(DEFINE_FILL_CPU);
|
||||
DEFINE_FILL_CPU(quint8);
|
||||
DEFINE_FILL_CPU(quint16);
|
||||
#undef DEFINE_FILL_CPU
|
||||
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
// Partial specialization of FillFunctor for Eigen::SyclDevice.
// Fills `out` by reshaping the scalar `in` to a one-element rank-1 tensor and
// broadcasting it along that single dimension to the length of `out`.
template <typename T>
|
||||
struct FillFunctor<Eigen::SyclDevice, T> {
|
||||
// d:   device on which the assignment expression is evaluated.
// out: flat tensor that receives the fill value.
// in:  scalar providing the fill value.
void operator()(const Eigen::SyclDevice& d, typename TTypes<T>::Flat out,
|
||||
typename TTypes<T>::ConstScalar in) {
|
||||
// `rank1` describes the one-element shape; its type depends on whether
// EIGEN_HAS_INDEX_LIST is defined.
#if !defined(EIGEN_HAS_INDEX_LIST)
|
||||
Eigen::array<int, 1> rank1{1};
|
||||
#else
|
||||
Eigen::IndexList<Eigen::type2index<1> > rank1;
|
||||
#endif
|
||||
// NOTE(review): the length of `out` is stored in an int here; presumably it
// is expected to fit in 32 bits, in line with the To32Bit use below - confirm.
const int size = out.dimension(0);
|
||||
Eigen::array<int, 1> broadcast_dims{size};
|
||||
|
||||
To32Bit(out).device(d) = in.reshape(rank1).broadcast(broadcast_dims);
|
||||
}
|
||||
};
|
||||
|
||||
// Explicit instantiations of the SYCL specialization: float, double, and the
// types enumerated by TF_CALL_INTEGRAL_TYPES.
#define DEFINE_FILL_SYCL(T) \
|
||||
template struct FillFunctor<Eigen::SyclDevice, T>;
|
||||
DEFINE_FILL_SYCL(float);
|
||||
DEFINE_FILL_SYCL(double);
|
||||
TF_CALL_INTEGRAL_TYPES(DEFINE_FILL_SYCL)
|
||||
#undef DEFINE_FILL_SYCL
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/kernels/fill_functor.h"
|
||||
#include "tensorflow/core/kernels/fused_batch_norm_op.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
|
||||
@ -239,6 +240,14 @@ struct FusedBatchNorm<GPUDevice, T, U> {
|
||||
<< " offset shape: " << offset.shape().DebugString()
|
||||
<< " tensor format: " << tensor_format;
|
||||
|
||||
// If input is empty, return NaN mean/variance
|
||||
if (x.shape().num_elements() == 0) {
|
||||
functor::SetNanFunctor<U> f;
|
||||
f(context->eigen_device<GPUDevice>(), batch_mean->flat<U>());
|
||||
f(context->eigen_device<GPUDevice>(), batch_var->flat<U>());
|
||||
return;
|
||||
}
|
||||
|
||||
Tensor x_maybe_transformed = x;
|
||||
Tensor x_transformed;
|
||||
Tensor y_transformed;
|
||||
@ -656,6 +665,14 @@ class FusedBatchNormGradOp : public OpKernel {
|
||||
context, context->allocate_output(4, TensorShape({}), &placeholder_2));
|
||||
FillZeros<Device>(placeholder_2);
|
||||
|
||||
// If input is empty, set gradients w.r.t scale/offset to zero.
|
||||
if (x.shape().num_elements() == 0) {
|
||||
functor::SetZeroFunctor<Device, U> f;
|
||||
f(context->eigen_device<Device>(), scale_backprop->flat<U>());
|
||||
f(context->eigen_device<Device>(), offset_backprop->flat<U>());
|
||||
return;
|
||||
}
|
||||
|
||||
if (is_training_) {
|
||||
functor::FusedBatchNormGrad<Device, T, U>()(
|
||||
context, y_backprop, x, scale, saved_mean_or_pop_mean,
|
||||
|
@ -65,8 +65,15 @@ void InvVarianceToVariance<T>::operator()(const Eigen::GpuDevice& d,
|
||||
epsilon, sample_size, variance);
|
||||
}
|
||||
|
||||
// Writes a quiet NaN of type T (Eigen::NumTraits<T>::quiet_NaN()) to every
// element of `out`. The assignment is evaluated on device `d`, with both sides
// of the expression going through To32Bit(out).
template <class T>
|
||||
void SetNanFunctor<T>::operator()(const Eigen::GpuDevice& d,
|
||||
typename TTypes<T>::Flat out) {
|
||||
To32Bit(out).device(d) = To32Bit(out).constant(Eigen::NumTraits<T>::quiet_NaN());
|
||||
}
|
||||
|
||||
// Explicit template instantiations for T = float.
template class VarianceToInvVariance<float>;
|
||||
template class InvVarianceToVariance<float>;
|
||||
template class SetNanFunctor<float>;
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -49,6 +49,12 @@ struct InvVarianceToVariance {
|
||||
int channels, T* variance);
|
||||
};
|
||||
|
||||
// Functor that sets a GPU tensor to NaNs.
// Only the call operator is declared here; no definition is given at this
// point.
//   d:   GPU device on which the fill is performed.
//   out: flat tensor to be overwritten.
|
||||
template <class T>
|
||||
struct SetNanFunctor {
|
||||
void operator()(const Eigen::GpuDevice& d, typename TTypes<T>::Flat out);
|
||||
};
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
// Functor used by FusedBatchNormGradOp to do the computations when
|
||||
|
@ -147,6 +147,9 @@ void DnnPoolingOp<T>::Compute(
|
||||
Tensor* tensor_out = nullptr;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(0, tensor_out_shape, &tensor_out));
|
||||
if (tensor_in.shape().num_elements() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
PoolParameters params{context, size, stride,
|
||||
padding, data_format, tensor_in.shape()};
|
||||
@ -247,6 +250,9 @@ void DnnPoolingGradOp<T>::Compute(
|
||||
Tensor* input_backprop = nullptr;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(0, tensor_in_shape, &input_backprop));
|
||||
if (tensor_in_shape.num_elements() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
PoolParameters params{context, size, stride,
|
||||
padding, data_format, tensor_in_shape};
|
||||
|
@ -796,6 +796,20 @@ class Conv2DTest(test.TestCase):
|
||||
data_format=data_format,
|
||||
use_gpu=use_gpu)
|
||||
|
||||
# Backprop-filter check for an empty input: the leading dimension of both
# input_sizes and output_sizes is 0, and all four filter-gradient values are
# expected to be 0. Repeated for every (data_format, use_gpu) pair yielded by
# GetTestConfigs().
@test_util.run_in_graph_and_eager_modes()
|
||||
def testConv2DBackpropFilterWithEmptyInput(self):
|
||||
expected = [0, 0, 0, 0]
|
||||
for (data_format, use_gpu) in GetTestConfigs():
|
||||
self._RunAndVerifyBackpropFilter(
|
||||
input_sizes=[0, 2, 3, 1],
|
||||
filter_sizes=[2, 2, 1, 1],
|
||||
output_sizes=[0, 1, 2, 1],
|
||||
strides=[1, 1],
|
||||
padding="VALID",
|
||||
expected=expected,
|
||||
data_format=data_format,
|
||||
use_gpu=use_gpu)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testConv2D2x2Depth3ValidBackpropFilter(self):
|
||||
expected = [
|
||||
|
@ -361,6 +361,16 @@ class PoolingTest(test.TestCase):
|
||||
expected=expected_output,
|
||||
use_gpu=use_gpu)
|
||||
|
||||
# Runs nn_ops.avg_pool through _VerifyValues on an input whose leading
# dimension is 0 and expects an empty list of output values.
# use_gpu is forwarded unchanged to _VerifyValues.
def _testAvgPoolEmptyInput(self, use_gpu):
|
||||
self._VerifyValues(
|
||||
nn_ops.avg_pool,
|
||||
input_sizes=[0, 8, 8, 8],
|
||||
ksize=[1, 3, 3, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding="SAME",
|
||||
expected=[],
|
||||
use_gpu=use_gpu)
|
||||
|
||||
def testAvgPooling(self):
|
||||
for use_gpu in True, False:
|
||||
self._testAvgPoolValidPadding(use_gpu)
|
||||
@ -371,6 +381,7 @@ class PoolingTest(test.TestCase):
|
||||
self._testAvgPoolSamePadding4(use_gpu)
|
||||
self._testAvgPoolSamePaddingPacket4(use_gpu)
|
||||
self._testAvgPoolSamePaddingPacket8(use_gpu)
|
||||
self._testAvgPoolEmptyInput(use_gpu)
|
||||
|
||||
def _testMaxPoolValidPadding(self, use_gpu):
|
||||
expected_output = [13.0, 14.0, 15.0]
|
||||
@ -543,6 +554,16 @@ class PoolingTest(test.TestCase):
|
||||
use_gpu=use_gpu,
|
||||
v2=v2)
|
||||
|
||||
# Runs gen_nn_ops._max_pool_v2 through _VerifyValues on an input whose leading
# dimension is 0 and expects an empty list of output values.
# use_gpu is forwarded unchanged to _VerifyValues.
def _testMaxPoolEmptyInput(self, use_gpu):
|
||||
self._VerifyValues(
|
||||
gen_nn_ops._max_pool_v2,
|
||||
input_sizes=[0, 8, 8, 8],
|
||||
ksize=[1, 3, 3, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding="SAME",
|
||||
expected=[],
|
||||
use_gpu=use_gpu)
|
||||
|
||||
def testMaxPooling(self):
|
||||
for use_gpu in True, False:
|
||||
self._testMaxPoolValidPadding(use_gpu)
|
||||
@ -551,6 +572,7 @@ class PoolingTest(test.TestCase):
|
||||
self._testMaxPoolValidPaddingUnevenStride(use_gpu)
|
||||
self._testMaxPoolSamePaddingPacket4(use_gpu)
|
||||
self._testMaxPoolSamePaddingPacket8(use_gpu)
|
||||
self._testMaxPoolEmptyInput(use_gpu)
|
||||
|
||||
# Tests for DepthwiseMaxPooling on CPU only.
|
||||
def testDepthwiseMaxPool1x1DepthWindow1(self):
|
||||
|
@ -171,6 +171,10 @@ class BatchNormalizationTest(test.TestCase):
|
||||
x, x_shape, y, y_shape, delta=1e-3, x_init_value=x_init_val)
|
||||
_, numerical_grad = gradient_checker.compute_gradient(
|
||||
x32, x_shape, y32, y_shape, delta=1e-3, x_init_value=x32_init_val)
|
||||
|
||||
# If grad is empty, no error.
|
||||
if theoretical_grad.size == 0 and numerical_grad.size == 0:
|
||||
return 0
|
||||
return np.fabs(theoretical_grad - numerical_grad).max()
|
||||
|
||||
def _test_gradient(self,
|
||||
@ -371,6 +375,17 @@ class BatchNormalizationTest(test.TestCase):
|
||||
self._test_inference(
|
||||
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
|
||||
|
||||
# Calls _test_inference with an x_shape whose leading dimension is 0, for
# float16 and float32. The use_gpu=True calls follow a
# test.is_gpu_available(cuda_only=True) check.
# NOTE(review): indentation is not visible in this rendering, so whether the
# use_gpu=False call is inside or outside that check cannot be told from
# here - confirm against the original file.
def testInferenceShape5(self):
|
||||
x_shape = [0, 131, 127, 6]
|
||||
for dtype in [np.float16, np.float32]:
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
self._test_inference(
|
||||
x_shape, dtype, [131], np.float32, use_gpu=True, data_format='NCHW')
|
||||
self._test_inference(
|
||||
x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
|
||||
self._test_inference(
|
||||
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
|
||||
|
||||
def testTrainingShape1(self):
|
||||
x_shape = [1, 1, 6, 1]
|
||||
for dtype in [np.float16, np.float32]:
|
||||
@ -409,6 +424,17 @@ class BatchNormalizationTest(test.TestCase):
|
||||
self._test_training(
|
||||
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
|
||||
|
||||
# Calls _test_training with an x_shape whose leading dimension is 0, for
# float16 and float32. The use_gpu=True calls follow a
# test.is_gpu_available(cuda_only=True) check.
# NOTE(review): indentation is not visible in this rendering, so whether the
# use_gpu=False call is inside or outside that check cannot be told from
# here - confirm against the original file.
def testTrainingShape5(self):
|
||||
x_shape = [0, 131, 127, 6]
|
||||
for dtype in [np.float16, np.float32]:
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
self._test_training(
|
||||
x_shape, dtype, [131], np.float32, use_gpu=True, data_format='NCHW')
|
||||
self._test_training(
|
||||
x_shape, dtype, [6], np.float32, use_gpu=True, data_format='NHWC')
|
||||
self._test_training(
|
||||
x_shape, dtype, [6], np.float32, use_gpu=False, data_format='NHWC')
|
||||
|
||||
def testBatchNormGradShape1(self):
|
||||
for is_training in [True, False]:
|
||||
x_shape = [1, 1, 6, 1]
|
||||
@ -496,6 +522,33 @@ class BatchNormalizationTest(test.TestCase):
|
||||
data_format='NHWC',
|
||||
is_training=is_training)
|
||||
|
||||
# Calls _test_gradient with an x_shape whose leading dimension is 0, for both
# values of is_training and for float16 and float32. The use_gpu=True calls
# follow a test.is_gpu_available(cuda_only=True) check.
# NOTE(review): indentation is not visible in this rendering, so whether the
# use_gpu=False call is inside or outside that check cannot be told from
# here - confirm against the original file.
def testBatchNormGradShape5(self):
|
||||
for is_training in [True, False]:
|
||||
x_shape = [0, 7, 11, 4]
|
||||
for dtype in [np.float16, np.float32]:
|
||||
if test.is_gpu_available(cuda_only=True):
|
||||
self._test_gradient(
|
||||
x_shape,
|
||||
dtype, [7],
|
||||
np.float32,
|
||||
use_gpu=True,
|
||||
data_format='NCHW',
|
||||
is_training=is_training)
|
||||
self._test_gradient(
|
||||
x_shape,
|
||||
dtype, [4],
|
||||
np.float32,
|
||||
use_gpu=True,
|
||||
data_format='NHWC',
|
||||
is_training=is_training)
|
||||
self._test_gradient(
|
||||
x_shape,
|
||||
dtype, [4],
|
||||
np.float32,
|
||||
use_gpu=False,
|
||||
data_format='NHWC',
|
||||
is_training=is_training)
|
||||
|
||||
def _testBatchNormGradGrad(self, config):
|
||||
shape = config['shape']
|
||||
err_tolerance = config['err_tolerance']
|
||||
|
Loading…
Reference in New Issue
Block a user