Improve BroadcastTo() to also support trivially broadcasting 1 -> n (as well as n -> n). Then, remove special casing in tile_ops.cc.
In all cases that would be triggered by the special case, it now will build a BroadcastInDim() such that broadcast_shape == output_dims and the xla::Reshape will not be triggered. Thus, the special case is not needed. PiperOrigin-RevId: 273015433
This commit is contained in:
parent
c5c69cd6a8
commit
441d8667ac
@ -79,29 +79,6 @@ class TileOp : public XlaOpKernel {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool can_tile_with_implicit_broadcast = true;
|
|
||||||
for (int i = 0; i < input_dims; ++i) {
|
|
||||||
int64 multiple = multiples[i];
|
|
||||||
// If the multiple and input dimension are not 1, then tile cannot be
|
|
||||||
// implemented with a single hlo broadcast.
|
|
||||||
if (multiple != 1 && input_shape.dim_size(i) != 1) {
|
|
||||||
can_tile_with_implicit_broadcast = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (can_tile_with_implicit_broadcast) {
|
|
||||||
// Create a constant Zero the size of the output shape to leverage binary
|
|
||||||
// operation broadcast semantics.
|
|
||||||
auto broadcasted_zero = xla::Broadcast(
|
|
||||||
XlaHelpers::Zero(ctx->builder(), ctx->input_type(0)), output_dims);
|
|
||||||
if (ctx->input_type(0) == DT_BOOL) {
|
|
||||||
ctx->SetOutput(0, xla::Or(broadcasted_zero, input));
|
|
||||||
} else {
|
|
||||||
ctx->SetOutput(0, xla::Add(broadcasted_zero, input));
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto result = BroadcastTo(ctx->Input("input"), output_dims);
|
auto result = BroadcastTo(ctx->Input("input"), output_dims);
|
||||||
OP_REQUIRES_OK(ctx, result.status());
|
OP_REQUIRES_OK(ctx, result.status());
|
||||||
ctx->SetOutput(0, result.ValueOrDie());
|
ctx->SetOutput(0, result.ValueOrDie());
|
||||||
|
@ -61,7 +61,7 @@ xla::StatusOr<xla::XlaOp> BroadcastTo(xla::XlaOp input,
|
|||||||
}
|
}
|
||||||
|
|
||||||
broadcast_dims.push_back(broadcast_shape.size());
|
broadcast_dims.push_back(broadcast_shape.size());
|
||||||
if (*output_it == *input_it) {
|
if (*output_it == *input_it || *input_it == 1) {
|
||||||
broadcast_shape.push_back(*output_it);
|
broadcast_shape.push_back(*output_it);
|
||||||
} else if (*output_it != *input_it) {
|
} else if (*output_it != *input_it) {
|
||||||
// Add dimensions [I, O/I], which we will later flatten to just
|
// Add dimensions [I, O/I], which we will later flatten to just
|
||||||
|
Loading…
Reference in New Issue
Block a user