[MLIR] Add XLA HLO -> LMHLO conversion for all elementwise ops.
PiperOrigin-RevId: 345557248 Change-Id: I5832bb00cb735489f6115c19007b68a49b434a0a
This commit is contained in:
parent
a0a365e408
commit
e742fee43c
@ -85,6 +85,8 @@ def LHLO_AbsOp: LHLO_UnaryElementwiseOp<"abs">, BASE_HLO_AbsOp;
|
||||
def LHLO_BitcastConvertOp:
|
||||
LHLO_UnaryElementwiseOp<"bitcast_convert", LHLO_Buffer, [SameOperandsShape]>, BASE_HLO_BitcastConvertOp;
|
||||
|
||||
def LHLO_CbrtOp: LHLO_UnaryElementwiseOp<"cbrt", LHLO_FpBuffer>, BASE_HLO_CbrtOp;
|
||||
|
||||
def LHLO_CeilOp: LHLO_UnaryElementwiseOp<"ceil", LHLO_FpBuffer>, BASE_HLO_CeilOp;
|
||||
|
||||
def LHLO_ClzOp: LHLO_UnaryElementwiseOp<"count_leading_zeros", LHLO_IntBuffer>, BASE_HLO_ClzOp;
|
||||
|
@ -139,6 +139,7 @@ cc_library(
|
||||
":mlir_hlo_to_hlo",
|
||||
":translate_cl_options",
|
||||
"//tensorflow/compiler/mlir/hlo",
|
||||
"//tensorflow/compiler/mlir/hlo:hlo_ops_base_enums",
|
||||
"//tensorflow/compiler/mlir/hlo:lhlo",
|
||||
"//tensorflow/compiler/mlir/hlo:lhlo_gpu",
|
||||
"//tensorflow/compiler/xla:debug_options_flags",
|
||||
|
@ -537,7 +537,8 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
|
||||
}
|
||||
case HloOpcode::kAllReduce: {
|
||||
auto all_reduce = Cast<HloAllReduceInstruction>(instruction);
|
||||
attributes.push_back(ConvertReplicaGroups(all_reduce->replica_groups()));
|
||||
attributes.push_back(
|
||||
ConvertReplicaGroups(all_reduce->replica_groups(), *builder_));
|
||||
attributes.push_back(ConvertChannelHandle(all_reduce->channel_id()));
|
||||
auto all_reduce_op = func_builder->create<mlir::mhlo::AllReduceOp>(
|
||||
loc, result_type, operands, attributes);
|
||||
@ -932,7 +933,7 @@ mlir::NamedAttribute HloFunctionImporter::ConvertSourceTargetPairs(
|
||||
}
|
||||
|
||||
mlir::NamedAttribute HloFunctionImporter::ConvertReplicaGroups(
|
||||
const std::vector<ReplicaGroup>& replica_groups) {
|
||||
const std::vector<ReplicaGroup>& replica_groups, mlir::Builder builder) {
|
||||
int64_t num_groups = replica_groups.size();
|
||||
int64_t group_size =
|
||||
num_groups == 0 ? 0 : replica_groups[0].replica_ids_size();
|
||||
@ -944,9 +945,9 @@ mlir::NamedAttribute HloFunctionImporter::ConvertReplicaGroups(
|
||||
attr[flat_index++] = group.replica_ids(i);
|
||||
}
|
||||
auto type = mlir::RankedTensorType::get({num_groups, group_size},
|
||||
builder_->getIntegerType(64));
|
||||
return builder_->getNamedAttr("replica_groups",
|
||||
DenseIntElementsAttr::get(type, attr));
|
||||
builder.getIntegerType(64));
|
||||
return builder.getNamedAttr("replica_groups",
|
||||
DenseIntElementsAttr::get(type, attr));
|
||||
}
|
||||
|
||||
mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle(
|
||||
|
@ -64,6 +64,12 @@ class HloFunctionImporter {
|
||||
|
||||
static void SetLayoutForMlir(mlir::Operation* op, const Shape& shape);
|
||||
|
||||
// Converts `replica_groups` into a NamedAttribute named "replica_groups".
// The attribute value is a dense integer elements attribute whose type is a
// ranked tensor of 64-bit integers with shape [num_groups, group_size], where
// group_size is the number of replica ids in the first group (0 when
// `replica_groups` is empty). `builder` is used to create the integer type and
// the named attribute; the function is static and takes the builder
// explicitly.
// NOTE(review): presumably every group is expected to have the same number of
// replica ids as the first one -- confirm against the definition.
|
||||
//
|
||||
// TODO(timshen): move this to attribute_importer.h.
|
||||
static mlir::NamedAttribute ConvertReplicaGroups(
|
||||
const std::vector<ReplicaGroup>& replica_groups, mlir::Builder builder);
|
||||
|
||||
private:
|
||||
HloFunctionImporter(mlir::ModuleOp module,
|
||||
std::unordered_map<const xla::HloComputation*,
|
||||
@ -136,10 +142,6 @@ class HloFunctionImporter {
|
||||
// padding low and padding high for each of the spatial dimensions.
|
||||
mlir::NamedAttribute ConvertPadding(llvm::ArrayRef<int64_t> padding);
|
||||
|
||||
// Converts replica groups to attribute
|
||||
mlir::NamedAttribute ConvertReplicaGroups(
|
||||
const std::vector<ReplicaGroup>& replica_groups);
|
||||
|
||||
// Converts channel id to attribute
|
||||
mlir::NamedAttribute ConvertChannelHandle(
|
||||
absl::optional<tensorflow::int64> channel_id);
|
||||
|
@ -169,3 +169,52 @@ ENTRY main {
|
||||
backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":1,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"batch_size\":\"1\",\"selected_algorithm\":\"0\"}"
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
HloModule GemmBias
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK: "lmhlo_gpu.gemm_bias"
|
||||
// CHECK-SAME: algorithm = 0 : i64
|
||||
// CHECK-SAME: alpha_imag = 0.000000e+00 : f64
|
||||
// CHECK-SAME: alpha_real = 1.000000e+00 : f64
|
||||
// CHECK-SAME: batch_size = 1 : i64
|
||||
// CHECK-SAME: beta = 1.000000e+00 : f64
|
||||
// CHECK-SAME: lhs_batching_dimensions = dense<> : tensor<0xi64>
|
||||
// CHECK-SAME: lhs_contracting_dimensions = dense<1> : tensor<1xi64>
|
||||
// CHECK-SAME: rhs_batching_dimensions = dense<> : tensor<0xi64>
|
||||
// CHECK-SAME: rhs_contracting_dimensions = dense<0> : tensor<1xi64>
|
||||
// CHECK: (memref<1x1xf32>, memref<1x4xf32>, memref<1x4xf32>, memref<1x4xf32>)
|
||||
ENTRY main {
|
||||
%A = f32[1,1]{1,0} parameter(0)
|
||||
%B = f32[1,4]{1,0} parameter(1)
|
||||
%C = f32[1,4]{1,0} parameter(2)
|
||||
ROOT %sgemm_add = f32[1,4]{1,0} custom-call(f32[1,1]{0,1} %A, f32[1,4]{1,0} %B, f32[1,4]{1,0} %C),
|
||||
custom_call_target="__cublas$gemm",
|
||||
backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":1,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"batch_size\":\"1\",\"selected_algorithm\":\"0\"}"
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
HloModule AllReduce
|
||||
|
||||
// Test all-reduce
|
||||
add {
|
||||
lhs = f32[] parameter(0)
|
||||
rhs = f32[] parameter(1)
|
||||
ROOT add = f32[] add(lhs, rhs)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_all_reduce
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: memref<8xf32>
|
||||
%test_all_reduce {
|
||||
input = f32[8] parameter(0)
|
||||
// CHECK: "lmhlo.all_reduce"([[INPUT]], {{.*}})
|
||||
// CHECK: ^bb0([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>):
|
||||
// CHECK: [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]]
|
||||
// CHECK: "mhlo.return"([[ADD]]) : (tensor<f32>) -> ()
|
||||
// CHECK: }) {
|
||||
// CHECK-SAME: channel_id = {handle = 1 : i64, type = 0 : i64}
|
||||
// CHECK-SAME: replica_groups = dense<{{\[\[}}0, 1, 2, 3], [5, 6, 7, 8]]> : tensor<2x4xi64>
|
||||
ROOT result = f32[8] all-reduce(input), channel_id=1, replica_groups={{0,1,2,3}, {5,6,7,8}}, to_apply=add
|
||||
}
|
||||
|
@ -43,6 +43,34 @@ func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
|
||||
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
|
||||
func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
|
||||
// CHECK: lmhlo.atan2
|
||||
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
|
||||
// CHECK-NEXT: return
|
||||
%res = "mhlo.atan2"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %res : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
|
||||
func @main(%value: tensor<2x2xf32>) -> tensor<2x2xi32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
|
||||
// CHECK: lmhlo.bitcast_convert
|
||||
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
|
||||
%res = "mhlo.bitcast_convert"(%value) : (tensor<2x2xf32>) -> tensor<2x2xi32>
|
||||
return %res : tensor<2x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
|
||||
@ -56,6 +84,63 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
|
||||
func @main(%value: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
|
||||
// CHECK: lmhlo.cbrt
|
||||
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
|
||||
%res = "mhlo.cbrt"(%value) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %res : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
|
||||
// CHECK-SAME: %[[ARG2:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 2
|
||||
// CHECK-SAME: %[[ARG3:.*]]: memref<16xi8>
|
||||
// Exercises mhlo.clamp, whose operands are (min, operand, max), and expects an
// lmhlo.clamp that receives the three inputs followed by a view of the output
// buffer, with nothing emitted between it and the return.
func @main(%min: tensor<2x2xf32>, %operand: tensor<2x2xf32>, %max: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
|
||||
// CHECK: lmhlo.clamp
|
||||
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[VIEW]]
|
||||
// CHECK-NEXT: return
|
||||
%0 = "mhlo.clamp"(%min, %operand, %max) : (tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
|
||||
func @main(%value: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
|
||||
// CHECK: lmhlo.count_leading_zeros
|
||||
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
|
||||
%res = "mhlo.count_leading_zeros"(%value) : (tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||
return %res : tensor<2x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
|
||||
// CHECK-SAME: %[[ARG2:.*]]: memref<4xi8>
|
||||
func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xi1> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<4xi8> to memref<2x2xi1>
|
||||
// CHECK: lmhlo.compare
|
||||
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
|
||||
// CHECK-NEXT: return
|
||||
%res = "mhlo.compare"(%value0, %value1) {comparison_direction="GT"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
|
||||
return %res : tensor<2x2xi1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<1x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
|
||||
@ -71,6 +156,19 @@ func @main(%value0: tensor<1x2xf32>, %value1: tensor<1x2xf32>) -> tensor<1x2xcom
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<8xi8>
|
||||
func @main(%value: tensor<2x2xf32>) -> tensor<2x2xf16> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<8xi8> to memref<2x2xf16>
|
||||
// CHECK: lmhlo.convert
|
||||
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
|
||||
%res = "mhlo.convert"(%value) : (tensor<2x2xf32>) -> tensor<2x2xf16>
|
||||
return %res : tensor<2x2xf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex<f32>> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
|
||||
@ -113,6 +211,45 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
|
||||
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
|
||||
// CHECK: lmhlo.exponential_minus_one
|
||||
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
|
||||
%res = "mhlo.exponential_minus_one"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %res : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
|
||||
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
|
||||
// CHECK: lmhlo.floor
|
||||
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
|
||||
%res = "mhlo.floor"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %res : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<4xi8>
|
||||
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xi1> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<4xi8> to memref<2x2xi1>
|
||||
// CHECK: lmhlo.is_finite
|
||||
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
|
||||
%res = "mhlo.is_finite"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xi1>
|
||||
return %res : tensor<2x2xi1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
|
||||
@ -126,6 +263,39 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
|
||||
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
|
||||
// CHECK: lmhlo.log_plus_one
|
||||
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
|
||||
%res = "mhlo.log_plus_one"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %res : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
|
||||
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
|
||||
func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
|
||||
// CHECK: lmhlo.map
|
||||
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
|
||||
// CHECK: return
|
||||
%res = "mhlo.map"(%value0, %value1) ({
|
||||
^bb0(%a: tensor<f32>, %b: tensor<f32>):
|
||||
%c = "mhlo.add"(%a, %b) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
%ret = "mhlo.add"(%a, %c) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
"mhlo.return"(%ret) : (tensor<f32>) -> ()
|
||||
}) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %res : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
|
||||
@ -184,6 +354,90 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi1> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<4xi8>
|
||||
func @main(%value0: tensor<2x2xi1>) -> tensor<2x2xi1> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<4xi8> to memref<2x2xi1>
|
||||
// CHECK: lmhlo.not
|
||||
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
|
||||
%res = "mhlo.not"(%value0) : (tensor<2x2xi1>) -> tensor<2x2xi1>
|
||||
return %res : tensor<2x2xi1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
|
||||
func @main(%value0: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
|
||||
// CHECK: lmhlo.not
|
||||
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
|
||||
%res = "mhlo.not"(%value0) : (tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||
return %res : tensor<2x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi1> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi1> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
|
||||
// CHECK-SAME: %[[ARG2:.*]]: memref<4xi8>
|
||||
func @main(%value0: tensor<2x2xi1>, %value1: tensor<2x2xi1>) -> tensor<2x2xi1> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<4xi8> to memref<2x2xi1>
|
||||
// CHECK: lmhlo.or
|
||||
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
|
||||
// CHECK-NEXT: return
|
||||
%res = "mhlo.or"(%value0, %value1) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
|
||||
return %res : tensor<2x2xi1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
|
||||
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
|
||||
func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
|
||||
// CHECK: lmhlo.or
|
||||
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
|
||||
// CHECK-NEXT: return
|
||||
%res = "mhlo.or"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||
return %res : tensor<2x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
|
||||
func @main(%value0: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
|
||||
// CHECK: lmhlo.popcnt
|
||||
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
|
||||
%res = "mhlo.popcnt"(%value0) : (tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||
return %res : tensor<2x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
|
||||
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
|
||||
func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
|
||||
// CHECK: lmhlo.power
|
||||
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
|
||||
// CHECK-NEXT: return
|
||||
%res = "mhlo.power"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %res : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex<f32>> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<8xi8>
|
||||
@ -210,6 +464,19 @@ func @main(%value0: tensor<1x2xcomplex<f32>>) -> tensor<1x2xf32> {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
|
||||
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
|
||||
// CHECK: lmhlo.reduce_precision
|
||||
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
|
||||
%res = "mhlo.reduce_precision"(%value0) {exponent_bits=5 : i32, mantissa_bits=12 : i32}: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %res : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
|
||||
@ -225,6 +492,19 @@ func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
|
||||
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
|
||||
// CHECK: lmhlo.round_nearest_afz
|
||||
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
|
||||
%res = "mhlo.round_nearest_afz"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %res : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
|
||||
@ -254,6 +534,51 @@ func @main(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>)
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
|
||||
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
|
||||
func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
|
||||
// CHECK: lmhlo.shift_left
|
||||
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
|
||||
// CHECK-NEXT: return
|
||||
%res = "mhlo.shift_left"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||
return %res : tensor<2x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
|
||||
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
|
||||
func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
|
||||
// CHECK: lmhlo.shift_right_arithmetic
|
||||
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
|
||||
// CHECK-NEXT: return
|
||||
%res = "mhlo.shift_right_arithmetic"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||
return %res : tensor<2x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
|
||||
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
|
||||
func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
|
||||
// CHECK: lmhlo.shift_right_logical
|
||||
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
|
||||
// CHECK-NEXT: return
|
||||
%res = "mhlo.shift_right_logical"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||
return %res : tensor<2x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
|
||||
@ -267,6 +592,19 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
|
||||
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
|
||||
// CHECK: lmhlo.sine
|
||||
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
|
||||
%res = "mhlo.sine"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %res : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
|
||||
@ -308,6 +646,36 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi1> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi1> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
|
||||
// CHECK-SAME: %[[ARG2:.*]]: memref<4xi8>
|
||||
func @main(%value0: tensor<2x2xi1>, %value1: tensor<2x2xi1>) -> tensor<2x2xi1> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<4xi8> to memref<2x2xi1>
|
||||
// CHECK: lmhlo.xor
|
||||
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
|
||||
// CHECK-NEXT: return
|
||||
%res = "mhlo.xor"(%value0, %value1) : (tensor<2x2xi1>, tensor<2x2xi1>) -> tensor<2x2xi1>
|
||||
return %res : tensor<2x2xi1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1
|
||||
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
|
||||
func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
|
||||
// CHECK: lmhlo.xor
|
||||
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
|
||||
// CHECK-NEXT: return
|
||||
%res = "mhlo.xor"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||
return %res : tensor<2x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<5x5xi32>
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<5x5xf32>
|
||||
|
@ -39,6 +39,7 @@ limitations under the License.
|
||||
#include "mlir/Pass/PassOptions.h" // from @llvm-project
|
||||
#include "mlir/Translation.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h"
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
|
||||
@ -229,12 +230,28 @@ StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(HloInstruction* instr) {
|
||||
return CreateOpWithoutAttrs<lmhlo::AbsOp>(instr);
|
||||
case HloOpcode::kAdd:
|
||||
return CreateOpWithoutAttrs<lmhlo::AddOp>(instr);
|
||||
case HloOpcode::kAllReduce:
|
||||
return EmitAllReduceOp(instr);
|
||||
case HloOpcode::kAnd:
|
||||
return CreateOpWithoutAttrs<lmhlo::AndOp>(instr);
|
||||
case HloOpcode::kAtan2:
|
||||
return CreateOpWithoutAttrs<lmhlo::Atan2Op>(instr);
|
||||
case HloOpcode::kBitcastConvert:
|
||||
return CreateOpWithoutAttrs<lmhlo::BitcastConvertOp>(instr);
|
||||
case HloOpcode::kCeil:
|
||||
return CreateOpWithoutAttrs<lmhlo::CeilOp>(instr);
|
||||
case HloOpcode::kCbrt:
|
||||
return CreateOpWithoutAttrs<lmhlo::CbrtOp>(instr);
|
||||
case HloOpcode::kClamp:
|
||||
return CreateOpWithoutAttrs<lmhlo::ClampOp>(instr);
|
||||
case HloOpcode::kClz:
|
||||
return CreateOpWithoutAttrs<lmhlo::ClzOp>(instr);
|
||||
case HloOpcode::kCompare:
|
||||
return EmitCompareOp(instr);
|
||||
case HloOpcode::kComplex:
|
||||
return CreateOpWithoutAttrs<lmhlo::ComplexOp>(instr);
|
||||
case HloOpcode::kConvert:
|
||||
return CreateOpWithoutAttrs<lmhlo::ConvertOp>(instr);
|
||||
case HloOpcode::kCopy:
|
||||
return CreateOpWithoutAttrs<lmhlo::CopyOp>(instr);
|
||||
case HloOpcode::kCos:
|
||||
@ -243,10 +260,20 @@ StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(HloInstruction* instr) {
|
||||
return CreateOpWithoutAttrs<lmhlo::DivOp>(instr);
|
||||
case HloOpcode::kExp:
|
||||
return CreateOpWithoutAttrs<lmhlo::ExpOp>(instr);
|
||||
case HloOpcode::kExpm1:
|
||||
return CreateOpWithoutAttrs<lmhlo::Expm1Op>(instr);
|
||||
case HloOpcode::kFloor:
|
||||
return CreateOpWithoutAttrs<lmhlo::FloorOp>(instr);
|
||||
case HloOpcode::kImag:
|
||||
return CreateOpWithoutAttrs<lmhlo::ImagOp>(instr);
|
||||
case HloOpcode::kIsFinite:
|
||||
return CreateOpWithoutAttrs<lmhlo::IsFiniteOp>(instr);
|
||||
case HloOpcode::kLog:
|
||||
return CreateOpWithoutAttrs<lmhlo::LogOp>(instr);
|
||||
case HloOpcode::kLog1p:
|
||||
return CreateOpWithoutAttrs<lmhlo::Log1pOp>(instr);
|
||||
case HloOpcode::kMap:
|
||||
return EmitMapOp(instr);
|
||||
case HloOpcode::kMaximum:
|
||||
return CreateOpWithoutAttrs<lmhlo::MaxOp>(instr);
|
||||
case HloOpcode::kMinimum:
|
||||
@ -255,22 +282,44 @@ StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(HloInstruction* instr) {
|
||||
return CreateOpWithoutAttrs<lmhlo::MulOp>(instr);
|
||||
case HloOpcode::kNegate:
|
||||
return CreateOpWithoutAttrs<lmhlo::NegOp>(instr);
|
||||
case HloOpcode::kNot:
|
||||
return CreateOpWithoutAttrs<lmhlo::NotOp>(instr);
|
||||
case HloOpcode::kOr:
|
||||
return CreateOpWithoutAttrs<lmhlo::OrOp>(instr);
|
||||
case HloOpcode::kPopulationCount:
|
||||
return CreateOpWithoutAttrs<lmhlo::PopulationCountOp>(instr);
|
||||
case HloOpcode::kPower:
|
||||
return CreateOpWithoutAttrs<lmhlo::PowOp>(instr);
|
||||
case HloOpcode::kReal:
|
||||
return CreateOpWithoutAttrs<lmhlo::RealOp>(instr);
|
||||
case HloOpcode::kReducePrecision:
|
||||
return EmitReducePrecisionOp(instr);
|
||||
case HloOpcode::kRemainder:
|
||||
return CreateOpWithoutAttrs<lmhlo::RemOp>(instr);
|
||||
case HloOpcode::kRoundNearestAfz:
|
||||
return CreateOpWithoutAttrs<lmhlo::RoundOp>(instr);
|
||||
case HloOpcode::kRsqrt:
|
||||
return CreateOpWithoutAttrs<lmhlo::RsqrtOp>(instr);
|
||||
case HloOpcode::kSelect:
|
||||
return CreateOpWithoutAttrs<lmhlo::SelectOp>(instr);
|
||||
case HloOpcode::kShiftLeft:
|
||||
return CreateOpWithoutAttrs<lmhlo::ShiftLeftOp>(instr);
|
||||
case HloOpcode::kShiftRightLogical:
|
||||
return CreateOpWithoutAttrs<lmhlo::ShiftRightLogicalOp>(instr);
|
||||
case HloOpcode::kShiftRightArithmetic:
|
||||
return CreateOpWithoutAttrs<lmhlo::ShiftRightArithmeticOp>(instr);
|
||||
case HloOpcode::kSign:
|
||||
return CreateOpWithoutAttrs<lmhlo::SignOp>(instr);
|
||||
case HloOpcode::kSin:
|
||||
return CreateOpWithoutAttrs<lmhlo::SinOp>(instr);
|
||||
case HloOpcode::kSqrt:
|
||||
return CreateOpWithoutAttrs<lmhlo::SqrtOp>(instr);
|
||||
case HloOpcode::kSubtract:
|
||||
return CreateOpWithoutAttrs<lmhlo::SubOp>(instr);
|
||||
case HloOpcode::kTanh:
|
||||
return CreateOpWithoutAttrs<lmhlo::TanhOp>(instr);
|
||||
case HloOpcode::kXor:
|
||||
return CreateOpWithoutAttrs<lmhlo::XorOp>(instr);
|
||||
case HloOpcode::kSort:
|
||||
return EmitSortOp(instr);
|
||||
case HloOpcode::kFusion:
|
||||
@ -642,6 +691,92 @@ StatusOr<lmhlo::ReduceOp> LhloDialectEmitter::EmitReduceOp(
|
||||
return reduce_op;
|
||||
}
|
||||
|
||||
// Builds an lmhlo::MapOp for `instr`.
//
// The op is first created without attributes; its `dimensions` attribute is
// then filled from the HLO map instruction, and the instruction's first called
// computation is imported into the op's `computation` region. Failures from
// op creation or region import are returned to the caller through the
// TF_ASSIGN_OR_RETURN / TF_RETURN_IF_ERROR macros.
StatusOr<lmhlo::MapOp> LhloDialectEmitter::EmitMapOp(HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto lmhlo_map,
                      CreateOpWithoutAttrs<lmhlo::MapOp>(instr));

  // Copy the HLO dimensions into an int64_t vector before building the attr.
  auto* hlo_map = ::xla::Cast<::xla::HloMapInstruction>(instr);
  const auto& hlo_dims = hlo_map->dimensions();
  std::vector<int64_t> dims(hlo_dims.begin(), hlo_dims.end());
  lmhlo_map.dimensionsAttr(GetI64DenseElementsAttr(dims));

  // The mapped computation is the first (index 0) called computation.
  auto* mapped_computation = instr->called_computations()[0];
  TF_RETURN_IF_ERROR(::xla::HloFunctionImporter::ImportAsRegion(
      *mapped_computation, &lmhlo_map.computation(), &builder_));

  return lmhlo_map;
}
|
||||
|
||||
// Builds an lmhlo::CompareOp for `instr`.
//
// The op is created without attributes, after which the XLA comparison
// direction and comparison type of the HLO compare instruction are translated
// to their mhlo enum counterparts and stored on the op as string attributes
// (`comparison_direction` and `compare_type`).
StatusOr<lmhlo::CompareOp> LhloDialectEmitter::EmitCompareOp(
    HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto lmhlo_compare,
                      CreateOpWithoutAttrs<lmhlo::CompareOp>(instr));
  auto* hlo_compare = ::xla::Cast<::xla::HloCompareInstruction>(instr);

  // XLA comparison direction -> mhlo comparison direction.
  const auto to_mhlo_direction = [](xla::ComparisonDirection hlo_direction) {
    switch (hlo_direction) {
      case xla::ComparisonDirection::kEq:
        return mhlo::ComparisonDirection::EQ;
      case xla::ComparisonDirection::kNe:
        return mhlo::ComparisonDirection::NE;
      case xla::ComparisonDirection::kGe:
        return mhlo::ComparisonDirection::GE;
      case xla::ComparisonDirection::kGt:
        return mhlo::ComparisonDirection::GT;
      case xla::ComparisonDirection::kLe:
        return mhlo::ComparisonDirection::LE;
      case xla::ComparisonDirection::kLt:
        return mhlo::ComparisonDirection::LT;
    }
  };

  // XLA comparison type -> mhlo comparison type.
  const auto to_mhlo_type = [](xla::Comparison::Type hlo_type) {
    switch (hlo_type) {
      case xla::Comparison::Type::kFloat:
        return mhlo::ComparisonType::FLOAT;
      case xla::Comparison::Type::kFloatTotalOrder:
        return mhlo::ComparisonType::TOTALORDER;
      case xla::Comparison::Type::kSigned:
        return mhlo::ComparisonType::SIGNED;
      case xla::Comparison::Type::kUnsigned:
        return mhlo::ComparisonType::UNSIGNED;
    }
  };

  const auto mhlo_direction = to_mhlo_direction(hlo_compare->direction());
  lmhlo_compare.comparison_directionAttr(
      builder_.getStringAttr(stringifyComparisonDirection(mhlo_direction)));

  const auto mhlo_type = to_mhlo_type(hlo_compare->type());
  lmhlo_compare.compare_typeAttr(
      builder_.getStringAttr(stringifyComparisonType(mhlo_type)));

  return lmhlo_compare;
}
|
||||
|
||||
// Builds an lmhlo::ReducePrecisionOp for `instr`.
//
// After creating the attribute-less op, the exponent and mantissa bit counts
// of the HLO reduce-precision instruction are attached as 32-bit integer
// attributes (`exponent_bits`, `mantissa_bits`).
StatusOr<lmhlo::ReducePrecisionOp> LhloDialectEmitter::EmitReducePrecisionOp(
    HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto lmhlo_op,
                      CreateOpWithoutAttrs<lmhlo::ReducePrecisionOp>(instr));
  auto* hlo_op = ::xla::Cast<::xla::HloReducePrecisionInstruction>(instr);

  auto exponent_attr = builder_.getI32IntegerAttr(hlo_op->exponent_bits());
  lmhlo_op.exponent_bitsAttr(exponent_attr);

  auto mantissa_attr = builder_.getI32IntegerAttr(hlo_op->mantissa_bits());
  lmhlo_op.mantissa_bitsAttr(mantissa_attr);

  return lmhlo_op;
}
|
||||
|
||||
// Builds an lmhlo::AllReduceOp for `instr`.
//
// The op is created without attributes and then populated from the HLO
// all-reduce instruction: replica groups, constrain_layout, channel_id and
// use_global_device_ids. Finally the instruction's first called computation is
// imported into the op's `computation` region; an import failure is returned
// to the caller.
StatusOr<lmhlo::AllReduceOp> LhloDialectEmitter::EmitAllReduceOp(
    HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(auto lmhlo_all_reduce,
                      CreateOpWithoutAttrs<lmhlo::AllReduceOp>(instr));
  auto* hlo_all_reduce = ::xla::Cast<::xla::HloAllReduceInstruction>(instr);

  // Replica groups come back as a (name, value) pair from the HLO importer.
  auto named_groups = ::xla::HloFunctionImporter::ConvertReplicaGroups(
      hlo_all_reduce->replica_groups(), builder_);
  lmhlo_all_reduce.setAttr(named_groups.first, named_groups.second);

  auto constrain_layout =
      builder_.getBoolAttr(hlo_all_reduce->constrain_layout());
  lmhlo_all_reduce.constrain_layoutAttr(constrain_layout);

  // An absent channel id falls back to 0 through value_or(0); the second
  // integer passed to ChannelHandle::get is the constant 0.
  auto channel_id_attr =
      builder_.getI64IntegerAttr(hlo_all_reduce->channel_id().value_or(0));
  auto zero_attr = builder_.getI64IntegerAttr(0);
  lmhlo_all_reduce.channel_idAttr(mlir::mhlo::ChannelHandle::get(
      channel_id_attr, zero_attr, builder_.getContext()));

  auto use_global_ids =
      builder_.getBoolAttr(hlo_all_reduce->use_global_device_ids());
  lmhlo_all_reduce.use_global_device_idsAttr(use_global_ids);

  // The reduction computation is the first (index 0) called computation.
  auto* reduction = instr->called_computations()[0];
  TF_RETURN_IF_ERROR(::xla::HloFunctionImporter::ImportAsRegion(
      *reduction, &lmhlo_all_reduce.computation(), &builder_));

  return lmhlo_all_reduce;
}
|
||||
|
||||
StatusOr<Value> LhloDialectEmitter::GetOrCreateArrayView(
|
||||
const ::xla::HloInstruction* instr, const ::xla::Shape& current_shape,
|
||||
const ::xla::ShapeIndex& shape_index) {
|
||||
|
@ -72,6 +72,16 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
|
||||
// Declared to yield a GetGlobalMemrefOp (wrapped in StatusOr) for the given
// HLO instruction. NOTE(review): the definition is not visible here; confirm
// the exact emission behavior against the implementation.
::xla::StatusOr<GetGlobalMemrefOp> EmitConstant(
|
||||
const ::xla::HloInstruction* instr);
|
||||
|
||||
// Emits an lmhlo::CompareOp for `instr` and sets its comparison_direction and
// compare_type string attributes from the HLO compare instruction.
::xla::StatusOr<lmhlo::CompareOp> EmitCompareOp(::xla::HloInstruction* instr);
|
||||
|
||||
// Emits an lmhlo::MapOp for `instr`, sets its dimensions attribute, and
// imports the instruction's first called computation as the op's region.
::xla::StatusOr<lmhlo::MapOp> EmitMapOp(::xla::HloInstruction* instr);
|
||||
|
||||
// Emits an lmhlo::ReducePrecisionOp for `instr` and sets its exponent_bits and
// mantissa_bits attributes (32-bit integers) from the HLO instruction.
::xla::StatusOr<lmhlo::ReducePrecisionOp> EmitReducePrecisionOp(
|
||||
::xla::HloInstruction* instr);
|
||||
|
||||
// Emits an lmhlo::AllReduceOp for `instr`, sets its replica groups,
// constrain_layout, channel_id and use_global_device_ids attributes, and
// imports the instruction's first called computation as the op's region.
::xla::StatusOr<lmhlo::AllReduceOp> EmitAllReduceOp(
|
||||
::xla::HloInstruction* instr);
|
||||
|
||||
// Takes `operands`, `num_arguments` and `num_results` by non-const reference
// and returns a Status. NOTE(review): presumably these are outputs filled in
// for `instr`; the definition is not visible here, so confirm against it.
::xla::Status CreateOperands(::xla::HloInstruction* instr,
|
||||
SmallVectorImpl<Value>& operands,
|
||||
size_t& num_arguments, size_t& num_results);
|
||||
|
Loading…
x
Reference in New Issue
Block a user