Preserve symbolic shape information as much as possible during shape creation
PiperOrigin-RevId: 178448208
This commit is contained in:
parent
74780531e9
commit
c697d96ec6
@ -342,8 +342,8 @@ Status InferenceContext::WithRank(ShapeHandle shape, int64 rank,
|
||||
for (int i = 0; i < rank; ++i) {
|
||||
dims.push_back(UnknownDim());
|
||||
}
|
||||
*out = shape_manager_.MakeShape(dims);
|
||||
return Status::OK();
|
||||
ShapeHandle shp = shape_manager_.MakeShape(dims);
|
||||
return Merge(shape, shp, out);
|
||||
}
|
||||
*out = nullptr;
|
||||
|
||||
@ -357,13 +357,10 @@ Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int64 rank,
|
||||
return errors::InvalidArgument("Rank cannot exceed kint32max");
|
||||
}
|
||||
const int32 existing = Rank(shape);
|
||||
if (existing >= rank) {
|
||||
if (existing >= rank || existing == kUnknownRank) {
|
||||
*out = shape;
|
||||
return Status::OK();
|
||||
}
|
||||
if (existing == kUnknownRank) {
|
||||
return ReturnUnknownShape(out);
|
||||
}
|
||||
*out = nullptr;
|
||||
return errors::InvalidArgument("Shape must be at least rank ", rank,
|
||||
" but is rank ", existing);
|
||||
@ -375,10 +372,7 @@ Status InferenceContext::WithRankAtMost(ShapeHandle shape, int64 rank,
|
||||
return errors::InvalidArgument("Rank cannot exceed kint32max");
|
||||
}
|
||||
const int32 existing = Rank(shape);
|
||||
if (existing == kUnknownRank) {
|
||||
return ReturnUnknownShape(out);
|
||||
}
|
||||
if (existing <= rank) {
|
||||
if (existing <= rank || existing == kUnknownRank) {
|
||||
*out = shape;
|
||||
return Status::OK();
|
||||
}
|
||||
@ -395,8 +389,8 @@ Status InferenceContext::WithValue(DimensionHandle dim, int64 value,
|
||||
return Status::OK();
|
||||
}
|
||||
if (existing == kUnknownDim) {
|
||||
*out = MakeDim(value);
|
||||
return Status::OK();
|
||||
DimensionHandle d = MakeDim(value);
|
||||
return Merge(dim, d, out);
|
||||
}
|
||||
*out = nullptr;
|
||||
return errors::InvalidArgument("Dimension must be ", value, " but is ",
|
||||
|
@ -359,11 +359,11 @@ TEST_F(ShapeInferenceTest, WithRankAtMost) {
|
||||
// WithRankAtMost on a shape with unknown dimensionality always succeeds.
|
||||
EXPECT_TRUE(c.WithRankAtMost(in0, 1, &s1).ok());
|
||||
EXPECT_EQ("?", c.DebugString(s1));
|
||||
EXPECT_FALSE(SameHandle(in0, s1));
|
||||
EXPECT_TRUE(SameHandle(in0, s1));
|
||||
|
||||
EXPECT_TRUE(c.WithRankAtMost(in0, 2, &s2).ok());
|
||||
EXPECT_EQ("?", c.DebugString(s2));
|
||||
EXPECT_FALSE(SameHandle(s1, s2));
|
||||
EXPECT_TRUE(SameHandle(s1, s2));
|
||||
|
||||
// WithRankAtMost on shape with known dimensionality.
|
||||
s1 = in1;
|
||||
@ -398,11 +398,11 @@ TEST_F(ShapeInferenceTest, WithRankAtLeast) {
|
||||
// WithRankAtLeast on a shape with unknown dimensionality always succeeds.
|
||||
EXPECT_TRUE(c.WithRankAtLeast(in0, 1, &s1).ok());
|
||||
EXPECT_EQ("?", c.DebugString(s1));
|
||||
EXPECT_FALSE(SameHandle(in0, s1));
|
||||
EXPECT_TRUE(SameHandle(in0, s1));
|
||||
|
||||
EXPECT_TRUE(c.WithRankAtLeast(in0, 2, &s2).ok());
|
||||
EXPECT_EQ("?", c.DebugString(s2));
|
||||
EXPECT_FALSE(SameHandle(s1, s2));
|
||||
EXPECT_TRUE(SameHandle(s1, s2));
|
||||
|
||||
// WithRankAtLeast on shape with known dimensionality.
|
||||
s1 = in1;
|
||||
|
@ -521,7 +521,7 @@ TEST(ArrayOpsTest, MatrixSetDiag_ShapeFn) {
|
||||
INFER_ERROR("Dimensions must be equal, but are 2 and 3", op, "[2,3];[3]");
|
||||
|
||||
// Output matches input.
|
||||
INFER_OK(op, "?;?", "?");
|
||||
INFER_OK(op, "?;?", "in0");
|
||||
INFER_OK(op, "[1,2,2];[1,2]", "in0");
|
||||
INFER_OK(op, "[1,2,3];?", "in0");
|
||||
INFER_OK(op, "[1,3,2];?", "in0");
|
||||
@ -1619,7 +1619,7 @@ TEST(ArrayOpsTest, UnchangedWithQuantizationScalars_ShapeFn) {
|
||||
TEST(ArrayOpsTest, FakeQuantWithMinMaxVarsPerChannel) {
|
||||
ShapeInferenceTestOp op("FakeQuantWithMinMaxVarsPerChannel");
|
||||
|
||||
INFER_OK(op, "?;?;?", "?");
|
||||
INFER_OK(op, "?;?;?", "in0");
|
||||
INFER_OK(op, "[?];?;?", "in0");
|
||||
INFER_OK(op, "[1,?,3];[3];[3]", "in0");
|
||||
INFER_OK(op, "[3];[3];[3]", "in0");
|
||||
@ -1638,7 +1638,7 @@ TEST(ArrayOpsTest, FakeQuantWithMinMaxVarsPerChannel) {
|
||||
TEST(ArrayOpsTest, FakeQuantWithMinMaxVarsPerChannelGradient) {
|
||||
ShapeInferenceTestOp op("FakeQuantWithMinMaxVarsPerChannelGradient");
|
||||
|
||||
INFER_OK(op, "?;?;?;?", "?;[?];[?]");
|
||||
INFER_OK(op, "?;?;?;?", "in0;[?];[?]");
|
||||
INFER_OK(op, "[3];[3];[3];[3]", "in0;in3;in3");
|
||||
INFER_OK(op, "[1,3];[1,3];[3];[3]", "in0;in3;in3");
|
||||
INFER_OK(op, "[1,2,3,4];[1,2,3,4];[4];[4]", "in0;in3;in3");
|
||||
|
@ -522,7 +522,7 @@ TEST(MathOpsTest, Cross_ShapeFn) {
|
||||
INFER_ERROR("Dimension 0 in both shapes must be equal, but", op, "[3];[5]");
|
||||
INFER_ERROR("Dimension must be 3 but", op, "[3,5];[3,5]");
|
||||
|
||||
INFER_OK(op, "?;?", "?");
|
||||
INFER_OK(op, "?;?", "in0");
|
||||
INFER_OK(op, "[?];[?]", "in0");
|
||||
INFER_OK(op, "[1,?,3];[?,?,?]", "in0");
|
||||
}
|
||||
|
@ -22,7 +22,7 @@ namespace tensorflow {
|
||||
TEST(MathOpsTest, FFT_ShapeFn) {
|
||||
for (const auto* op_name : {"FFT", "IFFT"}) {
|
||||
ShapeInferenceTestOp op(op_name);
|
||||
INFER_OK(op, "?", "?");
|
||||
INFER_OK(op, "?", "in0");
|
||||
INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[]");
|
||||
INFER_OK(op, "[?]", "in0");
|
||||
INFER_OK(op, "[1]", "in0");
|
||||
@ -31,7 +31,7 @@ TEST(MathOpsTest, FFT_ShapeFn) {
|
||||
|
||||
for (const auto* op_name : {"FFT2D", "IFFT2D"}) {
|
||||
ShapeInferenceTestOp op(op_name);
|
||||
INFER_OK(op, "?", "?");
|
||||
INFER_OK(op, "?", "in0");
|
||||
INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1]");
|
||||
INFER_OK(op, "[?,1]", "in0");
|
||||
INFER_OK(op, "[1,2]", "in0");
|
||||
@ -40,7 +40,7 @@ TEST(MathOpsTest, FFT_ShapeFn) {
|
||||
|
||||
for (const auto* op_name : {"FFT3D", "IFFT3D"}) {
|
||||
ShapeInferenceTestOp op(op_name);
|
||||
INFER_OK(op, "?", "?");
|
||||
INFER_OK(op, "?", "in0");
|
||||
INFER_ERROR("Shape must be at least rank 3 but is rank 2", op, "[1,2]");
|
||||
INFER_OK(op, "[?,1,?]", "in0");
|
||||
INFER_OK(op, "[1,2,3]", "in0");
|
||||
|
Loading…
x
Reference in New Issue
Block a user