Fix TensorFlow MatrixBandPart op lowering to HLO

createIotaOp should use the element type of the num_lower operand, and the lowering should handle the case where the conditional (indicator) result must be broadcast to the shape of the input.

PiperOrigin-RevId: 338778254
Change-Id: Ie569518a85735a9dc3c8885c20eda7fe74901ec6
This commit is contained in:
Smit Hinsu 2020-10-23 18:06:59 -07:00 committed by TensorFlower Gardener
parent d39bcbaefb
commit e7b54fbda2
4 changed files with 43 additions and 46 deletions

View File

@ -1224,32 +1224,30 @@ func @test_sparse_mat_mul_with_cast(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xbf
// CHECK-LABEL: matrix_band_part
// CHECK-SAME: (%[[INPUT:.*]]: tensor<64x64xbf16>, %[[LOWER:.*]]: tensor<i64>, %[[UPPER:.*]]: tensor<i64>)
func @matrix_band_part(%arg0: tensor<64x64xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<64x64xbf16> {
// CHECK: %[[M:.*]] = mhlo.constant dense<64> : tensor<i64>
// CHECK: %[[N:.*]] = mhlo.constant dense<64> : tensor<i64>
// CHECK-DAG: %[[M:.*]] = mhlo.constant dense<64> : tensor<i64>
// CHECK-DAG: %[[N:.*]] = mhlo.constant dense<64> : tensor<i64>
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i64>
// CHECK: %[[A:.*]] = "mhlo.compare"(%[[LOWER]], %[[ZERO]]) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
// CHECK: %[[B:.*]] = "mhlo.select"(%[[A]], %[[M]], %[[LOWER]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
// CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i64>
// CHECK-DAG: %[[A:.*]] = "mhlo.compare"(%[[LOWER]], %[[ZERO]]) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
// CHECK-DAG: %[[B:.*]] = "mhlo.select"(%[[A]], %[[M]], %[[LOWER]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
// CHECK: %[[C:.*]] = "mhlo.compare"(%[[UPPER]], %[[ZERO]]) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
// CHECK: %[[D:.*]] = "mhlo.select"(%[[C]], %[[N]], %[[UPPER]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
// CHECK-DAG: %[[C:.*]] = "mhlo.compare"(%[[UPPER]], %[[ZERO]]) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
// CHECK-DAG: %[[D:.*]] = "mhlo.select"(%[[C]], %[[N]], %[[UPPER]]) : (tensor<i1>, tensor<i64>, tensor<i64>) -> tensor<i64>
// CHECK-DAG: %[[F:.*]] = "mhlo.negate"(%[[B]]) : (tensor<i64>) -> tensor<i64>
// CHECK: %[[E:.*]] = "mhlo.convert"(%[[B]]) : (tensor<i64>) -> tensor<bf16>
// CHECK: %[[F:.*]] = "mhlo.negate"(%[[E]]) : (tensor<bf16>) -> tensor<bf16>
// CHECK-DAG: %[[X:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<64x64xi64>
// CHECK-DAG: %[[Y:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<64x64xi64>
// CHECK-DAG: %[[OFFSET:.*]] = mhlo.subtract %[[X]], %[[Y]] : tensor<64x64xi64>
// CHECK-DAG: %[[G:.*]] = chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor<i64>, tensor<64x64xi64>) -> tensor<64x64xi1>
// CHECK: %[[X:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<64x64xbf16>
// CHECK: %[[Y:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<64x64xbf16>
// CHECK: %[[OFFSET:.*]] = mhlo.subtract %[[X]], %[[Y]] : tensor<64x64xbf16>
// CHECK: %[[G:.*]] = chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor<bf16>, tensor<64x64xbf16>) -> tensor<64x64xi1>
// CHECK-DAG: %[[I:.*]] = chlo.broadcast_compare %[[OFFSET]], %[[D]] {comparison_direction = "LE"} : (tensor<64x64xi64>, tensor<i64>) -> tensor<64x64xi1>
// CHECK: %[[H:.*]] = "mhlo.convert"(%[[D]]) : (tensor<i64>) -> tensor<bf16>
// CHECK: %[[I:.*]] = chlo.broadcast_compare %[[OFFSET]], %[[H]] {comparison_direction = "LE"} : (tensor<64x64xbf16>, tensor<bf16>) -> tensor<64x64xi1>
// CHECK-DAG: %[[J:.*]] = mhlo.and %[[G]], %[[I]] : tensor<64x64xi1>
// CHECK: %[[J:.*]] = mhlo.and %[[G]], %[[I]] : tensor<64x64xi1>
// CHECK-DAG: %[[ZERO2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<64x64xbf16>
// CHECK: %[[ZERO2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<64x64xbf16>
// CHECK: %[[R:.*]] = "mhlo.select"(%[[J]], %[[INPUT]], %[[ZERO2]])
// CHECK: return %[[R]]
// CHECK-DAG: %[[R:.*]] = "mhlo.select"(%[[J]], %[[INPUT]], %[[ZERO2]])
// CHECK-DAG: return %[[R]]
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<64x64xbf16>, tensor<i64>, tensor<i64>) -> tensor<64x64xbf16>
return %0 : tensor<64x64xbf16>
}
@ -1257,19 +1255,20 @@ func @matrix_band_part(%arg0: tensor<64x64xbf16>, %arg1: tensor<i64>, %arg2: ten
// CHECK-LABEL: matrix_band_part_2
// CHECK-SAME: (%[[INPUT:.*]]: tensor<12x24x48xbf16>, %[[LOWER:.*]]: tensor<i64>, %[[UPPER:.*]]: tensor<i64>)
func @matrix_band_part_2(%arg0: tensor<12x24x48xbf16>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<12x24x48xbf16> {
// CHECK: %[[X:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<24x48xbf16>
// CHECK: %[[Y:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<24x48xbf16>
// CHECK: %[[OFFSET:.*]] = mhlo.subtract %[[X]], %[[Y]] : tensor<24x48xbf16>
// CHECK-DAG: %[[X:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<24x48xi64>
// CHECK-DAG: %[[Y:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<24x48xi64>
// CHECK-DAG: %[[OFFSET:.*]] = mhlo.subtract %[[X]], %[[Y]] : tensor<24x48xi64>
// CHECK: %[[G:.*]] = chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor<bf16>, tensor<24x48xbf16>) -> tensor<24x48xi1>
// CHECK-DAG: %[[G:.*]] = chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor<i64>, tensor<24x48xi64>) -> tensor<24x48xi1>
// CHECK: %[[H:.*]] = "mhlo.convert"(%[[D]]) : (tensor<i64>) -> tensor<bf16>
// CHECK: %[[I:.*]] = chlo.broadcast_compare %[[OFFSET]], %[[H]] {comparison_direction = "LE"} : (tensor<24x48xbf16>, tensor<bf16>) -> tensor<24x48xi1>
// CHECK: %[[J:.*]] = mhlo.and %[[G]], %[[I]] : tensor<24x48xi1>
// CHECK-DAG: %[[I:.*]] = chlo.broadcast_compare %[[OFFSET]], %[[D]] {comparison_direction = "LE"} : (tensor<24x48xi64>, tensor<i64>) -> tensor<24x48xi1>
// CHECK-DAG: %[[J:.*]] = mhlo.and %[[G]], %[[I]] : tensor<24x48xi1>
// CHECK: %[[ZERO2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<12x24x48xbf16>
// CHECK: %[[R:.*]] = "mhlo.select"(%[[J]], %[[INPUT]], %[[ZERO2]])
// CHECK: return %[[R]]
// CHECK-DAG: %[[ZERO2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<12x24x48xbf16>
// CHECK-DAG: %[[K:.*]] = "mhlo.broadcast_in_dim"(%[[J]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<24x48xi1>) -> tensor<12x24x48xi1>
// CHECK-DAG: %[[R:.*]] = "mhlo.select"(%[[K]], %[[INPUT]], %[[ZERO2]])
// CHECK-DAG: return %[[R]]
%0 = "tf.MatrixBandPart"(%arg0, %arg1, %arg2) : (tensor<12x24x48xbf16>, tensor<i64>, tensor<i64>) -> tensor<12x24x48xbf16>
return %0 : tensor<12x24x48xbf16>
}

View File

@ -786,11 +786,11 @@ static int GetDimensionSizeFromEnd(Value input, int dim_from_end) {
// dimension and last dimension, respectively). The element type of the
// outputted RankedTensorType will match the element type of `num_lower`.
// Requires that `input` and `num_lower` are tensors.
static RankedTensorType Get2DTensorType(Value input) {
static RankedTensorType Get2DTensorType(Value input, Value num_lower) {
// `dim_0` refers to the second-to-last dimension; `dim_1` refers to the last.
int dim_0 = GetDimensionSizeFromEnd(input, 1);
int dim_1 = GetDimensionSizeFromEnd(input, 0);
auto element_type = input.getType().cast<TensorType>().getElementType();
auto element_type = num_lower.getType().cast<TensorType>().getElementType();
return RankedTensorType::get({dim_0, dim_1}, element_type);
}

View File

@ -365,7 +365,8 @@ class getIntegerAttr<string x>: NativeCodeCall<
"$_builder.getI64IntegerAttr(" # x # ")">;
class GetDimensionSizeFromEnd<string dimFromEnd>: NativeCodeCall<
"$_builder.getI64IntegerAttr(GetDimensionSizeFromEnd($0, " # dimFromEnd # "))"
"$_builder.getIntegerAttr(getElementTypeOrSelf($1.getType()), "
" GetDimensionSizeFromEnd($0, " # dimFromEnd # "))"
>;
// TODO(b/149615308): Enable IotaOp usage as a child operation in a pattern
@ -373,7 +374,7 @@ class GetDimensionSizeFromEnd<string dimFromEnd>: NativeCodeCall<
// cannot be inferred.
class createIotaOp<string dim>: NativeCodeCall<
"$_builder.create<mhlo::IotaOp>($0.getOwner()->getLoc(), "
"Get2DTensorType($1), $_builder.getI64IntegerAttr(" # dim # "))">;
"Get2DTensorType($1, $2), $_builder.getI64IntegerAttr(" # dim # "))">;
// This op needs to be created in C++ because the generated Convert Op has no
// way to specify shape information as an input. In the MatrixBandPart op
@ -396,9 +397,10 @@ def createConvertOp: NativeCodeCall<
// return (indicator ? input : zero_matrix)
//
// TODO(b/149961547): Support dynamic shaped `input` in MatrixBandPartOp.
def : Pattern<(TF_MatrixBandPartOp:$op AnyRankedTensor:$input, $num_lower, $num_upper),
[(HLO_ConstOp:$m_dim (GetDimensionSizeFromEnd<"0"> $input)),
(HLO_ConstOp:$n_dim (GetDimensionSizeFromEnd<"1"> $input)),
def : Pattern<(TF_MatrixBandPartOp:$op AnyStaticShapeTensor:$input, $num_lower,
$num_upper),
[(HLO_ConstOp:$m_dim (GetDimensionSizeFromEnd<"1"> $input, $num_lower)),
(HLO_ConstOp:$n_dim (GetDimensionSizeFromEnd<"0"> $input, $num_upper)),
(HLO_SelectOp:$num_lower_or_m
(HLO_CompareOp
$num_lower, (HLO_ConstOp:$zero (ConstantSplat<"0"> $num_lower)),
@ -414,22 +416,17 @@ def : Pattern<(TF_MatrixBandPartOp:$op AnyRankedTensor:$input, $num_lower, $num_
$n_dim,
$num_upper
),
(HLO_SelectOp
(TF_SelectV2Op
(HLO_AndOp
(HLOClient_BroadcastCompareOp
(HLO_NegOp
(createConvertOp $op, $num_lower_or_m, $input)
),
(HLO_NegOp $num_lower_or_m),
(HLO_SubOp:$offset
(createIotaOp<"1"> $op, $input), (createIotaOp<"0"> $op, $input)
(createIotaOp<"1"> $op, $input, $num_lower),
(createIotaOp<"0"> $op, $input, $num_lower)
),
(NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LE
),
(HLOClient_BroadcastCompareOp
$offset,
(createConvertOp
$op, $num_upper_or_n, $input
),
(HLOClient_BroadcastCompareOp $offset, $num_upper_or_n,
(NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LE
)
),

View File

@ -848,6 +848,7 @@ tf_xla_py_test(
size = "medium",
timeout = "long",
srcs = ["matrix_band_part_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip