diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt index 7c42100e433..be8b2e13daf 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt @@ -2,7 +2,7 @@ HloModule TestModule -// CHECK: func @TestComputation +// CHECK-LABEL: func @TestComputation FusedComputation { // CHECK: tensor_load %arg0 {minor_to_major = dense<[0, 1]> : tensor<2xindex>} @@ -24,7 +24,7 @@ update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { ROOT rhs = s32[] parameter(1) } -// CHECK: func @main +// CHECK-LABEL: func @main // CHECK: "lmhlo.scatter" // CHECK: ^bb0(%[[ARG5:.*]]: tensor<i32>, %[[ARG6:.*]]: tensor<i32>): // CHECK: "mhlo.return"(%[[ARG6]]) @@ -46,3 +46,38 @@ ENTRY main { scatter_dims_to_operand_dims={0}, index_vector_dim=1 } + +// ----- + +HloModule SelectAndScatter + +%ge_F32 (lhs.5: f32[], rhs.6: f32[]) -> pred[] { + %lhs.5 = f32[] parameter(0) + %rhs.6 = f32[] parameter(1) + ROOT %compare.7 = pred[] compare(f32[] %lhs.5, f32[] %rhs.6), direction=GE +} + +%add_F32 (lhs.9: f32[], rhs.10: f32[]) -> f32[] { + %lhs.9 = f32[] parameter(0) + %rhs.10 = f32[] parameter(1) + ROOT %add.11 = f32[] add(f32[] %lhs.9, f32[] %rhs.10) +} + +// CHECK-LABEL: func @main +// CHECK: "lmhlo.select_and_scatter" +// CHECK: ^bb0(%[[ARG0:.*]]: tensor<f32>, %[[ARG1:.*]]: tensor<f32>): +// CHECK: %[[COMPARE:.*]] = "mhlo.compare"(%[[ARG0]], %[[ARG1]]) {comparison_direction = "GE"} +// CHECK: "mhlo.return"(%[[COMPARE]]) : (tensor<i1>) -> () +// CHECK: ^bb0(%[[ARG2:.*]]: tensor<f32>, %[[ARG3:.*]]: tensor<f32>): +// CHECK: %[[ADD:.*]] = mhlo.add %[[ARG2]], %[[ARG3]] +// CHECK: "mhlo.return"(%[[ADD]]) : (tensor<f32>) -> () +// CHECK: padding = dense<0> : tensor<1xi64> +// CHECK: window_dimensions = dense<3> : tensor<1xi64> +// CHECK: window_strides = dense<3> : tensor<1xi64> +// CHECK: 
(memref<6xf32>, memref<2xf32>, memref<f32>, memref<6xf32>) -> () +ENTRY main () -> f32[6] { + %operand = f32[6]{0} parameter(0) + %source = f32[2]{0} parameter(1) + %init = f32[] parameter(2) + ROOT %select-and-scatter.12 = f32[6]{0} select-and-scatter(f32[6]{0} %operand, f32[2]{0} %source, f32[] %init), window={size=3 stride=3}, select=%ge_F32, scatter=%add_F32 +} diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc index 0efb8a16ba3..2ce50dffcd0 100644 --- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc @@ -407,9 +407,9 @@ LhloDialectEmitter::GetScatterDimensionNumbers(HloInstruction* instr) { const ::xla::ScatterDimensionNumbers& xla_scatter_dim = scatter_instr->scatter_dimension_numbers(); auto scatter_dimension_numbers = mhlo::ScatterDimensionNumbers::get( - getI64DenseElementsAttr(xla_scatter_dim.update_window_dims()), - getI64DenseElementsAttr(xla_scatter_dim.inserted_window_dims()), - getI64DenseElementsAttr(xla_scatter_dim.scatter_dims_to_operand_dims()), + GetI64DenseElementsAttr(xla_scatter_dim.update_window_dims()), + GetI64DenseElementsAttr(xla_scatter_dim.inserted_window_dims()), + GetI64DenseElementsAttr(xla_scatter_dim.scatter_dims_to_operand_dims()), builder_.getI64IntegerAttr(xla_scatter_dim.index_vector_dim()), module_.getContext()); return scatter_dimension_numbers; @@ -443,6 +443,43 @@ Status LhloDialectEmitter::HandleScatter(HloInstruction* instr) { return EmitScatterOp(instr).status(); } +StatusOr<lmhlo::SelectAndScatterOp> LhloDialectEmitter::EmitSelectAndScatterOp( + HloInstruction* instr) { + TF_ASSIGN_OR_RETURN(auto select_and_scatter, + CreateOpWithoutAttrs<lmhlo::SelectAndScatterOp>(instr)); + + // copy attributes + auto* select_and_scatter_instr = + ::xla::Cast<::xla::HloSelectAndScatterInstruction>(instr); + const ::xla::Window& window = select_and_scatter_instr->window(); + + 
select_and_scatter.window_dimensionsAttr( + GetWindowElements(window, [](const ::xla::WindowDimension& dim) { + return static_cast<int64_t>(dim.size()); + })); + select_and_scatter.window_stridesAttr( + GetWindowElements(window, [](const ::xla::WindowDimension& dim) { + return static_cast<int64_t>(dim.stride()); + })); + select_and_scatter.paddingAttr( + GetWindowElements(window, [](const ::xla::WindowDimension& dim) { + return static_cast<int64_t>(dim.padding_low()); + })); + + // import select and scatter computation as region + TF_RETURN_IF_ERROR(::xla::HloFunctionImporter::ImportAsRegion( + *select_and_scatter_instr->select(), &select_and_scatter.select(), + &builder_)); + TF_RETURN_IF_ERROR(::xla::HloFunctionImporter::ImportAsRegion( + *select_and_scatter_instr->scatter(), &select_and_scatter.scatter(), + &builder_)); + return select_and_scatter; +} + +Status LhloDialectEmitter::HandleSelectAndScatter(HloInstruction* instr) { + return EmitSelectAndScatterOp(instr).status(); +} + StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView( const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape, const ::xla::ShapeIndex& shape_index) { diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h index 97a9b17e81d..47cde92a8bc 100644 --- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h +++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h @@ -49,17 +49,30 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault { ::xla::StatusOr<lmhlo::ScatterOp> EmitScatterOp(::xla::HloInstruction* instr); ::xla::StatusOr<mhlo::ScatterDimensionNumbers> GetScatterDimensionNumbers( ::xla::HloInstruction* instr); + ::xla::StatusOr<lmhlo::SelectAndScatterOp> EmitSelectAndScatterOp( + ::xla::HloInstruction* instr); private: template <typename OpType> ::xla::StatusOr<OpType> CreateOpWithoutAttrs(::xla::HloInstruction* instr); template <typename T> - DenseIntElementsAttr getI64DenseElementsAttr(const T& container) { + DenseIntElementsAttr GetI64DenseElementsAttr(const T& container) { return 
builder_.getI64TensorAttr( {container.data(), static_cast<int64_t>(container.size())}); } + DenseIntElementsAttr GetWindowElements( + const ::xla::Window& window, + std::function<int64_t(const ::xla::WindowDimension& dim)> getter) { + llvm::SmallVector<int64_t, 4> elements; + elements.reserve(window.dimensions_size()); + for (const ::xla::WindowDimension& dim : window.dimensions()) { + elements.push_back(getter(dim)); + } + return GetI64DenseElementsAttr(elements); + } + tensorflow::Status DefaultAction(::xla::HloInstruction* instr) final; // Computation parameters don't need any specific handling when they are @@ -71,6 +84,7 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault { tensorflow::Status HandleSort(::xla::HloInstruction* instr) final; tensorflow::Status HandleFusion(::xla::HloInstruction* instr) final; tensorflow::Status HandleScatter(::xla::HloInstruction* instr) final; + tensorflow::Status HandleSelectAndScatter(::xla::HloInstruction* instr) final; // Helper function that recursively visits the tuple structure in // `current_shape`, and reconstruct a matching lmhlo::TupleOp.