Preserve symbolic shape information as much as possible during shape creation

PiperOrigin-RevId: 178448208
This commit is contained in:
Benoit Steiner 2017-12-08 17:10:40 -08:00 committed by TensorFlower Gardener
parent 74780531e9
commit c697d96ec6
5 changed files with 17 additions and 23 deletions

View File

@ -342,8 +342,8 @@ Status InferenceContext::WithRank(ShapeHandle shape, int64 rank,
for (int i = 0; i < rank; ++i) {
dims.push_back(UnknownDim());
}
*out = shape_manager_.MakeShape(dims);
return Status::OK();
ShapeHandle shp = shape_manager_.MakeShape(dims);
return Merge(shape, shp, out);
}
*out = nullptr;
@ -357,13 +357,10 @@ Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int64 rank,
return errors::InvalidArgument("Rank cannot exceed kint32max");
}
const int32 existing = Rank(shape);
if (existing >= rank) {
if (existing >= rank || existing == kUnknownRank) {
*out = shape;
return Status::OK();
}
if (existing == kUnknownRank) {
return ReturnUnknownShape(out);
}
*out = nullptr;
return errors::InvalidArgument("Shape must be at least rank ", rank,
" but is rank ", existing);
@ -375,10 +372,7 @@ Status InferenceContext::WithRankAtMost(ShapeHandle shape, int64 rank,
return errors::InvalidArgument("Rank cannot exceed kint32max");
}
const int32 existing = Rank(shape);
if (existing == kUnknownRank) {
return ReturnUnknownShape(out);
}
if (existing <= rank) {
if (existing <= rank || existing == kUnknownRank) {
*out = shape;
return Status::OK();
}
@ -395,8 +389,8 @@ Status InferenceContext::WithValue(DimensionHandle dim, int64 value,
return Status::OK();
}
if (existing == kUnknownDim) {
*out = MakeDim(value);
return Status::OK();
DimensionHandle d = MakeDim(value);
return Merge(dim, d, out);
}
*out = nullptr;
return errors::InvalidArgument("Dimension must be ", value, " but is ",

View File

@ -359,11 +359,11 @@ TEST_F(ShapeInferenceTest, WithRankAtMost) {
// WithRankAtMost on a shape with unknown dimensionality always succeeds.
EXPECT_TRUE(c.WithRankAtMost(in0, 1, &s1).ok());
EXPECT_EQ("?", c.DebugString(s1));
EXPECT_FALSE(SameHandle(in0, s1));
EXPECT_TRUE(SameHandle(in0, s1));
EXPECT_TRUE(c.WithRankAtMost(in0, 2, &s2).ok());
EXPECT_EQ("?", c.DebugString(s2));
EXPECT_FALSE(SameHandle(s1, s2));
EXPECT_TRUE(SameHandle(s1, s2));
// WithRankAtMost on shape with known dimensionality.
s1 = in1;
@ -398,11 +398,11 @@ TEST_F(ShapeInferenceTest, WithRankAtLeast) {
// WithRankAtLeast on a shape with unknown dimensionality always succeeds.
EXPECT_TRUE(c.WithRankAtLeast(in0, 1, &s1).ok());
EXPECT_EQ("?", c.DebugString(s1));
EXPECT_FALSE(SameHandle(in0, s1));
EXPECT_TRUE(SameHandle(in0, s1));
EXPECT_TRUE(c.WithRankAtLeast(in0, 2, &s2).ok());
EXPECT_EQ("?", c.DebugString(s2));
EXPECT_FALSE(SameHandle(s1, s2));
EXPECT_TRUE(SameHandle(s1, s2));
// WithRankAtLeast on shape with known dimensionality.
s1 = in1;

View File

@ -521,7 +521,7 @@ TEST(ArrayOpsTest, MatrixSetDiag_ShapeFn) {
INFER_ERROR("Dimensions must be equal, but are 2 and 3", op, "[2,3];[3]");
// Output matches input.
INFER_OK(op, "?;?", "?");
INFER_OK(op, "?;?", "in0");
INFER_OK(op, "[1,2,2];[1,2]", "in0");
INFER_OK(op, "[1,2,3];?", "in0");
INFER_OK(op, "[1,3,2];?", "in0");
@ -1619,7 +1619,7 @@ TEST(ArrayOpsTest, UnchangedWithQuantizationScalars_ShapeFn) {
TEST(ArrayOpsTest, FakeQuantWithMinMaxVarsPerChannel) {
ShapeInferenceTestOp op("FakeQuantWithMinMaxVarsPerChannel");
INFER_OK(op, "?;?;?", "?");
INFER_OK(op, "?;?;?", "in0");
INFER_OK(op, "[?];?;?", "in0");
INFER_OK(op, "[1,?,3];[3];[3]", "in0");
INFER_OK(op, "[3];[3];[3]", "in0");
@ -1638,7 +1638,7 @@ TEST(ArrayOpsTest, FakeQuantWithMinMaxVarsPerChannel) {
TEST(ArrayOpsTest, FakeQuantWithMinMaxVarsPerChannelGradient) {
ShapeInferenceTestOp op("FakeQuantWithMinMaxVarsPerChannelGradient");
INFER_OK(op, "?;?;?;?", "?;[?];[?]");
INFER_OK(op, "?;?;?;?", "in0;[?];[?]");
INFER_OK(op, "[3];[3];[3];[3]", "in0;in3;in3");
INFER_OK(op, "[1,3];[1,3];[3];[3]", "in0;in3;in3");
INFER_OK(op, "[1,2,3,4];[1,2,3,4];[4];[4]", "in0;in3;in3");

View File

@ -522,7 +522,7 @@ TEST(MathOpsTest, Cross_ShapeFn) {
INFER_ERROR("Dimension 0 in both shapes must be equal, but", op, "[3];[5]");
INFER_ERROR("Dimension must be 3 but", op, "[3,5];[3,5]");
INFER_OK(op, "?;?", "?");
INFER_OK(op, "?;?", "in0");
INFER_OK(op, "[?];[?]", "in0");
INFER_OK(op, "[1,?,3];[?,?,?]", "in0");
}

View File

@ -22,7 +22,7 @@ namespace tensorflow {
TEST(MathOpsTest, FFT_ShapeFn) {
for (const auto* op_name : {"FFT", "IFFT"}) {
ShapeInferenceTestOp op(op_name);
INFER_OK(op, "?", "?");
INFER_OK(op, "?", "in0");
INFER_ERROR("Shape must be at least rank 1 but is rank 0", op, "[]");
INFER_OK(op, "[?]", "in0");
INFER_OK(op, "[1]", "in0");
@ -31,7 +31,7 @@ TEST(MathOpsTest, FFT_ShapeFn) {
for (const auto* op_name : {"FFT2D", "IFFT2D"}) {
ShapeInferenceTestOp op(op_name);
INFER_OK(op, "?", "?");
INFER_OK(op, "?", "in0");
INFER_ERROR("Shape must be at least rank 2 but is rank 1", op, "[1]");
INFER_OK(op, "[?,1]", "in0");
INFER_OK(op, "[1,2]", "in0");
@ -40,7 +40,7 @@ TEST(MathOpsTest, FFT_ShapeFn) {
for (const auto* op_name : {"FFT3D", "IFFT3D"}) {
ShapeInferenceTestOp op(op_name);
INFER_OK(op, "?", "?");
INFER_OK(op, "?", "in0");
INFER_ERROR("Shape must be at least rank 3 but is rank 2", op, "[1,2]");
INFER_OK(op, "[?,1,?]", "in0");
INFER_OK(op, "[1,2,3]", "in0");