Merge pull request #45280 from WindQAQ:fix-unique-shape-inference
PiperOrigin-RevId: 346991239 Change-Id: I50ea0421d99b72b1fb44a7da41198c21098e54bd
This commit is contained in:
commit
62197c31b9
@ -1441,6 +1441,50 @@ REGISTER_OP("_MklConjugateTranspose")
|
||||
#endif // INTEL_MKL
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
namespace {
|
||||
Status UniqueIdxShapeFn(InferenceContext* c) {
|
||||
ShapeHandle input = c->input(0);
|
||||
const Tensor* axis_t = c->input_tensor(1);
|
||||
if (axis_t == nullptr || !c->RankKnown(input)) {
|
||||
c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (c->Rank(c->input(1)) != 1) {
|
||||
return errors::InvalidArgument("axis expects a 1D vector.");
|
||||
}
|
||||
|
||||
int32 n = axis_t->NumElements();
|
||||
if (n == 0) {
|
||||
if (c->Rank(input) != 1) {
|
||||
return errors::InvalidArgument("x expects a 1D vector.");
|
||||
}
|
||||
c->set_output(1, input);
|
||||
return Status::OK();
|
||||
} else if (n == 1) {
|
||||
int64 axis;
|
||||
if (axis_t->dtype() == DT_INT32) {
|
||||
axis = static_cast<int64>(axis_t->flat<int32>()(0));
|
||||
} else {
|
||||
axis = axis_t->flat<int64>()(0);
|
||||
}
|
||||
|
||||
int64 input_rank = c->Rank(input);
|
||||
if (axis < -input_rank || axis >= input_rank) {
|
||||
return errors::InvalidArgument("axis expects to be in the range [",
|
||||
-input_rank, ", ", input_rank, ")");
|
||||
}
|
||||
if (axis < 0) {
|
||||
axis += input_rank;
|
||||
}
|
||||
c->set_output(1, c->Vector(c->Dim(input, axis)));
|
||||
return Status::OK();
|
||||
}
|
||||
return errors::InvalidArgument(
|
||||
"axis does not support input tensors larger than 1 elements.");
|
||||
}
|
||||
} // namespace
|
||||
|
||||
REGISTER_OP("Unique")
|
||||
.Input("x: T")
|
||||
.Output("y: T")
|
||||
@ -1465,7 +1509,7 @@ REGISTER_OP("UniqueV2")
|
||||
.Attr("out_idx: {int32, int64} = DT_INT32")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0))));
|
||||
c->set_output(1, c->input(0));
|
||||
TF_RETURN_IF_ERROR(UniqueIdxShapeFn(c));
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
@ -1496,7 +1540,7 @@ REGISTER_OP("UniqueWithCountsV2")
|
||||
.Attr("out_idx: {int32, int64} = DT_INT32")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0))));
|
||||
c->set_output(1, c->input(0));
|
||||
TF_RETURN_IF_ERROR(UniqueIdxShapeFn(c));
|
||||
c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
|
||||
return Status::OK();
|
||||
});
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
from tensorflow.python.platform import test
|
||||
@ -106,6 +107,51 @@ class UniqueTest(test.TestCase):
|
||||
for i in range(len(x)):
|
||||
self.assertEqual(x[i], tf_y[tf_idx[i]])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testShapeInferenceV2(self):
|
||||
"""Test shape inference."""
|
||||
x = np.arange(6).reshape(3, 2, 1)
|
||||
_, idx = gen_array_ops.unique_v2(x, axis=[0])
|
||||
self.assertEqual(idx.shape.as_list(), [3])
|
||||
_, idx = gen_array_ops.unique_v2(x, axis=[1])
|
||||
self.assertEqual(idx.shape.as_list(), [2])
|
||||
_, idx = gen_array_ops.unique_v2(x, axis=[2])
|
||||
self.assertEqual(idx.shape.as_list(), [1])
|
||||
_, idx = gen_array_ops.unique_v2(x, axis=[-1])
|
||||
self.assertEqual(idx.shape.as_list(), [1])
|
||||
_, idx = gen_array_ops.unique_v2(x, axis=[-2])
|
||||
self.assertEqual(idx.shape.as_list(), [2])
|
||||
_, idx = gen_array_ops.unique_v2(x, axis=[-3])
|
||||
self.assertEqual(idx.shape.as_list(), [3])
|
||||
_, idx = gen_array_ops.unique_v2([0, 1, 2], axis=[])
|
||||
self.assertEqual(idx.shape.as_list(), [3])
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "axis expects a 1D vector"):
|
||||
gen_array_ops.unique_v2(x, axis=[[0]])
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "x expects a 1D vector"):
|
||||
gen_array_ops.unique_v2(x, axis=[])
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "axis does not support input tensors larger than"):
|
||||
gen_array_ops.unique_v2(x, axis=[1, 2])
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
r"axis expects to be in the range \[-3, 3\)"):
|
||||
gen_array_ops.unique_v2(x, axis=[3])
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
r"axis expects to be in the range \[-3, 3\)"):
|
||||
gen_array_ops.unique_v2(x, axis=[-4])
|
||||
|
||||
x_t = array_ops.placeholder(dtypes.int32, shape=None)
|
||||
_, idx = gen_array_ops.unique_v2(x_t, axis=[0])
|
||||
self.assertEqual(idx.shape.as_list(), [None])
|
||||
|
||||
axis_t = array_ops.placeholder(dtypes.int32, shape=None)
|
||||
_, idx = gen_array_ops.unique_v2(x, axis=axis_t)
|
||||
self.assertEqual(idx.shape.as_list(), [None])
|
||||
|
||||
|
||||
class UniqueWithCountsTest(test.TestCase):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user