Merge pull request #40480 from yongtang:40471-equal-output-shapes-autograph

PiperOrigin-RevId: 317396929
Change-Id: Id5ce5a139055d367e97ff39fc65cbe68af3146b6
This commit is contained in:
TensorFlower Gardener 2020-06-19 16:13:47 -07:00
commit e264e71a44
2 changed files with 22 additions and 1 deletions

View File

@ -1955,6 +1955,7 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
// in C++ op code, we must still assert that the unknown dim is either 1 // in C++ op code, we must still assert that the unknown dim is either 1
// or the same as the known dim. // or the same as the known dim.
// - If either dimension is 1, the other dimension is the output. // - If either dimension is 1, the other dimension is the output.
// - If both are unknown then dimension is unknown
if (c->Value(dim_x) > 1) { if (c->Value(dim_x) > 1) {
if (!incompatible_shape_error) { if (!incompatible_shape_error) {
*out = c->UnknownShape(); *out = c->UnknownShape();
@ -1973,6 +1974,8 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
dims.push_back(dim_x); dims.push_back(dim_x);
} else if (dim_y.SameHandle(dim_x)) { } else if (dim_y.SameHandle(dim_x)) {
dims.push_back(dim_x); dims.push_back(dim_x);
} else if (!c->ValueKnown(dim_x) && !c->ValueKnown(dim_y)) {
dims.push_back(c->UnknownDim());
} else { } else {
if (!incompatible_shape_error) { if (!incompatible_shape_error) {
*out = c->UnknownShape(); *out = c->UnknownShape();

View File

@ -120,7 +120,7 @@ TEST(MathOpsTest, BroadcastBinaryOps_ShapeFn) {
INFER_OK(op, "[1];[?]", "[d1_0]"); INFER_OK(op, "[1];[?]", "[d1_0]");
INFER_OK(op, "[?];[2]", incompatible_shape_error ? "[d1_0]" : "?"); INFER_OK(op, "[?];[2]", incompatible_shape_error ? "[d1_0]" : "?");
INFER_OK(op, "[2];[?]", incompatible_shape_error ? "[d0_0]" : "?"); INFER_OK(op, "[2];[?]", incompatible_shape_error ? "[d0_0]" : "?");
INFER_OK(op, "[?];[?]", incompatible_shape_error ? "[?]" : "?"); INFER_OK(op, "[?];[?]", "[?]");
INFER_OK(op, "[];[?]", "[d1_0]"); INFER_OK(op, "[];[?]", "[d1_0]");
INFER_OK(op, "[?];[]", "[d0_0]"); INFER_OK(op, "[?];[]", "[d0_0]");
@ -604,4 +604,22 @@ TEST(MathOpsTest, SobolSample) {
INFER_OK(op, "[];[];[]", "[?,?]"); INFER_OK(op, "[];[];[]", "[?,?]");
} }
TEST(MathOpsTest, EqualOp) {
ShapeInferenceTestOp op("Equal");
AddNodeAttr("incompatible_shape_error", true, &op.node_def);
INFER_OK(op, "?;?", "?");
INFER_OK(op, "[1,2];?", "?");
INFER_OK(op, "?;[1,2]", "?");
INFER_OK(op, "[1,2,3];[1]", "[d0_0,d0_1,d0_2]");
INFER_OK(op, "[?,2,1];[1,3]", "[d0_0,d0_1,d1_1]");
INFER_OK(op, "[1,?,3];[3,1]", "[d0_0,d1_0,d0_2]");
INFER_OK(op, "[1,2,3];[2,1,3]", "[d1_0,d0_1,d0_2]");
// Note: Test case for GitHub issue 40471
INFER_OK(op, "[?,10,1];[?,1,4]", "[?,d0_1,d1_2]");
INFER_OK(op, "[10,?,1];[1,?,4]", "[d0_0,?,d1_2]");
}
} // end namespace tensorflow } // end namespace tensorflow