Don't crash in 3D pooling ops with empty batch size on GPU.

PiperOrigin-RevId: 313299099
Change-Id: I40ce8f57efc386ae820460a325cfebee1be14d77
This commit is contained in:
A. Unique TensorFlower 2020-05-26 17:28:20 -07:00 committed by TensorFlower Gardener
parent b6f542de70
commit b100b185ee
2 changed files with 3 additions and 2 deletions

View File

@ -192,6 +192,7 @@ class Pooling3DOp : public UnaryOp<T> {
{{out[2], out[1], out[0]}}, depth);
Tensor* output;
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
if (out_shape.num_elements() == 0) return;
LaunchPoolingOp<Device, T, Type>::launch(context, tensor_in, window, stride,
padding, data_format_, padding_,
output);

View File

@ -205,14 +205,14 @@ class PoolingTest(test.TestCase):
padding="VALID",
expected=[29.5, 32.5, 50.5, 53.5, 176.5, 179.5, 197.5, 200.5])
def _MaxPool3DEmptyTensorOutputShape(self):
def testMaxPool3DEmptyTensorOutputShape(self):
"""Verifies the output shape of the max pooling function when tensor is empty.
Args: none
"""
input_sizes = [0, 112, 112, 112, 64]
input_data = 1
input_data = 1.
input_tensor = constant_op.constant(
input_data, shape=input_sizes, name="input")
max_pool_3d = nn_ops.max_pool3d(