Fix tf.bincount errors in eager mode
PiperOrigin-RevId: 323718056 Change-Id: Ic454b2b76a313d68b18e22e705cd7cbe9ae107b7
This commit is contained in:
parent
d227b141aa
commit
5a71093625
@ -175,13 +175,15 @@ class BincountOp : public OpKernel {
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
const Tensor& arr_t = ctx->input(0);
|
||||
const Tensor& size_tensor = ctx->input(1);
|
||||
const Tensor& weights_t = ctx->input(2);
|
||||
|
||||
OP_REQUIRES(ctx, size_tensor.dims() == 0,
|
||||
errors::InvalidArgument("Shape must be rank 0 but is rank ",
|
||||
size_tensor.dims()));
|
||||
int32 size = size_tensor.scalar<int32>()();
|
||||
OP_REQUIRES(
|
||||
ctx, size >= 0,
|
||||
errors::InvalidArgument("size (", size, ") must be non-negative"));
|
||||
|
||||
const Tensor& weights_t = ctx->input(2);
|
||||
const auto arr = arr_t.flat<int32>();
|
||||
const auto weights = weights_t.flat<T>();
|
||||
Tensor* output_t;
|
||||
@ -226,6 +228,10 @@ class DenseBincountOp : public OpKernel {
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
const Tensor& data = ctx->input(0);
|
||||
OP_REQUIRES(ctx, data.dims() <= 2,
|
||||
errors::InvalidArgument(
|
||||
"Shape must be at most rank 2 but is rank ", data.dims()));
|
||||
|
||||
const Tensor& size_t = ctx->input(1);
|
||||
const Tensor& weights = ctx->input(2);
|
||||
|
||||
|
@ -119,22 +119,24 @@ class BincountTest(test_util.TensorFlowTestCase):
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
self.evaluate(bincount_ops.bincount([1, 2, 3, -1, 6, 8]))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_shape_function(self):
|
||||
# size must be scalar.
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Shape must be rank 0 but is rank 1 for .*Bincount"):
|
||||
gen_math_ops.bincount([1, 2, 3, -1, 6, 8], [1], [])
|
||||
(ValueError, errors.InvalidArgumentError),
|
||||
"Shape must be rank 0 but is rank 1 .*Bincount"):
|
||||
gen_math_ops.bincount([1, 2, 3, 1, 6, 8], [1], [])
|
||||
# size must be positive.
|
||||
with self.assertRaisesRegex(ValueError, "must be non-negative"):
|
||||
gen_math_ops.bincount([1, 2, 3, -1, 6, 8], -5, [])
|
||||
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
|
||||
"must be non-negative"):
|
||||
gen_math_ops.bincount([1, 2, 3, 1, 6, 8], -5, [])
|
||||
# if size is a constant then the shape is known.
|
||||
v1 = gen_math_ops.bincount([1, 2, 3, -1, 6, 8], 5, [])
|
||||
v1 = gen_math_ops.bincount([1, 2, 3, 1, 6, 8], 5, [])
|
||||
self.assertAllEqual(v1.get_shape().as_list(), [5])
|
||||
# if size is a placeholder then the shape is unknown.
|
||||
s = array_ops.placeholder(dtype=dtypes.int32)
|
||||
v2 = gen_math_ops.bincount([1, 2, 3, -1, 6, 8], s, [])
|
||||
self.assertAllEqual(v2.get_shape().as_list(), [None])
|
||||
with ops.Graph().as_default():
|
||||
s = array_ops.placeholder(dtype=dtypes.int32)
|
||||
v2 = gen_math_ops.bincount([1, 2, 3, 1, 6, 8], s, [])
|
||||
self.assertAllEqual(v2.get_shape().as_list(), [None])
|
||||
|
||||
|
||||
class BincountOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
@ -322,9 +324,9 @@ class BincountOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
size = 10
|
||||
self._test_bincount_col_binary(num_rows, num_cols, size, dtype)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_invalid_rank(self):
|
||||
with self.assertRaisesRegex(ValueError, "at most rank 2"):
|
||||
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
|
||||
"at most rank 2"):
|
||||
with test_util.use_gpu():
|
||||
self.evaluate(
|
||||
gen_math_ops.dense_bincount(
|
||||
|
Loading…
Reference in New Issue
Block a user