Merge pull request #45280 from WindQAQ:fix-unique-shape-inference

PiperOrigin-RevId: 346991239
Change-Id: I50ea0421d99b72b1fb44a7da41198c21098e54bd
This commit is contained in:
TensorFlower Gardener 2020-12-11 07:14:08 -08:00
commit 62197c31b9
2 changed files with 92 additions and 2 deletions

View File

@ -1441,6 +1441,50 @@ REGISTER_OP("_MklConjugateTranspose")
#endif // INTEL_MKL
// --------------------------------------------------------------------------
namespace {
Status UniqueIdxShapeFn(InferenceContext* c) {
ShapeHandle input = c->input(0);
const Tensor* axis_t = c->input_tensor(1);
if (axis_t == nullptr || !c->RankKnown(input)) {
c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
return Status::OK();
}
if (c->Rank(c->input(1)) != 1) {
return errors::InvalidArgument("axis expects a 1D vector.");
}
int32 n = axis_t->NumElements();
if (n == 0) {
if (c->Rank(input) != 1) {
return errors::InvalidArgument("x expects a 1D vector.");
}
c->set_output(1, input);
return Status::OK();
} else if (n == 1) {
int64 axis;
if (axis_t->dtype() == DT_INT32) {
axis = static_cast<int64>(axis_t->flat<int32>()(0));
} else {
axis = axis_t->flat<int64>()(0);
}
int64 input_rank = c->Rank(input);
if (axis < -input_rank || axis >= input_rank) {
return errors::InvalidArgument("axis expects to be in the range [",
-input_rank, ", ", input_rank, ")");
}
if (axis < 0) {
axis += input_rank;
}
c->set_output(1, c->Vector(c->Dim(input, axis)));
return Status::OK();
}
return errors::InvalidArgument(
"axis does not support input tensors larger than 1 elements.");
}
} // namespace
REGISTER_OP("Unique")
.Input("x: T")
.Output("y: T")
@ -1465,7 +1509,7 @@ REGISTER_OP("UniqueV2")
.Attr("out_idx: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0))));
c->set_output(1, c->input(0));
TF_RETURN_IF_ERROR(UniqueIdxShapeFn(c));
return Status::OK();
});
@ -1496,7 +1540,7 @@ REGISTER_OP("UniqueWithCountsV2")
.Attr("out_idx: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0))));
c->set_output(1, c->input(0));
TF_RETURN_IF_ERROR(UniqueIdxShapeFn(c));
c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
return Status::OK();
});

View File

@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.platform import test
@ -106,6 +107,51 @@ class UniqueTest(test.TestCase):
for i in range(len(x)):
self.assertEqual(x[i], tf_y[tf_idx[i]])
@test_util.run_deprecated_v1
def testShapeInferenceV2(self):
"""Test shape inference."""
x = np.arange(6).reshape(3, 2, 1)
_, idx = gen_array_ops.unique_v2(x, axis=[0])
self.assertEqual(idx.shape.as_list(), [3])
_, idx = gen_array_ops.unique_v2(x, axis=[1])
self.assertEqual(idx.shape.as_list(), [2])
_, idx = gen_array_ops.unique_v2(x, axis=[2])
self.assertEqual(idx.shape.as_list(), [1])
_, idx = gen_array_ops.unique_v2(x, axis=[-1])
self.assertEqual(idx.shape.as_list(), [1])
_, idx = gen_array_ops.unique_v2(x, axis=[-2])
self.assertEqual(idx.shape.as_list(), [2])
_, idx = gen_array_ops.unique_v2(x, axis=[-3])
self.assertEqual(idx.shape.as_list(), [3])
_, idx = gen_array_ops.unique_v2([0, 1, 2], axis=[])
self.assertEqual(idx.shape.as_list(), [3])
with self.assertRaisesRegexp(ValueError, "axis expects a 1D vector"):
gen_array_ops.unique_v2(x, axis=[[0]])
with self.assertRaisesRegexp(ValueError, "x expects a 1D vector"):
gen_array_ops.unique_v2(x, axis=[])
with self.assertRaisesRegexp(
ValueError, "axis does not support input tensors larger than"):
gen_array_ops.unique_v2(x, axis=[1, 2])
with self.assertRaisesRegexp(ValueError,
r"axis expects to be in the range \[-3, 3\)"):
gen_array_ops.unique_v2(x, axis=[3])
with self.assertRaisesRegexp(ValueError,
r"axis expects to be in the range \[-3, 3\)"):
gen_array_ops.unique_v2(x, axis=[-4])
x_t = array_ops.placeholder(dtypes.int32, shape=None)
_, idx = gen_array_ops.unique_v2(x_t, axis=[0])
self.assertEqual(idx.shape.as_list(), [None])
axis_t = array_ops.placeholder(dtypes.int32, shape=None)
_, idx = gen_array_ops.unique_v2(x, axis=axis_t)
self.assertEqual(idx.shape.as_list(), [None])
class UniqueWithCountsTest(test.TestCase):