Don't crash in 3D pooling ops with empty batch size on GPU.
PiperOrigin-RevId: 313299099 Change-Id: I40ce8f57efc386ae820460a325cfebee1be14d77
This commit is contained in:
parent
b6f542de70
commit
b100b185ee
@ -192,6 +192,7 @@ class Pooling3DOp : public UnaryOp<T> {
|
||||
{{out[2], out[1], out[0]}}, depth);
|
||||
Tensor* output;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
|
||||
if (out_shape.num_elements() == 0) return;
|
||||
LaunchPoolingOp<Device, T, Type>::launch(context, tensor_in, window, stride,
|
||||
padding, data_format_, padding_,
|
||||
output);
|
||||
|
@ -205,14 +205,14 @@ class PoolingTest(test.TestCase):
|
||||
padding="VALID",
|
||||
expected=[29.5, 32.5, 50.5, 53.5, 176.5, 179.5, 197.5, 200.5])
|
||||
|
||||
def _MaxPool3DEmptyTensorOutputShape(self):
|
||||
def testMaxPool3DEmptyTensorOutputShape(self):
|
||||
"""Verifies the output shape of the max pooling function when tensor is empty.
|
||||
|
||||
Args: none
|
||||
"""
|
||||
input_sizes = [0, 112, 112, 112, 64]
|
||||
|
||||
input_data = 1
|
||||
input_data = 1.
|
||||
input_tensor = constant_op.constant(
|
||||
input_data, shape=input_sizes, name="input")
|
||||
max_pool_3d = nn_ops.max_pool3d(
|
||||
|
Loading…
Reference in New Issue
Block a user