Fix TensorFlow MatrixBandPart op lowering to HLO
createIotaOp should use the element type of the num_lower operand, and the lowering should handle operands that require broadcasting between the conditional result and the input. PiperOrigin-RevId: 338778254 Change-Id: Ie569518a85735a9dc3c8885c20eda7fe74901ec6
This commit is contained in:
parent
d39bcbaefb
commit
e7b54fbda2
@ -1224,32 +1224,30 @@ func @test_sparse_mat_mul_with_cast(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xbf
|
||||
// CHECK-LABEL: matrix_band_part
|
||||
// CHECK-SAME: (%[[INPUT:.*]]: tensor<64x64xbf16>, %[[LOWER:.*]]: tensor<i64>, %[[UPPER:.*]]: tensor<i64>)
|
||||
func @matrix_band_part(%arg0: tensor<64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64xbf16> {
|
||||
// CHECK: %[[M:.*]] = mhlo.constant dense<64> : tensor<i64>
|
||||
// CHECK: %[[N:.*]] = mhlo.constant dense<64> : tensor<i64>
|
||||
// CHECK-DAG: %[[M:.*]] = mhlo.constant dense<64> : tensor<i64>
|
||||
// CHECK-DAG: %[[N:.*]] = mhlo.constant dense<64> : tensor<i64>
|
||||
|
||||
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i64>
|
||||
// CHECK: %[[A:.*]] = "mhlo.compare"(%[[LOWER]], %[[ZERO]]) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||
// CHECK: %[[B:.*]] = "mhlo.select"(%[[A]], %[[M]], %[[LOWER]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
|
||||
// CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i64>
|
||||
// CHECK-DAG: %[[A:.*]] = "mhlo.compare"(%[[LOWER]], %[[ZERO]]) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||
// CHECK-DAG: %[[B:.*]] = "mhlo.select"(%[[A]], %[[M]], %[[LOWER]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
|
||||
|
||||
// CHECK: %[[C:.*]] = "mhlo.compare"(%[[UPPER]], %[[ZERO]]) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||
// CHECK: %[[D:.*]] = "mhlo.select"(%[[C]], %[[N]], %[[UPPER]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
|
||||
// CHECK-DAG: %[[C:.*]] = "mhlo.compare"(%[[UPPER]], %[[ZERO]]) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||
// CHECK-DAG: %[[D:.*]] = "mhlo.select"(%[[C]], %[[N]], %[[UPPER]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
|
||||
// CHECK-DAG: %[[F:.*]] = "mhlo.negate"(%[[B]]) : (tensor<i64>) -> tensor<i64>
|
||||
|
||||
// CHECK: %[[E:.*]] = "mhlo.convert"(%[[B]]) : (tensor<i64>) -> tensor<bf16>
|
||||
// CHECK: %[[F:.*]] = "mhlo.negate"(%[[E]]) : (tensor<bf16>) -> tensor<bf16>
|
||||
// CHECK-DAG: %[[X:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<64x64xi64>
|
||||
// CHECK-DAG: %[[Y:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<64x64xi64>
|
||||
// CHECK-DAG: %[[OFFSET:.*]] = mhlo.subtract %[[X]], %[[Y]] : tensor<64x64xi64>
|
||||
// CHECK-DAG: %[[G:.*]] = chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor<i64>, tensor<64x64xi64>) -> tensor<64x64xi1>
|
||||
|
||||
// CHECK: %[[X:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<64x64xbf16>
|
||||
// CHECK: %[[Y:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<64x64xbf16>
|
||||
// CHECK: %[[OFFSET:.*]] = mhlo.subtract %[[X]], %[[Y]] : tensor<64x64xbf16>
|
||||
// CHECK: %[[G:.*]] = chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor<bf16>, tensor<64x64xbf16>) -> tensor<64x64xi1>
|
||||
// CHECK-DAG: %[[I:.*]] = chlo.broadcast_compare %[[OFFSET]], %[[D]] {comparison_direction = "LE"} : (tensor<64x64xi64>, tensor<i64>) -> tensor<64x64xi1>
|
||||
|
||||
// CHECK: %[[H:.*]] = "mhlo.convert"(%[[D]]) : (tensor<i64>) -> tensor<bf16>
|
||||
// CHECK: %[[I:.*]] = chlo.broadcast_compare %[[OFFSET]], %[[H]] {comparison_direction = "LE"} : (tensor<64x64xbf16>, tensor<bf16>) -> tensor<64x64xi1>
|
||||
// CHECK-DAG: %[[J:.*]] = mhlo.and %[[G]], %[[I]] : tensor<64x64xi1>
|
||||
|
||||
// CHECK: %[[J:.*]] = mhlo.and %[[G]], %[[I]] : tensor<64x64xi1>
|
||||
// CHECK-DAG: %[[ZERO2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<64x64xbf16>
|
||||
|
||||
// CHECK: %[[ZERO2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<64x64xbf16>
|
||||
// CHECK: %[[R:.*]] = "mhlo.select"(%[[J]], %[[INPUT]], %[[ZERO2]])
|
||||
// CHECK: return %[[R]]
|
||||
// CHECK-DAG: %[[R:.*]] = "mhlo.select"(%[[J]], %[[INPUT]], %[[ZERO2]])
|
||||
// CHECK-DAG: return %[[R]]
|
||||
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<64x64xbf16>
|
||||
return %0 : tensor<64x64xbf16>
|
||||
}
|
||||
@ -1257,19 +1255,20 @@ func @matrix_band_part(%arg0: tensor<64x64xbf16>, %arg1: tensor<i64>, %arg2: ten
|
||||
// CHECK-LABEL: matrix_band_part_2
|
||||
// CHECK-SAME: (%[[INPUT:.*]]: tensor<12x24x48xbf16>, %[[LOWER:.*]]: tensor<i64>, %[[UPPER:.*]]: tensor<i64>)
|
||||
func @matrix_band_part_2(%arg0: tensor<12x24x48xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<12x24x48xbf16> {
|
||||
// CHECK: %[[X:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<24x48xbf16>
|
||||
// CHECK: %[[Y:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<24x48xbf16>
|
||||
// CHECK: %[[OFFSET:.*]] = mhlo.subtract %[[X]], %[[Y]] : tensor<24x48xbf16>
|
||||
// CHECK-DAG: %[[X:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<24x48xi64>
|
||||
// CHECK-DAG: %[[Y:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<24x48xi64>
|
||||
// CHECK-DAG: %[[OFFSET:.*]] = mhlo.subtract %[[X]], %[[Y]] : tensor<24x48xi64>
|
||||
|
||||
// CHECK: %[[G:.*]] = chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor<bf16>, tensor<24x48xbf16>) -> tensor<24x48xi1>
|
||||
// CHECK-DAG: %[[G:.*]] = chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor<i64>, tensor<24x48xi64>) -> tensor<24x48xi1>
|
||||
|
||||
// CHECK: %[[H:.*]] = "mhlo.convert"(%[[D]]) : (tensor<i64>) -> tensor<bf16>
|
||||
// CHECK: %[[I:.*]] = chlo.broadcast_compare %[[OFFSET]], %[[H]] {comparison_direction = "LE"} : (tensor<24x48xbf16>, tensor<bf16>) -> tensor<24x48xi1>
|
||||
// CHECK: %[[J:.*]] = mhlo.and %[[G]], %[[I]] : tensor<24x48xi1>
|
||||
// CHECK-DAG: %[[I:.*]] = chlo.broadcast_compare %[[OFFSET]], %[[D]] {comparison_direction = "LE"} : (tensor<24x48xi64>, tensor<i64>) -> tensor<24x48xi1>
|
||||
// CHECK-DAG: %[[J:.*]] = mhlo.and %[[G]], %[[I]] : tensor<24x48xi1>
|
||||
|
||||
// CHECK: %[[ZERO2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<12x24x48xbf16>
|
||||
// CHECK: %[[R:.*]] = "mhlo.select"(%[[J]], %[[INPUT]], %[[ZERO2]])
|
||||
// CHECK: return %[[R]]
|
||||
// CHECK-DAG: %[[ZERO2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<12x24x48xbf16>
|
||||
|
||||
// CHECK-DAG: %[[K:.*]] = "mhlo.broadcast_in_dim"(%[[J]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<24x48xi1>) -> tensor<12x24x48xi1>
|
||||
// CHECK-DAG: %[[R:.*]] = "mhlo.select"(%[[K]], %[[INPUT]], %[[ZERO2]])
|
||||
// CHECK-DAG: return %[[R]]
|
||||
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<12x24x48xbf16>, tensor<i64>, tensor<i64>) -> tensor<12x24x48xbf16>
|
||||
return %0 : tensor<12x24x48xbf16>
|
||||
}
|
||||
|
@ -786,11 +786,11 @@ static int GetDimensionSizeFromEnd(Value input, int dim_from_end) {
|
||||
// dimension and last dimension, respectively). The element type of the
|
||||
// outputted RankedTensorType will match the element type of `num_lower`.
|
||||
// Requires that `input` and `num_lower` are tensors.
|
||||
static RankedTensorType Get2DTensorType(Value input) {
|
||||
static RankedTensorType Get2DTensorType(Value input, Value num_lower) {
|
||||
// `dim_0` refers to the second-to-last dimension; `dim_1` refers to the last.
|
||||
int dim_0 = GetDimensionSizeFromEnd(input, 1);
|
||||
int dim_1 = GetDimensionSizeFromEnd(input, 0);
|
||||
auto element_type = input.getType().cast<TensorType>().getElementType();
|
||||
auto element_type = num_lower.getType().cast<TensorType>().getElementType();
|
||||
return RankedTensorType::get({dim_0, dim_1}, element_type);
|
||||
}
|
||||
|
||||
|
@ -365,7 +365,8 @@ class getIntegerAttr<string x>: NativeCodeCall<
|
||||
"$_builder.getI64IntegerAttr(" # x # ")">;
|
||||
|
||||
class GetDimensionSizeFromEnd<string dimFromEnd>: NativeCodeCall<
|
||||
"$_builder.getI64IntegerAttr(GetDimensionSizeFromEnd($0, " # dimFromEnd # "))"
|
||||
"$_builder.getIntegerAttr(getElementTypeOrSelf($1.getType()), "
|
||||
" GetDimensionSizeFromEnd($0, " # dimFromEnd # "))"
|
||||
>;
|
||||
|
||||
// TODO(b/149615308): Enable IotaOp usage as a child operation in a pattern
|
||||
@ -373,7 +374,7 @@ class GetDimensionSizeFromEnd<string dimFromEnd>: NativeCodeCall<
|
||||
// cannot be inferred.
|
||||
class createIotaOp<string dim>: NativeCodeCall<
|
||||
"$_builder.create<mhlo::IotaOp>($0.getOwner()->getLoc(), "
|
||||
"Get2DTensorType($1), $_builder.getI64IntegerAttr(" # dim # "))">;
|
||||
"Get2DTensorType($1, $2), $_builder.getI64IntegerAttr(" # dim # "))">;
|
||||
|
||||
// This op needs to be created in C++ because the generated Convert Op has no
|
||||
// way to specify shape information as an input. In the MatrixBandPart op
|
||||
@ -396,9 +397,10 @@ def createConvertOp: NativeCodeCall<
|
||||
// return (indicator ? input : zero_matrix)
|
||||
//
|
||||
// TODO(b/149961547): Support dynamic shaped `input` in MatrixBandPartOp.
|
||||
def : Pattern<(TF_MatrixBandPartOp:$op AnyRankedTensor:$input, $num_lower, $num_upper),
|
||||
[(HLO_ConstOp:$m_dim (GetDimensionSizeFromEnd<"0"> $input)),
|
||||
(HLO_ConstOp:$n_dim (GetDimensionSizeFromEnd<"1"> $input)),
|
||||
def : Pattern<(TF_MatrixBandPartOp:$op AnyStaticShapeTensor:$input, $num_lower,
|
||||
$num_upper),
|
||||
[(HLO_ConstOp:$m_dim (GetDimensionSizeFromEnd<"1"> $input, $num_lower)),
|
||||
(HLO_ConstOp:$n_dim (GetDimensionSizeFromEnd<"0"> $input, $num_upper)),
|
||||
(HLO_SelectOp:$num_lower_or_m
|
||||
(HLO_CompareOp
|
||||
$num_lower, (HLO_ConstOp:$zero (ConstantSplat<"0"> $num_lower)),
|
||||
@ -414,22 +416,17 @@ def : Pattern<(TF_MatrixBandPartOp:$op AnyRankedTensor:$input, $num_lower, $num_
|
||||
$n_dim,
|
||||
$num_upper
|
||||
),
|
||||
(HLO_SelectOp
|
||||
(TF_SelectV2Op
|
||||
(HLO_AndOp
|
||||
(HLOClient_BroadcastCompareOp
|
||||
(HLO_NegOp
|
||||
(createConvertOp $op, $num_lower_or_m, $input)
|
||||
),
|
||||
(HLO_NegOp $num_lower_or_m),
|
||||
(HLO_SubOp:$offset
|
||||
(createIotaOp<"1"> $op, $input), (createIotaOp<"0"> $op, $input)
|
||||
(createIotaOp<"1"> $op, $input, $num_lower),
|
||||
(createIotaOp<"0"> $op, $input, $num_lower)
|
||||
),
|
||||
(NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LE
|
||||
),
|
||||
(HLOClient_BroadcastCompareOp
|
||||
$offset,
|
||||
(createConvertOp
|
||||
$op, $num_upper_or_n, $input
|
||||
),
|
||||
(HLOClient_BroadcastCompareOp $offset, $num_upper_or_n,
|
||||
(NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LE
|
||||
)
|
||||
),
|
||||
|
@ -848,6 +848,7 @@ tf_xla_py_test(
|
||||
size = "medium",
|
||||
timeout = "long",
|
||||
srcs = ["matrix_band_part_test.py"],
|
||||
enable_mlir_bridge = True,
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||
|
Loading…
x
Reference in New Issue
Block a user