Merge pull request #40480 from yongtang:40471-equal-output-shapes-autograph
PiperOrigin-RevId: 317396929 Change-Id: Id5ce5a139055d367e97ff39fc65cbe68af3146b6
This commit is contained in:
commit
e264e71a44
@ -1955,6 +1955,7 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
|
||||
// in C++ op code, we must still assert that the unknown dim is either 1
|
||||
// or the same as the known dim.
|
||||
// - If either dimension is 1, the other dimension is the output.
|
||||
// - If both are unknown then dimension is unknown
|
||||
if (c->Value(dim_x) > 1) {
|
||||
if (!incompatible_shape_error) {
|
||||
*out = c->UnknownShape();
|
||||
@ -1973,6 +1974,8 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
|
||||
dims.push_back(dim_x);
|
||||
} else if (dim_y.SameHandle(dim_x)) {
|
||||
dims.push_back(dim_x);
|
||||
} else if (!c->ValueKnown(dim_x) && !c->ValueKnown(dim_y)) {
|
||||
dims.push_back(c->UnknownDim());
|
||||
} else {
|
||||
if (!incompatible_shape_error) {
|
||||
*out = c->UnknownShape();
|
||||
|
@ -120,7 +120,7 @@ TEST(MathOpsTest, BroadcastBinaryOps_ShapeFn) {
|
||||
INFER_OK(op, "[1];[?]", "[d1_0]");
|
||||
INFER_OK(op, "[?];[2]", incompatible_shape_error ? "[d1_0]" : "?");
|
||||
INFER_OK(op, "[2];[?]", incompatible_shape_error ? "[d0_0]" : "?");
|
||||
INFER_OK(op, "[?];[?]", incompatible_shape_error ? "[?]" : "?");
|
||||
INFER_OK(op, "[?];[?]", "[?]");
|
||||
INFER_OK(op, "[];[?]", "[d1_0]");
|
||||
INFER_OK(op, "[?];[]", "[d0_0]");
|
||||
|
||||
@ -604,4 +604,22 @@ TEST(MathOpsTest, SobolSample) {
|
||||
|
||||
INFER_OK(op, "[];[];[]", "[?,?]");
|
||||
}
|
||||
|
||||
TEST(MathOpsTest, EqualOp) {
|
||||
ShapeInferenceTestOp op("Equal");
|
||||
AddNodeAttr("incompatible_shape_error", true, &op.node_def);
|
||||
|
||||
INFER_OK(op, "?;?", "?");
|
||||
INFER_OK(op, "[1,2];?", "?");
|
||||
INFER_OK(op, "?;[1,2]", "?");
|
||||
|
||||
INFER_OK(op, "[1,2,3];[1]", "[d0_0,d0_1,d0_2]");
|
||||
INFER_OK(op, "[?,2,1];[1,3]", "[d0_0,d0_1,d1_1]");
|
||||
INFER_OK(op, "[1,?,3];[3,1]", "[d0_0,d1_0,d0_2]");
|
||||
INFER_OK(op, "[1,2,3];[2,1,3]", "[d1_0,d0_1,d0_2]");
|
||||
|
||||
// Note: Test case for GitHub issue 40471
|
||||
INFER_OK(op, "[?,10,1];[?,1,4]", "[?,d0_1,d1_2]");
|
||||
INFER_OK(op, "[10,?,1];[1,?,4]", "[d0_0,?,d1_2]");
|
||||
}
|
||||
} // end namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user