[XLA:GPU] Add conversion from HLO -> MLIR LMHLO for TriangularSolve

- Also add layout attributes for inputs and output for error checking.

PiperOrigin-RevId: 355863625
Change-Id: I8b37440b7a3253709780833b3716ebdc73c7084a
This commit is contained in:
Rahul Joshi 2021-02-05 09:16:49 -08:00 committed by TensorFlower Gardener
parent 9922e83047
commit aeeafe8f66
8 changed files with 103 additions and 4 deletions

View File

@ -99,6 +99,17 @@ def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, AnyFloat, HLO_Complex]>;
// Any pred, int or floating-point tensor types
def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, AnyFloat]>;
// A layout attribute (1D tensor of index type)
def HLO_LayoutAttr : Attr<
And<[IndexElementsAttr.predicate,
CPred<[{$_self.cast<::mlir::DenseIntElementsAttr>().getType().getRank()
== 1}]>]>,
"A 1D tensor of index type (layout)"> {
let storageType = IndexElementsAttr.storageType;
let returnType = IndexElementsAttr.returnType;
let convertFromStorage = IndexElementsAttr.convertFromStorage;
}
//===----------------------------------------------------------------------===//
// MHLO nullary op definitions.
//===----------------------------------------------------------------------===//

View File

@ -641,7 +641,10 @@ def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType
BoolAttr:$left_side,
BoolAttr:$lower,
BoolAttr:$unit_diagonal,
HLO_TransposeAttr:$transpose_a
HLO_TransposeAttr:$transpose_a,
HLO_LayoutAttr:$layout_a,
HLO_LayoutAttr:$layout_b,
HLO_LayoutAttr:$layout_output
);
}

View File

@ -865,7 +865,12 @@ func @replica_id_memrefs(%arg_out: memref<ui32>) -> () {
// CHECK-LABEL: func @triangular_solve_memrefs
func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, %arg_out: memref<3x4xf32>) -> () {
"lmhlo.triangular_solve"(%arg0, %arg1, %arg_out) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true}
"lmhlo.triangular_solve"(%arg0, %arg1, %arg_out)
{layout_a = dense<[1, 0]> : tensor<2xindex>,
layout_b = dense<[1, 0]> : tensor<2xindex>,
layout_output = dense<[1, 0]> : tensor<2xindex>,
left_side = true, lower = true, transpose_a = "NO_TRANSPOSE",
unit_diagonal = true}
: (memref<4x4xf32>, memref<3x4xf32>, memref<3x4xf32>) -> ()
return
}

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
@ -138,4 +139,20 @@ StatusOr<mlir::mhlo::FftType> ConvertFftType(FftType type) {
}
}
StatusOr<mlir::mhlo::Transpose> ConvertTranspose(
xla::TriangularSolveOptions_Transpose transpose) {
switch (transpose) {
case TriangularSolveOptions::NO_TRANSPOSE:
return mlir::mhlo::Transpose::NO_TRANSPOSE;
case TriangularSolveOptions::TRANSPOSE:
return mlir::mhlo::Transpose::TRANSPOSE;
case TriangularSolveOptions::ADJOINT:
return mlir::mhlo::Transpose::ADJOINT;
case TriangularSolveOptions::TRANSPOSE_INVALID:
return mlir::mhlo::Transpose::TRANSPOSE_INVALID;
default:
return InvalidArgument("Unknown transpose enum value #%d", transpose);
}
}
} // namespace xla

View File

@ -46,6 +46,8 @@ mlir::mhlo::ConvDimensionNumbers ConvertConvDimensionNumbers(
const xla::ConvolutionDimensionNumbers& dnums, mlir::Builder* builder);
StatusOr<mlir::mhlo::FftType> ConvertFftType(FftType type);
StatusOr<mlir::mhlo::Transpose> ConvertTranspose(
TriangularSolveOptions_Transpose transpose);
} // namespace xla

View File

@ -1,4 +1,5 @@
// RUN: tf-mlir-translate -split-input-file -hlo-text-to-lhlo -optimize-xla-hlo=false %s | FileCheck %s
// RUN: tf-mlir-translate -split-input-file -hlo-text-to-lhlo -optimize-xla-hlo=false %s
/// | FileCheck %s
HloModule TestModule
@ -641,6 +642,7 @@ ENTRY main {
}
// -----
HloModule TestReplicaId
// CHECK: func @main
@ -651,6 +653,7 @@ ENTRY main {
}
// -----
HloModule fft
// CHECK: func @main
@ -661,3 +664,22 @@ ENTRY main {
%input = c64[5,8,32] parameter(0)
ROOT %fft = c64[5,8,32] fft(c64[5,8,32] %input), fft_type=IFFT, fft_length={8,32}
}
// -----
HloModule TriangularSolve_module
// CHECK: func @main
// CHECK: "lmhlo.triangular_solve"
// CHECK-SAME: layout_a = dense<[1, 0]> : tensor<2xindex>
// CHECK-SAME: layout_b = dense<[1, 0]> : tensor<2xindex>
// CHECK-SAME: layout_output = dense<[1, 0]> : tensor<2xindex>
// CHECK-SAME: left_side = false
// CHECK-SAME: lower = true
// CHECK-SAME: transpose_a = "NO_TRANSPOSE"
// CHECK-SAME: unit_diagonal = false
ENTRY main {
%a = f32[4,4]{1,0} parameter(0)
%b = f32[3,4]{1,0} parameter(1)
ROOT %triangular-solve = f32[3,4]{1,0} triangular-solve(f32[4,4]{1,0} %a, f32[3,4]{1,0} %b), lower=true, transpose_a=NO_TRANSPOSE
}

View File

@ -379,6 +379,8 @@ StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(
return CreateOpWithoutAttrs<lmhlo::TanhOp>(instr);
case HloOpcode::kTranspose:
return EmitTransposeOp(instr);
case HloOpcode::kTriangularSolve:
return EmitTriangularSolveOp(instr);
case HloOpcode::kXor:
return CreateOpWithoutAttrs<lmhlo::XorOp>(instr);
case HloOpcode::kSort:
@ -462,7 +464,7 @@ StatusOr<Value> LhloDialectEmitter::RewriteFusionOperand(
llvm::SmallVector<int64_t, 4> minor_to_major(
shape.layout().minor_to_major().begin(),
shape.layout().minor_to_major().end());
load->setAttr("minor_to_major", b->getIndexTensorAttr(minor_to_major));
load->setAttr("minor_to_major", GetLayoutAttribute(shape.layout(), b));
}
return load.getResult();
}
@ -1290,6 +1292,38 @@ xla::StatusOr<lmhlo::FftOp> LhloDialectEmitter::EmitFftOp(
return fft;
}
xla::StatusOr<lmhlo::TriangularSolveOp>
LhloDialectEmitter::EmitTriangularSolveOp(const xla::HloInstruction* instr) {
auto hlo_triangular_solve =
xla::Cast<xla::HloTriangularSolveInstruction>(instr);
TF_ASSIGN_OR_RETURN(auto triangular_solve,
CreateOpWithoutAttrs<lmhlo::TriangularSolveOp>(instr));
const xla::TriangularSolveOptions& options =
hlo_triangular_solve->triangular_solve_options();
triangular_solve.left_sideAttr(builder_.getBoolAttr(options.left_side()));
triangular_solve.lowerAttr(builder_.getBoolAttr(options.lower()));
triangular_solve.unit_diagonalAttr(
builder_.getBoolAttr(options.unit_diagonal()));
TF_ASSIGN_OR_RETURN(mlir::mhlo::Transpose transpose,
xla::ConvertTranspose(options.transpose_a()));
triangular_solve.transpose_aAttr(
builder_.getStringAttr(mlir::mhlo::stringifyTranspose(transpose)));
triangular_solve.layout_aAttr(
GetLayoutAttribute(instr->operand(0)->shape().layout(), &builder_));
triangular_solve.layout_bAttr(
GetLayoutAttribute(instr->operand(1)->shape().layout(), &builder_));
triangular_solve.layout_outputAttr(
GetLayoutAttribute(instr->shape().layout(), &builder_));
return triangular_solve;
}
mlir::DenseIntElementsAttr LhloDialectEmitter::GetLayoutAttribute(
const xla::Layout& layout, Builder* builder) {
llvm::SmallVector<int64_t, 4> minor_to_major(layout.minor_to_major().begin(),
layout.minor_to_major().end());
return builder->getIndexTensorAttr(minor_to_major);
}
StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
const xla::HloInstruction* instr, const xla::Shape& current_shape,
const xla::ShapeIndex& shape_index) {

View File

@ -126,6 +126,8 @@ class LhloDialectEmitter : public xla::ConstDfsHloVisitorWithDefault {
xla::StatusOr<lmhlo::RngGetAndUpdateStateOp> EmitRngGetAndUpdateStateOp(
const xla::HloInstruction* instr);
xla::StatusOr<lmhlo::FftOp> EmitFftOp(const xla::HloInstruction* instr);
xla::StatusOr<lmhlo::TriangularSolveOp> EmitTriangularSolveOp(
const xla::HloInstruction* instr);
// Create LHLO operation operands given an XLA HLO instruction. By default,
// all XLA HLO operands and results are converted to MLIR and appended to
@ -173,6 +175,9 @@ class LhloDialectEmitter : public xla::ConstDfsHloVisitorWithDefault {
return GetI64DenseElementsAttr(elements);
}
static mlir::DenseIntElementsAttr GetLayoutAttribute(
const xla::Layout& layout, Builder* builder);
tensorflow::Status DefaultAction(const xla::HloInstruction* instr) final;
// Computation parameters don't need any specific handling when they are