Enable GPU tests for bucketize op and fix a hard crash on empty input tensors.
PiperOrigin-RevId: 268139518
This commit is contained in:
parent
dd9b975b03
commit
1ec2677a9c
@ -67,9 +67,11 @@ class BucketizeOp : public OpKernel {
|
|||||||
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
|
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
|
||||||
&output_tensor));
|
&output_tensor));
|
||||||
auto output = output_tensor->template flat<int32>();
|
auto output = output_tensor->template flat<int32>();
|
||||||
|
if (input.size() > 0) {
|
||||||
OP_REQUIRES_OK(context, functor::BucketizeFunctor<Device, T>::Compute(
|
OP_REQUIRES_OK(context, functor::BucketizeFunctor<Device, T>::Compute(
|
||||||
context, input, boundaries_, output));
|
context, input, boundaries_, output));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<float> boundaries_;
|
std::vector<float> boundaries_;
|
||||||
|
@ -3680,7 +3680,7 @@ tf_py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_py_test(
|
cuda_py_test(
|
||||||
name = "bucketize_op_test",
|
name = "bucketize_op_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["bucketize_op_test.py"],
|
srcs = ["bucketize_op_test.py"],
|
||||||
|
@ -18,9 +18,13 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors_impl
|
from tensorflow.python.framework import errors_impl
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
@ -35,6 +39,13 @@ class BucketizationOpTest(test.TestCase):
|
|||||||
with self.session(use_gpu=True) as sess:
|
with self.session(use_gpu=True) as sess:
|
||||||
self.assertAllEqual(expected_out, self.evaluate(op))
|
self.assertAllEqual(expected_out, self.evaluate(op))
|
||||||
|
|
||||||
|
def testEmptyFloat(self):
|
||||||
|
op = math_ops._bucketize(
|
||||||
|
array_ops.zeros([0, 3], dtype=dtypes.float32), boundaries=[])
|
||||||
|
expected_out = np.zeros([0, 3], dtype=np.float32)
|
||||||
|
with self.session(use_gpu=True):
|
||||||
|
self.assertAllEqual(expected_out, self.evaluate(op))
|
||||||
|
|
||||||
def testFloat(self):
|
def testFloat(self):
|
||||||
op = math_ops._bucketize(
|
op = math_ops._bucketize(
|
||||||
constant_op.constant([-5., 0., 2., 3., 5., 8., 10., 11., 12.]),
|
constant_op.constant([-5., 0., 2., 3., 5., 8., 10., 11., 12.]),
|
||||||
|
Loading…
Reference in New Issue
Block a user