diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td
index b5411e3b9ba..896fe0fff05 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td
@@ -99,6 +99,17 @@ def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, AnyFloat, HLO_Complex]>;
 // Any pred, int or floating-point tensor types
 def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, AnyFloat]>;
 
+// A layout attribute (1D tensor of index type)
+def HLO_LayoutAttr : Attr<
+  And<[IndexElementsAttr.predicate,
+       CPred<[{$_self.cast<::mlir::DenseIntElementsAttr>().getType().getRank()
+               == 1}]>]>,
+  "A 1D tensor of index type (layout)"> {
+  let storageType = IndexElementsAttr.storageType;
+  let returnType = IndexElementsAttr.returnType;
+  let convertFromStorage = IndexElementsAttr.convertFromStorage;
+}
+
 //===----------------------------------------------------------------------===//
 // MHLO nullary op definitions.
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
index d1bdd498eff..fb1e17692ed 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
@@ -641,7 +641,10 @@ def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType
     BoolAttr:$left_side,
     BoolAttr:$lower,
     BoolAttr:$unit_diagonal,
-    HLO_TransposeAttr:$transpose_a
+    HLO_TransposeAttr:$transpose_a,
+    HLO_LayoutAttr:$layout_a,
+    HLO_LayoutAttr:$layout_b,
+    HLO_LayoutAttr:$layout_output
   );
 }
 
diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo_ops.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo_ops.mlir
index 76be69f81ab..b01beb79ca6 100644
--- a/tensorflow/compiler/mlir/hlo/tests/lhlo_ops.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/lhlo_ops.mlir
@@ -865,7 +865,12 @@ func @replica_id_memrefs(%arg_out: memref<ui32>) -> () {
 
 // CHECK-LABEL: func @triangular_solve_memrefs
 func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, %arg_out: memref<3x4xf32>) -> () {
-  "lmhlo.triangular_solve"(%arg0, %arg1, %arg_out) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true}
+  "lmhlo.triangular_solve"(%arg0, %arg1, %arg_out)
+      {layout_a = dense<[1, 0]> : tensor<2xindex>,
+       layout_b = dense<[1, 0]> : tensor<2xindex>,
+       layout_output = dense<[1, 0]> : tensor<2xindex>,
+       left_side = true, lower = true, transpose_a = "NO_TRANSPOSE",
+       unit_diagonal = true}
       : (memref<4x4xf32>, memref<3x4xf32>, memref<3x4xf32>) -> ()
   return
 }
diff --git a/tensorflow/compiler/mlir/xla/attribute_importer.cc b/tensorflow/compiler/mlir/xla/attribute_importer.cc
index c18b2157f87..6cd7041a16a 100644
--- a/tensorflow/compiler/mlir/xla/attribute_importer.cc
+++ b/tensorflow/compiler/mlir/xla/attribute_importer.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include <sys/types.h>
 
 #include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
 
 namespace xla {
 
@@ -138,4 +139,20 @@ StatusOr<mlir::mhlo::FftType> ConvertFftType(FftType type) {
   }
 }
 
+StatusOr<mlir::mhlo::Transpose> ConvertTranspose(
+    xla::TriangularSolveOptions_Transpose transpose) {
+  switch (transpose) {
+    case TriangularSolveOptions::NO_TRANSPOSE:
+      return mlir::mhlo::Transpose::NO_TRANSPOSE;
+    case TriangularSolveOptions::TRANSPOSE:
+      return mlir::mhlo::Transpose::TRANSPOSE;
+    case TriangularSolveOptions::ADJOINT:
+      return mlir::mhlo::Transpose::ADJOINT;
+    case TriangularSolveOptions::TRANSPOSE_INVALID:
+      return mlir::mhlo::Transpose::TRANSPOSE_INVALID;
+    default:
+      return InvalidArgument("Unknown transpose enum value #%d", transpose);
+  }
+}
+
 }  // namespace xla
diff --git a/tensorflow/compiler/mlir/xla/attribute_importer.h b/tensorflow/compiler/mlir/xla/attribute_importer.h
index e8bf8cb6950..1555e420c53 100644
--- a/tensorflow/compiler/mlir/xla/attribute_importer.h
+++ b/tensorflow/compiler/mlir/xla/attribute_importer.h
@@ -46,6 +46,8 @@ mlir::mhlo::ConvDimensionNumbers ConvertConvDimensionNumbers(
     const xla::ConvolutionDimensionNumbers& dnums, mlir::Builder* builder);
 
 StatusOr<mlir::mhlo::FftType> ConvertFftType(FftType type);
+StatusOr<mlir::mhlo::Transpose> ConvertTranspose(
+    TriangularSolveOptions_Transpose transpose);
 
 }  // namespace xla
 
diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt
index 48f7620d6cb..7f8d648ae6a 100644
--- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt
@@ -1,4 +1,5 @@
-// RUN: tf-mlir-translate -split-input-file -hlo-text-to-lhlo -optimize-xla-hlo=false %s | FileCheck %s
+// RUN: tf-mlir-translate -split-input-file -hlo-text-to-lhlo -optimize-xla-hlo=false %s \
+// RUN:   | FileCheck %s
 
 HloModule TestModule
 
@@ -641,6 +642,7 @@ ENTRY main {
 }
 
 // -----
+
 HloModule TestReplicaId
 
 // CHECK: func @main
@@ -651,6 +653,7 @@ ENTRY main {
 }
 
 // -----
+
 HloModule fft
 
 // CHECK: func @main
@@ -661,3 +664,22 @@ ENTRY main {
   %input = c64[5,8,32] parameter(0)
   ROOT %fft = c64[5,8,32] fft(c64[5,8,32] %input), fft_type=IFFT, fft_length={8,32}
 }
+
+// -----
+
+HloModule TriangularSolve_module
+
+// CHECK: func @main
+// CHECK: "lmhlo.triangular_solve"
+// CHECK-SAME: layout_a = dense<[1, 0]> : tensor<2xindex>
+// CHECK-SAME: layout_b = dense<[1, 0]> : tensor<2xindex>
+// CHECK-SAME: layout_output = dense<[1, 0]> : tensor<2xindex>
+// CHECK-SAME: left_side = false
+// CHECK-SAME: lower = true
+// CHECK-SAME: transpose_a = "NO_TRANSPOSE"
+// CHECK-SAME: unit_diagonal = false
+ENTRY main {
+  %a = f32[4,4]{1,0} parameter(0)
+  %b = f32[3,4]{1,0} parameter(1)
+  ROOT %triangular-solve = f32[3,4]{1,0} triangular-solve(f32[4,4]{1,0} %a, f32[3,4]{1,0} %b), lower=true, transpose_a=NO_TRANSPOSE
+}
diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
index dcf2e9546dd..7d924c51e89 100644
--- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
@@ -379,6 +379,8 @@ StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(
       return CreateOpWithoutAttrs<lmhlo::TanhOp>(instr);
     case HloOpcode::kTranspose:
       return EmitTransposeOp(instr);
+    case HloOpcode::kTriangularSolve:
+      return EmitTriangularSolveOp(instr);
     case HloOpcode::kXor:
       return CreateOpWithoutAttrs<lmhlo::XorOp>(instr);
     case HloOpcode::kSort:
@@ -462,7 +464,7 @@ StatusOr<Value> LhloDialectEmitter::RewriteFusionOperand(
     llvm::SmallVector<int64_t, 4> minor_to_major(
        shape.layout().minor_to_major().begin(),
        shape.layout().minor_to_major().end());
-    load->setAttr("minor_to_major", b->getIndexTensorAttr(minor_to_major));
+    load->setAttr("minor_to_major", GetLayoutAttribute(shape.layout(), b));
   }
   return load.getResult();
 }
@@ -1290,6 +1292,38 @@ xla::StatusOr<lmhlo::FftOp> LhloDialectEmitter::EmitFftOp(
   return fft;
 }
 
+xla::StatusOr<lmhlo::TriangularSolveOp>
+LhloDialectEmitter::EmitTriangularSolveOp(const xla::HloInstruction* instr) {
+  auto hlo_triangular_solve =
+      xla::Cast<xla::HloTriangularSolveInstruction>(instr);
+  TF_ASSIGN_OR_RETURN(auto triangular_solve,
+                      CreateOpWithoutAttrs<lmhlo::TriangularSolveOp>(instr));
+  const xla::TriangularSolveOptions& options =
+      hlo_triangular_solve->triangular_solve_options();
+  triangular_solve.left_sideAttr(builder_.getBoolAttr(options.left_side()));
+  triangular_solve.lowerAttr(builder_.getBoolAttr(options.lower()));
+  triangular_solve.unit_diagonalAttr(
+      builder_.getBoolAttr(options.unit_diagonal()));
+  TF_ASSIGN_OR_RETURN(mlir::mhlo::Transpose transpose,
+                      xla::ConvertTranspose(options.transpose_a()));
+  triangular_solve.transpose_aAttr(
+      builder_.getStringAttr(mlir::mhlo::stringifyTranspose(transpose)));
+  triangular_solve.layout_aAttr(
+      GetLayoutAttribute(instr->operand(0)->shape().layout(), &builder_));
+  triangular_solve.layout_bAttr(
+      GetLayoutAttribute(instr->operand(1)->shape().layout(), &builder_));
+  triangular_solve.layout_outputAttr(
+      GetLayoutAttribute(instr->shape().layout(), &builder_));
+  return triangular_solve;
+}
+
+mlir::DenseIntElementsAttr LhloDialectEmitter::GetLayoutAttribute(
+    const xla::Layout& layout, Builder* builder) {
+  llvm::SmallVector<int64_t, 4> minor_to_major(layout.minor_to_major().begin(),
+                                               layout.minor_to_major().end());
+  return builder->getIndexTensorAttr(minor_to_major);
+}
+
 StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
     const xla::HloInstruction* instr, const xla::Shape& current_shape,
     const xla::ShapeIndex& shape_index) {
diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
index 4a7e1b631da..5c8a0fe92ff 100644
--- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
+++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
@@ -126,6 +126,8 @@ class LhloDialectEmitter : public xla::ConstDfsHloVisitorWithDefault {
   xla::StatusOr<lmhlo::RngGetAndUpdateStateOp> EmitRngGetAndUpdateStateOp(
       const xla::HloInstruction* instr);
   xla::StatusOr<lmhlo::FftOp> EmitFftOp(const xla::HloInstruction* instr);
+  xla::StatusOr<lmhlo::TriangularSolveOp> EmitTriangularSolveOp(
+      const xla::HloInstruction* instr);
 
   // Create LHLO operation operands given an XLA HLO instruction. By default,
   // all XLA HLO operands and results are converted to MLIR and appended to
@@ -173,6 +175,9 @@ class LhloDialectEmitter : public xla::ConstDfsHloVisitorWithDefault {
     return GetI64DenseElementsAttr(elements);
   }
 
+  static mlir::DenseIntElementsAttr GetLayoutAttribute(
+      const xla::Layout& layout, Builder* builder);
+
   tensorflow::Status DefaultAction(const xla::HloInstruction* instr) final;
 
   // Computation parameters don't need any specific handling when they are
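
Note (reviewer illustration, not part of the patch): the sketch below shows how the two pieces added here are expected to fit together. A minor_to_major layout is encoded as the rank-1 index-tensor attribute that HLO_LayoutAttr accepts, and the TriangularSolveOptions transpose enum is converted to the transpose_a string attribute via the new xla::ConvertTranspose. MakeLayoutAttr and the main() driver are hypothetical stand-ins (GetLayoutAttribute is a private member of LhloDialectEmitter); the includes assume a TensorFlow build at this revision and the BUILD dependencies are omitted.

// Hypothetical standalone sketch; mirrors LhloDialectEmitter::GetLayoutAttribute
// and exercises xla::ConvertTranspose introduced by this change.
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/attribute_importer.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

// Same encoding as GetLayoutAttribute: the layout's minor-to-major order is
// stored as a 1-D tensor of index elements, e.g. dense<[1, 0]> : tensor<2xindex>.
static mlir::DenseIntElementsAttr MakeLayoutAttr(const xla::Layout& layout,
                                                 mlir::Builder* builder) {
  llvm::SmallVector<int64_t, 4> minor_to_major(layout.minor_to_major().begin(),
                                               layout.minor_to_major().end());
  return builder->getIndexTensorAttr(minor_to_major);
}

int main() {
  mlir::MLIRContext context;
  mlir::Builder builder(&context);

  // A row-major 2-D layout {1, 0} produces the layout_a/layout_b/layout_output
  // value checked for in the tests above.
  xla::Layout layout = xla::LayoutUtil::MakeLayout({1, 0});
  mlir::DenseIntElementsAttr layout_attr = MakeLayoutAttr(layout, &builder);
  (void)layout_attr;

  // transpose_a: proto enum -> mhlo::Transpose -> string attribute value.
  xla::TriangularSolveOptions options;
  options.set_transpose_a(xla::TriangularSolveOptions::ADJOINT);
  auto transpose = xla::ConvertTranspose(options.transpose_a());
  if (transpose.ok()) {
    // stringifyTranspose(ADJOINT) yields "ADJOINT", the string stored in transpose_a.
    llvm::StringRef name = mlir::mhlo::stringifyTranspose(transpose.ValueOrDie());
    (void)name;
  }
  return 0;
}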