Don't crash in 3D pooling ops with empty batch size on GPU.
PiperOrigin-RevId: 313299099 Change-Id: I40ce8f57efc386ae820460a325cfebee1be14d77
This commit is contained in:
parent
b6f542de70
commit
b100b185ee
@ -192,6 +192,7 @@ class Pooling3DOp : public UnaryOp<T> {
|
|||||||
{{out[2], out[1], out[0]}}, depth);
|
{{out[2], out[1], out[0]}}, depth);
|
||||||
Tensor* output;
|
Tensor* output;
|
||||||
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
|
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
|
||||||
|
if (out_shape.num_elements() == 0) return;
|
||||||
LaunchPoolingOp<Device, T, Type>::launch(context, tensor_in, window, stride,
|
LaunchPoolingOp<Device, T, Type>::launch(context, tensor_in, window, stride,
|
||||||
padding, data_format_, padding_,
|
padding, data_format_, padding_,
|
||||||
output);
|
output);
|
||||||
|
@ -205,14 +205,14 @@ class PoolingTest(test.TestCase):
|
|||||||
padding="VALID",
|
padding="VALID",
|
||||||
expected=[29.5, 32.5, 50.5, 53.5, 176.5, 179.5, 197.5, 200.5])
|
expected=[29.5, 32.5, 50.5, 53.5, 176.5, 179.5, 197.5, 200.5])
|
||||||
|
|
||||||
def _MaxPool3DEmptyTensorOutputShape(self):
|
def testMaxPool3DEmptyTensorOutputShape(self):
|
||||||
"""Verifies the output shape of the max pooling function when tensor is empty.
|
"""Verifies the output shape of the max pooling function when tensor is empty.
|
||||||
|
|
||||||
Args: none
|
Args: none
|
||||||
"""
|
"""
|
||||||
input_sizes = [0, 112, 112, 112, 64]
|
input_sizes = [0, 112, 112, 112, 64]
|
||||||
|
|
||||||
input_data = 1
|
input_data = 1.
|
||||||
input_tensor = constant_op.constant(
|
input_tensor = constant_op.constant(
|
||||||
input_data, shape=input_sizes, name="input")
|
input_data, shape=input_sizes, name="input")
|
||||||
max_pool_3d = nn_ops.max_pool3d(
|
max_pool_3d = nn_ops.max_pool3d(
|
||||||
|
Loading…
Reference in New Issue
Block a user