Enable GPU tests for bucketize op and fix a hard crash on empty input tensors.

PiperOrigin-RevId: 268139518
This commit is contained in:
Eugene Brevdo 2019-09-09 20:41:34 -07:00 committed by TensorFlower Gardener
parent dd9b975b03
commit 1ec2677a9c
3 changed files with 16 additions and 3 deletions

View File

@ -67,9 +67,11 @@ class BucketizeOp : public OpKernel {
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
&output_tensor)); &output_tensor));
auto output = output_tensor->template flat<int32>(); auto output = output_tensor->template flat<int32>();
if (input.size() > 0) {
OP_REQUIRES_OK(context, functor::BucketizeFunctor<Device, T>::Compute( OP_REQUIRES_OK(context, functor::BucketizeFunctor<Device, T>::Compute(
context, input, boundaries_, output)); context, input, boundaries_, output));
} }
}
private: private:
std::vector<float> boundaries_; std::vector<float> boundaries_;

View File

@ -3680,7 +3680,7 @@ tf_py_test(
], ],
) )
tf_py_test( cuda_py_test(
name = "bucketize_op_test", name = "bucketize_op_test",
size = "small", size = "small",
srcs = ["bucketize_op_test.py"], srcs = ["bucketize_op_test.py"],

View File

@ -18,9 +18,13 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -35,6 +39,13 @@ class BucketizationOpTest(test.TestCase):
with self.session(use_gpu=True) as sess: with self.session(use_gpu=True) as sess:
self.assertAllEqual(expected_out, self.evaluate(op)) self.assertAllEqual(expected_out, self.evaluate(op))
def testEmptyFloat(self):
op = math_ops._bucketize(
array_ops.zeros([0, 3], dtype=dtypes.float32), boundaries=[])
expected_out = np.zeros([0, 3], dtype=np.float32)
with self.session(use_gpu=True):
self.assertAllEqual(expected_out, self.evaluate(op))
def testFloat(self): def testFloat(self):
op = math_ops._bucketize( op = math_ops._bucketize(
constant_op.constant([-5., 0., 2., 3., 5., 8., 10., 11., 12.]), constant_op.constant([-5., 0., 2., 3., 5., 8., 10., 11., 12.]),