Fix tf.bincount errors in eager mode

PiperOrigin-RevId: 323718056
Change-Id: Ic454b2b76a313d68b18e22e705cd7cbe9ae107b7
This commit is contained in:
Gaurav Jain 2020-07-28 21:52:27 -07:00 committed by TensorFlower Gardener
parent d227b141aa
commit 5a71093625
2 changed files with 21 additions and 13 deletions

View File

@ -175,13 +175,15 @@ class BincountOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
const Tensor& arr_t = ctx->input(0);
const Tensor& size_tensor = ctx->input(1);
const Tensor& weights_t = ctx->input(2);
OP_REQUIRES(ctx, size_tensor.dims() == 0,
errors::InvalidArgument("Shape must be rank 0 but is rank ",
size_tensor.dims()));
int32 size = size_tensor.scalar<int32>()();
OP_REQUIRES(
ctx, size >= 0,
errors::InvalidArgument("size (", size, ") must be non-negative"));
const Tensor& weights_t = ctx->input(2);
const auto arr = arr_t.flat<int32>();
const auto weights = weights_t.flat<T>();
Tensor* output_t;
@ -226,6 +228,10 @@ class DenseBincountOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
const Tensor& data = ctx->input(0);
OP_REQUIRES(ctx, data.dims() <= 2,
errors::InvalidArgument(
"Shape must be at most rank 2 but is rank ", data.dims()));
const Tensor& size_t = ctx->input(1);
const Tensor& weights = ctx->input(2);

View File

@ -119,22 +119,24 @@ class BincountTest(test_util.TensorFlowTestCase):
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(bincount_ops.bincount([1, 2, 3, -1, 6, 8]))
@test_util.run_deprecated_v1
def test_shape_function(self):
# size must be scalar.
with self.assertRaisesRegex(
ValueError, "Shape must be rank 0 but is rank 1 for .*Bincount"):
gen_math_ops.bincount([1, 2, 3, -1, 6, 8], [1], [])
(ValueError, errors.InvalidArgumentError),
"Shape must be rank 0 but is rank 1 .*Bincount"):
gen_math_ops.bincount([1, 2, 3, 1, 6, 8], [1], [])
# size must be positive.
with self.assertRaisesRegex(ValueError, "must be non-negative"):
gen_math_ops.bincount([1, 2, 3, -1, 6, 8], -5, [])
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
"must be non-negative"):
gen_math_ops.bincount([1, 2, 3, 1, 6, 8], -5, [])
# if size is a constant then the shape is known.
v1 = gen_math_ops.bincount([1, 2, 3, -1, 6, 8], 5, [])
v1 = gen_math_ops.bincount([1, 2, 3, 1, 6, 8], 5, [])
self.assertAllEqual(v1.get_shape().as_list(), [5])
# if size is a placeholder then the shape is unknown.
s = array_ops.placeholder(dtype=dtypes.int32)
v2 = gen_math_ops.bincount([1, 2, 3, -1, 6, 8], s, [])
self.assertAllEqual(v2.get_shape().as_list(), [None])
with ops.Graph().as_default():
s = array_ops.placeholder(dtype=dtypes.int32)
v2 = gen_math_ops.bincount([1, 2, 3, 1, 6, 8], s, [])
self.assertAllEqual(v2.get_shape().as_list(), [None])
class BincountOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
@ -322,9 +324,9 @@ class BincountOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
size = 10
self._test_bincount_col_binary(num_rows, num_cols, size, dtype)
@test_util.run_deprecated_v1
def test_invalid_rank(self):
with self.assertRaisesRegex(ValueError, "at most rank 2"):
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
"at most rank 2"):
with test_util.use_gpu():
self.evaluate(
gen_math_ops.dense_bincount(