[TF2XLA]Tile Op: Support dynamic multipliers.

PiperOrigin-RevId: 305920299
Change-Id: I6c06b626e82f334ca62ff2fce026f16e4b9beabd
This commit is contained in:
Yunxing Dai 2020-04-10 12:11:47 -07:00 committed by TensorFlower Gardener
parent 2947e86977
commit 8c138f1684

View File

@ -78,10 +78,28 @@ class TileOp : public XlaOpKernel {
ctx->SetOutput(0, input);
return;
}
std::vector<int64> dynamic_multiples;
ctx->set_dynamic_dimension_is_minus_one(true);
// The multiplier can be a dynamic value.
OP_REQUIRES_OK(
ctx, ctx->ConstantInputAsIntVector("multiples", &dynamic_multiples));
auto result = BroadcastTo(ctx->Input("input"), output_dims);
OP_REQUIRES_OK(ctx, result.status());
ctx->SetOutput(0, result.ValueOrDie());
auto result_or = BroadcastTo(ctx->Input("input"), output_dims);
OP_REQUIRES_OK(ctx, result_or.status());
auto result = result_or.ValueOrDie();
for (int64 i = 0; i < dynamic_multiples.size(); ++i) {
// If a dimension is dynamic, call set-dimension-size on the output.
if (dynamic_multiples[i] == -1) {
auto dynamic_dim_size =
xla::Slice(ctx->Input("multiples"), {i}, {i + 1}, {1});
dynamic_dim_size = xla::Reshape(dynamic_dim_size, {});
dynamic_dim_size = xla::ConvertElementType(dynamic_dim_size, xla::S32);
result = xla::SetDimensionSize(result, dynamic_dim_size, i);
}
}
ctx->SetOutput(0, result);
}
private: