diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 95a010a119d..224e5ea123b 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -121,6 +121,7 @@ tf_kernel_library(
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/tf2xla/lib:batch_dot",
+        "//tensorflow/compiler/tf2xla/lib:broadcast",
         "//tensorflow/compiler/tf2xla/lib:cholesky",
         "//tensorflow/compiler/tf2xla/lib:qr",
         "//tensorflow/compiler/tf2xla/lib:random",
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
index a988d3c33ed..47e517a6576 100644
--- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
@@ -64,7 +64,7 @@ XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions));
 // }
 static xla::XlaOp DivNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
                                xla::XlaOp y, const BCast& broadcast_helper) {
-  std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+  std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
   auto zero = XlaHelpers::Zero(b, dtype);
   auto y_equals_0 = xla::Eq(y, zero);
   auto zeros = xla::ZerosLike(x);
@@ -84,7 +84,7 @@ XLA_MAKE_BINARY(DivNoNan,
 // }
 static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
                                xla::XlaOp y, const BCast& broadcast_helper) {
-  std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+  std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
   if (DataTypeIsUnsigned(dtype)) {
     return xla::Div(x, y);
   }
@@ -105,7 +105,7 @@ XLA_MAKE_BINARY(FloorDiv,
 
 static xla::XlaOp XlogyImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
                             xla::XlaOp y, const BCast& broadcast_helper) {
-  std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+  std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
   auto zero = XlaHelpers::Zero(b, dtype);
   auto is_zero = xla::Eq(x, zero);
   return xla::Select(is_zero, zero, xla::Mul(x, xla::Log(y)));
@@ -114,7 +114,7 @@ XLA_MAKE_BINARY(Xlogy, XlogyImpl(b, input_type(0), lhs, rhs, broadcast_helper));
 
 static xla::XlaOp XdivyImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
                             xla::XlaOp y, const BCast& broadcast_helper) {
-  std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+  std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
   auto zero = XlaHelpers::Zero(b, dtype);
   auto is_zero = xla::Eq(x, zero);
   return xla::Select(is_zero, zero, xla::Div(x, y));
@@ -126,7 +126,7 @@ XLA_MAKE_BINARY(Xdivy, XdivyImpl(b, input_type(0), lhs, rhs, broadcast_helper));
 // return (x < T(0)) == (y < T(0)) ? trunc_mod : std::fmod(trunc_mod + y, y);
 static xla::XlaOp FloorModImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
                                xla::XlaOp y, const BCast& broadcast_helper) {
-  std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+  std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
   auto zero = XlaHelpers::Zero(b, dtype);
   auto same_sign = xla::Eq(xla::Lt(x, zero), xla::Lt(y, zero));
   auto trunc_mod = xla::Rem(x, y);
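Every *Impl helper above changes identically: the explicit XlaBuilder argument drops out because an xla::XlaOp now carries a pointer to the builder that produced it. A minimal sketch of the resulting call pattern for a new binary op; the op, its math, and the availability of xla::Sqrt from the XLA client math library are illustrative assumptions, not part of this patch:

    // Hypothetical op, shown only to illustrate the new Broadcast() call shape.
    static xla::XlaOp HypotImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
                                xla::XlaOp y, const BCast& broadcast_helper) {
      // Broadcast both operands to the common output shape first; the builder
      // is implicit in x and y, so it is no longer passed in.
      std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
      // After broadcasting, both sides have identical shapes, so the
      // element-wise math needs no further shape handling.
      return xla::Sqrt(xla::Add(xla::Mul(x, x), xla::Mul(y, y)));
    }
    XLA_MAKE_BINARY(Hypot,
                    HypotImpl(b, input_type(0), lhs, rhs, broadcast_helper));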
diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
index 696c1c39bef..9bb11fb67e3 100644
--- a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
@@ -13,16 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "absl/algorithm/container.h"
-#include "tensorflow/compiler/tf2xla/shape_util.h"
-#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/lib/constants.h"
-#include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/bcast.h"
 
 namespace tensorflow {
 namespace {
@@ -37,59 +32,9 @@ class BroadcastToOp : public XlaOpKernel {
     TensorShape output_shape;
     OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape));
 
-    OP_REQUIRES(context, input_shape.dims() <= output_shape.dims(),
-                errors::InvalidArgument(
-                    "Input rank (", input_shape.dims(),
-                    ") must be less than or equal to the output rank (",
-                    output_shape.dims(), ")"));
-
-    auto input_dims = input_shape.dim_sizes();
-    auto output_dims = output_shape.dim_sizes();
-
-    // Broadcasting is done right-to-left on right-aligned dimensions; reverse
-    // the two vectors so elements to be broadcast are aligned.
-    absl::c_reverse(input_dims);
-    absl::c_reverse(output_dims);
-
-    std::vector<int64> broadcast_dims;
-    std::vector<int64> broadcast_shape;
-    for (int i = 0; i < output_shape.dims(); ++i) {
-      if (i < input_shape.dims()) {
-        OP_REQUIRES(
-            context,
-            (output_dims[i] == 0 && input_dims[i] == 0) ||
-                (input_dims[i] != 0 && output_dims[i] % input_dims[i] == 0),
-            errors::InvalidArgument("invalid shape to broadcast from ",
-                                    input_shape.DebugString(), " to ",
-                                    output_shape.DebugString()));
-
-        broadcast_dims.push_back(broadcast_shape.size());
-        if (output_dims[i] == input_dims[i]) {
-          broadcast_shape.push_back(output_dims[i]);
-        } else if (output_dims[i] != input_dims[i]) {
-          // Add dimensions [I, O/I], which we will later flatten to just
-          // [O]. We must do this in two phases since XLA broadcasting does not
-          // support tiling.
-          broadcast_shape.push_back(input_dims[i]);
-          broadcast_shape.push_back(output_dims[i] / input_dims[i]);
-        }
-      } else {
-        broadcast_shape.push_back(output_dims[i]);
-      }
-    }
-    absl::c_reverse(broadcast_dims);
-    int broadcast_shape_size = broadcast_shape.size();
-    for (int64& broadcast_dim : broadcast_dims) {
-      broadcast_dim = broadcast_shape_size - broadcast_dim - 1;
-    }
-    absl::c_reverse(broadcast_shape);
-    xla::XlaOp output = xla::Reshape(
-        xla::BroadcastInDim(context->Input(0),
-                            xla::ShapeUtil::MakeShape(
-                                context->input_xla_type(0), broadcast_shape),
-                            broadcast_dims),
-        output_shape.dim_sizes());
-    context->SetOutput(0, output);
+    auto output = BroadcastTo(context->Input(0), output_shape.dim_sizes());
+    OP_REQUIRES_OK(context, output.status());
+    context->SetOutput(0, output.ValueOrDie());
   }
 };
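The kernel now delegates to the new helper and only handles the Status. The same entry point can also be exercised directly against a builder; a minimal usage sketch (the builder and parameter names are ours, not part of the patch):

    xla::XlaBuilder b("broadcast_to_example");
    xla::XlaOp x =
        xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {2, 3}), "x");
    // [2,3] lines up right-aligned with [5,2,3]; only a new major dimension
    // of size 5 is prepended.
    xla::StatusOr<xla::XlaOp> y = tensorflow::BroadcastTo(x, {5, 2, 3});
    // On success, y.ValueOrDie() has shape f32[5,2,3]. No Reshape is emitted
    // because the intermediate broadcast shape already equals the output.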
diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
index ef1015552d1..234f7b4a019 100644
--- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
 
+#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@@ -39,7 +40,8 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) {
   // compute valid broadcast shapes, but rely below on XLA to
   // automatically perform the broadcast assuming its valid shapes are
   // a superset of TensorFlow's valid shapes.
-  BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape));
+  BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape),
+              /*fewer_dims_optimization=*/false);
   if (!bcast.IsValid()) {
     ctx->SetStatus(errors::InvalidArgument("Incompatible shapes: ",
                                            lhs_shape.DebugString(), " vs. ",
@@ -86,51 +88,18 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) {
 }
 
 /* static */ std::pair<xla::XlaOp, xla::XlaOp> XlaBinaryOp::Broadcast(
-    xla::XlaBuilder* builder, const xla::XlaOp& lhs, const xla::XlaOp& rhs,
-    const BCast& broadcast_helper) {
-  // Manually construct the broadcasting since MapN does not do
-  // automatic broadcasting. The bcast helper ensures that
-  // lhs.reshape(bcast.x_reshape()).broadcast(bcast.x_bcast()) and
-  // rhs.reshape(bcast.y_reshape()).broadcast(bcast.y_bcast()) have
-  // the same shape, so can be operated on by MapN.
-
-  // First reshape the inputs, which should be a metadata-only
-  // operation since we are flattening the dimensions in order.
-  auto lhs_shaped = xla::Reshape(lhs, broadcast_helper.x_reshape());
-  auto rhs_shaped = xla::Reshape(rhs, broadcast_helper.y_reshape());
-
-  // Next broadcast the necessary input dimensions. We rely on the
-  // XLA optimizer to be smart about the fact that we are asking
-  // it to broadcast size 1 on some of these dimensions, to avoid
-  // adding complexity to this code.
-  auto lhs_broadcast = xla::Broadcast(lhs_shaped, broadcast_helper.x_bcast());
-  int lhs_size = broadcast_helper.x_bcast().size();
-  auto rhs_broadcast = xla::Broadcast(rhs_shaped, broadcast_helper.y_bcast());
-  int rhs_size = broadcast_helper.y_bcast().size();
-
-  // Now reshape them to the correct output shape. After the
-  // broadcast each side is twice as wide as it should be, since the
-  // broadcast dimensions were prepended to the shape. Reshape
-  // flattening each original dimension with the prepended broadcast
-  // dimension. E.g. if we started out with lhs_shaped with shape
-  // [5,2,3] and x_bcast was [2,1,7] then lhs_broadcast would have
-  // shape [2,1,7,5,2,3] and we want to reshape it to [10,2,21].
-  std::vector<int64> lhs_reorder;
-  for (int i = 0; i < lhs_size; ++i) {
-    lhs_reorder.push_back(i);
-    lhs_reorder.push_back(i + lhs_size);
+    xla::XlaOp lhs, xla::XlaOp rhs, const BCast& broadcast_helper) {
+  auto lhs_output = BroadcastTo(lhs, broadcast_helper.output_shape());
+  if (!lhs_output.ok()) {
+    xla::XlaOp error = lhs.builder()->ReportError(lhs_output.status());
+    return {error, error};
   }
-  auto lhs_output =
-      xla::Reshape(lhs_broadcast, lhs_reorder, broadcast_helper.output_shape());
-  std::vector<int64> rhs_reorder;
-  for (int i = 0; i < rhs_size; ++i) {
-    rhs_reorder.push_back(i);
-    rhs_reorder.push_back(i + rhs_size);
+  auto rhs_output = BroadcastTo(rhs, broadcast_helper.output_shape());
+  if (!rhs_output.ok()) {
+    xla::XlaOp error = rhs.builder()->ReportError(rhs_output.status());
+    return {error, error};
   }
-  auto rhs_output =
-      xla::Reshape(rhs_broadcast, rhs_reorder, broadcast_helper.output_shape());
-
-  return {lhs_output, rhs_output};
+  return {lhs_output.ValueOrDie(), rhs_output.ValueOrDie()};
 }
 
 }  // namespace tensorflow
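Passing fewer_dims_optimization=false is what makes the delegation to BroadcastTo valid: with the optimization on, BCast coalesces adjacent dimensions that broadcast the same way, so output_shape() can come back with a smaller rank than the true output. BroadcastTo aligns dimensions right-to-left against that shape, so it needs the uncollapsed form. An illustrative sketch (shapes chosen by us; see tensorflow/core/util/bcast.h for the helper's contract):

    #include "tensorflow/core/util/bcast.h"

    // Dims 0 and 1 agree, dim 2 broadcasts 1 -> 5.
    tensorflow::BCast coalesced({2, 3, 5}, {2, 3, 1});  // optimization on
    // coalesced.output_shape() may come back folded, e.g. as {6, 5}.
    tensorflow::BCast full({2, 3, 5}, {2, 3, 1},
                           /*fewer_dims_optimization=*/false);
    // full.output_shape() is {2, 3, 5}, the full-rank shape BroadcastTo
    // expects.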
diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
index 6653944a911..516ead4bfe8 100644
--- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
+++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
@@ -67,8 +67,7 @@ class XlaBinaryOp : public XlaOpKernel {
   // 'broadcast_helper', yielding arguments 'lhs' and 'rhs' that have the same
   // shape.
   static std::pair<xla::XlaOp, xla::XlaOp> Broadcast(
-      xla::XlaBuilder* builder, const xla::XlaOp& lhs, const xla::XlaOp& rhs,
-      const BCast& broadcast_helper);
+      xla::XlaOp lhs, xla::XlaOp rhs, const BCast& broadcast_helper);
 };
 
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD
index 8597e7f139d..1ce3930fd1c 100644
--- a/tensorflow/compiler/tf2xla/lib/BUILD
+++ b/tensorflow/compiler/tf2xla/lib/BUILD
@@ -31,6 +31,22 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "broadcast",
+    srcs = ["broadcast.cc"],
+    hdrs = ["broadcast.h"],
+    deps = [
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla/client:xla_builder",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
 cc_library(
     name = "cholesky",
     srcs = ["cholesky.cc"],
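The rewritten Broadcast above also demonstrates the idiom for surfacing a Status from a helper that must return xla::XlaOp by value: XlaBuilder::ReportError folds the error into the builder, so the eventual XlaBuilder::Build() call fails with that Status instead of the op construction crashing. Condensed (variable names are ours):

    // On failure, poison the computation rather than CHECK-failing.
    xla::StatusOr<xla::XlaOp> broadcast = BroadcastTo(lhs, output_shape);
    xla::XlaOp result = broadcast.ok()
                            ? broadcast.ValueOrDie()
                            : lhs.builder()->ReportError(broadcast.status());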
diff --git a/tensorflow/compiler/tf2xla/lib/broadcast.cc b/tensorflow/compiler/tf2xla/lib/broadcast.cc
new file mode 100644
index 00000000000..3e402ef855c
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/lib/broadcast.cc
@@ -0,0 +1,93 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
+
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/util.h"
+
+namespace tensorflow {
+
+xla::StatusOr<xla::XlaOp> BroadcastTo(xla::XlaOp input,
+                                      absl::Span<int64 const> output_dims) {
+  xla::XlaBuilder* builder = input.builder();
+  TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(input));
+  absl::Span<int64 const> input_dims =
+      xla::AsInt64Slice(input_shape.dimensions());
+
+  if (input_dims == output_dims) {
+    return input;
+  }
+
+  if (input_dims.size() > output_dims.size()) {
+    return errors::InvalidArgument(
+        "Input shape (", xla::ShapeUtil::HumanString(input_shape),
+        ") must have rank less than or equal to the output shape [",
+        absl::StrJoin(output_dims, ","), "]");
+  }
+
+  std::vector<int64> broadcast_dims;
+  std::vector<int64> broadcast_shape;
+  auto input_it = input_dims.rbegin();
+  for (auto output_it = output_dims.rbegin(); output_it != output_dims.rend();
+       ++output_it) {
+    if (input_it != input_dims.rend()) {
+      if (!(*output_it == 0 && *input_it == 0) &&
+          !(*input_it != 0 && *output_it % *input_it == 0)) {
+        return errors::InvalidArgument("Invalid shape broadcast from ",
+                                       xla::ShapeUtil::HumanString(input_shape),
+                                       " to [", absl::StrJoin(output_dims, ","),
+                                       "]");
+      }
+
+      broadcast_dims.push_back(broadcast_shape.size());
+      if (*output_it == *input_it) {
+        broadcast_shape.push_back(*output_it);
+      } else if (*output_it != *input_it) {
+        // Add dimensions [I, O/I], which we will later flatten to just
+        // [O]. We must do this in two phases since XLA broadcasting does not
+        // support tiling.
+        broadcast_shape.push_back(*input_it);
+        broadcast_shape.push_back(*output_it / *input_it);
+      }
+      ++input_it;
+    } else {
+      broadcast_shape.push_back(*output_it);
+    }
+  }
+  TF_RET_CHECK(input_it == input_dims.rend());
+
+  absl::c_reverse(broadcast_dims);
+  int broadcast_shape_size = broadcast_shape.size();
+  for (int64& broadcast_dim : broadcast_dims) {
+    broadcast_dim = broadcast_shape_size - broadcast_dim - 1;
+  }
+  absl::c_reverse(broadcast_shape);
+  xla::XlaOp output = xla::BroadcastInDim(
+      input,
+      xla::ShapeUtil::MakeShape(input_shape.element_type(), broadcast_shape),
+      broadcast_dims);
+  if (broadcast_shape != output_dims) {
+    output = xla::Reshape(output, output_dims);
+  }
+  return output;
+}
+
+}  // namespace tensorflow
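The [I, O/I] two-phase trick is easiest to follow on the smallest tiling case. A worked trace, ours rather than the patch's: broadcasting f32[3] to {6}.

    xla::XlaBuilder b("tiling_example");
    xla::XlaOp x =
        xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {3}), "x");
    // 6 % 3 == 0 and 6 != 3, so the loop records [I, O/I]; after the final
    // reversals broadcast_shape = {2, 3} and broadcast_dims = {1}.
    // BroadcastInDim lays x along dimension 1 of an f32[2,3], and since
    // {2, 3} != {6} the trailing Reshape flattens it into the tiled sequence
    // x0, x1, x2, x0, x1, x2.
    xla::StatusOr<xla::XlaOp> tiled = tensorflow::BroadcastTo(x, {6});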
diff --git a/tensorflow/compiler/tf2xla/lib/broadcast.h b/tensorflow/compiler/tf2xla/lib/broadcast.h
new file mode 100644
index 00000000000..591e696f06b
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/lib/broadcast.h
@@ -0,0 +1,32 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_
+#define TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_
+
+#include "absl/types/span.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace tensorflow {
+
+// Broadcasts 'input' up to shape 'output_dims', using TensorFlow broadcasting
+// rules. Supports broadcasting a dimension of size x to size x*y, i.e., tiling.
+xla::StatusOr<xla::XlaOp> BroadcastTo(xla::XlaOp input,
+                                      absl::Span<int64 const> output_dims);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_
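And the rejection path, for contrast with the header comment's tiling promise: tiling applies only when the output dimension is an exact multiple of the input dimension. A sketch:

    xla::XlaBuilder b("invalid_broadcast_example");
    xla::XlaOp x =
        xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {4}), "x");
    // 6 % 4 != 0, so this returns InvalidArgument ("Invalid shape broadcast
    // from f32[4] to [6]") instead of building an op.
    xla::StatusOr<xla::XlaOp> bad = tensorflow::BroadcastTo(x, {6});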