Enable GPU tests for bucketize op and fix a hard crash on empty input tensors.
PiperOrigin-RevId: 268139518
This commit is contained in:
parent
dd9b975b03
commit
1ec2677a9c
@ -67,9 +67,11 @@ class BucketizeOp : public OpKernel {
|
||||
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
|
||||
&output_tensor));
|
||||
auto output = output_tensor->template flat<int32>();
|
||||
if (input.size() > 0) {
|
||||
OP_REQUIRES_OK(context, functor::BucketizeFunctor<Device, T>::Compute(
|
||||
context, input, boundaries_, output));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<float> boundaries_;
|
||||
|
@ -3680,7 +3680,7 @@ tf_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
cuda_py_test(
|
||||
name = "bucketize_op_test",
|
||||
size = "small",
|
||||
srcs = ["bucketize_op_test.py"],
|
||||
|
@ -18,9 +18,13 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
@ -35,6 +39,13 @@ class BucketizationOpTest(test.TestCase):
|
||||
with self.session(use_gpu=True) as sess:
|
||||
self.assertAllEqual(expected_out, self.evaluate(op))
|
||||
|
||||
def testEmptyFloat(self):
|
||||
op = math_ops._bucketize(
|
||||
array_ops.zeros([0, 3], dtype=dtypes.float32), boundaries=[])
|
||||
expected_out = np.zeros([0, 3], dtype=np.float32)
|
||||
with self.session(use_gpu=True):
|
||||
self.assertAllEqual(expected_out, self.evaluate(op))
|
||||
|
||||
def testFloat(self):
|
||||
op = math_ops._bucketize(
|
||||
constant_op.constant([-5., 0., 2., 3., 5., 8., 10., 11., 12.]),
|
||||
|
Loading…
Reference in New Issue
Block a user