diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc index 178005847c4..7867183012f 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc @@ -287,6 +287,12 @@ StatusOr HloFunctionImporter::ImportInstruction( llvm::ArrayRef(operands.begin() + 2, operands.end())) .getOperation(); } + case HloOpcode::kInfeed: { + attributes.push_back(builder_->getNamedAttr( + "infeed_config", mlir::StringAttr::get(instruction->infeed_config(), + builder_->getContext()))); + MakeAndReturn(InfeedOp); + } case HloOpcode::kPad: { const auto& padding_config = instruction->padding_config(); llvm::SmallVector edge_padding_low; @@ -448,7 +454,6 @@ StatusOr HloFunctionImporter::ImportInstruction( NoAttributeCase(kExp, ExpOp); NoAttributeCase(kExpm1, Expm1Op); NoAttributeCase(kFloor, FloorOp); - NoAttributeCase(kInfeed, InfeedOp); NoAttributeCase(kImag, ImagOp); NoAttributeCase(kLog, LogOp); NoAttributeCase(kLog1p, Log1pOp); diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td index 3b867d5dd85..c535dcfb60b 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td @@ -343,7 +343,10 @@ def HLO_InfeedOp : HLO_Op<"infeed", []> { See https://www.tensorflow.org/xla/operation_semantics#infeed. }]; - let arguments = (ins HLO_Token:$token); + let arguments = (ins + HLO_Token:$token, + DefaultValuedAttr:$infeed_config + ); let results = (outs HLO_Tuple); let hasCustomHLOConverter = 1; } diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index cb1eab1dcba..f6cc76693f7 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -558,8 +558,8 @@ LogicalResult ExportXlaOp(InfeedOp op, OpLoweringContext ctx) { // The shape argument expected by the xla client API is the type of the first // element in the result tuple. auto result_type = op.getType().cast().getType(0); - value_map[op] = xla::InfeedWithToken(value_map[op.token()], - xla::TypeToShape(result_type)); + value_map[op] = xla::InfeedWithToken( + value_map[op.token()], xla::TypeToShape(result_type), op.infeed_config()); return success(); } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir index aac3a82081b..581e3e71ff6 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir @@ -398,13 +398,13 @@ func @main(%arg0: tuple, tensor>) -> tensor { // CHECK: HloModule func @main(%arg0: !xla_hlo.token) -> tuple, tensor>, !xla_hlo.token> { - %0 = "xla_hlo.infeed"(%arg0) : (!xla_hlo.token) -> tuple, tensor>, !xla_hlo.token> - return %0 : tuple, tensor>, !xla_hlo.token> - } + %0 = "xla_hlo.infeed"(%arg0) {infeed_config = "foobar"} : (!xla_hlo.token) -> tuple, tensor>, !xla_hlo.token> + return %0 : tuple, tensor>, !xla_hlo.token> +} // CHECK: ENTRY // CHECK: [[ARG:%.*]] = token[] parameter(0) -// CHECK: ROOT %[[RESULT:.*]] = ((s32[3], pred[]), token[]) infeed(token[] [[ARG]]) +// CHECK: ROOT %[[RESULT:.*]] = ((s32[3], pred[]), token[]) infeed(token[] [[ARG]]), infeed_config="foobar" // ----- diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt index 5f9670be2f1..f425184677f 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt @@ -364,6 +364,16 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] { ROOT %imag.3 = f32[4] imag(c64[4] %Arg_0.1) } +// CHECK-LABEL: func @test_infeed +// CHECK-SAME: ([[TOKEN:%.*]]: !xla_hlo.token) -> tuple, !xla_hlo.token> { +%test_infeed (token0: token[]) -> (s32[3], token[]) { + %token0 = token[] parameter(0) + // CHECK-NEXT: "xla_hlo.infeed"([[TOKEN]]) + // CHECK-SAME: infeed_config = "foobar" + ROOT %infeed = (s32[3], token[]) infeed(token[] %token0), infeed_config="foobar" +} + + // CHECK-LABEL: func @test_iota_1() -> tensor<4xf32> { %test_iota_1 () -> f32[4] { // CHECK-NEXT: "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>