Add TriangularSolve op to HLO dialect.
Adds op definition and import/export support for it. Adds extra verifier checks on shape of op operands and results. PiperOrigin-RevId: 289152231 Change-Id: I1f13f18131fe13ccfdb451b2748a9d76312211a2
This commit is contained in:
parent
d4c8c604ee
commit
dff1d31b49
@ -449,6 +449,24 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
|
||||
"permutation", ConvertDimensions(instruction->dimensions())));
|
||||
MakeAndReturn(TransposeOp);
|
||||
}
|
||||
    case HloOpcode::kTriangularSolve: {
      // Map each field of the instruction's TriangularSolveOptions proto onto
      // a named attribute of the xla_hlo.triangular_solve op.
      attributes.push_back(builder_->getNamedAttr(
          "left_side",
          builder_->getBoolAttr(
              instruction->triangular_solve_options().left_side())));
      attributes.push_back(builder_->getNamedAttr(
          "lower", builder_->getBoolAttr(
                       instruction->triangular_solve_options().lower())));
      attributes.push_back(builder_->getNamedAttr(
          "unit_diagonal",
          builder_->getBoolAttr(
              instruction->triangular_solve_options().unit_diagonal())));
      // The transpose option is imported as the enum's string name (e.g.
      // "NO_TRANSPOSE"), matching the string-enum HLO_TransposeAttr; the
      // exporter parses it back with Transpose_Parse.
      auto transpose_a =
          builder_->getStringAttr(TriangularSolveOptions::Transpose_Name(
              instruction->triangular_solve_options().transpose_a()));
      attributes.push_back(builder_->getNamedAttr("transpose_a", transpose_a));
      MakeAndReturn(TriangularSolveOp);
    }
|
||||
case HloOpcode::kMap: {
|
||||
auto op = func_builder->create<mlir::xla_hlo::MapOp>(
|
||||
loc, result_type, operands,
|
||||
|
@ -1103,6 +1103,63 @@ static LogicalResult Verify(TransposeOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TriangularSolveOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Verifies the shape constraints of TriangularSolveOp:
//  * 'a' (if ranked) has rank >= 2 and equal-sized minor dimensions,
//  * 'b' (if ranked) matches 'a' in rank, shared dimension, and leading
//    batch dimensions,
//  * the result (if ranked) has the same shape as 'b'.
// Dynamic dimensions are treated as compatible with any size; only a pair of
// statically-known, differing sizes is reported as an error.
static LogicalResult Verify(TriangularSolveOp op) {
  auto a_type = op.a().getType().dyn_cast<RankedTensorType>();

  // Skip verifier if a is unranked tensor.
  if (!a_type) return success();

  // Check that a should have rank >= 2
  auto a_rank = a_type.getRank();
  if (a_rank < 2)
    return op.emitOpError()
           << "operand 'a' must have rank >= 2, but got " << a_type;

  // Two dimension sizes are compatible unless both are static and unequal.
  auto is_compatible = [](int64_t size_a, int64_t size_b) {
    return size_a == size_b || ShapedType::isDynamic(size_a) ||
           ShapedType::isDynamic(size_b);
  };

  // The two minor dimensions of a must have same size.
  if (!is_compatible(a_type.getDimSize(a_rank - 2),
                     a_type.getDimSize(a_rank - 1)))
    return op.emitOpError() << "two minor dimensions of operand 'a' must have "
                               "equal size, but got "
                            << a_type;

  auto b_type = op.b().getType().dyn_cast<RankedTensorType>();
  // If b is unranked skip remaining checks.
  if (!b_type) return success();

  // Check that a and b have same rank.
  auto b_rank = b_type.getRank();
  if (a_rank != b_rank)
    return op.emitOpError() << "operands must have equal rank, but got "
                            << a_type << " and " << b_type;

  // The shared dimension of a and b should match: with left_side the
  // contraction is op(a) * x = b (second-to-last dim of b), otherwise
  // x * op(a) = b (last dim of b).
  if (!is_compatible(a_type.getDimSize(a_rank - 1),
                     b_type.getDimSize(b_rank - (op.left_side() ? 2 : 1))))
    return op.emitOpError() << "shared dimension of operands 'a' and 'b' does "
                               "not match, but got "
                            << a_type << " and " << b_type;

  // The leading batch dimensions of a and b must be equal.
  for (int64_t i = 0; i < a_rank - 2; ++i)
    if (!is_compatible(a_type.getDimSize(i), b_type.getDimSize(i)))
      return op.emitOpError()
             << "leading batch dimensions of the operands must be same, but got "
             << a_type << " and " << b_type;

  // Result and argument b must have same shape. Element-type agreement is
  // already enforced by the SameOperandsAndResultElementType trait.
  auto result_type = op.getType().dyn_cast<RankedTensorType>();
  if (!result_type) return success();
  bool same_shape = result_type.getRank() == b_rank;
  for (int64_t i = 0; same_shape && i < b_rank; ++i)
    same_shape =
        is_compatible(result_type.getDimSize(i), b_type.getDimSize(i));
  if (!same_shape)
    return op.emitOpError()
           << "result and operand 'b' must have same shape, but got "
           << result_type << " and " << b_type;
  return success();
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GetTupleElementOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1146,6 +1146,20 @@ def HLO_TransposeOp: HLO_Op<"transpose",
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
// xla_hlo.triangular_solve: solves op(a) * x = b (or x * op(a) = b) for x,
// where 'a' holds batched triangular coefficient matrices. The boolean
// attributes mirror XLA's TriangularSolveOptions; transpose_a selects which
// op(a) variant is applied. See BASE_HLO_TriangularSolveOp for full semantics.
def HLO_TriangularSolveOp: HLO_Op<"triangular_solve",
    [NoSideEffect, SameOperandsAndResultElementType]>,
    BASE_HLO_TriangularSolveOp {
  let arguments = (ins
    HLO_FpOrComplexTensor:$a,
    HLO_FpOrComplexTensor:$b,
    BoolAttr:$left_side,
    BoolAttr:$lower,
    BoolAttr:$unit_diagonal,
    HLO_TransposeAttr:$transpose_a
  );
  // Result has the same shape as operand 'b' (checked by the op verifier).
  let results = (outs HLO_FpOrComplexTensor);
}
|
||||
|
||||
def HLO_ReduceWindowOp: HLO_Op<"reduce_window", [
|
||||
NoSideEffect,
|
||||
SingleBlockImplicitTerminator<"ReturnOp">
|
||||
|
@ -1064,6 +1064,46 @@ class BASE_HLO_TransposeOp {
|
||||
}];
|
||||
}
|
||||
|
||||
// These mirror the XLA Transpose enum in Triangular Solve options.
// The case strings must match the spellings produced by
// TriangularSolveOptions::Transpose_Name so the attribute round-trips through
// HLO import (Transpose_Name) and export (Transpose_Parse).
def HLO_TRANSPOSE_INVALID : StrEnumAttrCase<"TRANSPOSE_INVALID">;
def HLO_NO_TRANSPOSE : StrEnumAttrCase<"NO_TRANSPOSE">;
def HLO_TRANSPOSE : StrEnumAttrCase<"TRANSPOSE">;
def HLO_ADJOINT : StrEnumAttrCase<"ADJOINT">;

// String-enum attribute used for TriangularSolveOp's transpose_a attribute.
def HLO_TransposeAttr : StrEnumAttr<"Transpose",
    "Transpose options",
    [
      HLO_TRANSPOSE_INVALID,
      HLO_NO_TRANSPOSE,
      HLO_TRANSPOSE,
      HLO_ADJOINT
    ]>;
|
||||
|
||||
// Shared base class carrying the ODS summary/description for the
// triangular_solve op definition(s).
class BASE_HLO_TriangularSolveOp {
  string summary = "TriangularSolve operator";

  string description = [{
    Solves systems of linear equations with lower or upper triangular
    coefficient matrices by forward- or back-substitution. Broadcasting along
    leading dimensions, this routine solves one of the matrix systems
    op(a) * x = b, or x * op(a) = b, for the variable x, given a and b, where
    op(a) is either op(a) = a, or op(a) = Transpose(a), or
    op(a) = Conj(Transpose(a)).

    Input data is read only from the lower/upper triangle of a, depending on the
    value of lower. Values from the other triangle are ignored. Output data is
    returned in the same triangle; the values in the other triangle are
    implementation-defined and may be anything.

    If the rank of a and b are greater than 2, they are treated as batches of
    matrices, where all except the minor 2 dimensions are batch dimensions. a
    and b must have equal batch dimensions.

    See https://www.tensorflow.org/xla/operation_semantics#triangularsolve.
  }];

}
|
||||
|
||||
class BASE_HLO_RngUniformOp {
|
||||
string summary = "RNG with uniform distribution.";
|
||||
|
||||
|
@ -173,6 +173,18 @@ static std::vector<xla::ReplicaGroup> Convert_replica_groups(
|
||||
return result;
|
||||
}
|
||||
|
||||
// Converts StringRef to xla Transpose enum.
static xla::TriangularSolveOptions::Transpose Convert_transpose_a(
    llvm::StringRef transpose_str) {
  xla::TriangularSolveOptions::Transpose transpose_enum;
  // An illegal transpose string would be caught by the op verifier, so the
  // 'Transpose_Parse' call below should never return false; fall back to
  // NO_TRANSPOSE defensively if it somehow does.
  if (!xla::TriangularSolveOptions::Transpose_Parse(transpose_str,
                                                    &transpose_enum))
    return xla::TriangularSolveOptions::NO_TRANSPOSE;
  return transpose_enum;
}
|
||||
|
||||
#define I64_ELEMENTS_ATTR_TO_VECTOR(attribute) \
|
||||
static std::vector<int64> Convert_##attribute( \
|
||||
llvm::Optional<mlir::DenseIntElementsAttr> attribute) { \
|
||||
|
@ -537,6 +537,61 @@ func @transpose_operand_result_permutation_mismatch(%arg0: tensor<1x?x3x?xi32>)
|
||||
|
||||
// -----
|
||||
|
||||
// Valid: unranked operands skip all shape checks in the verifier.
func @triangular_solve_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
  %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// Invalid: operand 'a' must be at least a matrix (rank >= 2).
func @triangular_solve_rank_less_than_2(%arg0: tensor<4xf32>, %arg1: tensor<4x3xf32>) -> tensor<4x3xf32> {
  // expected-error@+1 {{operand 'a' must have rank >= 2, but got 'tensor<4xf32>'}}
  %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4xf32>, tensor<4x3xf32>) -> tensor<4x3xf32>
  return %0 : tensor<4x3xf32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// Invalid: the two minor dimensions of 'a' must form square matrices.
func @triangular_solve_unequal_minor_dims_a(%arg0: tensor<4x3xf32>, %arg1: tensor<4x3xf32>) -> tensor<4x3xf32> {
  // expected-error@+1 {{two minor dimensions of operand 'a' must have equal size, but got 'tensor<4x3xf32>'}}
  %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4x3xf32>, tensor<4x3xf32>) -> tensor<4x3xf32>
  return %0 : tensor<4x3xf32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// Invalid: 'a' and 'b' must have the same rank.
func @triangular_solve_unequal_rank(%arg0: tensor<10x4x4xf32>, %arg1: tensor<4x3xf32>) -> tensor<4x3xf32> {
  // expected-error@+1 {{operands must have equal rank, but got 'tensor<10x4x4xf32>' and 'tensor<4x3xf32>'}}
  %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<10x4x4xf32>, tensor<4x3xf32>) -> tensor<4x3xf32>
  return %0 : tensor<4x3xf32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// Invalid: with left_side the contracted dim of 'a' (size 4) must match the
// second-to-last dim of 'b' (size 3 here).
func @triangular_solve_mismatch_shared_dim(%arg0: tensor<4x4xf32>, %arg1: tensor<3x4xf32>) -> tensor<3x4xf32> {
  // expected-error@+1 {{shared dimension of operands 'a' and 'b' does not match, but got 'tensor<4x4xf32>' and 'tensor<3x4xf32>'}}
  %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32>
  return %0 : tensor<3x4xf32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// Invalid: leading batch dimensions (10x5 vs 10x6) must match.
func @triangular_solve_mismatch_leading_dims(%arg0: tensor<10x5x4x4xf32>, %arg1: tensor<10x6x4x3xf32>) -> tensor<10x6x4x3xf32> {
  // expected-error@+1 {{leading batch dimensions of the operands must be same, but got 'tensor<10x5x4x4xf32>' and 'tensor<10x6x4x3xf32>'}}
  %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<10x5x4x4xf32>, tensor<10x6x4x3xf32>) -> tensor<10x6x4x3xf32>
  return %0 : tensor<10x6x4x3xf32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// Invalid: the result must have the same shape as operand 'b'.
func @triangular_solve_mismatch_result_and_b_type(%arg0: tensor<4x4xf32>, %arg1: tensor<4x3xf32>) -> tensor<4x4xf32> {
  // expected-error@+1 {{result and operand 'b' must have same shape, but got 'tensor<4x4xf32>' and 'tensor<4x3xf32>'}}
  %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4x4xf32>, tensor<4x3xf32>) -> tensor<4x4xf32>
  return %0 : tensor<4x4xf32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @tuple
|
||||
func @tuple(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) -> tuple<tensor<1xi32>, tensor<1x2xf32>> {
|
||||
%0 = "xla_hlo.tuple"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xf32>) -> tuple<tensor<1xi32>, tensor<1x2xf32>>
|
||||
|
@ -836,6 +836,19 @@ func @main(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> {
|
||||
|
||||
// -----
|
||||
|
||||
// Exports xla_hlo.triangular_solve to an HLO triangular-solve instruction,
// checking that all four attributes are serialized into the options.
// CHECK: HloModule
func @main(%arg0: tensor<4x4xf32>, %arg1: tensor<4x3xf32>) -> tensor<4x3xf32> {
  %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4x4xf32>, tensor<4x3xf32>) -> tensor<4x3xf32>
  return %0 : tensor<4x3xf32>
}

// CHECK: [[ARG_A:%.*]] = f32[4,4] parameter(0)
// CHECK: [[ARG_B:%.*]] = f32[4,3] parameter(1)
// CHECK: ROOT
// CHECK-SAME: f32[4,3] triangular-solve(f32[4,4] [[ARG_A]], f32[4,3] [[ARG_B]]), left_side=true, lower=true, unit_diagonal=true, transpose_a=NO_TRANSPOSE
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: HloModule
|
||||
func @main(%arg0: tensor<f32>, %arg1 : tensor<i32>) -> tuple<tensor<f32>, tensor<i32>> {
|
||||
%result = "xla_hlo.tuple"(%arg0, %arg1) {} : (tensor<f32>, tensor<i32>) -> tuple<tensor<f32>, tensor<i32>>
|
||||
|
@ -744,6 +744,19 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
|
||||
ROOT %transpose.2 = s32[2,1,4,3] transpose(s32[1,2,3,4] %Arg_0.1), dimensions={1,0,3,2}
|
||||
}
|
||||
|
||||
// Imports an HLO triangular-solve instruction, checking that each option is
// turned into the corresponding op attribute (transpose_a as a string enum).
// CHECK-LABEL: func @test_triangular_solve
// CHECK-SAME: ([[ARG_A:%.*]]: tensor<4x4xf32>, [[ARG_B:%.*]]: tensor<4x3xf32>) -> tensor<4x3xf32>
%test_triangular_solve (Arg_0.1: f32[4,4], Arg_1.2: f32[4,3]) -> f32[4,3] {
  %Arg_0.1 = f32[4,4] parameter(0)
  %Arg_1.2 = f32[4,3] parameter(1)
  // CHECK-NEXT: "xla_hlo.triangular_solve"([[ARG_A]], [[ARG_B]])
  // CHECK-SAME: left_side = true
  // CHECK-SAME: lower = true
  // CHECK-SAME: transpose_a = "NO_TRANSPOSE"
  // CHECK-SAME: unit_diagonal = true
  ROOT %triangular-solve.3 = f32[4,3] triangular-solve(f32[4,4] %Arg_0.1, f32[4,3] %Arg_1.2), left_side=true, lower=true, transpose_a=NO_TRANSPOSE, unit_diagonal=true
}
|
||||
|
||||
// CHECK-LABEL: func @test_tuple(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) -> tuple<tensor<1xi32>, tensor<1x2xf32>> {
|
||||
%test_tuple(Arg_0.1: s32[1], Arg_1.2: f32[1, 2]) -> (s32[1], f32[1,2]) {
|
||||
%Arg_0.1 = s32[1] parameter(0)
|
||||
|
Loading…
x
Reference in New Issue
Block a user