Add TriangularSolve op to HLO dialect.

Adds the op definition and import/export support for it.
Adds extra verifier checks on the shapes of the op's operands and results.
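
For reference, the new op's MLIR textual form (as exercised by the tests added in this change):

  %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4x4xf32>, tensor<4x3xf32>) -> tensor<4x3xf32>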

PiperOrigin-RevId: 289152231
Change-Id: I1f13f18131fe13ccfdb451b2748a9d76312211a2
Authored by Prakalp Srivastava on 2020-01-10 13:17:13 -08:00; committed by TensorFlower Gardener.
Commit dff1d31b49 (parent d4c8c604ee); 8 changed files with 222 additions and 0 deletions.

@@ -449,6 +449,24 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
          "permutation", ConvertDimensions(instruction->dimensions())));
      MakeAndReturn(TransposeOp);
    }
    case HloOpcode::kTriangularSolve: {
      attributes.push_back(builder_->getNamedAttr(
          "left_side",
          builder_->getBoolAttr(
              instruction->triangular_solve_options().left_side())));
      attributes.push_back(builder_->getNamedAttr(
          "lower", builder_->getBoolAttr(
                       instruction->triangular_solve_options().lower())));
      attributes.push_back(builder_->getNamedAttr(
          "unit_diagonal",
          builder_->getBoolAttr(
              instruction->triangular_solve_options().unit_diagonal())));
      auto transpose_a =
          builder_->getStringAttr(TriangularSolveOptions::Transpose_Name(
              instruction->triangular_solve_options().transpose_a()));
      attributes.push_back(builder_->getNamedAttr("transpose_a", transpose_a));
      MakeAndReturn(TriangularSolveOp);
    }
    case HloOpcode::kMap: {
      auto op = func_builder->create<mlir::xla_hlo::MapOp>(
          loc, result_type, operands,

@@ -1103,6 +1103,63 @@ static LogicalResult Verify(TransposeOp op) {
  return success();
}

//===----------------------------------------------------------------------===//
// TriangularSolveOp
//===----------------------------------------------------------------------===//

static LogicalResult Verify(TriangularSolveOp op) {
  auto a_type = op.a().getType().dyn_cast<RankedTensorType>();

  // Skip the verifier if 'a' is an unranked tensor.
  if (!a_type) return success();

  // Check that 'a' has rank >= 2.
  auto a_rank = a_type.getRank();
  if (a_rank < 2)
    return op.emitOpError()
           << "operand 'a' must have rank >= 2, but got " << a_type;

  // The two minor dimensions of 'a' must have equal size.
  if (a_type.getDimSize(a_rank - 2) != a_type.getDimSize(a_rank - 1))
    return op.emitOpError() << "two minor dimensions of operand 'a' must have "
                               "equal size, but got "
                            << a_type;

  auto b_type = op.b().getType().dyn_cast<RankedTensorType>();
  // If 'b' is unranked, skip the remaining checks.
  if (!b_type) return success();

  // Check that 'a' and 'b' have equal rank.
  auto b_rank = b_type.getRank();
  if (a_rank != b_rank)
    return op.emitOpError() << "operands must have equal rank, but got "
                            << a_type << " and " << b_type;

  // The shared dimension of 'a' and 'b' must match.
  if (a_type.getDimSize(a_rank - 1) !=
      b_type.getDimSize(b_rank - (op.left_side() ? 2 : 1)))
    return op.emitOpError() << "shared dimension of operands 'a' and 'b' does "
                               "not match, but got "
                            << a_type << " and " << b_type;

  // The leading batch dimensions of 'a' and 'b' must be equal.
  auto a_batch_dims = a_type.getShape().drop_back(2);
  auto b_batch_dims = b_type.getShape().drop_back(2);
  if (a_batch_dims != b_batch_dims)
    return op.emitOpError()
           << "leading batch dimensions of the operands must be same, but got "
           << a_type << " and " << b_type;

  // The result and operand 'b' must have the same shape.
  auto result_type = op.getType().dyn_cast<RankedTensorType>();
  if (!result_type) return success();
  if (result_type != b_type)
    return op.emitOpError()
           << "result and operand 'b' must have same shape, but got "
           << result_type << " and " << b_type;

  return success();
}
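
// For reference, a batched solve that satisfies all of the checks above
// (illustrative shapes; with left_side = true the shared dimension is the
// last dimension of 'a' and the second-to-last dimension of 'b'):
//
//   %0 = "xla_hlo.triangular_solve"(%a, %b)
//       {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE",
//        unit_diagonal = false}
//       : (tensor<10x4x4xf32>, tensor<10x4x3xf32>) -> tensor<10x4x3xf32>
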
//===----------------------------------------------------------------------===//
// GetTupleElementOp
//===----------------------------------------------------------------------===//

@@ -1146,6 +1146,20 @@ def HLO_TransposeOp: HLO_Op<"transpose",
  let hasFolder = 1;
}

def HLO_TriangularSolveOp: HLO_Op<"triangular_solve",
    [NoSideEffect, SameOperandsAndResultElementType]>,
    BASE_HLO_TriangularSolveOp {
  let arguments = (ins
    HLO_FpOrComplexTensor:$a,
    HLO_FpOrComplexTensor:$b,
    BoolAttr:$left_side,
    BoolAttr:$lower,
    BoolAttr:$unit_diagonal,
    HLO_TransposeAttr:$transpose_a
  );

  let results = (outs HLO_FpOrComplexTensor);
}
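
// For illustration only (shapes assumed, not taken from this change's tests):
// a right-side solve x * op(a) = b using the adjoint of 'a' would set
// left_side = false:
//
//   %0 = "xla_hlo.triangular_solve"(%a, %b)
//       {left_side = false, lower = false, transpose_a = "ADJOINT",
//        unit_diagonal = false}
//       : (tensor<4x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32>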

def HLO_ReduceWindowOp: HLO_Op<"reduce_window", [
    NoSideEffect,
    SingleBlockImplicitTerminator<"ReturnOp">

@@ -1064,6 +1064,46 @@ class BASE_HLO_TransposeOp {
  }];
}

// These mirror the XLA Transpose enum in TriangularSolveOptions.
def HLO_TRANSPOSE_INVALID : StrEnumAttrCase<"TRANSPOSE_INVALID">;
def HLO_NO_TRANSPOSE : StrEnumAttrCase<"NO_TRANSPOSE">;
def HLO_TRANSPOSE : StrEnumAttrCase<"TRANSPOSE">;
def HLO_ADJOINT : StrEnumAttrCase<"ADJOINT">;

def HLO_TransposeAttr : StrEnumAttr<"Transpose",
    "Transpose options",
    [
      HLO_TRANSPOSE_INVALID,
      HLO_NO_TRANSPOSE,
      HLO_TRANSPOSE,
      HLO_ADJOINT
    ]>;

class BASE_HLO_TriangularSolveOp {
  string summary = "TriangularSolve operator";

  string description = [{
    Solves systems of linear equations with lower or upper triangular
    coefficient matrices by forward- or back-substitution. Broadcasting along
    leading dimensions, this routine solves one of the matrix systems
    op(a) * x = b or x * op(a) = b for the variable x, given a and b, where
    op(a) is op(a) = a, op(a) = Transpose(a), or op(a) = Conj(Transpose(a)).

    Input data is read only from the lower/upper triangle of a, depending on
    the value of lower; values from the other triangle are ignored. Output
    data is returned in the same triangle; the values in the other triangle
    are implementation-defined and may be anything.

    If the ranks of a and b are greater than 2, they are treated as batches
    of matrices, where all dimensions except the minor two are batch
    dimensions. a and b must have equal batch dimensions.

    See https://www.tensorflow.org/xla/operation_semantics#triangularsolve.
  }];
}
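
// A worked sketch (illustrative, not part of this change): with
// left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", and a 2x2
// system a * x = b, the solve is plain forward substitution:
//
//   x_1 = b_1 / a_11
//   x_2 = (b_2 - a_21 * x_1) / a_22
//
// With unit_diagonal = true, the diagonal of 'a' is assumed to be 1 and the
// divisions drop out.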

class BASE_HLO_RngUniformOp {
  string summary = "RNG with uniform distribution.";

@@ -173,6 +173,18 @@ static std::vector<xla::ReplicaGroup> Convert_replica_groups(
  return result;
}

// Converts a StringRef to the xla Transpose enum.
static xla::TriangularSolveOptions::Transpose Convert_transpose_a(
    llvm::StringRef transpose_str) {
  xla::TriangularSolveOptions::Transpose transpose_enum;
  // An illegal transpose string would be caught by the verifier, so the
  // 'Transpose_Parse' call below should never fail.
  if (!xla::TriangularSolveOptions::Transpose_Parse(transpose_str,
                                                    &transpose_enum))
    return xla::TriangularSolveOptions::NO_TRANSPOSE;
  return transpose_enum;
}

#define I64_ELEMENTS_ATTR_TO_VECTOR(attribute)                \
  static std::vector<int64> Convert_##attribute(              \
      llvm::Optional<mlir::DenseIntElementsAttr> attribute) { \

@@ -537,6 +537,61 @@ func @transpose_operand_result_permutation_mismatch(%arg0: tensor<1x?x3x?xi32>)

// -----

func @triangular_solve_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
  %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// -----

func @triangular_solve_rank_less_than_2(%arg0: tensor<4xf32>, %arg1: tensor<4x3xf32>) -> tensor<4x3xf32> {
  // expected-error@+1 {{operand 'a' must have rank >= 2, but got 'tensor<4xf32>'}}
  %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4xf32>, tensor<4x3xf32>) -> tensor<4x3xf32>
  return %0 : tensor<4x3xf32>
}

// -----

func @triangular_solve_unequal_minor_dims_a(%arg0: tensor<4x3xf32>, %arg1: tensor<4x3xf32>) -> tensor<4x3xf32> {
  // expected-error@+1 {{two minor dimensions of operand 'a' must have equal size, but got 'tensor<4x3xf32>'}}
  %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4x3xf32>, tensor<4x3xf32>) -> tensor<4x3xf32>
  return %0 : tensor<4x3xf32>
}

// -----

func @triangular_solve_unequal_rank(%arg0: tensor<10x4x4xf32>, %arg1: tensor<4x3xf32>) -> tensor<4x3xf32> {
  // expected-error@+1 {{operands must have equal rank, but got 'tensor<10x4x4xf32>' and 'tensor<4x3xf32>'}}
  %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<10x4x4xf32>, tensor<4x3xf32>) -> tensor<4x3xf32>
  return %0 : tensor<4x3xf32>
}

// -----

func @triangular_solve_mismatch_shared_dim(%arg0: tensor<4x4xf32>, %arg1: tensor<3x4xf32>) -> tensor<3x4xf32> {
  // expected-error@+1 {{shared dimension of operands 'a' and 'b' does not match, but got 'tensor<4x4xf32>' and 'tensor<3x4xf32>'}}
  %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32>
  return %0 : tensor<3x4xf32>
}

// -----

func @triangular_solve_mismatch_leading_dims(%arg0: tensor<10x5x4x4xf32>, %arg1: tensor<10x6x4x3xf32>) -> tensor<10x6x4x3xf32> {
  // expected-error@+1 {{leading batch dimensions of the operands must be same, but got 'tensor<10x5x4x4xf32>' and 'tensor<10x6x4x3xf32>'}}
  %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<10x5x4x4xf32>, tensor<10x6x4x3xf32>) -> tensor<10x6x4x3xf32>
  return %0 : tensor<10x6x4x3xf32>
}

// -----

func @triangular_solve_mismatch_result_and_b_type(%arg0: tensor<4x4xf32>, %arg1: tensor<4x3xf32>) -> tensor<4x4xf32> {
  // expected-error@+1 {{result and operand 'b' must have same shape, but got 'tensor<4x4xf32>' and 'tensor<4x3xf32>'}}
  %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4x4xf32>, tensor<4x3xf32>) -> tensor<4x4xf32>
  return %0 : tensor<4x4xf32>
}

// -----

// CHECK-LABEL: func @tuple
func @tuple(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) -> tuple<tensor<1xi32>, tensor<1x2xf32>> {
  %0 = "xla_hlo.tuple"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xf32>) -> tuple<tensor<1xi32>, tensor<1x2xf32>>

@@ -836,6 +836,19 @@ func @main(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> {

// -----

// CHECK: HloModule
func @main(%arg0: tensor<4x4xf32>, %arg1: tensor<4x3xf32>) -> tensor<4x3xf32> {
  %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4x4xf32>, tensor<4x3xf32>) -> tensor<4x3xf32>
  return %0 : tensor<4x3xf32>
}

// CHECK: [[ARG_A:%.*]] = f32[4,4] parameter(0)
// CHECK: [[ARG_B:%.*]] = f32[4,3] parameter(1)
// CHECK: ROOT
// CHECK-SAME: f32[4,3] triangular-solve(f32[4,4] [[ARG_A]], f32[4,3] [[ARG_B]]), left_side=true, lower=true, unit_diagonal=true, transpose_a=NO_TRANSPOSE

// -----

// CHECK: HloModule
func @main(%arg0: tensor<f32>, %arg1: tensor<i32>) -> tuple<tensor<f32>, tensor<i32>> {
  %result = "xla_hlo.tuple"(%arg0, %arg1) {} : (tensor<f32>, tensor<i32>) -> tuple<tensor<f32>, tensor<i32>>

@@ -744,6 +744,19 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
  ROOT %transpose.2 = s32[2,1,4,3] transpose(s32[1,2,3,4] %Arg_0.1), dimensions={1,0,3,2}
}

// CHECK-LABEL: func @test_triangular_solve
// CHECK-SAME: ([[ARG_A:%.*]]: tensor<4x4xf32>, [[ARG_B:%.*]]: tensor<4x3xf32>) -> tensor<4x3xf32>
%test_triangular_solve (Arg_0.1: f32[4,4], Arg_1.2: f32[4,3]) -> f32[4,3] {
  %Arg_0.1 = f32[4,4] parameter(0)
  %Arg_1.2 = f32[4,3] parameter(1)
  // CHECK-NEXT: "xla_hlo.triangular_solve"([[ARG_A]], [[ARG_B]])
  // CHECK-SAME: left_side = true
  // CHECK-SAME: lower = true
  // CHECK-SAME: transpose_a = "NO_TRANSPOSE"
  // CHECK-SAME: unit_diagonal = true
  ROOT %triangular-solve.3 = f32[4,3] triangular-solve(f32[4,4] %Arg_0.1, f32[4,3] %Arg_1.2), left_side=true, lower=true, transpose_a=NO_TRANSPOSE, unit_diagonal=true
}

// CHECK-LABEL: func @test_tuple(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) -> tuple<tensor<1xi32>, tensor<1x2xf32>> {
%test_tuple(Arg_0.1: s32[1], Arg_1.2: f32[1, 2]) -> (s32[1], f32[1,2]) {
  %Arg_0.1 = s32[1] parameter(0)