Add support for PadOp in HLO exporter.

Generic support cannot be used for PadOp because hlo dialect representation breaks the `PaddingConfig` Pad HLO instruction into three separate attributes of int64 vectors in hlo dialect.

PiperOrigin-RevId: 281399246
Change-Id: Iab6f9b5fff925c9a3936ad8714d68c49b8893c00
This commit is contained in:
Prakalp Srivastava 2019-11-19 16:00:21 -08:00 committed by TensorFlower Gardener
parent 0231137ed0
commit 15d103c409
2 changed files with 29 additions and 1 deletions
tensorflow/compiler/mlir/xla

View File

@ -446,7 +446,22 @@ LogicalResult ExportXlaOp(IotaOp op, OpLoweringContext ctx) {
return success();
}
LogicalResult ExportXlaOp(PadOp op, OpLoweringContext ctx) { return failure(); }
LogicalResult ExportXlaOp(PadOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;
xla::PaddingConfig padding_config;
auto edge_padding_low = ConvertDenseIntAttr(op.edge_padding_low());
auto edge_padding_high = ConvertDenseIntAttr(op.edge_padding_high());
auto interior_padding = ConvertDenseIntAttr(op.interior_padding());
for (xla::int64 i = 0; i < edge_padding_low.size(); ++i) {
auto* dims = padding_config.add_dimensions();
dims->set_edge_padding_low(edge_padding_low[i]);
dims->set_edge_padding_high(edge_padding_high[i]);
dims->set_interior_padding(interior_padding[i]);
}
value_map[op] = xla::Pad(value_map[op.getOperand(0)],
value_map[op.getOperand(1)], padding_config);
return success();
}
LogicalResult ExportXlaOp(ReduceOp op, OpLoweringContext ctx) {
auto& value_map = *ctx.values;

View File

@ -0,0 +1,13 @@
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
func @main(%arg: tensor<4x6xf32>, %pad: tensor<f32>) -> tensor<13x19xf32> {
%0 = "xla_hlo.pad"(%arg, %pad) {edge_padding_high = dense<[4,5]> : tensor<2xi64>, edge_padding_low = dense<[2,3]> : tensor<2xi64>, interior_padding = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>, tensor<f32>) -> tensor<13x19xf32>
return %0 : tensor<13x19xf32>
}
// CHECK-LABEL: ENTRY
// CHECK: [[ARG:%.*]] = f32[4,6] parameter(0)
// CHECK: [[PADDING_VAL:%.*]] = f32[] parameter(1)
// CHECK-LABEL: ROOT
// CHECK-SAME: f32[13,19] pad(f32[4,6] [[ARG]], f32[] [[PADDING_VAL]]), padding=2_4_1x3_5_1
// CHECK: }