Merge pull request #20476 from yongtang:06052018-bincount-shape
PiperOrigin-RevId: 215947463
This commit is contained in:
commit
c5bd63fd52
@ -1437,7 +1437,24 @@ REGISTER_OP("Bincount")
|
||||
.Attr("T: {int32, int64, float32, float64}")
|
||||
.Output("bins: T")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
c->set_output(0, c->UnknownShapeOfRank(1));
|
||||
ShapeHandle unused;
|
||||
// The input `size` must be a scalar.
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
|
||||
|
||||
const Tensor* size_tensor = c->input_tensor(1);
|
||||
if (size_tensor == nullptr) {
|
||||
// Return unknown shape if size is not known.
|
||||
c->set_output(0, c->UnknownShapeOfRank(1));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Return `[size]` shape if size is known.
|
||||
int32 size_val = size_tensor->scalar<int32>()();
|
||||
if (size_val < 0) {
|
||||
return errors::InvalidArgument("size (", size_val,
|
||||
") must be non-negative");
|
||||
}
|
||||
c->set_output(0, c->MakeShape({size_val}));
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
|
@ -559,4 +559,16 @@ TEST(MathOpsTest, QuantizedAdd_ShapeFn) {
|
||||
INFER_ERROR("must be rank 0", op, "?;?;?;?;[3];?");
|
||||
INFER_ERROR("must be rank 0", op, "?;?;?;?;?;[4]");
|
||||
}
|
||||
|
||||
TEST(MathOpsTest, Bincount_ShapeFn) {
|
||||
ShapeInferenceTestOp op("Bincount");
|
||||
|
||||
// size should be scalar.
|
||||
INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;[1];?");
|
||||
|
||||
INFER_OK(op, "?;?;?", "[?]");
|
||||
INFER_OK(op, "?;[];?", "[?]");
|
||||
INFER_OK(op, "[?];[];?", "[?]");
|
||||
INFER_OK(op, "[?];[];[?]", "[?]");
|
||||
}
|
||||
} // end namespace tensorflow
|
||||
|
@ -22,6 +22,8 @@ import numpy as np
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
@ -97,6 +99,22 @@ class BincountTest(test_util.TensorFlowTestCase):
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
math_ops.bincount([1, 2, 3, -1, 6, 8]).eval()
|
||||
|
||||
def test_shape_function(self):
|
||||
# size must be scalar.
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "Shape must be rank 0 but is rank 1 for 'Bincount'"):
|
||||
gen_math_ops.bincount([1, 2, 3, -1, 6, 8], [1], [])
|
||||
# size must be positive.
|
||||
with self.assertRaisesRegexp(ValueError, "must be non-negative"):
|
||||
gen_math_ops.bincount([1, 2, 3, -1, 6, 8], -5, [])
|
||||
# if size is a constant then the shape is known.
|
||||
v1 = gen_math_ops.bincount([1, 2, 3, -1, 6, 8], 5, [])
|
||||
self.assertAllEqual(v1.get_shape().as_list(), [5])
|
||||
# if size is a placeholder then the shape is unknown.
|
||||
s = array_ops.placeholder(dtype=dtypes.int32)
|
||||
v2 = gen_math_ops.bincount([1, 2, 3, -1, 6, 8], s, [])
|
||||
self.assertAllEqual(v2.get_shape().as_list(), [None])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user