diff --git a/tensorflow/core/kernels/cwise_ops_common.cc b/tensorflow/core/kernels/cwise_ops_common.cc index 44c57a6e1b2..abad82d9fc0 100644 --- a/tensorflow/core/kernels/cwise_ops_common.cc +++ b/tensorflow/core/kernels/cwise_ops_common.cc @@ -62,15 +62,15 @@ static TensorShape ToShape(const BCast::Vec& vec) { BinaryOpShared::BinaryOpState::BinaryOpState(OpKernelContext* ctx) : in0(ctx->input(0)), in1(ctx->input(1)), - bcast(FromShape(in0.shape()), FromShape(in1.shape())) { + bcast(BCast::FromShape(in0.shape()), BCast::FromShape(in1.shape())) { if (!bcast.IsValid()) { ctx->SetStatus(errors::InvalidArgument("Incompatible shapes: ", in0.shape().DebugString(), " vs. ", in1.shape().DebugString())); return; } - OP_REQUIRES_OK(ctx, - ctx->allocate_output(0, ToShape(bcast.output_shape()), &out)); + OP_REQUIRES_OK( + ctx, ctx->allocate_output(0, BCast::ToShape(bcast.output_shape()), &out)); out_num_elements = out->NumElements(); in0_num_elements = in0.NumElements(); in1_num_elements = in1.NumElements(); diff --git a/tensorflow/core/util/bcast.cc b/tensorflow/core/util/bcast.cc index c045ee902b1..d49512819cc 100644 --- a/tensorflow/core/util/bcast.cc +++ b/tensorflow/core/util/bcast.cc @@ -21,29 +21,29 @@ namespace tensorflow { /* static */ void BCast::Reverse(Vec* shape) { std::reverse(shape->begin(), shape->end()); } -BCast::BCast(const Vec& sx, const Vec& sy) { - if (sx == sy) { +BCast::BCast(const Vec& sx, const Vec& sy, const bool fewer_dims_optimization) { + if (sx == sy && TF_PREDICT_TRUE(fewer_dims_optimization)) { // Fast path for common case of identical shapes for sx and sy int64 elements = 1; const int n = sx.size(); output_.resize(n); for (int i = 0; i < n; i++) { - int64 dim = sx[i]; + const int64 dim = sx[i]; elements *= dim; output_[i] = dim; } + result_.push_back(elements); x_reshape_.push_back(elements); y_reshape_.push_back(elements); x_bcast_.push_back(1); y_bcast_.push_back(1); - result_.push_back(elements); // grad_x_reduce_ and grad_y_reduce_ are left as empty } else { // Reverse the shape of x and y for convenience. // After the reverse, 0-th is the inner-most dimension. Vec x = sx; - Reverse(&x); Vec y = sy; + Reverse(&x); Reverse(&y); // 1-extend and align x and y so that they are the same size. @@ -108,11 +108,18 @@ BCast::BCast(const Vec& sx, const Vec& sy) { // Both side are 1s. grad_x_reduce_idx_.push_back(n - 1 - i); grad_y_reduce_idx_.push_back(n - 1 - i); + if (!TF_PREDICT_TRUE(fewer_dims_optimization)) { + result_.push_back(o_i); + x_reshape_.push_back(x_i); + x_bcast_.push_back(bx_i); + y_reshape_.push_back(y_i); + y_bcast_.push_back(by_i); + } continue; - } else if (prev == curr) { - // It is a run of the same cases (no broadcast, x broadcast to - // y, y broadcast to x). We can reshape the input so that fewer - // dimensions are involved in the intermediate computation. + } else if (TF_PREDICT_TRUE(fewer_dims_optimization) && prev == curr) { + // It is a run of the same cases(no broadcast, x broadcast to y, y + // broadcast to x). We can reshape the input so that fewer dimensions + // are involved in the intermediate computation. 
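// (Illustrative example, not part of this patch; shapes chosen for exposition.)
// Broadcasting x = [2, 3, 5] against y = [2, 3, 1]: the two outer dimensions
// form a run of the "no broadcast" case and collapse into a single dimension
// of size 6, so with the optimization enabled
//   x_reshape = [6, 5], y_reshape = [6, 1], y_bcast = [1, 5], result = [6, 5],
// whereas with fewer_dims_optimization == false every vector keeps one entry
// per output dimension:
//   x_reshape = [2, 3, 5], y_reshape = [2, 3, 1], y_bcast = [1, 1, 5].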
result_.back() *= o_i; x_reshape_.back() *= x_i; x_bcast_.back() *= bx_i; @@ -150,4 +157,18 @@ BCast::BCast(const Vec& sx, const Vec& sy) { } } +BCast::Vec BCast::FromShape(const TensorShape& shape) { + const int N = shape.dims(); + BCast::Vec ret(N); + for (int i = 0; i < N; ++i) { + ret[i] = shape.dim_size(i); + } + return ret; +} + +TensorShape BCast::ToShape(const BCast::Vec& vec) { + TensorShape shape(vec); + return shape; +} + } // end namespace tensorflow diff --git a/tensorflow/core/util/bcast.h b/tensorflow/core/util/bcast.h index 1eeb97778f1..45584bb37fd 100644 --- a/tensorflow/core/util/bcast.h +++ b/tensorflow/core/util/bcast.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -73,7 +74,15 @@ class BCast { // it's more convenient to manipulate Vec directly for this module. typedef gtl::InlinedVector Vec; - BCast(const Vec& x, const Vec& y); + // Constructs all helper shapes, following the aforementioned rules. + // + // If "fewer_dims_optimization" is set to true (the default), the + // implementation tries to reduce intermediate dimensions needed to be more + // efficient. This is transparent to the caller. + // + // If false, all intermediate shapes (except for grad_{x,y}_reduce_idx()) have + // the same number of dimensions as the larger of the two inputs. + BCast(const Vec& x, const Vec& y, const bool fewer_dims_optimization = true); ~BCast() {} // Returns true iff two operands are compatible according to the @@ -92,6 +101,10 @@ class BCast { const Vec& grad_x_reduce_idx() const { return grad_x_reduce_idx_; } const Vec& grad_y_reduce_idx() const { return grad_y_reduce_idx_; } + // Static helpers. + static Vec FromShape(const TensorShape& shape); + static TensorShape ToShape(const BCast::Vec& vec); + private: bool valid_ = true; Vec x_reshape_; diff --git a/tensorflow/core/util/bcast_test.cc b/tensorflow/core/util/bcast_test.cc index a9602e6db8d..16f75e19295 100644 --- a/tensorflow/core/util/bcast_test.cc +++ b/tensorflow/core/util/bcast_test.cc @@ -23,8 +23,9 @@ limitations under the License. 
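A rough usage sketch of the static helpers and the constructor flag declared above; the BroadcastedShape wrapper below is hypothetical and not part of this patch, it only mirrors what BinaryOpState in cwise_ops_common.cc now does with BCast::FromShape and BCast::ToShape:

#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/bcast.h"

namespace tensorflow {

// Hypothetical helper (sketch only): broadcasted output shape of two shapes.
TensorShape BroadcastedShape(const TensorShape& a, const TensorShape& b) {
  BCast bcast(BCast::FromShape(a), BCast::FromShape(b),
              /*fewer_dims_optimization=*/true);
  CHECK(bcast.IsValid()) << "Incompatible shapes: " << a.DebugString()
                         << " vs. " << b.DebugString();
  return BCast::ToShape(bcast.output_shape());
}

}  // namespace tensorflow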
namespace tensorflow { namespace { -string BCast(const tensorflow::BCast::Vec& x, const tensorflow::BCast::Vec& y) { - tensorflow::BCast b(x, y); +string BCast(const tensorflow::BCast::Vec& x, const tensorflow::BCast::Vec& y, + const bool fewer_dims_optimization = true) { + tensorflow::BCast b(x, y, fewer_dims_optimization); if (!b.IsValid()) { return "invalid"; } @@ -43,10 +44,13 @@ string BCast(const tensorflow::BCast::Vec& x, const tensorflow::BCast::Vec& y) { } TEST(BCastTest, Invalid) { - EXPECT_EQ("invalid", BCast({5, 3, 2}, {3})); - EXPECT_EQ("invalid", BCast({5, 3, 2}, {2, 2})); - EXPECT_EQ("invalid", BCast({5, 3, 2}, {10, 1, 1})); - EXPECT_EQ("invalid", BCast({1, 2, 1, 2, 1, 2}, {2, 4, 2, 1, 2, 1})); + for (const bool use_optimization : {true, false}) { + EXPECT_EQ("invalid", BCast({5, 3, 2}, {3}, use_optimization)); + EXPECT_EQ("invalid", BCast({5, 3, 2}, {2, 2}, use_optimization)); + EXPECT_EQ("invalid", BCast({5, 3, 2}, {10, 1, 1}, use_optimization)); + EXPECT_EQ("invalid", + BCast({1, 2, 1, 2, 1, 2}, {2, 4, 2, 1, 2, 1}, use_optimization)); + } } TEST(BCastTest, Basic_SameShape) { @@ -56,6 +60,12 @@ TEST(BCastTest, Basic_SameShape) { "[2310]" "[11,7,5,3,2]" "[][]"); + + EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {11, 7, 5, 3, 2}, false), + "[11,7,5,3,2][1,1,1,1,1][11,7,5,3,2][1,1,1,1,1]" + "[11,7,5,3,2]" + "[11,7,5,3,2]" + "[][]"); } TEST(BCastTest, Basic_SameShapeWithZeroDim) { @@ -65,6 +75,12 @@ TEST(BCastTest, Basic_SameShapeWithZeroDim) { "[0]" "[11,7,0,3,2]" "[][]"); + + EXPECT_EQ(BCast({11, 7, 0, 3, 2}, {11, 7, 0, 3, 2}, false), + "[11,7,0,3,2][1,1,1,1,1][11,7,0,3,2][1,1,1,1,1]" + "[11,7,0,3,2]" + "[11,7,0,3,2]" + "[][]"); } TEST(BCastTest, Basic_Scalar_Scalar) { @@ -76,12 +92,24 @@ TEST(BCastTest, Basic_Scalar_Scalar) { "[1,1]" "[0,1][0,1]"); + EXPECT_EQ(BCast({1, 1}, {1}, false), + "[1,1][1,1][1,1][1,1]" + "[1,1]" + "[1,1]" + "[0,1][0,1]"); + // [1] [1, 1] EXPECT_EQ(BCast({1}, {1, 1}), "[1][1][1][1]" "[1]" "[1,1]" "[0,1][0,1]"); + + EXPECT_EQ(BCast({1}, {1, 1}, false), + "[1,1][1,1][1,1][1,1]" + "[1,1]" + "[1,1]" + "[0,1][0,1]"); } TEST(BCastTest, Basic_Tensor_Scalar) { @@ -93,12 +121,24 @@ TEST(BCastTest, Basic_Tensor_Scalar) { "[11,7,5,3,2]" "[][0,1,2,3,4]"); + EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {1}, false), + "[11,7,5,3,2][1,1,1,1,1][1,1,1,1,1][11,7,5,3,2]" + "[11,7,5,3,2]" + "[11,7,5,3,2]" + "[][0,1,2,3,4]"); + // [1] [11, 7, 5, 3, 2] EXPECT_EQ(BCast({1}, {11, 7, 5, 3, 2}), "[1][2310][2310][1]" "[2310]" "[11,7,5,3,2]" "[0,1,2,3,4][]"); + + EXPECT_EQ(BCast({1}, {11, 7, 5, 3, 2}, false), + "[1,1,1,1,1][11,7,5,3,2][11,7,5,3,2][1,1,1,1,1]" + "[11,7,5,3,2]" + "[11,7,5,3,2]" + "[0,1,2,3,4][]"); } TEST(BCastTest, Basic_Tensor_With_DimSize_1_Scalar) { @@ -110,6 +150,12 @@ TEST(BCastTest, Basic_Tensor_With_DimSize_1_Scalar) { "[11,7,5,3,2,1]" "[5][0,1,2,3,4,5]"); + EXPECT_EQ(BCast({11, 7, 5, 3, 2, 1}, {1}, false), + "[11,7,5,3,2,1][1,1,1,1,1,1][1,1,1,1,1,1][11,7,5,3,2,1]" + "[11,7,5,3,2,1]" + "[11,7,5,3,2,1]" + "[5][0,1,2,3,4,5]"); + // [1] [11, 7, 5, 3, 2, 1] EXPECT_EQ(BCast({1}, {11, 7, 5, 3, 2, 1}), "[1][2310][2310][1]" @@ -117,6 +163,12 @@ TEST(BCastTest, Basic_Tensor_With_DimSize_1_Scalar) { "[11,7,5,3,2,1]" "[0,1,2,3,4,5][5]"); + EXPECT_EQ(BCast({1}, {11, 7, 5, 3, 2, 1}, false), + "[1,1,1,1,1,1][11,7,5,3,2,1][11,7,5,3,2,1][1,1,1,1,1,1]" + "[11,7,5,3,2,1]" + "[11,7,5,3,2,1]" + "[0,1,2,3,4,5][5]"); + // Effectively it's a tensor and a scalar. 
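// (Clarifying note, not in the original test: with the optimization the five
// non-unit dimensions collapse into a single dimension of 11*7*5*3*2 = 2310
// elements, while the fewer_dims_optimization == false expectations below keep
// all nine dimensions, including the trivially-broadcast size-1 ones.)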
// [11, 7, 5, 1, 1, 3, 2, 1] [1] EXPECT_EQ(BCast({11, 7, 5, 1, 1, 3, 2, 1, 1}, {1}), @@ -125,12 +177,26 @@ TEST(BCastTest, Basic_Tensor_With_DimSize_1_Scalar) { "[11,7,5,1,1,3,2,1,1]" "[3,4,7,8][0,1,2,3,4,5,6,7,8]"); + EXPECT_EQ(BCast({11, 7, 5, 1, 1, 3, 2, 1, 1}, {1}, false), + "[11,7,5,1,1,3,2,1,1][1,1,1,1,1,1,1,1,1]" // x_reshape(), x_bcast() + "[1,1,1,1,1,1,1,1,1][11,7,5,1,1,3,2,1,1]" // y_reshape(), y_bcast() + "[11,7,5,1,1,3,2,1,1]" + "[11,7,5,1,1,3,2,1,1]" + "[3,4,7,8][0,1,2,3,4,5,6,7,8]"); + // [1] [11, 7, 5, 1, 1, 3, 2, 1] EXPECT_EQ(BCast({1}, {11, 7, 5, 1, 1, 3, 2, 1, 1}), "[1][2310][2310][1]" "[2310]" "[11,7,5,1,1,3,2,1,1]" "[0,1,2,3,4,5,6,7,8][3,4,7,8]"); + + EXPECT_EQ(BCast({1}, {11, 7, 5, 1, 1, 3, 2, 1, 1}, false), + "[1,1,1,1,1,1,1,1,1][11,7,5,1,1,3,2,1,1]" // x_reshape(), x_bcast() + "[11,7,5,1,1,3,2,1,1][1,1,1,1,1,1,1,1,1]" // y_reshape(), y_bcast() + "[11,7,5,1,1,3,2,1,1]" + "[11,7,5,1,1,3,2,1,1]" + "[0,1,2,3,4,5,6,7,8][3,4,7,8]"); } TEST(BCastTest, Basic_Tensor_Vector) { @@ -141,12 +207,24 @@ TEST(BCastTest, Basic_Tensor_Vector) { "[11,7,5,3,2]" "[][0,1,2,3]"); + EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {2}, false), + "[11,7,5,3,2][1,1,1,1,1][1,1,1,1,2][11,7,5,3,1]" + "[11,7,5,3,2]" + "[11,7,5,3,2]" + "[][0,1,2,3]"); + // [2] [11, 7, 5, 3, 2] EXPECT_EQ(BCast({2}, {11, 7, 5, 3, 2}), "[1,2][1155,1][1155,2][1,1]" "[1155,2]" "[11,7,5,3,2]" "[0,1,2,3][]"); + + EXPECT_EQ(BCast({2}, {11, 7, 5, 3, 2}, false), + "[1,1,1,1,2][11,7,5,3,1][11,7,5,3,2][1,1,1,1,1]" + "[11,7,5,3,2]" + "[11,7,5,3,2]" + "[0,1,2,3][]"); } TEST(BCastTest, Basic_Tensor_Matrix) { @@ -156,12 +234,25 @@ TEST(BCastTest, Basic_Tensor_Matrix) { "[385,6]" "[11,7,5,3,2]" "[][0,1,2]"); + + EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {3, 2}, false), + "[11,7,5,3,2][1,1,1,1,1][1,1,1,3,2][11,7,5,1,1]" + "[11,7,5,3,2]" + "[11,7,5,3,2]" + "[][0,1,2]"); + // [3, 2] [11, 7, 5, 3, 2] EXPECT_EQ(BCast({3, 2}, {11, 7, 5, 3, 2}), "[1,6][385,1][385,6][1,1]" "[385,6]" "[11,7,5,3,2]" "[0,1,2][]"); + + EXPECT_EQ(BCast({3, 2}, {11, 7, 5, 3, 2}, false), + "[1,1,1,3,2][11,7,5,1,1][11,7,5,3,2][1,1,1,1,1]" + "[11,7,5,3,2]" + "[11,7,5,3,2]" + "[0,1,2][]"); } TEST(BCastTest, Basic_Tensor_Matrix_Column) { @@ -172,12 +263,24 @@ TEST(BCastTest, Basic_Tensor_Matrix_Column) { "[11,7,5,3,2]" "[][0,1,2,4]"); + EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {3, 1}, false), + "[11,7,5,3,2][1,1,1,1,1][1,1,1,3,1][11,7,5,1,2]" + "[11,7,5,3,2]" + "[11,7,5,3,2]" + "[][0,1,2,4]"); + // [3, 1] [11, 7, 5, 3, 2] EXPECT_EQ(BCast({3, 1}, {11, 7, 5, 3, 2}), "[1,3,1][385,1,2][385,3,2][1,1,1]" "[385,3,2]" "[11,7,5,3,2]" "[0,1,2,4][]"); + + EXPECT_EQ(BCast({3, 1}, {11, 7, 5, 3, 2}, false), + "[1,1,1,3,1][11,7,5,1,2][11,7,5,3,2][1,1,1,1,1]" + "[11,7,5,3,2]" + "[11,7,5,3,2]" + "[0,1,2,4][]"); } TEST(BCastTest, Basic_Tensor_Matrix_As_Tensor) { @@ -188,12 +291,23 @@ TEST(BCastTest, Basic_Tensor_Matrix_As_Tensor) { "[11,7,5,3,2]" "[][0,3,4]"); + EXPECT_EQ(BCast({11, 7, 5, 3, 2}, {7, 5, 1, 1}, false), + "[11,7,5,3,2][1,1,1,1,1][1,7,5,1,1][11,1,1,3,2]" + "[11,7,5,3,2]" + "[11,7,5,3,2]" + "[][0,3,4]"); + // [7, 5, 1, 1] [11, 7, 5, 3, 2] EXPECT_EQ(BCast({7, 5, 1, 1}, {11, 7, 5, 3, 2}), "[1,35,1][11,1,6][11,35,6][1,1,1]" "[11,35,6]" "[11,7,5,3,2]" "[0,3,4][]"); + + EXPECT_EQ(BCast({7, 5, 1, 1}, {11, 7, 5, 3, 2}, false), + "[1,7,5,1,1][11,1,1,3,2][11,7,5,3,2][1,1,1,1,1]" + "[11,7,5,3,2][11,7,5,3,2]" + "[0,3,4][]"); } TEST(BCastTest, Complex_BCast_To_Each_Other) { @@ -205,14 +319,18 @@ TEST(BCastTest, Complex_BCast_To_Each_Other) { // y = np.arange(0,21).reshape([7,1,3,1]) // np.shape(x + y) // Out[.]: 
(11, 7, 5, 3, 2) - EXPECT_EQ(BCast({11, 1, 5, 1, 2}, {7, 1, 3, 1}), - "[11,1,5,1,2][1,7,1,3,1][1,7,1,3,1][11,1,5,1,2]" - "[11,7,5,3,2]" - "[11,7,5,3,2]" - "[1,3][0,2,4]"); + string truth = + "[11,1,5,1,2][1,7,1,3,1][1,7,1,3,1][11,1,5,1,2]" + "[11,7,5,3,2]" + "[11,7,5,3,2]" + "[1,3][0,2,4]"; + + EXPECT_EQ(BCast({11, 1, 5, 1, 2}, {7, 1, 3, 1}), truth); + EXPECT_EQ(BCast({11, 1, 5, 1, 2}, {7, 1, 3, 1}, false), truth); } TEST(BCastTest, TestZeroDimensionShape) { + // (2,0,5) and (5) in both orders EXPECT_EQ(BCast({2, 0, 5}, {5}), "[0,5][1,1][1,5][0,1]" "[0,5]" @@ -224,6 +342,18 @@ TEST(BCastTest, TestZeroDimensionShape) { "[2,0,5]" "[0,1][]"); + EXPECT_EQ(BCast({2, 0, 5}, {5}, false), + "[2,0,5][1,1,1][1,1,5][2,0,1]" + "[2,0,5]" + "[2,0,5]" + "[][0,1]"); + EXPECT_EQ(BCast({5}, {2, 0, 5}, false), + "[1,1,5][2,0,1][2,0,5][1,1,1]" + "[2,0,5]" + "[2,0,5]" + "[0,1][]"); + + // (2,0,3,0,5) and (5) in both orders EXPECT_EQ(BCast({2, 0, 3, 0, 5}, {5}), "[0,5][1,1][1,5][0,1]" "[0,5]" @@ -235,6 +365,18 @@ TEST(BCastTest, TestZeroDimensionShape) { "[2,0,3,0,5]" "[0,1,2,3][]"); + EXPECT_EQ(BCast({2, 0, 3, 0, 5}, {5}, false), + "[2,0,3,0,5][1,1,1,1,1][1,1,1,1,5][2,0,3,0,1]" + "[2,0,3,0,5]" + "[2,0,3,0,5]" + "[][0,1,2,3]"); + EXPECT_EQ(BCast({5}, {2, 0, 3, 0, 5}, false), + "[1,1,1,1,5][2,0,3,0,1][2,0,3,0,5][1,1,1,1,1]" + "[2,0,3,0,5]" + "[2,0,3,0,5]" + "[0,1,2,3][]"); + + // (2,0,3,0,5) and (3,1,5) in both orders EXPECT_EQ(BCast({2, 0, 3, 0, 5}, {3, 1, 5}), "[0,3,0,5][1,1,1,1][1,3,1,5][0,1,0,1]" "[0,3,0,5]" @@ -245,6 +387,17 @@ TEST(BCastTest, TestZeroDimensionShape) { "[0,3,0,5]" "[2,0,3,0,5]" "[0,1,3][]"); + + EXPECT_EQ(BCast({2, 0, 3, 0, 5}, {3, 1, 5}, false), + "[2,0,3,0,5][1,1,1,1,1][1,1,3,1,5][2,0,1,0,1]" + "[2,0,3,0,5]" + "[2,0,3,0,5]" + "[][0,1,3]"); + EXPECT_EQ(BCast({3, 1, 5}, {2, 0, 3, 0, 5}, false), + "[1,1,3,1,5][2,0,1,0,1][2,0,3,0,5][1,1,1,1,1]" + "[2,0,3,0,5]" + "[2,0,3,0,5]" + "[0,1,3][]"); } static void BM_BCastSetup(int iters, int same_shape) {
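As a closing illustration (hedged sketch, not part of the patch; the Example function and variable names are made up), here is what the new flag changes for a caller, using the same shapes as Basic_Tensor_Matrix above:

#include "tensorflow/core/util/bcast.h"

void Example() {
  using tensorflow::BCast;
  BCast fast({11, 7, 5, 3, 2}, {3, 2});         // default: runs of dimensions coalesced
  BCast full({11, 7, 5, 3, 2}, {3, 2}, false);  // one entry per output dimension
  // fast.x_reshape() == {385, 6}      full.x_reshape() == {11, 7, 5, 3, 2}
  // fast.y_reshape() == {1, 6}        full.y_reshape() == {1, 1, 1, 3, 2}
  // fast.y_bcast()   == {385, 1}      full.y_bcast()   == {11, 7, 5, 1, 1}
  // Both agree on output_shape() == {11, 7, 5, 3, 2} and on
  // grad_x_reduce_idx() == {} / grad_y_reduce_idx() == {0, 1, 2}.
}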