Add additional test cases for Bincount Shape function, and fix clang-format issue

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang 2018-08-11 18:39:13 +00:00
parent 740c58b6fa
commit e6981fc222
3 changed files with 33 additions and 1 deletions

View File

@ -1430,7 +1430,8 @@ REGISTER_OP("Bincount")
// Return `[size]` shape if size is known.
int32 size_val = size_tensor->scalar<int32>()();
if (size_val < 0) {
return errors::InvalidArgument("size (", size_val, ") must be non-negative");
return errors::InvalidArgument("size (", size_val,
") must be non-negative");
}
c->set_output(0, c->MakeShape({size_val}));
return Status::OK();

View File

@ -558,4 +558,16 @@ TEST(MathOpsTest, QuantizedAdd_ShapeFn) {
INFER_ERROR("must be rank 0", op, "?;?;?;?;[3];?");
INFER_ERROR("must be rank 0", op, "?;?;?;?;?;[4]");
}
TEST(MathOpsTest, Bincount_ShapeFn) {
ShapeInferenceTestOp op("Bincount");
// size should be scalar.
INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;[1];?");
INFER_OK(op, "?;?;?", "[?]");
INFER_OK(op, "?;[];?", "[?]");
INFER_OK(op, "[?];[];?", "[?]");
INFER_OK(op, "[?];[];[?]", "[?]");
}
} // end namespace tensorflow

View File

@ -22,6 +22,8 @@ import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
@ -97,6 +99,23 @@ class BincountTest(test_util.TensorFlowTestCase):
with self.assertRaises(errors.InvalidArgumentError):
math_ops.bincount([1, 2, 3, -1, 6, 8]).eval()
def test_shape_function(self):
# size must be scalar.
with self.assertRaisesRegexp(
ValueError, "Shape must be rank 0 but is rank 1 for 'Bincount'"):
gen_math_ops.bincount([1, 2, 3, -1, 6, 8], [1], [])
# size must be positive.
with self.assertRaisesRegexp(
ValueError, "must be non-negative"):
gen_math_ops.bincount([1, 2, 3, -1, 6, 8], -5, [])
# if size is a constant then the shape is known.
v1 = gen_math_ops.bincount([1, 2, 3, -1, 6, 8], 5, [])
self.assertAllEqual(v1.get_shape().as_list(), [5])
# if size is a placeholder then the shape is unknown.
s = array_ops.placeholder(dtype=dtypes.int32)
v2 = gen_math_ops.bincount([1, 2, 3, -1, 6, 8], s, [])
self.assertAllEqual(v2.get_shape().as_list(), [None])
if __name__ == "__main__":
googletest.main()