Fix the MHLO to LMHLO lowering of 'gather'

The lowering assumes that the 'gather' op attributes are identical in both MHLO and LMHLO. But that's not true; some time ago the MHLO version was changed to pack 4 of its attributes into a struct. By doing the same for the LMHLO version we both fix the lowering for this op and resolve a longstanding TODO.

PiperOrigin-RevId: 337943946
Change-Id: I872adaec775cffa21a35ff2dd2d10db1b15330fe
This commit is contained in:
A. Unique TensorFlower 2020-10-19 15:13:24 -07:00 committed by TensorFlower Gardener
parent ea13fb0c5a
commit 06384d97df
4 changed files with 31 additions and 13 deletions
tensorflow/compiler
mlir/hlo
include/mlir-hlo/Dialect/mhlo/IR
tests
xla/service/mlir_gpu

View File

@ -602,11 +602,8 @@ def LHLO_GatherOp: LHLO_Op<"gather", []>, BASE_HLO_GatherOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_IntBuffer, "", [MemRead]>:$start_indices,
I64Attr:$index_vector_dim,
I64ElementsAttr:$offset_dims,
GatherDimensionNumbers:$dimension_numbers,
I64ElementsAttr:$slice_sizes,
I64ElementsAttr:$collapsed_slice_dims,
I64ElementsAttr:$start_index_map,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}

View File

@ -287,6 +287,28 @@ func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
// -----
// BOTH-LABEL: func @gather
func @gather(%operand: memref<13x7xf32>, %idxs: memref<5xi32>, %result: memref<5x7xf32>) {
%tensor_operand = tensor_load %operand : memref<13x7xf32>
%tensor_idxs = tensor_load %idxs : memref<5xi32>
%tensor_result =
"mhlo.gather"(%tensor_operand, %tensor_idxs)
{ dimension_numbers =
{ collapsed_slice_dims = dense<0> : tensor<1xi64>
, index_vector_dim = 1 : i64
, offset_dims = dense<1> : tensor<1xi64>
, start_index_map = dense<0> : tensor<1xi64> }
, indices_are_sorted = false
, name = "gather.71"
, slice_sizes = dense<[1, 7]> : tensor<2xi64> }
: (tensor<13x7xf32>, tensor<5xi32>) -> tensor<5x7xf32>
// BOTH: "lmhlo.gather"(%{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<5x7xf32>
return
}
// -----
// BOTH-LABEL: func @imag_dyn
func @imag_dyn(%operand: memref<?xcomplex<f32>>, %result: memref<?xf32>) {
%tensor_operand = tensor_load %operand : memref<?xcomplex<f32>>

View File

@ -322,12 +322,9 @@ Status LhloDialectEmitter::HandleGather(HloInstruction* instr) {
TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*instr));
OpBuilder func_builder(function.getBody());
// TODO(pifon): Clean-up LHLO GatherOp to be consistent with HLO GatherOp.
func_builder.create<lhlo::GatherOp>(
getLocation(instr), function.getArgument(0), function.getArgument(1),
dim_numbers.index_vector_dim(), dim_numbers.offset_dims(), slice_sizes,
dim_numbers.collapsed_slice_dims(), dim_numbers.start_index_map(),
function.getArgument(2));
dim_numbers, slice_sizes, function.getArgument(2));
return Status::OK();
}

View File

@ -12,9 +12,11 @@ ENTRY %Gather (x: f32[100,10], y: s64[4,6]) -> f32[4,6,10] {
// CHECK: func @gather(%[[ARG0:.*]]: [[TYPE0:.*]], %[[ARG1:.*]]: [[TYPE1:.*]],
// CHECK-SAME: %[[RESULT:.*]]: [[RTYPE:.*]]) {
// CHECK-NEXT: "lmhlo.gather"(%[[ARG0]], %[[ARG1]], %[[RESULT]]) {
// CHECK-SAME: collapsed_slice_dims = dense<0> : tensor<1xi64>,
// CHECK-SAME: index_vector_dim = 2 : i64,
// CHECK-SAME: offset_dims = dense<2> : tensor<1xi64>,
// CHECK-SAME: slice_sizes = dense<[1, 10]> : tensor<2xi64>,
// CHECK-SAME: start_index_map = dense<0> : tensor<1xi64>
// CHECK-SAME: dimension_numbers = {
// CHECK-SAME: collapsed_slice_dims = dense<0> : tensor<1xi64>,
// CHECK-SAME: index_vector_dim = 2 : i64,
// CHECK-SAME: offset_dims = dense<2> : tensor<1xi64>,
// CHECK-SAME: start_index_map = dense<0> : tensor<1xi64>
// CHECK-SAME: },
// CHECK-SAME: slice_sizes = dense<[1, 10]> : tensor<2xi64>
// CHECK-SAME: } : ([[TYPE0]], [[TYPE1]], [[RTYPE]]) -> ()