Add C++ shape inference function for broadcasting binary ops.

This is the same logic as the Python implementation, with one addition for the
case where one dim is unknown and the other is known to be 1: here we propagate
the unknown input dim instead of creating a new unknown output dim (this case
did not apply in Python, where None was used for the unknown input).
Change: 126308395
commit 1f120e3f00 (parent fdc2d055da)
Author: A. Unique TensorFlower, 2016-06-30 08:08:47 -08:00 (committed by TensorFlower Gardener)
2 changed files with 152 additions and 8 deletions
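As a standalone illustration of the rule described above, here is a minimal C++ sketch of the same broadcast merge outside the InferenceContext API. It uses -1 as an unknown-dimension sentinel; the names kUnknown and BroadcastShape are invented for this sketch and are not part of the commit.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

constexpr int64_t kUnknown = -1;  // Stands in for an unknown dimension.

// Merges two shapes under broadcasting; returns false on a hard mismatch
// (both dims known, unequal, and neither equal to 1).
bool BroadcastShape(const std::vector<int64_t>& x,
                    const std::vector<int64_t>& y,
                    std::vector<int64_t>* out) {
  const size_t rank_out = std::max(x.size(), y.size());
  out->assign(rank_out, kUnknown);
  for (size_t i = 0; i < rank_out; ++i) {
    // Zip from the right, padding the shorter shape with 1s.
    const int64_t dx =
        i < rank_out - x.size() ? 1 : x[i - (rank_out - x.size())];
    const int64_t dy =
        i < rank_out - y.size() ? 1 : y[i - (rank_out - y.size())];
    if (dx == kUnknown || dy == kUnknown) {
      // A known dim > 1 wins; a known dim == 1 defers to the unknown input
      // dim (the special case called out above); otherwise stay unknown.
      if (dx > 1) {
        (*out)[i] = dx;
      } else if (dy > 1) {
        (*out)[i] = dy;
      } else if (dx == 1) {
        (*out)[i] = dy;
      } else if (dy == 1) {
        (*out)[i] = dx;
      }  // Both unknown: the slot stays kUnknown.
    } else if (dx == 1) {
      (*out)[i] = dy;
    } else if (dy == 1) {
      (*out)[i] = dx;
    } else if (dx == dy) {
      (*out)[i] = dx;  // The Merge() case: equal known dims.
    } else {
      return false;  // Incompatible known dims, e.g. 2 vs. 3.
    }
  }
  return true;
}

int main() {
  std::vector<int64_t> out;
  // [?,1,2,3,4,5] broadcast with [3,1,?] -> [?,1,2,3,4,5].
  BroadcastShape({kUnknown, 1, 2, 3, 4, 5}, {3, 1, kUnknown}, &out);
  for (int64_t d : out) std::cout << d << ' ';  // Prints: -1 1 2 3 4 5
  std::cout << '\n';
}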

tensorflow/core/ops/math_ops.cc

@@ -47,6 +47,74 @@ Add all input tensors element wise.
inputs: Must all be the same size and shape.
)doc");

namespace {

// Shape inference function for binary operators that broadcast their inputs.
Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
  const Shape* shape_x = c->input(0);
  const Shape* shape_y = c->input(1);
  if (!c->RankKnown(shape_x) || !c->RankKnown(shape_y)) {
    c->set_output(0, c->CreateUnknownShape());
    return Status::OK();
  }
  const int32 rank_x = c->Rank(shape_x);
  const int32 rank_y = c->Rank(shape_y);
  const int32 rank_out = std::max(rank_x, rank_y);

  // To compute the broadcast dimensions, we zip together shape_x and shape_y
  // and pad with 1 to make them the same length.
  std::vector<const Dimension*> dims;
  const Dimension* dim_one = rank_x == rank_y ? nullptr : c->CreateDim(1);
  for (int i = 0; i < rank_out; ++i) {
    const auto* dim_x = i < (rank_out - rank_x)
                            ? dim_one
                            : c->Dim(shape_x, i - (rank_out - rank_x));
    const auto* dim_y = i < (rank_out - rank_y)
                            ? dim_one
                            : c->Dim(shape_y, i - (rank_out - rank_y));
    if (!c->ValueKnown(dim_x) || !c->ValueKnown(dim_y)) {
      // One or both dimensions are unknown.
      //
      // - If either dimension is greater than 1, we assume that the program
      //   is correct, and the other dimension will be broadcast to match it.
      //   TODO(cwhipkey): For shape inference, if we eliminate the shape
      //   checks in C++ op code, we must still assert that the unknown dim
      //   is either 1 or the same as the known dim.
      // - If either dimension is 1, the other dimension is the output.
      if (c->Value(dim_x) > 1) {
        dims.push_back(dim_x);
      } else if (c->Value(dim_y) > 1) {
        dims.push_back(dim_y);
      } else if (c->Value(dim_x) == 1) {
        dims.push_back(dim_y);
      } else if (c->Value(dim_y) == 1) {
        dims.push_back(dim_x);
      } else {
        dims.push_back(c->CreateUnknownDim());
      }
    } else if (c->Value(dim_x) == 1 || c->Value(dim_y) == 1) {
      if (c->Value(dim_x) == 1 && dim_y != dim_one) {
        // We will broadcast dim_x to dim_y.
        dims.push_back(dim_y);
      } else {
        DCHECK_EQ(c->Value(dim_y), 1);
        // We will broadcast dim_y to dim_x.
        dims.push_back(dim_x);
      }
    } else {
      const Dimension* dim;
      TF_RETURN_IF_ERROR(c->Merge(dim_x, dim_y, &dim));
      dims.push_back(dim);
    }
  }

  c->set_output(0, c->CreateShape(dims));
  return Status::OK();
}

}  // namespace
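Every broadcasting op below opts in by attaching this function at registration time. For a sense of the pattern in isolation, a hypothetical new binary op (the name MyBinaryOp and its attr list are invented for illustration) would register like this:

REGISTER_OP("MyBinaryOp")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {float, double}")
    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
    .Doc(R"doc(
Returns an element-wise combination of x and y, with broadcasting.
)doc");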
// --------------------------------------------------------------------------
REGISTER_OP("BatchMatMul")
@@ -373,6 +441,7 @@ REGISTER_OP("Add")
    .Attr(
        "T: {half, float, double, uint8, int8, int16, int32, int64, complex64, "
        "complex128, string}")
    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
    .Doc(R"doc(
Returns x + y element-wise.
@@ -381,6 +450,7 @@ Returns x + y element-wise.
REGISTER_OP("Sub")
    .BINARY_FEWER()
    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
    .Doc(R"doc(
Returns x - y element-wise.
)doc");
@@ -388,12 +458,14 @@ Returns x - y element-wise.
REGISTER_OP("Mul")
    .BINARY_MORE()
    .SetIsCommutative()
    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
    .Doc(R"doc(
Returns x * y element-wise.
)doc");

REGISTER_OP("Div")
    .BINARY_MORE()
    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
    .Doc(R"doc(
Returns x / y element-wise.
)doc");
@@ -401,6 +473,7 @@ Returns x / y element-wise.
REGISTER_OP("SquaredDifference")
    .BINARY_FEWER()
    .SetIsCommutative()
    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
    .Doc(R"doc(
Returns (x - y)(x - y) element-wise.
)doc");
@@ -414,6 +487,7 @@ REGISTER_OP("Maximum")
    .Output("z: T")
    .Attr("T: {half, float, double, int32, int64}")
    .SetIsCommutative()
    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
    .Doc(R"doc(
Returns the max of x and y (i.e. x > y ? x : y) element-wise, broadcasts.
)doc");
@@ -424,6 +498,7 @@ REGISTER_OP("Minimum")
    .Output("z: T")
    .Attr("T: {half, float, double, int32, int64}")
    .SetIsCommutative()
    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
    .Doc(R"doc(
Returns the min of x and y (i.e. x < y ? x : y) element-wise, broadcasts.
)doc");
@@ -433,6 +508,7 @@ REGISTER_OP("Mod")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {int32, int64, float, double}")
    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
    .Doc(R"doc(
Returns element-wise remainder of division.
)doc");
@@ -442,6 +518,7 @@ REGISTER_OP("Pow")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {half, float, double, int32, int64, complex64, complex128}")
    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
    .Doc(R"doc(
Computes the power of one value to another.
@@ -460,6 +537,7 @@ REGISTER_OP("Igammac")
    .Input("x: T")
    .Output("z: T")
    .Attr("T: {float, double}")
    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
    .Doc(R"doc(
Compute the upper regularized incomplete Gamma function `Q(a, x)`.
@@ -483,6 +561,7 @@ REGISTER_OP("Igamma")
    .Input("x: T")
    .Output("z: T")
    .Attr("T: {float, double}")
    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
    .Doc(R"doc(
Compute the lower regularized incomplete Gamma function `P(a, x)`.
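The hunks above show only the first doc line for Igammac and Igamma. For reference, the standard definitions of the two regularized incomplete Gamma functions (a factual aside, not text from this commit) are \\(Q(a, x) = \Gamma(a, x) / \Gamma(a)\\) and \\(P(a, x) = \gamma(a, x) / \Gamma(a) = 1 - Q(a, x)\\), where \\(\Gamma(a, x)\\) and \\(\gamma(a, x)\\) are the upper and lower incomplete Gamma functions.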
@@ -506,6 +585,7 @@ REGISTER_OP("Zeta")
    .Input("q: T")
    .Output("z: T")
    .Attr("T: {float, double}")
    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
    .Doc(R"doc(
Compute the Hurwitz zeta function \\(\zeta(x, q)\\).
@@ -521,6 +601,7 @@ REGISTER_OP("Polygamma")
    .Input("x: T")
    .Output("z: T")
    .Attr("T: {float, double}")
    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
    .Doc(R"doc(
Compute the polygamma function \\(\psi^{(n)}(x)\\).
@@ -536,8 +617,12 @@ where \\(\psi(x)\\) is the digamma function.
// Declares cwise binary comparison operations signature: 't, 't -> bool,
// where 't has a natural total order.
#define COMPARISON() \
  Input("x: T").Input("y: T").Output("z: bool").Attr("T: realnumbertype")
#define COMPARISON()                                            \
  Input("x: T")                                                 \
      .Input("y: T")                                            \
      .Output("z: bool")                                        \
      .Attr("T: realnumbertype")                                \
      .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))

REGISTER_OP("Less")
    .COMPARISON()
@@ -567,10 +652,16 @@ Returns the truth value of (x >= y) element-wise.
// --------------------------------------------------------------------------

#define EQUALITY_COMPARISON()                                                 \
  Input("x: T").Input("y: T").Output("z: bool").SetIsCommutative().Attr(      \
      "T: {half, float, double, uint8, int8, int16, int32, int64, complex64, " \
      "quint8, qint8, qint32, string, bool, complex128}")
#define EQUALITY_COMPARISON()                                              \
  Input("x: T")                                                            \
      .Input("y: T")                                                       \
      .Output("z: bool")                                                   \
      .SetIsCommutative()                                                  \
      .Attr("T: {half, float, double, uint8, int8, int16, int32, int64, " \
            "complex64, quint8, qint8, qint32, string, bool, complex128}") \
      .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))

REGISTER_OP("Equal")
    .EQUALITY_COMPARISON()
@@ -596,8 +687,12 @@ REGISTER_OP("LogicalNot")
Returns the truth value of NOT x element-wise.
)doc");

#define BINARY_LOGICAL() \
  Input("x: bool").Input("y: bool").Output("z: bool").SetIsCommutative()
#define BINARY_LOGICAL()                                        \
  Input("x: bool")                                              \
      .Input("y: bool")                                         \
      .Output("z: bool")                                        \
      .SetIsCommutative()                                       \
      .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))

REGISTER_OP("LogicalAnd")
    .BINARY_LOGICAL()
@@ -1271,6 +1366,7 @@ REGISTER_OP("Complex")
    .Output("out: Tout")
    .Attr("T: {float, double} = DT_FLOAT")
    .Attr("Tout: {complex64, complex128} = DT_COMPLEX64")
    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
    .Doc(R"doc(
Converts two real numbers to a complex number.
tensorflow/core/ops/math_ops_test.cc

@@ -121,4 +121,52 @@ TEST(MathOpsTest, Segment) {
  }
}

TEST(MathOpsTest, BroadcastBinaryOps) {
  for (const auto* op : {"Add",        "Complex",
                         "Div",        "Equal",
                         "Greater",    "GreaterEqual",
                         "Igamma",     "Igammac",
                         "Zeta",       "Polygamma",
                         "Less",       "LessEqual",
                         "LogicalAnd", "LogicalOr",
                         "Maximum",    "Minimum",
                         "Mod",        "Mul",
                         "NotEqual",   "Pow",
                         "Sub",        "SquaredDifference"}) {
    INFER_OK(op, "?;?", "?");
    INFER_OK(op, "[1,2];?", "?");
    INFER_OK(op, "?;[1,2]", "?");

    INFER_OK(op, "[?];[1]", "[d0_0]");
    INFER_OK(op, "[1];[?]", "[d1_0]");
    INFER_OK(op, "[?];[2]", "[d1_0]");
    INFER_OK(op, "[2];[?]", "[d0_0]");
    INFER_OK(op, "[?];[?]", "[?]");
    INFER_OK(op, "[];[?]", "[d1_0]");
    INFER_OK(op, "[?];[]", "[d0_0]");

    INFER_OK(op, "[1];[1]", "[d0_0|d1_0]");
    INFER_OK(op, "[];[1]", "[d1_0]");
    INFER_OK(op, "[1];[]", "[d0_0]");

    INFER_OK(op, "[2];[2]", "[d0_0|d1_0]");
    INFER_OK(op, "[];[2]", "[d1_0]");
    INFER_OK(op, "[1];[2]", "[d1_0]");
    INFER_OK(op, "[2];[1]", "[d0_0]");
    INFER_OK(op, "[2];[]", "[d0_0]");

    INFER_OK(op, "[0];[0]", "[d0_0|d1_0]");
    INFER_OK(op, "[];[0]", "[d1_0]");
    INFER_OK(op, "[1];[0]", "[d1_0]");
    INFER_OK(op, "[0];[1]", "[d0_0]");
    INFER_OK(op, "[0];[]", "[d0_0]");

    // Multiple dimension cases (same test cases, switching x and y).
    INFER_OK(op, "[?,1,2,3,4,5];[3,1,?]",
             "[d0_0,d0_1,d0_2,d0_3|d1_0,d0_4,d0_5]");
    INFER_OK(op, "[3,1,?];[?,1,2,3,4,5]",
             "[d1_0,d1_1,d1_2,d1_3|d0_0,d1_4,d1_5]");
  }
}
} // end namespace tensorflow
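A note on the INFER_OK notation used in the test (the macro comes from TensorFlow's shape_inference_testutil.h): the second argument lists the input shapes separated by semicolons, with "?" marking an unknown shape or dimension; the third is the expected output, which names dimensions by input, so "d0_1" means dimension 1 of input 0 and "|" accepts either handle. Two cases repeated from above with explanatory comments added:

INFER_OK(op, "[?];[2]", "[d1_0]");  // Unknown vs. 2: the known dim wins, and
                                    // the output reuses input 1's dim handle.
INFER_OK(op, "[?];[1]", "[d0_0]");  // Unknown vs. 1: the unknown input dim
                                    // itself is propagated (the new special
                                    // case from the commit message).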