diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 216002ad8e7..b9efddf4cdb 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -1955,6 +1955,7 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c, // in C++ op code, we must still assert that the unknown dim is either 1 // or the same as the known dim. // - If either dimension is 1, the other dimension is the output. + // - If both are unknown then dimension is unknown if (c->Value(dim_x) > 1) { if (!incompatible_shape_error) { *out = c->UnknownShape(); @@ -1973,6 +1974,8 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c, dims.push_back(dim_x); } else if (dim_y.SameHandle(dim_x)) { dims.push_back(dim_x); + } else if (!c->ValueKnown(dim_x) && !c->ValueKnown(dim_y)) { + dims.push_back(c->UnknownDim()); } else { if (!incompatible_shape_error) { *out = c->UnknownShape(); diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc index 5c69a2a7f1c..2b65f88042c 100644 --- a/tensorflow/core/ops/math_ops_test.cc +++ b/tensorflow/core/ops/math_ops_test.cc @@ -120,7 +120,7 @@ TEST(MathOpsTest, BroadcastBinaryOps_ShapeFn) { INFER_OK(op, "[1];[?]", "[d1_0]"); INFER_OK(op, "[?];[2]", incompatible_shape_error ? "[d1_0]" : "?"); INFER_OK(op, "[2];[?]", incompatible_shape_error ? "[d0_0]" : "?"); - INFER_OK(op, "[?];[?]", incompatible_shape_error ? "[?]" : "?"); + INFER_OK(op, "[?];[?]", "[?]"); INFER_OK(op, "[];[?]", "[d1_0]"); INFER_OK(op, "[?];[]", "[d0_0]"); @@ -604,4 +604,22 @@ TEST(MathOpsTest, SobolSample) { INFER_OK(op, "[];[];[]", "[?,?]"); } + +TEST(MathOpsTest, EqualOp) { + ShapeInferenceTestOp op("Equal"); + AddNodeAttr("incompatible_shape_error", true, &op.node_def); + + INFER_OK(op, "?;?", "?"); + INFER_OK(op, "[1,2];?", "?"); + INFER_OK(op, "?;[1,2]", "?"); + + INFER_OK(op, "[1,2,3];[1]", "[d0_0,d0_1,d0_2]"); + INFER_OK(op, "[?,2,1];[1,3]", "[d0_0,d0_1,d1_1]"); + INFER_OK(op, "[1,?,3];[3,1]", "[d0_0,d1_0,d0_2]"); + INFER_OK(op, "[1,2,3];[2,1,3]", "[d1_0,d0_1,d0_2]"); + + // Note: Test case for GitHub issue 40471 + INFER_OK(op, "[?,10,1];[?,1,4]", "[?,d0_1,d1_2]"); + INFER_OK(op, "[10,?,1];[1,?,4]", "[d0_0,?,d1_2]"); +} } // end namespace tensorflow