[MLIR] Add XLA HLO -> LMHLO conversion for all elementwise ops.

PiperOrigin-RevId: 345557248
Change-Id: I5832bb00cb735489f6115c19007b68a49b434a0a
This commit is contained in:
Tim Shen 2020-12-03 16:11:22 -08:00 committed by TensorFlower Gardener
parent a0a365e408
commit e742fee43c
8 changed files with 577 additions and 9 deletions

View File

@ -85,6 +85,8 @@ def LHLO_AbsOp: LHLO_UnaryElementwiseOp<"abs">, BASE_HLO_AbsOp;
def LHLO_BitcastConvertOp:
LHLO_UnaryElementwiseOp<"bitcast_convert", LHLO_Buffer, [SameOperandsShape]>, BASE_HLO_BitcastConvertOp;
def LHLO_CbrtOp: LHLO_UnaryElementwiseOp<"cbrt", LHLO_FpBuffer>, BASE_HLO_CbrtOp;
def LHLO_CeilOp: LHLO_UnaryElementwiseOp<"ceil", LHLO_FpBuffer>, BASE_HLO_CeilOp;
def LHLO_ClzOp: LHLO_UnaryElementwiseOp<"count_leading_zeros", LHLO_IntBuffer>, BASE_HLO_ClzOp;

View File

@ -139,6 +139,7 @@ cc_library(
":mlir_hlo_to_hlo",
":translate_cl_options",
"//tensorflow/compiler/mlir/hlo",
"//tensorflow/compiler/mlir/hlo:hlo_ops_base_enums",
"//tensorflow/compiler/mlir/hlo:lhlo",
"//tensorflow/compiler/mlir/hlo:lhlo_gpu",
"//tensorflow/compiler/xla:debug_options_flags",

View File

@ -537,7 +537,8 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
}
case HloOpcode::kAllReduce: {
auto all_reduce = Cast<HloAllReduceInstruction>(instruction);
attributes.push_back(ConvertReplicaGroups(all_reduce->replica_groups()));
attributes.push_back(
ConvertReplicaGroups(all_reduce->replica_groups(), *builder_));
attributes.push_back(ConvertChannelHandle(all_reduce->channel_id()));
auto all_reduce_op = func_builder->create<mlir::mhlo::AllReduceOp>(
loc, result_type, operands, attributes);
@ -932,7 +933,7 @@ mlir::NamedAttribute HloFunctionImporter::ConvertSourceTargetPairs(
}
mlir::NamedAttribute HloFunctionImporter::ConvertReplicaGroups(
const std::vector<ReplicaGroup>& replica_groups) {
const std::vector<ReplicaGroup>& replica_groups, mlir::Builder builder) {
int64_t num_groups = replica_groups.size();
int64_t group_size =
num_groups == 0 ? 0 : replica_groups[0].replica_ids_size();
@ -944,9 +945,9 @@ mlir::NamedAttribute HloFunctionImporter::ConvertReplicaGroups(
attr[flat_index++] = group.replica_ids(i);
}
auto type = mlir::RankedTensorType::get({num_groups, group_size},
builder_->getIntegerType(64));
return builder_->getNamedAttr("replica_groups",
DenseIntElementsAttr::get(type, attr));
builder.getIntegerType(64));
return builder.getNamedAttr("replica_groups",
DenseIntElementsAttr::get(type, attr));
}
mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle(

View File

@ -64,6 +64,12 @@ class HloFunctionImporter {
static void SetLayoutForMlir(mlir::Operation* op, const Shape& shape);
// Converts XLA replica groups into a named attribute called "replica_groups".
// The attribute value is a dense i64 elements attribute of shape
// [num_groups, group_size], where group_size is read from the first group
// (0 when `replica_groups` is empty). `builder` is used to create the element
// type and the named attribute. Declared static and public so that callers
// outside of an importer instance can reuse it.
//
// TODO(timshen): move this to attribute_importer.h.
static mlir::NamedAttribute ConvertReplicaGroups(
const std::vector<ReplicaGroup>& replica_groups, mlir::Builder builder);
private:
HloFunctionImporter(mlir::ModuleOp module,
std::unordered_map<const xla::HloComputation*,
@ -136,10 +142,6 @@ class HloFunctionImporter {
// padding low and padding high for each of the spatial dimensions.
mlir::NamedAttribute ConvertPadding(llvm::ArrayRef<int64_t> padding);
// Converts replica groups to attribute
mlir::NamedAttribute ConvertReplicaGroups(
const std::vector<ReplicaGroup>& replica_groups);
// Converts channel id to attribute
mlir::NamedAttribute ConvertChannelHandle(
absl::optional<tensorflow::int64> channel_id);

View File

@ -169,3 +169,52 @@ ENTRY main {
backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":1,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"batch_size\":\"1\",\"selected_algorithm\":\"0\"}"
}
// -----
HloModule GemmBias
// CHECK-LABEL: func @main
// CHECK: "lmhlo_gpu.gemm_bias"
// CHECK-SAME: algorithm = 0 : i64
// CHECK-SAME: alpha_imag = 0.000000e+00 : f64
// CHECK-SAME: alpha_real = 1.000000e+00 : f64
// CHECK-SAME: batch_size = 1 : i64
// CHECK-SAME: beta = 1.000000e+00 : f64
// CHECK-SAME: lhs_batching_dimensions = dense<> : tensor<0xi64>
// CHECK-SAME: lhs_contracting_dimensions = dense<1> : tensor<1xi64>
// CHECK-SAME: rhs_batching_dimensions = dense<> : tensor<0xi64>
// CHECK-SAME: rhs_contracting_dimensions = dense<0> : tensor<1xi64>
// CHECK: (memref<1x1xf32>, memref<1x4xf32>, memref<1x4xf32>, memref<1x4xf32>)
ENTRY main {
%A = f32[1,1]{1,0} parameter(0)
%B = f32[1,4]{1,0} parameter(1)
%C = f32[1,4]{1,0} parameter(2)
ROOT %sgemm_add = f32[1,4]{1,0} custom-call(f32[1,1]{0,1} %A, f32[1,4]{1,0} %B, f32[1,4]{1,0} %C),
custom_call_target="__cublas$gemm",
backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":1,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"batch_size\":\"1\",\"selected_algorithm\":\"0\"}"
}
// -----
HloModule AllReduce
// Test all-reduce
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
// CHECK-LABEL: func @test_all_reduce
// CHECK-SAME: ([[INPUT:%.*]]: memref<8xf32>
%test_all_reduce {
input = f32[8] parameter(0)
// CHECK: "lmhlo.all_reduce"([[INPUT]], {{.*}})
// CHECK: ^bb0([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>):
// CHECK: [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]]
// CHECK: "mhlo.return"([[ADD]]) : (tensor<f32>) -> ()
// CHECK: }) {
// CHECK-SAME: channel_id = {handle = 1 : i64, type = 0 : i64}
// CHECK-SAME: replica_groups = dense<{{\[\[}}0, 1, 2, 3], [5, 6, 7, 8]]> : tensor<2x4xi64>
ROOT result = f32[8] all-reduce(input), channel_id=1, replica_groups={{0,1,2,3}, {5,6,7,8}}, to_apply=add
}

View File

@ -43,6 +43,34 @@ func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lmhlo.atan2
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "mhlo.atan2"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
func @main(%value: tensor<2x2xf32>) -> tensor<2x2xi32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
// CHECK: lmhlo.bitcast_convert
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "mhlo.bitcast_convert"(%value) : (tensor<2x2xf32>) -> tensor<2x2xi32>
return %res : tensor<2x2xi32>
}
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
@ -56,6 +84,63 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
func @main(%value: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lmhlo.cbrt
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "mhlo.cbrt"(%value) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
// CHECK-SAME: %[[ARG2:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 2
// CHECK-SAME: %[[ARG3:.*]]: memref<16xi8>
// Operands of mhlo.clamp are, in order: min, operand, max.
func @main(%min: tensor<2x2xf32>, %operand: tensor<2x2xf32>, %max: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
  // CHECK: lmhlo.clamp
  // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[VIEW]]
  // CHECK-NEXT: return
  %clamped = "mhlo.clamp"(%min, %operand, %max) : (tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
  return %clamped : tensor<2x2xf32>
}
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
func @main(%value: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
// CHECK: lmhlo.count_leading_zeros
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "mhlo.count_leading_zeros"(%value) : (tensor<2x2xi32>) -> tensor<2x2xi32>
return %res : tensor<2x2xi32>
}
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
// CHECK-SAME: %[[ARG2:.*]]: memref<4xi8>
func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xi1> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<4xi8> to memref<2x2xi1>
// CHECK: lmhlo.compare
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "mhlo.compare"(%value0, %value1) {comparison_direction="GT"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
return %res : tensor<2x2xi1>
}
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<1x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
@ -71,6 +156,19 @@ func @main(%value0: tensor<1x2xf32>, %value1: tensor<1x2xf32>) -> tensor<1x2xcom
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<8xi8>
func @main(%value: tensor<2x2xf32>) -> tensor<2x2xf16> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<8xi8> to memref<2x2xf16>
// CHECK: lmhlo.convert
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "mhlo.convert"(%value) : (tensor<2x2xf32>) -> tensor<2x2xf16>
return %res : tensor<2x2xf16>
}
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex<f32>> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
@ -113,6 +211,45 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lmhlo.exponential_minus_one
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "mhlo.exponential_minus_one"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lmhlo.floor
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "mhlo.floor"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<4xi8>
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xi1> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<4xi8> to memref<2x2xi1>
// CHECK: lmhlo.is_finite
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "mhlo.is_finite"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xi1>
return %res : tensor<2x2xi1>
}
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
@ -126,6 +263,39 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lmhlo.log_plus_one
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "mhlo.log_plus_one"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lmhlo.map
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK: return
%res = "mhlo.map"(%value0, %value1) ({
^bb0(%a: tensor<f32>, %b: tensor<f32>):
%c = "mhlo.add"(%a, %b) : (tensor<f32>, tensor<f32>) -> tensor<f32>
%ret = "mhlo.add"(%a, %c) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"mhlo.return"(%ret) : (tensor<f32>) -> ()
}) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
@ -184,6 +354,90 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi1> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<4xi8>
func @main(%value0: tensor<2x2xi1>) -> tensor<2x2xi1> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<4xi8> to memref<2x2xi1>
// CHECK: lmhlo.not
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "mhlo.not"(%value0) : (tensor<2x2xi1>) -> tensor<2x2xi1>
return %res : tensor<2x2xi1>
}
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
// CHECK: lmhlo.not
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "mhlo.not"(%value0) : (tensor<2x2xi32>) -> tensor<2x2xi32>
return %res : tensor<2x2xi32>
}
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi1> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi1> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
// CHECK-SAME: %[[ARG2:.*]]: memref<4xi8>
func @main(%value0: tensor<2x2xi1>, %value1: tensor<2x2xi1>) -> tensor<2x2xi1> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<4xi8> to memref<2x2xi1>
// CHECK: lmhlo.or
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "mhlo.or"(%value0, %value1) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
return %res : tensor<2x2xi1>
}
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
// CHECK: lmhlo.or
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "mhlo.or"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
return %res : tensor<2x2xi32>
}
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
// CHECK: lmhlo.popcnt
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "mhlo.popcnt"(%value0) : (tensor<2x2xi32>) -> tensor<2x2xi32>
return %res : tensor<2x2xi32>
}
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lmhlo.power
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "mhlo.power"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex<f32>> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<8xi8>
@ -210,6 +464,19 @@ func @main(%value0: tensor<1x2xcomplex<f32>>) -> tensor<1x2xf32> {
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lmhlo.reduce_precision
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "mhlo.reduce_precision"(%value0) {exponent_bits=5 : i32, mantissa_bits=12 : i32}: (tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
@ -225,6 +492,19 @@ func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lmhlo.round_nearest_afz
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "mhlo.round_nearest_afz"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
@ -254,6 +534,51 @@ func @main(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>)
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
// CHECK: lmhlo.shift_left
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "mhlo.shift_left"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
return %res : tensor<2x2xi32>
}
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
// CHECK: lmhlo.shift_right_arithmetic
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "mhlo.shift_right_arithmetic"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
return %res : tensor<2x2xi32>
}
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
// CHECK: lmhlo.shift_right_logical
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "mhlo.shift_right_logical"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
return %res : tensor<2x2xi32>
}
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
@ -267,6 +592,19 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lmhlo.sine
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "mhlo.sine"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
}
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
@ -308,6 +646,36 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi1> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi1> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
// CHECK-SAME: %[[ARG2:.*]]: memref<4xi8>
func @main(%value0: tensor<2x2xi1>, %value1: tensor<2x2xi1>) -> tensor<2x2xi1> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<4xi8> to memref<2x2xi1>
// CHECK: lmhlo.xor
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "mhlo.xor"(%value0, %value1) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
return %res : tensor<2x2xi1>
}
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
// CHECK: lmhlo.xor
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "mhlo.xor"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
return %res : tensor<2x2xi32>
}
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<5x5xi32>
// CHECK-SAME: %[[ARG1:.*]]: memref<5x5xf32>

View File

@ -39,6 +39,7 @@ limitations under the License.
#include "mlir/Pass/PassOptions.h" // from @llvm-project
#include "mlir/Translation.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
@ -229,12 +230,28 @@ StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(HloInstruction* instr) {
return CreateOpWithoutAttrs<lmhlo::AbsOp>(instr);
case HloOpcode::kAdd:
return CreateOpWithoutAttrs<lmhlo::AddOp>(instr);
case HloOpcode::kAllReduce:
return EmitAllReduceOp(instr);
case HloOpcode::kAnd:
return CreateOpWithoutAttrs<lmhlo::AndOp>(instr);
case HloOpcode::kAtan2:
return CreateOpWithoutAttrs<lmhlo::Atan2Op>(instr);
case HloOpcode::kBitcastConvert:
return CreateOpWithoutAttrs<lmhlo::BitcastConvertOp>(instr);
case HloOpcode::kCeil:
return CreateOpWithoutAttrs<lmhlo::CeilOp>(instr);
case HloOpcode::kCbrt:
return CreateOpWithoutAttrs<lmhlo::CbrtOp>(instr);
case HloOpcode::kClamp:
return CreateOpWithoutAttrs<lmhlo::ClampOp>(instr);
case HloOpcode::kClz:
return CreateOpWithoutAttrs<lmhlo::ClzOp>(instr);
case HloOpcode::kCompare:
return EmitCompareOp(instr);
case HloOpcode::kComplex:
return CreateOpWithoutAttrs<lmhlo::ComplexOp>(instr);
case HloOpcode::kConvert:
return CreateOpWithoutAttrs<lmhlo::ConvertOp>(instr);
case HloOpcode::kCopy:
return CreateOpWithoutAttrs<lmhlo::CopyOp>(instr);
case HloOpcode::kCos:
@ -243,10 +260,20 @@ StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(HloInstruction* instr) {
return CreateOpWithoutAttrs<lmhlo::DivOp>(instr);
case HloOpcode::kExp:
return CreateOpWithoutAttrs<lmhlo::ExpOp>(instr);
case HloOpcode::kExpm1:
return CreateOpWithoutAttrs<lmhlo::Expm1Op>(instr);
case HloOpcode::kFloor:
return CreateOpWithoutAttrs<lmhlo::FloorOp>(instr);
case HloOpcode::kImag:
return CreateOpWithoutAttrs<lmhlo::ImagOp>(instr);
case HloOpcode::kIsFinite:
return CreateOpWithoutAttrs<lmhlo::IsFiniteOp>(instr);
case HloOpcode::kLog:
return CreateOpWithoutAttrs<lmhlo::LogOp>(instr);
case HloOpcode::kLog1p:
return CreateOpWithoutAttrs<lmhlo::Log1pOp>(instr);
case HloOpcode::kMap:
return EmitMapOp(instr);
case HloOpcode::kMaximum:
return CreateOpWithoutAttrs<lmhlo::MaxOp>(instr);
case HloOpcode::kMinimum:
@ -255,22 +282,44 @@ StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(HloInstruction* instr) {
return CreateOpWithoutAttrs<lmhlo::MulOp>(instr);
case HloOpcode::kNegate:
return CreateOpWithoutAttrs<lmhlo::NegOp>(instr);
case HloOpcode::kNot:
return CreateOpWithoutAttrs<lmhlo::NotOp>(instr);
case HloOpcode::kOr:
return CreateOpWithoutAttrs<lmhlo::OrOp>(instr);
case HloOpcode::kPopulationCount:
return CreateOpWithoutAttrs<lmhlo::PopulationCountOp>(instr);
case HloOpcode::kPower:
return CreateOpWithoutAttrs<lmhlo::PowOp>(instr);
case HloOpcode::kReal:
return CreateOpWithoutAttrs<lmhlo::RealOp>(instr);
case HloOpcode::kReducePrecision:
return EmitReducePrecisionOp(instr);
case HloOpcode::kRemainder:
return CreateOpWithoutAttrs<lmhlo::RemOp>(instr);
case HloOpcode::kRoundNearestAfz:
return CreateOpWithoutAttrs<lmhlo::RoundOp>(instr);
case HloOpcode::kRsqrt:
return CreateOpWithoutAttrs<lmhlo::RsqrtOp>(instr);
case HloOpcode::kSelect:
return CreateOpWithoutAttrs<lmhlo::SelectOp>(instr);
case HloOpcode::kShiftLeft:
return CreateOpWithoutAttrs<lmhlo::ShiftLeftOp>(instr);
case HloOpcode::kShiftRightLogical:
return CreateOpWithoutAttrs<lmhlo::ShiftRightLogicalOp>(instr);
case HloOpcode::kShiftRightArithmetic:
return CreateOpWithoutAttrs<lmhlo::ShiftRightArithmeticOp>(instr);
case HloOpcode::kSign:
return CreateOpWithoutAttrs<lmhlo::SignOp>(instr);
case HloOpcode::kSin:
return CreateOpWithoutAttrs<lmhlo::SinOp>(instr);
case HloOpcode::kSqrt:
return CreateOpWithoutAttrs<lmhlo::SqrtOp>(instr);
case HloOpcode::kSubtract:
return CreateOpWithoutAttrs<lmhlo::SubOp>(instr);
case HloOpcode::kTanh:
return CreateOpWithoutAttrs<lmhlo::TanhOp>(instr);
case HloOpcode::kXor:
return CreateOpWithoutAttrs<lmhlo::XorOp>(instr);
case HloOpcode::kSort:
return EmitSortOp(instr);
case HloOpcode::kFusion:
@ -642,6 +691,92 @@ StatusOr<lmhlo::ReduceOp> LhloDialectEmitter::EmitReduceOp(
return reduce_op;
}
// Emits an lmhlo.map op for the HLO map instruction `instr`.
//
// The op is first created without attributes; afterwards the map dimensions
// are attached as a dense i64 attribute and the mapped HLO computation is
// imported into the op's `computation` region. Any error from op creation or
// region import is propagated to the caller.
StatusOr<lmhlo::MapOp> LhloDialectEmitter::EmitMapOp(HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto map_op, CreateOpWithoutAttrs<lmhlo::MapOp>(instr));

  // Widen/copy the HLO dimension list into the int64_t vector expected by the
  // attribute helper.
  auto* hlo_map = ::xla::Cast<::xla::HloMapInstruction>(instr);
  const auto& hlo_dimensions = hlo_map->dimensions();
  std::vector<int64_t> dimensions(hlo_dimensions.begin(),
                                  hlo_dimensions.end());
  map_op.dimensionsAttr(GetI64DenseElementsAttr(dimensions));

  // The first called computation is the function applied element-wise.
  const auto* mapped_computation = instr->called_computations()[0];
  TF_RETURN_IF_ERROR(::xla::HloFunctionImporter::ImportAsRegion(
      *mapped_computation, &map_op.computation(), &builder_));
  return map_op;
}
// Emits an lmhlo.compare op for the HLO compare instruction `instr`.
//
// The op is created without attributes, then its `comparison_direction` and
// `compare_type` attributes are filled in as string attributes obtained by
// translating the XLA enums to their MHLO counterparts and stringifying them.
//
// NOTE(review): neither translation switch has a `default` label or a return
// after the switch, so an out-of-range enum value would fall off the end of
// the lambda. Confirm that callers can only supply the listed enumerators.
StatusOr<lmhlo::CompareOp> LhloDialectEmitter::EmitCompareOp(
    HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto compare_op,
                      CreateOpWithoutAttrs<lmhlo::CompareOp>(instr));
  auto* hlo_compare = ::xla::Cast<::xla::HloCompareInstruction>(instr);

  // XLA comparison direction -> MHLO comparison direction.
  const auto to_mhlo_direction = [](xla::ComparisonDirection hlo_direction) {
    switch (hlo_direction) {
      case xla::ComparisonDirection::kEq:
        return mhlo::ComparisonDirection::EQ;
      case xla::ComparisonDirection::kNe:
        return mhlo::ComparisonDirection::NE;
      case xla::ComparisonDirection::kGe:
        return mhlo::ComparisonDirection::GE;
      case xla::ComparisonDirection::kGt:
        return mhlo::ComparisonDirection::GT;
      case xla::ComparisonDirection::kLe:
        return mhlo::ComparisonDirection::LE;
      case xla::ComparisonDirection::kLt:
        return mhlo::ComparisonDirection::LT;
    }
  };
  const auto mhlo_direction = to_mhlo_direction(hlo_compare->direction());
  compare_op.comparison_directionAttr(
      builder_.getStringAttr(stringifyComparisonDirection(mhlo_direction)));

  // XLA comparison type -> MHLO comparison type.
  const auto to_mhlo_type = [](xla::Comparison::Type hlo_type) {
    switch (hlo_type) {
      case xla::Comparison::Type::kFloat:
        return mhlo::ComparisonType::FLOAT;
      case xla::Comparison::Type::kFloatTotalOrder:
        return mhlo::ComparisonType::TOTALORDER;
      case xla::Comparison::Type::kSigned:
        return mhlo::ComparisonType::SIGNED;
      case xla::Comparison::Type::kUnsigned:
        return mhlo::ComparisonType::UNSIGNED;
    }
  };
  const auto mhlo_type = to_mhlo_type(hlo_compare->type());
  compare_op.compare_typeAttr(
      builder_.getStringAttr(stringifyComparisonType(mhlo_type)));

  return compare_op;
}
// Emits an lmhlo.reduce_precision op for the HLO reduce-precision instruction
// `instr`.
//
// The op is created without attributes, then `exponent_bits` and
// `mantissa_bits` are copied from the HLO instruction as i32 integer
// attributes.
StatusOr<lmhlo::ReducePrecisionOp> LhloDialectEmitter::EmitReducePrecisionOp(
    HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto reduce_precision_op,
                      CreateOpWithoutAttrs<lmhlo::ReducePrecisionOp>(instr));
  auto* hlo_reduce_precision =
      ::xla::Cast<::xla::HloReducePrecisionInstruction>(instr);

  auto exponent_bits_attr =
      builder_.getI32IntegerAttr(hlo_reduce_precision->exponent_bits());
  reduce_precision_op.exponent_bitsAttr(exponent_bits_attr);

  auto mantissa_bits_attr =
      builder_.getI32IntegerAttr(hlo_reduce_precision->mantissa_bits());
  reduce_precision_op.mantissa_bitsAttr(mantissa_bits_attr);

  return reduce_precision_op;
}
// Emits an lmhlo.all_reduce op for the HLO all-reduce instruction `instr`.
//
// The op is created without attributes, then populated with:
//   * replica_groups        - converted via HloFunctionImporter,
//   * constrain_layout      - copied as a bool attribute,
//   * channel_id            - a ChannelHandle whose handle is the HLO channel
//                             id (0 when the instruction has none) and whose
//                             type is 0,
//   * use_global_device_ids - copied as a bool attribute.
// Finally the reduction computation is imported into the op's region.
StatusOr<lmhlo::AllReduceOp> LhloDialectEmitter::EmitAllReduceOp(
    HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto all_reduce_op,
                      CreateOpWithoutAttrs<lmhlo::AllReduceOp>(instr));
  auto* hlo_all_reduce = ::xla::Cast<::xla::HloAllReduceInstruction>(instr);

  // The importer hands back a (name, value) pair; attach it under that name.
  auto replica_groups_attr = ::xla::HloFunctionImporter::ConvertReplicaGroups(
      hlo_all_reduce->replica_groups(), builder_);
  all_reduce_op.setAttr(replica_groups_attr.first, replica_groups_attr.second);

  all_reduce_op.constrain_layoutAttr(
      builder_.getBoolAttr(hlo_all_reduce->constrain_layout()));

  auto channel_handle = mlir::mhlo::ChannelHandle::get(
      /*handle=*/builder_.getI64IntegerAttr(
          hlo_all_reduce->channel_id().value_or(0)),
      /*type=*/builder_.getI64IntegerAttr(0), builder_.getContext());
  all_reduce_op.channel_idAttr(channel_handle);

  all_reduce_op.use_global_device_idsAttr(
      builder_.getBoolAttr(hlo_all_reduce->use_global_device_ids()));

  // The first called computation is the reduction function.
  const auto* reduction = instr->called_computations()[0];
  TF_RETURN_IF_ERROR(::xla::HloFunctionImporter::ImportAsRegion(
      *reduction, &all_reduce_op.computation(), &builder_));
  return all_reduce_op;
}
StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape,
const ::xla::ShapeIndex& shape_index) {

View File

@ -72,6 +72,16 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
::xla::StatusOr<GetGlobalMemrefOp> EmitConstant(
const ::xla::HloInstruction* instr);
// Emits an lmhlo.compare op for `instr` and sets its comparison direction and
// comparison type attributes from the HLO compare instruction.
::xla::StatusOr<lmhlo::CompareOp> EmitCompareOp(::xla::HloInstruction* instr);
// Emits an lmhlo.map op for `instr`, setting its dimensions attribute and
// importing the mapped HLO computation as the op's region.
::xla::StatusOr<lmhlo::MapOp> EmitMapOp(::xla::HloInstruction* instr);
// Emits an lmhlo.reduce_precision op for `instr`, copying the exponent and
// mantissa bit counts from the HLO instruction as i32 attributes.
::xla::StatusOr<lmhlo::ReducePrecisionOp> EmitReducePrecisionOp(
::xla::HloInstruction* instr);
// Emits an lmhlo.all_reduce op for `instr`, setting its replica groups,
// layout constraint, channel id and global-device-id attributes, and importing
// the reduction computation as the op's region.
::xla::StatusOr<lmhlo::AllReduceOp> EmitAllReduceOp(
::xla::HloInstruction* instr);
::xla::Status CreateOperands(::xla::HloInstruction* instr,
SmallVectorImpl<Value>& operands,
size_t& num_arguments, size_t& num_results);