From 441d8667ac91f64c86442d4b00077760c5869239 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 4 Oct 2019 21:47:59 -0700 Subject: [PATCH] Improve BroadcastTo() to also support trivially broadcasting 1 -> n (as well as n -> n). Then, remove special casing in tile_ops.cc. In all cases that would be triggered by the special case, it now will build a BroadcastInDim() such that broadcast_shape == output_dims and the xla::Reshape will not be triggered. Thus, the special case is not needed. PiperOrigin-RevId: 273015433 --- .../compiler/tf2xla/kernels/tile_ops.cc | 23 ------------------- tensorflow/compiler/tf2xla/lib/broadcast.cc | 2 +- 2 files changed, 1 insertion(+), 24 deletions(-) diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc index e1c764f3d5c..e8804cae037 100644 --- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc @@ -79,29 +79,6 @@ class TileOp : public XlaOpKernel { return; } - bool can_tile_with_implicit_broadcast = true; - for (int i = 0; i < input_dims; ++i) { - int64 multiple = multiples[i]; - // If the multiple and input dimension are not 1, then tile cannot be - // implemented with a single hlo broadcast. - if (multiple != 1 && input_shape.dim_size(i) != 1) { - can_tile_with_implicit_broadcast = false; - } - } - - if (can_tile_with_implicit_broadcast) { - // Create a constant Zero the size of the output shape to leverage binary - // operation broadcast semantics. 
- auto broadcasted_zero = xla::Broadcast( - XlaHelpers::Zero(ctx->builder(), ctx->input_type(0)), output_dims); - if (ctx->input_type(0) == DT_BOOL) { - ctx->SetOutput(0, xla::Or(broadcasted_zero, input)); - } else { - ctx->SetOutput(0, xla::Add(broadcasted_zero, input)); - } - return; - } - auto result = BroadcastTo(ctx->Input("input"), output_dims); OP_REQUIRES_OK(ctx, result.status()); ctx->SetOutput(0, result.ValueOrDie()); diff --git a/tensorflow/compiler/tf2xla/lib/broadcast.cc b/tensorflow/compiler/tf2xla/lib/broadcast.cc index a0789f982c3..7251a2e3dc6 100644 --- a/tensorflow/compiler/tf2xla/lib/broadcast.cc +++ b/tensorflow/compiler/tf2xla/lib/broadcast.cc @@ -61,7 +61,7 @@ xla::StatusOr<xla::XlaOp> BroadcastTo(xla::XlaOp input, } broadcast_dims.push_back(broadcast_shape.size()); - if (*output_it == *input_it) { + if (*output_it == *input_it || *input_it == 1) { broadcast_shape.push_back(*output_it); } else if (*output_it != *input_it) { // Add dimensions [I, O/I], which we will later flatten to just