From 1f120e3f004e6444095b26f851f69ff86b4c346d Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 30 Jun 2016 08:08:47 -0800
Subject: [PATCH] Add C++ shape inference function for broadcasting binary ops.

This is the same logic as python, with an addition in the case where 1
value is 1 and the other is unknown - in this case, we propagate the
unknown input dim instead of a new unknown input dim (this case did not
apply in python where None was used for the unknown input).
Change: 126308395
---
 tensorflow/core/ops/math_ops.cc      | 112 +++++++++++++++++++++++++--
 tensorflow/core/ops/math_ops_test.cc |  48 ++++++++++++
 2 files changed, 152 insertions(+), 8 deletions(-)

diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 51377b6cf08..0f9ee4942aa 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -47,6 +47,74 @@ Add all input tensors element wise.
 inputs: Must all be the same size and shape.
 )doc");
 
+namespace {
+
+// Shape inference function for binary operators that broadcast their inputs.
+Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
+  const Shape* shape_x = c->input(0);
+  const Shape* shape_y = c->input(1);
+  if (!c->RankKnown(shape_x) || !c->RankKnown(shape_y)) {
+    c->set_output(0, c->CreateUnknownShape());
+    return Status::OK();
+  }
+  const int32 rank_x = c->Rank(shape_x);
+  const int32 rank_y = c->Rank(shape_y);
+  const int32 rank_out = std::max(rank_x, rank_y);
+
+  // To compute the broadcast dimensions, we zip together shape_x and shape_y
+  // and
+  // pad with 1 to make them the same length.
+  std::vector<const Dimension*> dims;
+  const Dimension* dim_one = rank_x == rank_y ? nullptr : c->CreateDim(1);
+  for (int i = 0; i < rank_out; ++i) {
+    const auto* dim_x = i < (rank_out - rank_x)
+                            ? dim_one
+                            : c->Dim(shape_x, i - (rank_out - rank_x));
+    const auto* dim_y = i < (rank_out - rank_y)
+                            ? dim_one
+                            : c->Dim(shape_y, i - (rank_out - rank_y));
+    if (!c->ValueKnown(dim_x) || !c->ValueKnown(dim_y)) {
+      // One or both dimensions is unknown.
+      //
+      // - If either dimension is greater than 1, we assume that the program is
+      // correct, and the other dimension will be broadcast to match it.
+      // TODO(cwhipkey): For shape inference, if we eliminate the shape checks
+      // in C++ op code, we must still assert that the unknown dim is either 1
+      // or the same as the known dim.
+      // - If either dimension is 1, the other dimension is the output.
+      if (c->Value(dim_x) > 1) {
+        dims.push_back(dim_x);
+      } else if (c->Value(dim_y) > 1) {
+        dims.push_back(dim_y);
+      } else if (c->Value(dim_x) == 1) {
+        dims.push_back(dim_y);
+      } else if (c->Value(dim_y) == 1) {
+        dims.push_back(dim_x);
+      } else {
+        dims.push_back(c->CreateUnknownDim());
+      }
+    } else if (c->Value(dim_x) == 1 || c->Value(dim_y) == 1) {
+      if (c->Value(dim_x) == 1 && dim_y != dim_one) {
+        // We will broadcast dim_x to dim_y.
+        dims.push_back(dim_y);
+      } else {
+        DCHECK_EQ(c->Value(dim_y), 1);
+        // We will broadcast dim_y to dim_x.
+        dims.push_back(dim_x);
+      }
+    } else {
+      const Dimension* dim;
+      TF_RETURN_IF_ERROR(c->Merge(dim_x, dim_y, &dim));
+      dims.push_back(dim);
+    }
+  }
+
+  c->set_output(0, c->CreateShape(dims));
+  return Status::OK();
+}
+
+}  // namespace
+
 // --------------------------------------------------------------------------
 
 REGISTER_OP("BatchMatMul")
@@ -373,6 +441,7 @@ REGISTER_OP("Add")
     .Attr(
         "T: {half, float, double, uint8, int8, int16, int32, int64, complex64, "
         "complex128, string}")
+    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
     .Doc(R"doc(
 Returns x + y element-wise.
 
@@ -381,6 +450,7 @@ Returns x + y element-wise.
 
 REGISTER_OP("Sub")
     .BINARY_FEWER()
+    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
     .Doc(R"doc(
 Returns x - y element-wise.
 )doc");
@@ -388,12 +458,14 @@ Returns x - y element-wise.
 REGISTER_OP("Mul")
     .BINARY_MORE()
     .SetIsCommutative()
+    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
     .Doc(R"doc(
 Returns x * y element-wise.
 )doc");
 
 REGISTER_OP("Div")
     .BINARY_MORE()
+    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
     .Doc(R"doc(
 Returns x / y element-wise.
 )doc");
@@ -401,6 +473,7 @@ Returns x / y element-wise.
 REGISTER_OP("SquaredDifference")
     .BINARY_FEWER()
     .SetIsCommutative()
+    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
     .Doc(R"doc(
 Returns (x - y)(x - y) element-wise.
 )doc");
@@ -414,6 +487,7 @@ REGISTER_OP("Maximum")
     .Output("z: T")
     .Attr("T: {half, float, double, int32, int64}")
     .SetIsCommutative()
+    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
     .Doc(R"doc(
 Returns the max of x and y (i.e. x > y ? x : y) element-wise, broadcasts.
 )doc");
@@ -424,6 +498,7 @@ REGISTER_OP("Minimum")
     .Output("z: T")
     .Attr("T: {half, float, double, int32, int64}")
     .SetIsCommutative()
+    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
     .Doc(R"doc(
 Returns the min of x and y (i.e. x < y ? x : y) element-wise, broadcasts.
 )doc");
@@ -433,6 +508,7 @@ REGISTER_OP("Mod")
     .Input("y: T")
     .Output("z: T")
     .Attr("T: {int32, int64, float, double}")
+    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
     .Doc(R"doc(
 Returns element-wise remainder of division.
 )doc");
@@ -442,6 +518,7 @@ REGISTER_OP("Pow")
     .Input("y: T")
     .Output("z: T")
     .Attr("T: {half, float, double, int32, int64, complex64, complex128}")
+    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
     .Doc(R"doc(
 Computes the power of one value to another.
 
@@ -460,6 +537,7 @@ REGISTER_OP("Igammac")
     .Input("x: T")
     .Output("z: T")
     .Attr("T: {float, double}")
+    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
     .Doc(R"doc(
 Compute the upper regularized incomplete Gamma function `Q(a, x)`.
 
@@ -483,6 +561,7 @@ REGISTER_OP("Igamma")
     .Input("x: T")
     .Output("z: T")
     .Attr("T: {float, double}")
+    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
     .Doc(R"doc(
 Compute the lower regularized incomplete Gamma function `Q(a, x)`.
 
@@ -506,6 +585,7 @@ REGISTER_OP("Zeta")
     .Input("q: T")
     .Output("z: T")
     .Attr("T: {float, double}")
+    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
     .Doc(R"doc(
 Compute the Hurwitz zeta function \\(\zeta(x, q)\\).
 
@@ -521,6 +601,7 @@ REGISTER_OP("Polygamma")
     .Input("x: T")
     .Output("z: T")
     .Attr("T: {float, double}")
+    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
     .Doc(R"doc(
 Compute the polygamma function \\(\psi^{(n)}(x)\\).
 
@@ -536,8 +617,12 @@ where \\(\psi(x)\\) is the digamma function.
 
 // Declares cwise binary comparison operations signature: 't, 't -> bool,
 // where 't has a natural total order.
-#define COMPARISON() \
-  Input("x: T").Input("y: T").Output("z: bool").Attr("T: realnumbertype")
+#define COMPARISON()                                            \
+  Input("x: T")                                                 \
+      .Input("y: T")                                            \
+      .Output("z: bool")                                        \
+      .Attr("T: realnumbertype")                                \
+      .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
 
 REGISTER_OP("Less")
     .COMPARISON()
@@ -567,10 +652,16 @@ Returns the truth value of (x >= y) element-wise.
 
 // --------------------------------------------------------------------------
 
-#define EQUALITY_COMPARISON()                                              \
-  Input("x: T").Input("y: T").Output("z: bool").SetIsCommutative().Attr(   \
-      "T: {half, float, double, uint8, int8, int16, int32, int64, complex64, " \
-      "quint8, qint8, qint32, string, bool, complex128}")
+#define EQUALITY_COMPARISON()                                           \
+  Input("x: T")                                                         \
+      .Input("y: T")                                                    \
+      .Output("z: bool")                                                \
+      .SetIsCommutative()                                               \
+      .Attr(                                                            \
+          "T: {half, float, double, uint8, int8, int16, int32, int64, " \
+          "complex64, "                                                 \
+          "quint8, qint8, qint32, string, bool, complex128}")           \
+      .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
 
 REGISTER_OP("Equal")
     .EQUALITY_COMPARISON()
@@ -596,8 +687,12 @@ REGISTER_OP("LogicalNot")
 Returns the truth value of NOT x element-wise.
 )doc");
 
-#define BINARY_LOGICAL() \
-  Input("x: bool").Input("y: bool").Output("z: bool").SetIsCommutative()
+#define BINARY_LOGICAL()                                        \
+  Input("x: bool")                                              \
+      .Input("y: bool")                                         \
+      .Output("z: bool")                                        \
+      .SetIsCommutative()                                       \
+      .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
 
 REGISTER_OP("LogicalAnd")
     .BINARY_LOGICAL()
@@ -1271,6 +1366,7 @@ REGISTER_OP("Complex")
     .Output("out: Tout")
     .Attr("T: {float, double} = DT_FLOAT")
     .Attr("Tout: {complex64, complex128} = DT_COMPLEX64")
+    .SetShapeFn(OpShapeInferenceFn(BroadcastBinaryOpShapeFn))
     .Doc(R"doc(
 Converts two real numbers to a complex number.
 
diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc
index c75b0310a28..0aee63ce1d3 100644
--- a/tensorflow/core/ops/math_ops_test.cc
+++ b/tensorflow/core/ops/math_ops_test.cc
@@ -121,4 +121,52 @@ TEST(MathOpsTest, Segment) {
   }
 }
 
+TEST(MathOpsTest, BroadcastBinaryOps) {
+  for (const auto* op : {"Add",        "Complex",
+                         "Div",        "Equal",
+                         "Greater",    "GreaterEqual",
+                         "Igamma",     "Igammac",
+                         "Zeta",       "Polygamma",
+                         "Less",       "LessEqual",
+                         "LogicalAnd", "LogicalOr",
+                         "Maximum",    "Minimum",
+                         "Mod",        "Mul",
+                         "NotEqual",   "Pow",
+                         "Sub",        "SquaredDifference"}) {
+    INFER_OK(op, "?;?", "?");
+    INFER_OK(op, "[1,2];?", "?");
+    INFER_OK(op, "?;[1,2]", "?");
+
+    INFER_OK(op, "[?];[1]", "[d0_0]");
+    INFER_OK(op, "[1];[?]", "[d1_0]");
+    INFER_OK(op, "[?];[2]", "[d1_0]");
+    INFER_OK(op, "[2];[?]", "[d0_0]");
+    INFER_OK(op, "[?];[?]", "[?]");
+    INFER_OK(op, "[];[?]", "[d1_0]");
+    INFER_OK(op, "[?];[]", "[d0_0]");
+
+    INFER_OK(op, "[1];[1]", "[d0_0|d1_0]");
+    INFER_OK(op, "[];[1]", "[d1_0]");
+    INFER_OK(op, "[1];[]", "[d0_0]");
+
+    INFER_OK(op, "[2];[2]", "[d0_0|d1_0]");
+    INFER_OK(op, "[];[2]", "[d1_0]");
+    INFER_OK(op, "[1];[2]", "[d1_0]");
+    INFER_OK(op, "[2];[1]", "[d0_0]");
+    INFER_OK(op, "[2];[]", "[d0_0]");
+
+    INFER_OK(op, "[0];[0]", "[d0_0|d1_0]");
+    INFER_OK(op, "[];[0]", "[d1_0]");
+    INFER_OK(op, "[1];[0]", "[d1_0]");
+    INFER_OK(op, "[0];[1]", "[d0_0]");
+    INFER_OK(op, "[0];[]", "[d0_0]");
+
+    // Multiple dimension cases (same test cases, switching x and y).
+    INFER_OK(op, "[?,1,2,3,4,5];[3,1,?]",
+             "[d0_0,d0_1,d0_2,d0_3|d1_0,d0_4,d0_5]");
+    INFER_OK(op, "[3,1,?];[?,1,2,3,4,5]",
+             "[d1_0,d1_1,d1_2,d1_3|d0_0,d1_4,d1_5]");
+  }
+}
+
 } // end namespace tensorflow