From 1ec2677a9c0780615a7f0732e2a4e6353c8a09d9 Mon Sep 17 00:00:00 2001 From: Eugene Brevdo Date: Mon, 9 Sep 2019 20:41:34 -0700 Subject: [PATCH] Enable GPU tests for bucketize op and fix a hard crash on empty input tensors. PiperOrigin-RevId: 268139518 --- tensorflow/core/kernels/bucketize_op.cc | 6 ++++-- tensorflow/python/kernel_tests/BUILD | 2 +- tensorflow/python/kernel_tests/bucketize_op_test.py | 11 +++++++++++ 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/kernels/bucketize_op.cc b/tensorflow/core/kernels/bucketize_op.cc index 393254474e2..1b65d46dffb 100644 --- a/tensorflow/core/kernels/bucketize_op.cc +++ b/tensorflow/core/kernels/bucketize_op.cc @@ -67,8 +67,10 @@ class BucketizeOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor)); auto output = output_tensor->template flat(); - OP_REQUIRES_OK(context, functor::BucketizeFunctor::Compute( - context, input, boundaries_, output)); + if (input.size() > 0) { + OP_REQUIRES_OK(context, functor::BucketizeFunctor::Compute( + context, input, boundaries_, output)); + } } private: diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index f124541d87f..cab700f1f3a 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -3680,7 +3680,7 @@ tf_py_test( ], ) -tf_py_test( +cuda_py_test( name = "bucketize_op_test", size = "small", srcs = ["bucketize_op_test.py"], diff --git a/tensorflow/python/kernel_tests/bucketize_op_test.py b/tensorflow/python/kernel_tests/bucketize_op_test.py index 95df6943705..128cc17db15 100644 --- a/tensorflow/python/kernel_tests/bucketize_op_test.py +++ b/tensorflow/python/kernel_tests/bucketize_op_test.py @@ -18,9 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -35,6 +39,13 @@ class BucketizationOpTest(test.TestCase): with self.session(use_gpu=True) as sess: self.assertAllEqual(expected_out, self.evaluate(op)) + def testEmptyFloat(self): + op = math_ops._bucketize( + array_ops.zeros([0, 3], dtype=dtypes.float32), boundaries=[]) + expected_out = np.zeros([0, 3], dtype=np.float32) + with self.session(use_gpu=True): + self.assertAllEqual(expected_out, self.evaluate(op)) + def testFloat(self): op = math_ops._bucketize( constant_op.constant([-5., 0., 2., 3., 5., 8., 10., 11., 12.]),