[XLA:GPU] Add conversion from HLO -> MLIR LMHLO for TriangularSolve

- Also add layout attributes for inputs and output for error checking.

PiperOrigin-RevId: 355863625
Change-Id: I8b37440b7a3253709780833b3716ebdc73c7084a

@@ -99,6 +99,17 @@ def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, AnyFloat, HLO_Complex]>;
 // Any pred, int or floating-point tensor types
 def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, AnyFloat]>;
 
+// A layout attribute (1D tensor of index type)
+def HLO_LayoutAttr : Attr<
+  And<[IndexElementsAttr.predicate,
+       CPred<[{$_self.cast<::mlir::DenseIntElementsAttr>().getType().getRank()
+               == 1}]>]>,
+  "A 1D tensor of index type (layout)"> {
+  let storageType = IndexElementsAttr.storageType;
+  let returnType = IndexElementsAttr.returnType;
+  let convertFromStorage = IndexElementsAttr.convertFromStorage;
+}
+
 //===----------------------------------------------------------------------===//
 // MHLO nullary op definitions.
 //===----------------------------------------------------------------------===//
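
For reference, the attribute this constraint accepts is just a rank-1 DenseIntElementsAttr with index element type, e.g. dense<[1, 0]> : tensor<2xindex>. A minimal sketch of building such a value with an MLIR Builder (the helper name MakeExampleLayoutAttr is illustrative only, not part of this change):

#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"

// Sketch: build dense<[1, 0]> : tensor<2xindex>, the kind of value that
// satisfies the HLO_LayoutAttr predicate above (rank 1, index element type).
static mlir::DenseIntElementsAttr MakeExampleLayoutAttr(mlir::MLIRContext* context) {
  mlir::Builder builder(context);
  // getIndexTensorAttr returns a 1-D DenseIntElementsAttr with index element type.
  return builder.getIndexTensorAttr({1, 0});
}
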
@@ -641,7 +641,10 @@ def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType
     BoolAttr:$left_side,
     BoolAttr:$lower,
     BoolAttr:$unit_diagonal,
-    HLO_TransposeAttr:$transpose_a
+    HLO_TransposeAttr:$transpose_a,
+    HLO_LayoutAttr:$layout_a,
+    HLO_LayoutAttr:$layout_b,
+    HLO_LayoutAttr:$layout_output
   );
 }
 
@@ -865,7 +865,12 @@ func @replica_id_memrefs(%arg_out: memref<ui32>) -> () {
 
 // CHECK-LABEL: func @triangular_solve_memrefs
 func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, %arg_out: memref<3x4xf32>) -> () {
-  "lmhlo.triangular_solve"(%arg0, %arg1, %arg_out) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true}
+  "lmhlo.triangular_solve"(%arg0, %arg1, %arg_out)
+      {layout_a = dense<[1, 0]> : tensor<2xindex>,
+       layout_b = dense<[1, 0]> : tensor<2xindex>,
+       layout_output = dense<[1, 0]> : tensor<2xindex>,
+       left_side = true, lower = true, transpose_a = "NO_TRANSPOSE",
+       unit_diagonal = true}
       : (memref<4x4xf32>, memref<3x4xf32>, memref<3x4xf32>) -> ()
   return
 }
@@ -18,6 +18,7 @@ limitations under the License.
 #include <vector>
 
+#include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 
 namespace xla {
 
@@ -138,4 +139,20 @@ StatusOr<mlir::mhlo::FftType> ConvertFftType(FftType type) {
   }
 }
 
+StatusOr<mlir::mhlo::Transpose> ConvertTranspose(
+    xla::TriangularSolveOptions_Transpose transpose) {
+  switch (transpose) {
+    case TriangularSolveOptions::NO_TRANSPOSE:
+      return mlir::mhlo::Transpose::NO_TRANSPOSE;
+    case TriangularSolveOptions::TRANSPOSE:
+      return mlir::mhlo::Transpose::TRANSPOSE;
+    case TriangularSolveOptions::ADJOINT:
+      return mlir::mhlo::Transpose::ADJOINT;
+    case TriangularSolveOptions::TRANSPOSE_INVALID:
+      return mlir::mhlo::Transpose::TRANSPOSE_INVALID;
+    default:
+      return InvalidArgument("Unknown transpose enum value #%d", transpose);
+  }
+}
+
 }  // namespace xla
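
A minimal usage sketch of the new helper (illustrative only; TransposeAttrValue is a hypothetical wrapper, and TF_ASSIGN_OR_RETURN, xla::StatusOr and the generated mlir::mhlo::stringifyTranspose come from headers already used elsewhere in this change):

// Sketch: map the proto enum carried by an HLO triangular-solve to the MHLO
// Transpose enum, then to the string form stored on the LMHLO op attribute.
xla::StatusOr<std::string> TransposeAttrValue(
    xla::TriangularSolveOptions_Transpose transpose) {
  TF_ASSIGN_OR_RETURN(mlir::mhlo::Transpose converted,
                      xla::ConvertTranspose(transpose));
  return mlir::mhlo::stringifyTranspose(converted).str();
}
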
@@ -46,6 +46,8 @@ mlir::mhlo::ConvDimensionNumbers ConvertConvDimensionNumbers(
     const xla::ConvolutionDimensionNumbers& dnums, mlir::Builder* builder);
 
 StatusOr<mlir::mhlo::FftType> ConvertFftType(FftType type);
+StatusOr<mlir::mhlo::Transpose> ConvertTranspose(
+    TriangularSolveOptions_Transpose transpose);
 
 }  // namespace xla
 
@@ -1,4 +1,5 @@
-// RUN: tf-mlir-translate -split-input-file -hlo-text-to-lhlo -optimize-xla-hlo=false %s | FileCheck %s
+// RUN: tf-mlir-translate -split-input-file -hlo-text-to-lhlo -optimize-xla-hlo=false %s
+/// | FileCheck %s
 
 HloModule TestModule
 
@@ -641,6 +642,7 @@ ENTRY main {
 }
 
 // -----
 
 HloModule TestReplicaId
 
 // CHECK: func @main
@@ -651,6 +653,7 @@ ENTRY main {
 }
 
 // -----
 
 HloModule fft
 
 // CHECK: func @main
@@ -661,3 +664,22 @@ ENTRY main {
   %input = c64[5,8,32] parameter(0)
   ROOT %fft = c64[5,8,32] fft(c64[5,8,32] %input), fft_type=IFFT, fft_length={8,32}
 }
+
+// -----
+
+HloModule TriangularSolve_module
+
+// CHECK: func @main
+// CHECK: "lmhlo.triangular_solve"
+// CHECK-SAME: layout_a = dense<[1, 0]> : tensor<2xindex>
+// CHECK-SAME: layout_b = dense<[1, 0]> : tensor<2xindex>
+// CHECK-SAME: layout_output = dense<[1, 0]> : tensor<2xindex>
+// CHECK-SAME: left_side = false
+// CHECK-SAME: lower = true
+// CHECK-SAME: transpose_a = "NO_TRANSPOSE"
+// CHECK-SAME: unit_diagonal = false
+ENTRY main {
+  %a = f32[4,4]{1,0} parameter(0)
+  %b = f32[3,4]{1,0} parameter(1)
+  ROOT %triangular-solve = f32[3,4]{1,0} triangular-solve(f32[4,4]{1,0} %a, f32[3,4]{1,0} %b), lower=true, transpose_a=NO_TRANSPOSE
+}
@@ -379,6 +379,8 @@ StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(
       return CreateOpWithoutAttrs<lmhlo::TanhOp>(instr);
     case HloOpcode::kTranspose:
       return EmitTransposeOp(instr);
+    case HloOpcode::kTriangularSolve:
+      return EmitTriangularSolveOp(instr);
     case HloOpcode::kXor:
       return CreateOpWithoutAttrs<lmhlo::XorOp>(instr);
     case HloOpcode::kSort:
@@ -462,7 +464,7 @@ StatusOr<Value> LhloDialectEmitter::RewriteFusionOperand(
     llvm::SmallVector<int64_t, 4> minor_to_major(
         shape.layout().minor_to_major().begin(),
         shape.layout().minor_to_major().end());
-    load->setAttr("minor_to_major", b->getIndexTensorAttr(minor_to_major));
+    load->setAttr("minor_to_major", GetLayoutAttribute(shape.layout(), b));
   }
   return load.getResult();
 }
@@ -1290,6 +1292,38 @@ xla::StatusOr<lmhlo::FftOp> LhloDialectEmitter::EmitFftOp(
   return fft;
 }
 
+xla::StatusOr<lmhlo::TriangularSolveOp>
+LhloDialectEmitter::EmitTriangularSolveOp(const xla::HloInstruction* instr) {
+  auto hlo_triangular_solve =
+      xla::Cast<xla::HloTriangularSolveInstruction>(instr);
+  TF_ASSIGN_OR_RETURN(auto triangular_solve,
+                      CreateOpWithoutAttrs<lmhlo::TriangularSolveOp>(instr));
+  const xla::TriangularSolveOptions& options =
+      hlo_triangular_solve->triangular_solve_options();
+  triangular_solve.left_sideAttr(builder_.getBoolAttr(options.left_side()));
+  triangular_solve.lowerAttr(builder_.getBoolAttr(options.lower()));
+  triangular_solve.unit_diagonalAttr(
+      builder_.getBoolAttr(options.unit_diagonal()));
+  TF_ASSIGN_OR_RETURN(mlir::mhlo::Transpose transpose,
+                      xla::ConvertTranspose(options.transpose_a()));
+  triangular_solve.transpose_aAttr(
+      builder_.getStringAttr(mlir::mhlo::stringifyTranspose(transpose)));
+  triangular_solve.layout_aAttr(
+      GetLayoutAttribute(instr->operand(0)->shape().layout(), &builder_));
+  triangular_solve.layout_bAttr(
+      GetLayoutAttribute(instr->operand(1)->shape().layout(), &builder_));
+  triangular_solve.layout_outputAttr(
+      GetLayoutAttribute(instr->shape().layout(), &builder_));
+  return triangular_solve;
+}
+
+mlir::DenseIntElementsAttr LhloDialectEmitter::GetLayoutAttribute(
+    const xla::Layout& layout, Builder* builder) {
+  llvm::SmallVector<int64_t, 4> minor_to_major(layout.minor_to_major().begin(),
+                                               layout.minor_to_major().end());
+  return builder->getIndexTensorAttr(minor_to_major);
+}
+
 StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
     const xla::HloInstruction* instr, const xla::Shape& current_shape,
     const xla::ShapeIndex& shape_index) {
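
For context, a small standalone sketch (not part of the commit) of the value the new GetLayoutAttribute helper produces: XLA layouts are minor-to-major dimension orders, so the default row-major layout of a 2-D shape such as f32[3,4]{1,0} becomes dense<[1, 0]> : tensor<2xindex>, matching the CHECK lines in the test above. The function name and the use of ShapeUtil::MakeShapeWithDescendingLayout are assumptions for illustration.

#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Sketch: mirror what GetLayoutAttribute computes for a row-major 2-D shape.
static mlir::DenseIntElementsAttr RowMajor2DLayoutAttr(mlir::MLIRContext* context) {
  // f32[3,4] with the default descending layout, i.e. minor_to_major = {1, 0}.
  xla::Shape shape = xla::ShapeUtil::MakeShapeWithDescendingLayout(xla::F32, {3, 4});
  mlir::Builder builder(context);
  llvm::SmallVector<int64_t, 4> minor_to_major(
      shape.layout().minor_to_major().begin(),
      shape.layout().minor_to_major().end());
  return builder.getIndexTensorAttr(minor_to_major);  // dense<[1, 0]> : tensor<2xindex>
}
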
@@ -126,6 +126,8 @@ class LhloDialectEmitter : public xla::ConstDfsHloVisitorWithDefault {
   xla::StatusOr<lmhlo::RngGetAndUpdateStateOp> EmitRngGetAndUpdateStateOp(
       const xla::HloInstruction* instr);
   xla::StatusOr<lmhlo::FftOp> EmitFftOp(const xla::HloInstruction* instr);
+  xla::StatusOr<lmhlo::TriangularSolveOp> EmitTriangularSolveOp(
+      const xla::HloInstruction* instr);
 
   // Create LHLO operation operands given an XLA HLO instruction. By default,
   // all XLA HLO operands and results are converted to MLIR and appended to
@@ -173,6 +175,9 @@ class LhloDialectEmitter : public xla::ConstDfsHloVisitorWithDefault {
     return GetI64DenseElementsAttr(elements);
   }
 
+  static mlir::DenseIntElementsAttr GetLayoutAttribute(
+      const xla::Layout& layout, Builder* builder);
+
   tensorflow::Status DefaultAction(const xla::HloInstruction* instr) final;
 
   // Computation parameters don't need any specific handling when they are