diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 2018f793741..7b0557c4595 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -1441,6 +1441,53 @@ REGISTER_OP("_MklConjugateTranspose") #endif // INTEL_MKL // -------------------------------------------------------------------------- +namespace { +Status UniqueIdxShapeFn(InferenceContext* c) { + ShapeHandle input = c->input(0); + const Tensor* axis_t = c->input_tensor(1); + if (axis_t == nullptr || !c->RankKnown(input)) { + c->set_output(1, c->Vector(InferenceContext::kUnknownDim)); + return Status::OK(); + } + + if (c->Rank(c->input(1)) != 1) { + return errors::InvalidArgument("axis expects a 1D vector."); + } + + int32 n = axis_t->NumElements(); + if (n == 0) { + if (c->Rank(input) != 1) { + return errors::InvalidArgument("x expects a 1D vector."); + } + c->set_output(1, input); + return Status::OK(); + } else if (n == 1) { + int64 axis; + if (axis_t->dtype() == DT_INT32) { + axis = static_cast(axis_t->flat()(0)); + } else { + axis = axis_t->flat()(0); + } + + int64 input_rank = c->Rank(input); + if (axis < -input_rank || axis >= input_rank) { + return errors::InvalidArgument("axis expects to be in the range [", + -input_rank, ", ", input_rank, ")"); + } + if (axis < 0) { + axis += input_rank; + } + c->set_output(1, c->Vector(c->Dim(input, axis))); + return Status::OK(); + } else { + return errors::InvalidArgument( + "axis does not support input tensors larger than 1 elements."); + } + + return Status::OK(); +} +} // namespace + REGISTER_OP("Unique") .Input("x: T") .Output("y: T") @@ -1465,7 +1512,7 @@ REGISTER_OP("UniqueV2") .Attr("out_idx: {int32, int64} = DT_INT32") .SetShapeFn([](InferenceContext* c) { c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0)))); - c->set_output(1, c->input(0)); + TF_RETURN_IF_ERROR(UniqueIdxShapeFn(c)); return Status::OK(); }); @@ -1496,7 +1543,7 @@ REGISTER_OP("UniqueWithCountsV2") .Attr("out_idx: {int32, int64} = DT_INT32") .SetShapeFn([](InferenceContext* c) { c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0)))); - c->set_output(1, c->input(0)); + TF_RETURN_IF_ERROR(UniqueIdxShapeFn(c)); c->set_output(2, c->Vector(InferenceContext::kUnknownDim)); return Status::OK(); }); diff --git a/tensorflow/python/kernel_tests/unique_op_test.py b/tensorflow/python/kernel_tests/unique_op_test.py index 436fef8171f..188a4e1bae1 100644 --- a/tensorflow/python/kernel_tests/unique_op_test.py +++ b/tensorflow/python/kernel_tests/unique_op_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import numpy as np from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.platform import test @@ -106,6 +107,49 @@ class UniqueTest(test.TestCase): for i in range(len(x)): self.assertEqual(x[i], tf_y[tf_idx[i]]) + @test_util.run_deprecated_v1 + def testShapeInferenceV2(self): + """Test shape inference.""" + x = np.random.randint(2, high=10, size=(3, 2, 1)) + _, idx = gen_array_ops.unique_v2(x, axis=[0]) + self.assertEqual(idx.shape.as_list(), [3]) + _, idx = gen_array_ops.unique_v2(x, axis=[1]) + self.assertEqual(idx.shape.as_list(), [2]) + _, idx = gen_array_ops.unique_v2(x, axis=[2]) + self.assertEqual(idx.shape.as_list(), [1]) + _, idx = gen_array_ops.unique_v2(x, axis=[-1]) + self.assertEqual(idx.shape.as_list(), [1]) + _, idx = gen_array_ops.unique_v2(x, axis=[-2]) + self.assertEqual(idx.shape.as_list(), [2]) + _, idx = gen_array_ops.unique_v2(x, axis=[-3]) + self.assertEqual(idx.shape.as_list(), [3]) + + with self.assertRaisesRegexp(ValueError, "axis expects a 1D vector"): + gen_array_ops.unique_v2(x, axis=[[0]]) + + with self.assertRaisesRegexp(ValueError, "x expects a 1D vector"): + gen_array_ops.unique_v2(x, axis=[]) + + with self.assertRaisesRegexp( + ValueError, "axis does not support input tensors larger than"): + gen_array_ops.unique_v2(x, axis=[1, 2]) + + with self.assertRaisesRegexp( + ValueError, r"axis expects to be in the range \[-3, 3\)"): + gen_array_ops.unique_v2(x, axis=[3]) + + with self.assertRaisesRegexp( + ValueError, r"axis expects to be in the range \[-3, 3\)"): + gen_array_ops.unique_v2(x, axis=[-4]) + + x_t = array_ops.placeholder(dtypes.int32, shape=None) + _, idx = gen_array_ops.unique_v2(x_t, axis=[0]) + self.assertEqual(idx.shape.as_list(), [None]) + + axis_t = array_ops.placeholder(dtypes.int32, shape=None) + _, idx = gen_array_ops.unique_v2(x, axis=axis_t) + self.assertEqual(idx.shape.as_list(), [None]) + class UniqueWithCountsTest(test.TestCase):