diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 3eff728f032..a9e5e7824d2 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -1437,7 +1437,24 @@ REGISTER_OP("Bincount") .Attr("T: {int32, int64, float32, float64}") .Output("bins: T") .SetShapeFn([](InferenceContext* c) { - c->set_output(0, c->UnknownShapeOfRank(1)); + ShapeHandle unused; + // The input `size` must be a scalar. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + + const Tensor* size_tensor = c->input_tensor(1); + if (size_tensor == nullptr) { + // Return unknown shape if size is not known. + c->set_output(0, c->UnknownShapeOfRank(1)); + return Status::OK(); + } + + // Return `[size]` shape if size is known. + int32 size_val = size_tensor->scalar()(); + if (size_val < 0) { + return errors::InvalidArgument("size (", size_val, + ") must be non-negative"); + } + c->set_output(0, c->MakeShape({size_val})); return Status::OK(); }); diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc index be4c3ed2b6e..05379a7d699 100644 --- a/tensorflow/core/ops/math_ops_test.cc +++ b/tensorflow/core/ops/math_ops_test.cc @@ -559,4 +559,16 @@ TEST(MathOpsTest, QuantizedAdd_ShapeFn) { INFER_ERROR("must be rank 0", op, "?;?;?;?;[3];?"); INFER_ERROR("must be rank 0", op, "?;?;?;?;?;[4]"); } + +TEST(MathOpsTest, Bincount_ShapeFn) { + ShapeInferenceTestOp op("Bincount"); + + // size should be scalar. + INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;[1];?"); + + INFER_OK(op, "?;?;?", "[?]"); + INFER_OK(op, "?;[];?", "[?]"); + INFER_OK(op, "[?];[];?", "[?]"); + INFER_OK(op, "[?];[];[?]", "[?]"); +} } // end namespace tensorflow diff --git a/tensorflow/python/kernel_tests/bincount_op_test.py b/tensorflow/python/kernel_tests/bincount_op_test.py index 8a58b3f97ec..8177cdd454e 100644 --- a/tensorflow/python/kernel_tests/bincount_op_test.py +++ b/tensorflow/python/kernel_tests/bincount_op_test.py @@ -22,6 +22,8 @@ import numpy as np from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest @@ -97,6 +99,22 @@ class BincountTest(test_util.TensorFlowTestCase): with self.assertRaises(errors.InvalidArgumentError): math_ops.bincount([1, 2, 3, -1, 6, 8]).eval() + def test_shape_function(self): + # size must be scalar. + with self.assertRaisesRegexp( + ValueError, "Shape must be rank 0 but is rank 1 for 'Bincount'"): + gen_math_ops.bincount([1, 2, 3, -1, 6, 8], [1], []) + # size must be positive. + with self.assertRaisesRegexp(ValueError, "must be non-negative"): + gen_math_ops.bincount([1, 2, 3, -1, 6, 8], -5, []) + # if size is a constant then the shape is known. + v1 = gen_math_ops.bincount([1, 2, 3, -1, 6, 8], 5, []) + self.assertAllEqual(v1.get_shape().as_list(), [5]) + # if size is a placeholder then the shape is unknown. + s = array_ops.placeholder(dtype=dtypes.int32) + v2 = gen_math_ops.bincount([1, 2, 3, -1, 6, 8], s, []) + self.assertAllEqual(v2.get_shape().as_list(), [None]) + if __name__ == "__main__": googletest.main()