Add C++ shape inference function for broadcasting binary ops.
This is the same logic as the Python version, with one addition: in the case where one input dim is 1 and the other is unknown, we propagate the unknown input dim instead of creating a new unknown dim. (This case did not arise in Python, where None was used for the unknown input.)

Change: 126308395
parent fdc2d055da
commit 1f120e3f00
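For illustration only, here is a minimal standalone sketch of the per-dimension rule described above, using plain int64 values with -1 standing in for an unknown dim. The names BroadcastDim/BroadcastShape and the -1 convention are this sketch's own, not the InferenceContext API used in the diff below: a known dim greater than 1 wins, a dim equal to 1 defers to the other input's dim even when that dim is unknown, and two unknown dims produce a fresh unknown.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Sketch of the broadcast rule; -1 marks an unknown dimension.
int64_t BroadcastDim(int64_t x, int64_t y) {
  const bool x_known = x >= 0;
  const bool y_known = y >= 0;
  if (!x_known || !y_known) {
    if (x_known && x > 1) return x;   // A known dim > 1 wins.
    if (y_known && y > 1) return y;
    if (x_known && x == 1) return y;  // Propagate the unknown input dim.
    if (y_known && y == 1) return x;
    return -1;                        // Both unknown: a new unknown dim.
  }
  if (x == 1) return y;
  if (y == 1) return x;
  return x;  // The real code merges and errors on mismatch; assume x == y here.
}

// Zip the shapes right-aligned, padding the shorter one with 1s.
std::vector<int64_t> BroadcastShape(const std::vector<int64_t>& x,
                                    const std::vector<int64_t>& y) {
  const int rank_x = static_cast<int>(x.size());
  const int rank_y = static_cast<int>(y.size());
  const int rank_out = std::max(rank_x, rank_y);
  std::vector<int64_t> out;
  for (int i = 0; i < rank_out; ++i) {
    const int64_t dim_x = i < rank_out - rank_x ? 1 : x[i - (rank_out - rank_x)];
    const int64_t dim_y = i < rank_out - rank_y ? 1 : y[i - (rank_out - rank_y)];
    out.push_back(BroadcastDim(dim_x, dim_y));
  }
  return out;
}

int main() {
  // Mirrors the multi-dimension test case below: [3,1,?] broadcast with
  // [?,1,2,3,4,5] should come out as [?,1,2,3,4,5].
  for (int64_t d : BroadcastShape({3, 1, -1}, {-1, 1, 2, 3, 4, 5})) {
    std::cout << d << ' ';  // prints: -1 1 2 3 4 5
  }
  std::cout << '\n';
  return 0;
}

The "[2];[?]" -> "[d0_0]" and "[1];[?]" -> "[d1_0]" cases in the test at the end of the diff correspond to the known-dim-wins and propagate-the-unknown-dim branches of this sketch.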
@@ -47,6 +47,74 @@ Add all input tensors element wise.
inputs: Must all be the same size and shape.
)doc");

namespace {

// Shape inference function for binary operators that broadcast their inputs.
Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
  const Shape* shape_x = c->input(0);
  const Shape* shape_y = c->input(1);
  if (!c->RankKnown(shape_x) || !c->RankKnown(shape_y)) {
    c->set_output(0, c->CreateUnknownShape());
    return Status::OK();
  }
  const int32 rank_x = c->Rank(shape_x);
  const int32 rank_y = c->Rank(shape_y);
  const int32 rank_out = std::max(rank_x, rank_y);

  // To compute the broadcast dimensions, we zip together shape_x and shape_y
  // and pad with 1 to make them the same length.
  std::vector<const Dimension*> dims;
  const Dimension* dim_one = rank_x == rank_y ? nullptr : c->CreateDim(1);
  for (int i = 0; i < rank_out; ++i) {
    const auto* dim_x = i < (rank_out - rank_x)
                            ? dim_one
                            : c->Dim(shape_x, i - (rank_out - rank_x));
    const auto* dim_y = i < (rank_out - rank_y)
                            ? dim_one
                            : c->Dim(shape_y, i - (rank_out - rank_y));
    if (!c->ValueKnown(dim_x) || !c->ValueKnown(dim_y)) {
      // One or both dimensions is unknown.
      //
      // - If either dimension is greater than 1, we assume that the program is
      // correct, and the other dimension will be broadcast to match it.
      // TODO(cwhipkey): For shape inference, if we eliminate the shape checks
      // in C++ op code, we must still assert that the unknown dim is either 1
      // or the same as the known dim.
      // - If either dimension is 1, the other dimension is the output.
      if (c->Value(dim_x) > 1) {
        dims.push_back(dim_x);
      } else if (c->Value(dim_y) > 1) {
        dims.push_back(dim_y);
      } else if (c->Value(dim_x) == 1) {
        dims.push_back(dim_y);
      } else if (c->Value(dim_y) == 1) {
        dims.push_back(dim_x);
      } else {
        dims.push_back(c->CreateUnknownDim());
      }
    } else if (c->Value(dim_x) == 1 || c->Value(dim_y) == 1) {
      if (c->Value(dim_x) == 1 && dim_y != dim_one) {
        // We will broadcast dim_x to dim_y.
        dims.push_back(dim_y);
      } else {
        DCHECK_EQ(c->Value(dim_y), 1);
        // We will broadcast dim_y to dim_x.
        dims.push_back(dim_x);
      }
    } else {
      const Dimension* dim;
      TF_RETURN_IF_ERROR(c->Merge(dim_x, dim_y, &dim));
      dims.push_back(dim);
    }
  }

  c->set_output(0, c->CreateShape(dims));
  return Status::OK();
}

}  // namespace

// --------------------------------------------------------------------------

REGISTER_OP("BatchMatMul")
@@ -373,6 +441,7 @@ REGISTER_OP("Add")
    .Attr(
        "T: {half, float, double, uint8, int8, int16, int32, int64, complex64, "
        "complex128, string}")
    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
    .Doc(R"doc(
Returns x + y element-wise.

@@ -381,6 +450,7 @@ Returns x + y element-wise.

REGISTER_OP("Sub")
    .BINARY_FEWER()
    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
    .Doc(R"doc(
Returns x - y element-wise.
)doc");
@@ -388,12 +458,14 @@ Returns x - y element-wise.
REGISTER_OP("Mul")
    .BINARY_MORE()
    .SetIsCommutative()
    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
    .Doc(R"doc(
Returns x * y element-wise.
)doc");

REGISTER_OP("Div")
    .BINARY_MORE()
    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
    .Doc(R"doc(
Returns x / y element-wise.
)doc");
@@ -401,6 +473,7 @@ Returns x / y element-wise.
REGISTER_OP("SquaredDifference")
    .BINARY_FEWER()
    .SetIsCommutative()
    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
    .Doc(R"doc(
Returns (x - y)(x - y) element-wise.
)doc");
@@ -414,6 +487,7 @@ REGISTER_OP("Maximum")
    .Output("z: T")
    .Attr("T: {half, float, double, int32, int64}")
    .SetIsCommutative()
    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
    .Doc(R"doc(
Returns the max of x and y (i.e. x > y ? x : y) element-wise, broadcasts.
)doc");
@@ -424,6 +498,7 @@ REGISTER_OP("Minimum")
    .Output("z: T")
    .Attr("T: {half, float, double, int32, int64}")
    .SetIsCommutative()
    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
    .Doc(R"doc(
Returns the min of x and y (i.e. x < y ? x : y) element-wise, broadcasts.
)doc");
@@ -433,6 +508,7 @@ REGISTER_OP("Mod")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {int32, int64, float, double}")
    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
    .Doc(R"doc(
Returns element-wise remainder of division.
)doc");
@@ -442,6 +518,7 @@ REGISTER_OP("Pow")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {half, float, double, int32, int64, complex64, complex128}")
    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
    .Doc(R"doc(
Computes the power of one value to another.

@@ -460,6 +537,7 @@ REGISTER_OP("Igammac")
    .Input("x: T")
    .Output("z: T")
    .Attr("T: {float, double}")
    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
    .Doc(R"doc(
Compute the upper regularized incomplete Gamma function `Q(a, x)`.

@@ -483,6 +561,7 @@ REGISTER_OP("Igamma")
    .Input("x: T")
    .Output("z: T")
    .Attr("T: {float, double}")
    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
    .Doc(R"doc(
Compute the lower regularized incomplete Gamma function `Q(a, x)`.

@@ -506,6 +585,7 @@ REGISTER_OP("Zeta")
    .Input("q: T")
    .Output("z: T")
    .Attr("T: {float, double}")
    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
    .Doc(R"doc(
Compute the Hurwitz zeta function \\(\zeta(x, q)\\).

@@ -521,6 +601,7 @@ REGISTER_OP("Polygamma")
    .Input("x: T")
    .Output("z: T")
    .Attr("T: {float, double}")
    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
    .Doc(R"doc(
Compute the polygamma function \\(\psi^{(n)}(x)\\).

@@ -536,8 +617,12 @@ where \\(\psi(x)\\) is the digamma function.

// Declares cwise binary comparison operations signature: 't, 't -> bool,
// where 't has a natural total order.
#define COMPARISON() \
  Input("x: T").Input("y: T").Output("z: bool").Attr("T: realnumbertype")
#define COMPARISON() \
  Input("x: T") \
      .Input("y: T") \
      .Output("z: bool") \
      .Attr("T: realnumbertype") \
      .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))

REGISTER_OP("Less")
    .COMPARISON()
@@ -567,10 +652,16 @@ Returns the truth value of (x >= y) element-wise.

// --------------------------------------------------------------------------

#define EQUALITY_COMPARISON() \
  Input("x: T").Input("y: T").Output("z: bool").SetIsCommutative().Attr( \
      "T: {half, float, double, uint8, int8, int16, int32, int64, complex64, " \
      "quint8, qint8, qint32, string, bool, complex128}")
#define EQUALITY_COMPARISON() \
  Input("x: T") \
      .Input("y: T") \
      .Output("z: bool") \
      .SetIsCommutative() \
      .Attr( \
          "T: {half, float, double, uint8, int8, int16, int32, int64, " \
          "complex64, " \
          "quint8, qint8, qint32, string, bool, complex128}") \
      .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))

REGISTER_OP("Equal")
    .EQUALITY_COMPARISON()
@@ -596,8 +687,12 @@ REGISTER_OP("LogicalNot")
Returns the truth value of NOT x element-wise.
)doc");

#define BINARY_LOGICAL() \
  Input("x: bool").Input("y: bool").Output("z: bool").SetIsCommutative()
#define BINARY_LOGICAL() \
  Input("x: bool") \
      .Input("y: bool") \
      .Output("z: bool") \
      .SetIsCommutative() \
      .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))

REGISTER_OP("LogicalAnd")
    .BINARY_LOGICAL()
@@ -1271,6 +1366,7 @@ REGISTER_OP("Complex")
    .Output("out: Tout")
    .Attr("T: {float, double} = DT_FLOAT")
    .Attr("Tout: {complex64, complex128} = DT_COMPLEX64")
    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
    .Doc(R"doc(
Converts two real numbers to a complex number.

@@ -121,4 +121,52 @@ TEST(MathOpsTest, Segment) {
  }
}

TEST(MathOpsTest, BroadcastBinaryOps) {
  for (const auto* op : {"Add",        "Complex",
                         "Div",        "Equal",
                         "Greater",    "GreaterEqual",
                         "Igamma",     "Igammac",
                         "Zeta",       "Polygamma",
                         "Less",       "LessEqual",
                         "LogicalAnd", "LogicalOr",
                         "Maximum",    "Minimum",
                         "Mod",        "Mul",
                         "NotEqual",   "Pow",
                         "Sub",        "SquaredDifference"}) {
    INFER_OK(op, "?;?", "?");
    INFER_OK(op, "[1,2];?", "?");
    INFER_OK(op, "?;[1,2]", "?");

    INFER_OK(op, "[?];[1]", "[d0_0]");
    INFER_OK(op, "[1];[?]", "[d1_0]");
    INFER_OK(op, "[?];[2]", "[d1_0]");
    INFER_OK(op, "[2];[?]", "[d0_0]");
    INFER_OK(op, "[?];[?]", "[?]");
    INFER_OK(op, "[];[?]", "[d1_0]");
    INFER_OK(op, "[?];[]", "[d0_0]");

    INFER_OK(op, "[1];[1]", "[d0_0|d1_0]");
    INFER_OK(op, "[];[1]", "[d1_0]");
    INFER_OK(op, "[1];[]", "[d0_0]");

    INFER_OK(op, "[2];[2]", "[d0_0|d1_0]");
    INFER_OK(op, "[];[2]", "[d1_0]");
    INFER_OK(op, "[1];[2]", "[d1_0]");
    INFER_OK(op, "[2];[1]", "[d0_0]");
    INFER_OK(op, "[2];[]", "[d0_0]");

    INFER_OK(op, "[0];[0]", "[d0_0|d1_0]");
    INFER_OK(op, "[];[0]", "[d1_0]");
    INFER_OK(op, "[1];[0]", "[d1_0]");
    INFER_OK(op, "[0];[1]", "[d0_0]");
    INFER_OK(op, "[0];[]", "[d0_0]");

    // Multiple dimension cases (same test cases, switching x and y).
    INFER_OK(op, "[?,1,2,3,4,5];[3,1,?]",
             "[d0_0,d0_1,d0_2,d0_3|d1_0,d0_4,d0_5]");
    INFER_OK(op, "[3,1,?];[?,1,2,3,4,5]",
             "[d1_0,d1_1,d1_2,d1_3|d0_0,d1_4,d1_5]");
  }
}

}  // end namespace tensorflow