Add support for PadOp in HLO exporter.
Generic export support cannot be used for PadOp because the HLO dialect representation splits the Pad HLO instruction's `PaddingConfig` into three separate int64 vector attributes.

PiperOrigin-RevId: 281399246
Change-Id: Iab6f9b5fff925c9a3936ad8714d68c49b8893c00
parent 0231137ed0
commit 15d103c409
tensorflow/compiler/mlir/xla
@@ -446,7 +446,22 @@ LogicalResult ExportXlaOp(IotaOp op, OpLoweringContext ctx) {
   return success();
 }
 
-LogicalResult ExportXlaOp(PadOp op, OpLoweringContext ctx) { return failure(); }
+LogicalResult ExportXlaOp(PadOp op, OpLoweringContext ctx) {
+  auto& value_map = *ctx.values;
+  xla::PaddingConfig padding_config;
+  auto edge_padding_low = ConvertDenseIntAttr(op.edge_padding_low());
+  auto edge_padding_high = ConvertDenseIntAttr(op.edge_padding_high());
+  auto interior_padding = ConvertDenseIntAttr(op.interior_padding());
+  for (xla::int64 i = 0; i < edge_padding_low.size(); ++i) {
+    auto* dims = padding_config.add_dimensions();
+    dims->set_edge_padding_low(edge_padding_low[i]);
+    dims->set_edge_padding_high(edge_padding_high[i]);
+    dims->set_interior_padding(interior_padding[i]);
+  }
+  value_map[op] = xla::Pad(value_map[op.getOperand(0)],
+                           value_map[op.getOperand(1)], padding_config);
+  return success();
+}
 
 LogicalResult ExportXlaOp(ReduceOp op, OpLoweringContext ctx) {
   auto& value_map = *ctx.values;
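For reference, here is a minimal standalone sketch of how the three per-dimension attribute vectors recombine into the textual padding= form printed for the resulting Pad HLO instruction (low_high_interior per dimension, joined by 'x'). The sketch is not part of this commit; the `PaddingString` helper is hypothetical and uses plain standard C++ rather than the real xla::PaddingConfig proto.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Hypothetical helper: format edge_padding_low/high and interior_padding the
// way HLO text prints a Pad instruction's padding configuration.
std::string PaddingString(const std::vector<int64_t>& low,
                          const std::vector<int64_t>& high,
                          const std::vector<int64_t>& interior) {
  std::string out;
  for (std::size_t i = 0; i < low.size(); ++i) {
    if (i > 0) out += "x";
    out += std::to_string(low[i]) + "_" + std::to_string(high[i]) + "_" +
           std::to_string(interior[i]);
  }
  return out;
}

int main() {
  // Attribute values from the pad.mlir test added below.
  std::cout << PaddingString({2, 3}, {4, 5}, {1, 1}) << "\n";  // 2_4_1x3_5_1
}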
tensorflow/compiler/mlir/xla/tests/translate/pad.mlir (new file, 13 additions)
@@ -0,0 +1,13 @@
+// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
+
+func @main(%arg: tensor<4x6xf32>, %pad: tensor<f32>) -> tensor<13x19xf32> {
+  %0 = "xla_hlo.pad"(%arg, %pad) {edge_padding_high = dense<[4,5]> : tensor<2xi64>, edge_padding_low = dense<[2,3]> : tensor<2xi64>, interior_padding = dense<1> : tensor<2xi64>} : (tensor<4x6xf32>, tensor<f32>) -> tensor<13x19xf32>
+  return %0 : tensor<13x19xf32>
+}
+
+// CHECK-LABEL: ENTRY
+// CHECK:  [[ARG:%.*]] = f32[4,6] parameter(0)
+// CHECK:  [[PADDING_VAL:%.*]] = f32[] parameter(1)
+// CHECK-LABEL: ROOT
+// CHECK-SAME:  f32[13,19] pad(f32[4,6] [[ARG]], f32[] [[PADDING_VAL]]), padding=2_4_1x3_5_1
+// CHECK: }
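As a sanity check on the shapes in this test, each padded dimension follows the XLA Pad size rule out = in + edge_padding_low + edge_padding_high + max(in - 1, 0) * interior_padding, so dimension 0 is 4 + 2 + 4 + 3 * 1 = 13 and dimension 1 is 6 + 3 + 5 + 5 * 1 = 19. A minimal standalone sketch of that rule (the `PaddedSize` helper is hypothetical, not part of the commit):

#include <algorithm>
#include <cstdint>
#include <iostream>

// Hypothetical helper: padded extent of one dimension under XLA Pad semantics.
int64_t PaddedSize(int64_t in, int64_t low, int64_t high, int64_t interior) {
  return in + low + high + std::max<int64_t>(in - 1, 0) * interior;
}

int main() {
  std::cout << PaddedSize(4, 2, 4, 1) << "\n";  // dimension 0: prints 13
  std::cout << PaddedSize(6, 3, 5, 1) << "\n";  // dimension 1: prints 19
}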