diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
index 2f0879778e5..77b5159fb1e 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
@@ -401,12 +401,18 @@ def HLO_InfeedOp : HLO_Op<"infeed", []> {
     of the data. Multiple Infeed operations are allowed in a computation, but
     there must be a total order among the Infeed operations.
 
+    Attributes:
+      layout: Array attribute. Same shape as the output of the infeed, except
+        that every tensor is replaced by a minor_to_major array for the
+        tensor's layout.
+
     See https://www.tensorflow.org/xla/operation_semantics#infeed.
   }];
 
   let arguments = (ins
     HLO_Token:$token,
-    DefaultValuedAttr<StrAttr, "">:$infeed_config
+    DefaultValuedAttr<StrAttr, "">:$infeed_config,
+    OptionalAttr<ArrayAttr>:$layout
   );
   let results = (outs HLO_Tuple);
   let hasCustomHLOConverter = 1;
diff --git a/tensorflow/compiler/mlir/hlo/tests/ops.mlir b/tensorflow/compiler/mlir/hlo/tests/ops.mlir
index 76d8896430c..5651f04fc65 100644
--- a/tensorflow/compiler/mlir/hlo/tests/ops.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/ops.mlir
@@ -442,7 +442,7 @@ func @dot_bad_precision_config(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x2xi32> {
 
 func @infeed_invalid_number_of_results(%token: !mhlo.token) -> tuple<tuple<tensor<i32>>, !mhlo.token, tensor<i32>> {
   // expected-error@+1 {{result is expected to be a tuple of size 2, but got 3}}
-  %0 = "mhlo.infeed"(%token) {infeed_config = "foobar"} : (!mhlo.token) -> tuple<tuple<tensor<i32>>, !mhlo.token, tensor<i32>>
+  %0 = "mhlo.infeed"(%token) {infeed_config = "foobar", layout = [[[0]], unit, [0]]} : (!mhlo.token) -> tuple<tuple<tensor<i32>>, !mhlo.token, tensor<i32>>
   return %0 : tuple<tuple<tensor<i32>>, !mhlo.token, tensor<i32>>
 }
 
@@ -450,7 +450,7 @@ func @infeed_invalid_number_of_results(%token: !mhlo.token) -> tuple<tuple<tensor<i32>>, !mhlo.token, tensor<i32>> {
 
 func @infeed_non_token_second_result(%token: !mhlo.token) -> tuple<tuple<tensor<i32>>, tensor<i32>> {
   // expected-error@+1 {{second element of result tuple is expected to be of token type, but got 'tensor<i32>'}}
-  %0 = "mhlo.infeed"(%token) {infeed_config = "foobar"} : (!mhlo.token) -> tuple<tuple<tensor<i32>>, tensor<i32>>
+  %0 = "mhlo.infeed"(%token) {infeed_config = "foobar", layout = [[[0]], [0]]} : (!mhlo.token) -> tuple<tuple<tensor<i32>>, tensor<i32>>
   return %0 : tuple<tuple<tensor<i32>>, tensor<i32>>
 }
 
diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
index deeb96b7d9d..55e3277afda 100644
--- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
+++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc
@@ -383,6 +383,19 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
           "infeed_config",
           mlir::StringAttr::get(builder_->getContext(),
                                 instruction->infeed_config())));
+      // TODO(kramm): Support tuples and tokens.
+      if (instruction->shape().IsArray()) {
+        const xla::Layout l = instruction->shape().layout();
+        absl::Span<const int64> minor_to_major = l.minor_to_major();
+        std::vector<mlir::Attribute> v;
+        for (int64 i : minor_to_major) {
+          v.push_back(builder_->getI32IntegerAttr(i));
+        }
+        llvm::ArrayRef<mlir::Attribute> array_ref(v);
+        mlir::ArrayAttr layout = builder_->getArrayAttr(array_ref);
+        attributes.push_back(builder_->getNamedAttr("layout", layout));
+      }
+
       MakeAndReturn(InfeedOp);
     }
     case HloOpcode::kOutfeed: {
diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
index 6cc053914e0..9ff44c04fdc 100644
--- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
+++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc
@@ -466,9 +466,11 @@ StatusOr<XlaOp> MlirHloBuilder::InfeedWithTokenInternal(
   TF_ASSIGN_OR_RETURN(mlir::Type result_type,
                       ConvertShapeToType<mlir::RankedTensorType>(
                           infeed_instruction_shape, builder_));
+  mlir::ArrayAttr layout;
   return MakeXlaOp(
       builder_.create<mlir::mhlo::InfeedOp>(loc_, result_type, GetValue(token),
-                                            /*infeed_config=*/config));
+                                            /*infeed_config=*/config,
+                                            /*layout=*/layout));
 }
 
 StatusOr<XlaOp> MlirHloBuilder::OutfeedWithTokenInternal(
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index 412a13b6358..a3165bafb8d 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -4613,9 +4613,15 @@ class ConvertInfeedDequeueTupleOp
     auto data_and_token_type = mlir::TupleType::get(
         rewriter.getContext(), {data_tuple_type, token.getType()});
 
+    ArrayAttr layout =
+        GetLayout(data_and_token_type, rewriter).cast<ArrayAttr>();
     auto data_and_token =
         rewriter.create<InfeedOp>(op.getLoc(), data_and_token_type, token,
-                                  /*infeed_config=*/rewriter.getStringAttr(""));
+                                  /*infeed_config=*/rewriter.getStringAttr(""),
+                                  /*layout=*/layout);
+
+    // TODO(b/171212005): Reenable layout.
+    data_and_token.removeAttr("layout");
 
     if (op._XlaSharding().hasValue()) {
       // _XlaSharding attribute in TF is a serialized string of the OpSharding
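
Note: the importer hunk above only handles array shapes and leaves the TODO(kramm) for tuples and tokens, while the ops.mlir tests already exercise the nested form `layout = [[[0]], unit, [0]]`. The following is a minimal sketch, not part of the patch, of how that nesting could be produced: tuples map to nested arrays, tokens to a unit attribute, and each array shape to its minor_to_major permutation. The helper name `GetLayoutAttribute`, its free-function form, and the use of `int64_t` with an explicit narrowing cast are assumptions made for illustration only.

```cpp
#include <cstdint>
#include <vector>

#include "mlir/IR/Builders.h"                  // mlir::Builder, attributes
#include "tensorflow/compiler/xla/shape.h"     // xla::Shape, xla::Layout

// Sketch under the assumptions stated above: builds an attribute with the
// same shape as the infeed output, where every tensor is replaced by its
// minor_to_major array and every token by a UnitAttr (printed as `unit`).
static mlir::Attribute GetLayoutAttribute(const xla::Shape& shape,
                                          mlir::Builder* builder) {
  if (shape.IsTuple()) {
    // Tuple: recurse into each element and wrap the results in an ArrayAttr.
    std::vector<mlir::Attribute> elements;
    for (const xla::Shape& element : shape.tuple_shapes()) {
      elements.push_back(GetLayoutAttribute(element, builder));
    }
    return builder->getArrayAttr(elements);
  }
  if (shape.IsToken()) {
    // Token: no layout; matches the `unit` entries in the ops.mlir tests.
    return builder->getUnitAttr();
  }
  // Array shape: emit the minor_to_major permutation as i32 integers,
  // mirroring the getI32IntegerAttr choice in the importer hunk above.
  std::vector<mlir::Attribute> minor_to_major;
  for (int64_t dim : shape.layout().minor_to_major()) {
    minor_to_major.push_back(
        builder->getI32IntegerAttr(static_cast<int32_t>(dim)));
  }
  return builder->getArrayAttr(minor_to_major);
}
```

A helper of this shape would produce the `[[[0]], unit, [0]]` form used in the tests; the patch itself emits only the array case in hlo_function_importer.cc and still strips the attribute again in legalize_tf.cc pending b/171212005.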