diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc index 270674bea9f..6005fe6e6dd 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc @@ -174,15 +174,16 @@ tensorflow::Status HloFunctionImporter::ImportInstructions( return tensorflow::Status::OK(); } -StatusOr HloFunctionImporter::ImportInstruction( +StatusOr HloFunctionImporter::ImportInstructionImpl( HloInstruction* instruction, mlir::OpBuilder* func_builder) { TF_ASSIGN_OR_RETURN(auto operands, GetOperands(instruction)); TF_ASSIGN_OR_RETURN(auto result_type, ConvertShapeToType( instruction->shape(), *builder_)); - llvm::SmallVector attributes = {builder_->getNamedAttr( - "name", builder_->getStringAttr(instruction->name()))}; - mlir::Location loc = func_builder->getUnknownLoc(); + mlir::Location loc = + mlir::NameLoc::get(func_builder->getIdentifier(instruction->name()), + func_builder->getContext()); + llvm::SmallVector attributes; switch (instruction->opcode()) { case HloOpcode::kParameter: { return nullptr; @@ -214,8 +215,8 @@ StatusOr HloFunctionImporter::ImportInstruction( return new_operation; \ } case HloOpcode::kBroadcast: { - // Note that the HLO broadcast is more powerful than the XLA broadcast op. - // BroadcastInDim offers a superset of the HLO op's functionality. + // Note that the HLO broadcast is more powerful than the XLA broadcast + // op. BroadcastInDim offers a superset of the HLO op's functionality. attributes.push_back( builder_->getNamedAttr("broadcast_dimensions", ConvertDimensions(instruction->dimensions()))); @@ -458,7 +459,8 @@ StatusOr HloFunctionImporter::ImportInstruction( return op.getOperation(); } - // Otherwise, it is a indexed conditional and should be mapped to Case op. + // Otherwise, it is a indexed conditional and should be mapped to Case + // op. TF_RETURN_IF_ERROR(GetMlirTypes( {instruction->branch_computation(0)->root_instruction()}, &rets)); @@ -474,8 +476,8 @@ StatusOr HloFunctionImporter::ImportInstruction( return op.getOperation(); } case HloOpcode::kConcatenate: { - // TODO(b/132057942): Support taking an uint64_t instead of an IntegerAttr - // for concatenate dimension. + // TODO(b/132057942): Support taking an uint64_t instead of an + // IntegerAttr for concatenate dimension. return func_builder ->create( loc, result_type, operands, @@ -703,9 +705,9 @@ StatusOr HloFunctionImporter::ImportInstruction( NoAttributeCase(kReal, RealOp); NoAttributeCase(kRemainder, RemOp); NoAttributeCase(kReplicaId, ReplicaIdOp); - // The dimensions attribute is not present on the HLO Reshape instruction. - // If dimensions are non-default, the XLA builder implements it as a - // separate transpose. + // The dimensions attribute is not present on the HLO Reshape + // instruction. If dimensions are non-default, the XLA builder + // implements it as a separate transpose. NoAttributeCase(kReshape, ReshapeOp); NoAttributeCase(kRoundNearestAfz, RoundOp); NoAttributeCase(kRsqrt, RsqrtOp); @@ -720,9 +722,9 @@ StatusOr HloFunctionImporter::ImportInstruction( NoAttributeCase(kTanh, TanhOp); NoAttributeCase(kTuple, TupleOp); NoAttributeCase(kXor, XorOp); - // TODO(b/129422361) Copy needs special handling because it is not defined - // in tensorflow/compiler/xla/client/xla_builder.h. - // See operation semantics in + // TODO(b/129422361) Copy needs special handling because it is not + // defined in tensorflow/compiler/xla/client/xla_builder.h. See + // operation semantics in // g3doc/platforms/xla/g3doc/internal/hlo_semantics#copy NoAttributeCase(kCopy, CopyOp); #undef NoAttributeCase @@ -754,6 +756,35 @@ StatusOr HloFunctionImporter::ImportInstruction( } } +StatusOr HloFunctionImporter::ImportInstruction( + HloInstruction* instruction, mlir::OpBuilder* func_builder) { + TF_ASSIGN_OR_RETURN(mlir::Operation * op, + ImportInstructionImpl(instruction, func_builder)); + if (op == nullptr) return op; + + // Best-effort propagation of the layouts. These layouts serve as performance + // hints to the backend. + // + // Minor-to-major is a permutation of [0, rank), presenting tensor dimensions + // in physical minor-to-major order. + // + // Note that non-array shapes are not carrying layouts, and users have to + // figure out the proper layouts of them through context. This is one of the + // reasons why the attribute-based solution is temporary. + // + // TODO(timshen): Investigate the necessity of having layouts in MHLO. + if (instruction->shape().IsArray() && + instruction->shape().layout() != + LayoutUtil::MakeDescendingLayout( + instruction->shape().dimensions().size())) { + llvm::SmallVector minor_to_major( + instruction->shape().layout().minor_to_major().begin(), + instruction->shape().layout().minor_to_major().end()); + op->setAttr("minor_to_major", builder_->getIndexTensorAttr(minor_to_major)); + } + return op; +} + StatusOr> HloFunctionImporter::GetOperands( HloInstruction* instruction) { llvm::SmallVector operands; diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.h b/tensorflow/compiler/mlir/xla/hlo_function_importer.h index e0cc89004cf..f925f7f471b 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.h +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.h @@ -84,6 +84,8 @@ class HloFunctionImporter { // Imports an instruction. StatusOr ImportInstruction(xla::HloInstruction* instruction, mlir::OpBuilder* func_builder); + StatusOr ImportInstructionImpl( + HloInstruction* instruction, mlir::OpBuilder* func_builder); // Gets the MLIR operand values from an HLO Instruction. StatusOr> GetOperands( diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index d097b9ca314..c1d07702100 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -1064,7 +1064,7 @@ LogicalResult ExportXlaOp(FusionOp op, OpLoweringContext ctx) { llvm::SmallVector operands; for (auto operand : op.operands()) operands.push_back(values[operand]); - xla::XlaOp fusion = xla::internal::XlaBuilderBuildFusion( + xla::XlaOp fusion = xla::internal::XlaBuilderFriend::BuildFusion( ctx.builder, operands, absl::string_view(op.fusion_kind()->data(), op.fusion_kind()->size()), fused_computation); @@ -1160,7 +1160,32 @@ LogicalResult ConvertToHloModule::Lower( xla::XlaBuilder* builder, ConvertToHloModule::ValueLoweringMap* value_lowering, xla::XlaComputation* result) { + // See hlo_function_importer.cc for documentation about layouts in MHLO. + auto propagate_layouts = [](mlir::Operation* inst, xla::XlaOp xla_op) { + auto attr = + inst->getAttrOfType("minor_to_major"); + if (!attr) return; + + auto* v = xla::internal::XlaBuilderFriend::GetInstruction(xla_op) + ->mutable_shape() + ->mutable_layout() + ->mutable_minor_to_major(); + v->Clear(); + for (const llvm::APInt& i : attr) { + *v->Add() = i.getZExtValue(); + } + }; + if (succeeded(ExportXlaOperator(inst, {value_lowering, this, builder}))) { + if (inst->getNumResults() == 1) { + auto iter = value_lowering->find(inst->getResult(0)); + if (iter == value_lowering->end()) { + inst->emitOpError( + "inst has a result, but it's not found in value_lowering"); + return failure(); + } + propagate_layouts(inst, iter->second); + } return success(); } @@ -1186,16 +1211,17 @@ LogicalResult ConvertToHloModule::Lower( if (failed(GetXlaOp(operand, value_map, &xla_operand, op))) return failure(); value_map[op.getResult()] = xla_operand; + propagate_layouts(inst, xla_operand); return success(); } - // TODO(jpienaar): This doesn't support layouts yet. if (matchPattern(inst, m_Constant(&const_attr))) { auto literal_or = CreateLiteralFromAttr(const_attr); if (!literal_or.ok()) return inst->emitError(literal_or.status().ToString()); - value_map[inst->getResult(0)] = - xla::ConstantLiteral(builder, literal_or.ValueOrDie()); + auto constant = xla::ConstantLiteral(builder, literal_or.ValueOrDie()); + value_map[inst->getResult(0)] = constant; + propagate_layouts(inst, constant); return success(); } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/case.mlir b/tensorflow/compiler/mlir/xla/tests/translate/case.mlir index 1032bb723c5..cea0599adb0 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/case.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/case.mlir @@ -1,10 +1,10 @@ // RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FILECHECK_OPTS="" FileCheck %s func @main() -> tensor { - %cst = constant {name = "constant"} dense<1> : tensor - %cst_0 = constant {name = "constant.1"} dense<5.600000e+01> : tensor - %cst_1 = constant {name = "constant.2"} dense<1.200000e+01> : tensor - %cst_2 = constant {name = "constant.3"} dense<1.300000e+01> : tensor + %cst = constant dense<1> : tensor + %cst_0 = constant dense<5.600000e+01> : tensor + %cst_1 = constant dense<1.200000e+01> : tensor + %cst_2 = constant dense<1.300000e+01> : tensor %0 = "mhlo.case"(%cst, %cst_0, %cst_1, %cst_2) ( { ^bb0(%arg0: tensor): %1 = "mhlo.negate"(%arg0) : (tensor) -> tensor @@ -17,7 +17,7 @@ func @main() -> tensor { ^bb0(%arg0: tensor): %1 = "mhlo.floor"(%arg0) : (tensor) -> tensor "mhlo.return"(%1) : (tensor) -> () - }) {name = "conditional"} : (tensor, tensor, tensor, tensor) -> tensor + }) : (tensor, tensor, tensor, tensor) -> tensor return %0 : tensor } @@ -48,23 +48,23 @@ func @main() -> tensor { // ----- func @main() -> (tensor, tensor) { - %cst = constant {name = "constant"} dense<1> : tensor - %cst_0 = constant {name = "constant.1"} dense<5.600000e+01> : tensor - %cst_1 = constant {name = "constant.2"} dense<1.200000e+01> : tensor - %cst_2 = constant {name = "constant.3"} dense<1.300000e+01> : tensor + %cst = constant dense<1> : tensor + %cst_0 = constant dense<5.600000e+01> : tensor + %cst_1 = constant dense<1.200000e+01> : tensor + %cst_2 = constant dense<1.300000e+01> : tensor %0:2 = "mhlo.case"(%cst, %cst_0, %cst_1, %cst_2) ( { ^bb0(%arg0: tensor): - %1 = "mhlo.negate"(%arg0) {name = "negate"} : (tensor) -> tensor + %1 = "mhlo.negate"(%arg0) : (tensor) -> tensor "mhlo.return"(%1, %1) : (tensor, tensor) -> () }, { ^bb0(%arg0: tensor): - %1 = "mhlo.copy"(%arg0) {name = "copy"} : (tensor) -> tensor + %1 = "mhlo.copy"(%arg0) : (tensor) -> tensor "mhlo.return"(%1, %1) : (tensor, tensor) -> () }, { ^bb0(%arg0: tensor): - %1 = "mhlo.floor"(%arg0) {name = "floor"} : (tensor) -> tensor + %1 = "mhlo.floor"(%arg0) : (tensor) -> tensor "mhlo.return"(%1, %1) : (tensor, tensor) -> () - }) {name = "conditional"} : (tensor, tensor, tensor, tensor) -> (tensor, tensor) + }) : (tensor, tensor, tensor, tensor) -> (tensor, tensor) return %0#0, %0#1 : tensor, tensor } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/case_conditional.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/case_conditional.hlotxt index 62f0d7a59e4..1fa7367763e 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/case_conditional.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/case_conditional.hlotxt @@ -26,21 +26,21 @@ ENTRY %indexed_conditional () -> f32[] { } // CHECK-LABEL: func @main() -> tensor -// CHECK: %[[INDEX:.*]] = constant {name = "constant"} dense<1> : tensor -// CHECK: %[[OPERAND_1:.*]] = constant {name = "{{.*}}"} dense<5.600000e+01> : tensor -// CHECK: %[[OPERAND_2:.*]] = constant {name = "{{.*}}"} dense<1.200000e+01> : tensor -// CHECK: %[[OPERAND_3:.*]] = constant {name = "{{.*}}"} dense<1.300000e+01> : tensor +// CHECK: %[[INDEX:.*]] = constant dense<1> : tensor +// CHECK: %[[OPERAND_1:.*]] = constant dense<5.600000e+01> : tensor +// CHECK: %[[OPERAND_2:.*]] = constant dense<1.200000e+01> : tensor +// CHECK: %[[OPERAND_3:.*]] = constant dense<1.300000e+01> : tensor // CHECK: %[[RESULT:.*]] = "mhlo.case"(%[[INDEX]], %[[OPERAND_1]], %[[OPERAND_2]], %[[OPERAND_3]]) ( { // CHECK: ^bb0(%[[ARG_1:.*]]: tensor): -// CHECK: %[[RES_1:.*]] = "mhlo.negate"(%[[ARG_1]]) {name = "{{.*}}"} : (tensor) -> tensor +// CHECK: %[[RES_1:.*]] = "mhlo.negate"(%[[ARG_1]]) : (tensor) -> tensor // CHECK: "mhlo.return"(%[[RES_1]]) : (tensor) -> () // CHECK: }, { // CHECK: ^bb0(%[[ARG_2:.*]]: tensor): -// CHECK: %[[RES_2:.*]] = "mhlo.copy"(%[[ARG_2]]) {name = "{{.*}}"} : (tensor) -> tensor +// CHECK: %[[RES_2:.*]] = "mhlo.copy"(%[[ARG_2]]) : (tensor) -> tensor // CHECK: "mhlo.return"(%[[RES_2]]) : (tensor) -> () // CHECK: }, { // CHECK: ^bb0(%[[ARG_3:.*]]: tensor): -// CHECK: %[[RES_3:.*]] = "mhlo.floor"(%[[ARG_3]]) {name = "{{.*}}"} : (tensor) -> tensor +// CHECK: %[[RES_3:.*]] = "mhlo.floor"(%[[ARG_3]]) : (tensor) -> tensor // CHECK: "mhlo.return"(%[[RES_3]]) : (tensor) -> () -// CHECK: }) {name = "{{.*}}"} : (tensor, tensor, tensor, tensor) -> tensor +// CHECK: }) : (tensor, tensor, tensor, tensor) -> tensor // CHECK: return %[[RESULT]] : tensor diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir index 0afa12a5c65..84816e6715a 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir @@ -439,7 +439,7 @@ func @main(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<10 // CHECK-SAME: index_vector_dim=1 // CHECK-SAME: slice_sizes={1,1,300} // CHECK-SAME: indices_are_sorted=true - %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<[0, 1]> : tensor<2xi64>, index_vector_dim = 1 : i64, offset_dims = dense<1> : tensor<1xi64>, start_index_map = dense<[0, 1]> : tensor<2xi64>}, indices_are_sorted = true, name = "gather", slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>} : (tensor<200x100x300xf32>, tensor<10x2xi32>) -> tensor<10x300xf32> + %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<[0, 1]> : tensor<2xi64>, index_vector_dim = 1 : i64, offset_dims = dense<1> : tensor<1xi64>, start_index_map = dense<[0, 1]> : tensor<2xi64>}, indices_are_sorted = true, slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>} : (tensor<200x100x300xf32>, tensor<10x2xi32>) -> tensor<10x300xf32> return %0 : tensor<10x300xf32> } @@ -502,7 +502,7 @@ func @main() -> tensor<1x10xf32> { func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { %0 = "mhlo.map"(%arg0, %arg1) ( { ^bb0(%arg2: tensor, %arg3: tensor): // no predecessors - %1 = mhlo.add %arg2, %arg3 {name = "add"} : tensor + %1 = mhlo.add %arg2, %arg3 : tensor "mhlo.return"(%1) : (tensor) -> () }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> @@ -739,7 +739,7 @@ func @main(%arg0: tensor, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> // CHECK: %[[ARG2:.*]] = s32[2,3] parameter(2) // CHECK: ROOT %[[RES:.*]] = s32[2,3] select(pred[2,3] %[[COND]], s32[2,3] %[[ARG1]], s32[2,3] %[[ARG2]]) - %0 = "mhlo.select"(%arg0, %arg1, %arg2) {name = "select.4"} : (tensor, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %0 : tensor<2x3xi32> } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt index 86adcf0710f..4cc70be0965 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt @@ -9,95 +9,95 @@ ENTRY %tfcompile.48 { %arg0.1 = f32[1,300] parameter(0) %arg1.2 = f32[1,300,3,1] parameter(1) - // CHECK-NEXT: %0 = "mhlo.reshape"(%arg0) {name = "reshape.3"} : (tensor<1x300xf32>) -> tensor<1x300xf32> + // CHECK-NEXT: %0 = "mhlo.reshape"(%arg0) : (tensor<1x300xf32>) -> tensor<1x300xf32> %reshape.3 = f32[1,300] reshape(%arg0.1) - // CHECK-NEXT: %1 = "mhlo.transpose"(%0) {name = "transpose.27", permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x300xf32>) -> tensor<300x1xf32> + // CHECK-NEXT: %1 = "mhlo.transpose"(%0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x300xf32>) -> tensor<300x1xf32> %transpose.27 = f32[300,1] transpose(%reshape.3), dimensions={1,0} - // CHECK-NEXT: %2 = "mhlo.reshape"(%1) {name = "reshape.28"} : (tensor<300x1xf32>) -> tensor<300x1x1xf32> + // CHECK-NEXT: %2 = "mhlo.reshape"(%1) : (tensor<300x1xf32>) -> tensor<300x1x1xf32> %reshape.28 = f32[300,1,1] reshape(%transpose.27) - // CHECK-NEXT: %3 = "mhlo.reshape"(%2) {name = "reshape.29"} : (tensor<300x1x1xf32>) -> tensor<300x1xf32> + // CHECK-NEXT: %3 = "mhlo.reshape"(%2) : (tensor<300x1x1xf32>) -> tensor<300x1xf32> %reshape.29 = f32[300,1] reshape(%reshape.28) - // CHECK-NEXT: %4 = "mhlo.broadcast_in_dim"(%3) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, name = "broadcast.30"} : (tensor<300x1xf32>) -> tensor<300x1x5xf32> + // CHECK-NEXT: %4 = "mhlo.broadcast_in_dim"(%3) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<300x1xf32>) -> tensor<300x1x5xf32> %broadcast.30 = f32[300,1,5] broadcast(%reshape.29), dimensions={0,1} - // CHECK-NEXT: %cst = constant {name = "constant.8"} dense<1.000000e+00> : tensor + // CHECK-NEXT: %cst = constant dense<1.000000e+00> : tensor %constant.8 = f32[] constant(1) - // CHECK-NEXT: %5 = "mhlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<> : tensor<0xi64>, name = "broadcast.9"} : (tensor) -> tensor<300x1x5xf32> + // CHECK-NEXT: %5 = "mhlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<300x1x5xf32> %broadcast.9 = f32[300,1,5] broadcast(%constant.8), dimensions={} - // CHECK-NEXT: %6 = mhlo.multiply %4, %5 {name = "multiply.31"} : tensor<300x1x5xf32> + // CHECK-NEXT: %6 = mhlo.multiply %4, %5 : tensor<300x1x5xf32> %multiply.31 = f32[300,1,5] multiply(%broadcast.30, %broadcast.9) - // CHECK-NEXT: %cst_0 = constant {name = "constant.32"} dense<0.000000e+00> : tensor + // CHECK-NEXT: %cst_0 = constant dense<0.000000e+00> : tensor %constant.32 = f32[] constant(0) - // CHECK-NEXT: %7 = "mhlo.broadcast_in_dim"(%cst_0) {broadcast_dimensions = dense<> : tensor<0xi64>, name = "broadcast.33"} : (tensor) -> tensor<300x1x5xf32> + // CHECK-NEXT: %7 = "mhlo.broadcast_in_dim"(%cst_0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<300x1x5xf32> %broadcast.33 = f32[300,1,5] broadcast(%constant.32), dimensions={} - // CHECK-NEXT: %8 = "mhlo.compare"(%6, %7) {comparison_direction = "GT", name = "compare.34"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xi1> + // CHECK-NEXT: %8 = "mhlo.compare"(%6, %7) {comparison_direction = "GT"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xi1> %compare.34 = pred[300,1,5] compare(%multiply.31, %broadcast.33), direction=GT - // CHECK-NEXT: %cst_1 = constant {name = "constant.10"} dense<0.000000e+00> : tensor + // CHECK-NEXT: %cst_1 = constant dense<0.000000e+00> : tensor %constant.10 = f32[] constant(0) - // CHECK-NEXT: %9 = "mhlo.broadcast_in_dim"(%cst_1) {broadcast_dimensions = dense<> : tensor<0xi64>, name = "broadcast.11"} : (tensor) -> tensor<300x1x5xf32> + // CHECK-NEXT: %9 = "mhlo.broadcast_in_dim"(%cst_1) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<300x1x5xf32> %broadcast.11 = f32[300,1,5] broadcast(%constant.10), dimensions={} - // CHECK-NEXT: %cst_2 = constant {name = "constant.40"} dense<0.000000e+00> : tensor + // CHECK-NEXT: %cst_2 = constant dense<0.000000e+00> : tensor %constant.40 = f32[] constant(0) - // CHECK-NEXT: %10 = "mhlo.broadcast_in_dim"(%cst_2) {broadcast_dimensions = dense<> : tensor<0xi64>, name = "broadcast.41"} : (tensor) -> tensor<300x5xf32> + // CHECK-NEXT: %10 = "mhlo.broadcast_in_dim"(%cst_2) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<300x5xf32> %broadcast.41 = f32[300,5] broadcast(%constant.40), dimensions={} - // CHECK-NEXT: %11 = "mhlo.copy"(%arg1) {name = "copy.1"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32> + // CHECK-NEXT: %11 = "mhlo.copy"(%arg1) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32> %copy.1 = f32[1,300,3,1] copy(%arg1.2) - // CHECK-NEXT: %12 = "mhlo.reshape"(%11) {name = "reshape.4"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32> + // CHECK-NEXT: %12 = "mhlo.reshape"(%11) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32> %reshape.4 = f32[1,300,3,1] reshape(%copy.1) - // CHECK-NEXT: %13 = "mhlo.reshape"(%12) {name = "reshape.24"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3xf32> + // CHECK-NEXT: %13 = "mhlo.reshape"(%12) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3xf32> %reshape.24 = f32[1,300,3] reshape(%reshape.4) - // CHECK-NEXT: %14 = "mhlo.transpose"(%13) {name = "transpose.25", permutation = dense<[1, 0, 2]> : tensor<3xi64>} : (tensor<1x300x3xf32>) -> tensor<300x1x3xf32> + // CHECK-NEXT: %14 = "mhlo.transpose"(%13) {permutation = dense<[1, 0, 2]> : tensor<3xi64>} : (tensor<1x300x3xf32>) -> tensor<300x1x3xf32> %transpose.25 = f32[300,1,3] transpose(%reshape.24), dimensions={1,0,2} - // CHECK-NEXT: %15 = "mhlo.reshape"(%14) {name = "reshape.26"} : (tensor<300x1x3xf32>) -> tensor<300x3xf32> + // CHECK-NEXT: %15 = "mhlo.reshape"(%14) : (tensor<300x1x3xf32>) -> tensor<300x3xf32> %reshape.26 = f32[300,3] reshape(%transpose.25) - // CHECK-NEXT: %cst_3 = constant {name = "constant.35"} dense<{{\[\[}}-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]> : tensor<3x5xf32> + // CHECK-NEXT: %cst_3 = constant dense<{{\[\[}}-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]> : tensor<3x5xf32> %constant.35 = f32[3,5] constant({ { -0.106023, 0.121505, 0.800239, -0.768885, 0.0966113 }, { 0.689014, -0.407056, -0.797853, 0.00378925, -0.208881 }, { -0.608529, 0.0276617, 0.268557, 0.577401, -0.428437 } }) // TODO(b/129709049) consider making this default precision config implied. - // CHECK-NEXT: %16 = "mhlo.dot"(%15, %cst_3) {name = "dot.36", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32> + // CHECK-NEXT: %16 = "mhlo.dot"(%15, %cst_3) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32> %dot.36 = f32[300,5] dot(%reshape.26, %constant.35), lhs_contracting_dims={1}, rhs_contracting_dims={0} - // CHECK-NEXT: %cst_4 = constant {name = "constant.37"} dense<0.000000e+00> : tensor<5xf32> + // CHECK-NEXT: %cst_4 = constant dense<0.000000e+00> : tensor<5xf32> %constant.37 = f32[5]{0} constant({0, 0, 0, 0, 0}) - // CHECK-NEXT: %17 = "mhlo.broadcast_in_dim"(%cst_4) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "broadcast.38"} : (tensor<5xf32>) -> tensor<300x5xf32> + // CHECK-NEXT: %17 = "mhlo.broadcast_in_dim"(%cst_4) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<5xf32>) -> tensor<300x5xf32> %broadcast.38 = f32[300,5] broadcast(%constant.37), dimensions={1} - // CHECK-NEXT: %18 = mhlo.add %16, %17 {name = "add.39"} : tensor<300x5xf32> + // CHECK-NEXT: %18 = mhlo.add %16, %17 : tensor<300x5xf32> %add.39 = f32[300,5] add(%dot.36, %broadcast.38) - // CHECK-NEXT: %19 = mhlo.maximum %10, %18 {name = "maximum.42"} : tensor<300x5xf32> + // CHECK-NEXT: %19 = mhlo.maximum %10, %18 : tensor<300x5xf32> %maximum.42 = f32[300,5] maximum(%broadcast.41, %add.39) - // CHECK-NEXT: %20 = "mhlo.reshape"(%19) {name = "reshape.44"} : (tensor<300x5xf32>) -> tensor<300x1x5xf32> + // CHECK-NEXT: %20 = "mhlo.reshape"(%19) : (tensor<300x5xf32>) -> tensor<300x1x5xf32> %reshape.44 = f32[300,1,5] reshape(%maximum.42) - // CHECK-NEXT: %21 = "mhlo.select"(%8, %9, %20) {name = "select.45"} : (tensor<300x1x5xi1>, tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32> + // CHECK-NEXT: %21 = "mhlo.select"(%8, %9, %20) : (tensor<300x1x5xi1>, tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32> %select.45 = f32[300,1,5] select(%compare.34, %broadcast.11, %reshape.44) - // CHECK-NEXT: %22 = "mhlo.reshape"(%21) {name = "reshape.46"} : (tensor<300x1x5xf32>) -> tensor<300x1x5xf32> + // CHECK-NEXT: %22 = "mhlo.reshape"(%21) : (tensor<300x1x5xf32>) -> tensor<300x1x5xf32> %reshape.46 = f32[300,1,5] reshape(%select.45) - // CHECK-NEXT: %23 = "mhlo.tuple"(%22) {name = "tuple.47"} : (tensor<300x1x5xf32>) -> tuple> + // CHECK-NEXT: %23 = "mhlo.tuple"(%22) : (tensor<300x1x5xf32>) -> tuple> // CHECK-NEXT: return %23 : tuple> ROOT %tuple.47 = (f32[300,1,5]) tuple(%reshape.46) } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt index 4d4e0213da8..90034ce8c07 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt @@ -13,12 +13,12 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] { %Arg_0.1 = f32[4]{0} parameter(0) %Arg_1.2 = f32[4]{0} parameter(1) - // CHECK-NEXT: mhlo.add %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> - %add.3 = f32[4]{0} add(f32[4]{0} %Arg_0.1, f32[4]{0} %Arg_1.2) + // CHECK-NEXT: mhlo.add %arg0, %arg1 : tensor<4xf32> + %add.42 = f32[4]{0} add(f32[4]{0} %Arg_0.1, f32[4]{0} %Arg_1.2) // TODO(b/129709049) consider making this default precision config inferred. - // CHECK-NEXT: "mhlo.dot"(%0, %arg1) {name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor - ROOT %dot.4 = f32[] dot(f32[4]{0} %add.3, f32[4]{0} %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0} + // CHECK-NEXT: "mhlo.dot"(%0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor + ROOT %dot.4 = f32[] dot(f32[4]{0} %add.42, f32[4]{0} %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0} } // CHECK-LABEL: func @test_after_all @@ -26,7 +26,7 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] { %test_after_all (token0: token[], token1: token[] ) -> token[] { token0 = token[] parameter(0) token1 = token[] parameter(1) - // CHECK-NEXT: "mhlo.after_all"([[VAL_0]], [[VAL_1]]) {name = "{{.*}}"} : (!mhlo.token, !mhlo.token) -> !mhlo.token + // CHECK-NEXT: "mhlo.after_all"([[VAL_0]], [[VAL_1]]) : (!mhlo.token, !mhlo.token) -> !mhlo.token ROOT after-all = token[] after-all(token0, token1) } @@ -75,10 +75,10 @@ add { %test_broadcast_in_dim { %Arg_0.1 = f32[1, 2] parameter(0) - // CHECK-NEXT: "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, name = "{{.*}}"} : (tensor<1x2xf32>) -> tensor<1x2x3xf32> + // CHECK-NEXT: "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<1x2x3xf32> %broadcast.2 = f32[1,2,3] broadcast(%Arg_0.1), dimensions={0,1} - // CHECK-NEXT: "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>, name = "{{.*}}"} : (tensor<1x2xf32>) -> tensor<3x1x2xf32> + // CHECK-NEXT: "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<3x1x2xf32> ROOT broadcast.4 = f32[3,1,2] broadcast(%Arg_0.1), dimensions={1, 2} } @@ -113,7 +113,7 @@ add { // CHECK-SAME: ([[ARG:%.*]]: tensor<1x291x291xf32>) -> tensor<1x291x291xf32> %test_cholesky (a: f32[1,291,291]) -> f32[1,291,291] { %a = f32[1,291,291] parameter(0) - // CHECK-NEXT: "mhlo.cholesky"([[ARG]]) {lower = true, name = {{.*}}} : (tensor<1x291x291xf32>) -> tensor<1x291x291xf32> + // CHECK-NEXT: "mhlo.cholesky"([[ARG]]) {lower = true} : (tensor<1x291x291xf32>) -> tensor<1x291x291xf32> ROOT %out = f32[1,291,291] cholesky(f32[1,291,291] %a), lower=true } @@ -124,16 +124,16 @@ add { %Arg_1.2 = f32[4] parameter(1) %Arg_2.3 = f32[] parameter(2) - // CHECK-NEXT: "mhlo.clamp"(%arg0, %arg1, %arg2) {name = "{{.*}}"} : (tensor, tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK-NEXT: "mhlo.clamp"(%arg0, %arg1, %arg2) : (tensor, tensor<4xf32>, tensor) -> tensor<4xf32> ROOT %clamp.3 = f32[4] clamp(f32[] %Arg_0.1, f32[4] %Arg_1.2, f32[] %Arg_2.3) } // CHECK-LABEL: func @test_collective_permute // CHECK-SAME: ([[ARG:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> %test_collective_permute (input: f32[128,32]) -> f32[128,32] { - %input = f32[128,32]{0,1} parameter(0) - // CHECK-NEXT: "mhlo.collective_permute"([[ARG]]) {name = {{.*}}, source_target_pairs = dense<{{\[\[}}0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>} : (tensor<128x32xf32>) -> tensor<128x32xf32> - ROOT root = f32[128,32]{0,1} collective-permute(%input), source_target_pairs={{0,1},{1,2},{2,3}} + %input = f32[128,32]{1,0} parameter(0) + // CHECK-NEXT: "mhlo.collective_permute"([[ARG]]) {source_target_pairs = dense<{{\[\[}}0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>} : (tensor<128x32xf32>) -> tensor<128x32xf32> + ROOT root = f32[128,32]{1,0} collective-permute(%input), source_target_pairs={{0,1},{1,2},{2,3}} } @@ -143,14 +143,14 @@ add { %Arg_1.2 = f32[3] parameter(1) %Arg_2.3 = f32[3] parameter(2) - // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> + // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> %compare.4 = pred[3] compare(Arg_0.1, Arg_1.2), direction=EQ - // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LE", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> + // CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> %compare.5 = pred[3] compare(Arg_0.1, Arg_1.2), direction=LE // Requires broadcast of compatible tensors. - // CHECK-NEXT: "mhlo.compare"(%arg0, %arg2) {comparison_direction = "GT", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> + // CHECK-NEXT: "mhlo.compare"(%arg0, %arg2) {comparison_direction = "GT"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> ROOT %compare.6 = pred[3] compare(Arg_0.1, Arg_2.3), direction=GT } @@ -159,7 +159,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: "mhlo.complex"(%arg0, %arg1) {name = "{{.*}}"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex> + // CHECK-NEXT: "mhlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex> ROOT %complex.3 = c64[4] complex(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } @@ -176,12 +176,12 @@ add { %test_constant { // Scalar/0D tensor constant - // CHECK-NEXT: %cst = constant {name = "{{.*}}"} dense<1> : tensor + // CHECK-NEXT: %cst = constant dense<1> : tensor %constant.0 = s64[] constant(1) // Note that double brackets "[[" have to be escaped as they denote variables // in FileCheck. The only way to do so is to drop into regex with "{{" - // CHECK-NEXT: constant {name = "{{.*}}"} dense<{{\[\[\[\[}}1.000000e+00]], {{\[\[}}2.000000e+00]]], {{\[\[\[}}3.000000e+00]], {{\[\[}}4.000000e+00]]]]> : tensor<2x2x1x1xf32> + // CHECK-NEXT: constant dense<{{\[\[\[\[}}1.000000e+00]], {{\[\[}}2.000000e+00]]], {{\[\[\[}}3.000000e+00]], {{\[\[}}4.000000e+00]]]]> : tensor<2x2x1x1xf32> %constant.1 = f32[2,2,1,1]{3,2,1,0} constant({{{{1.0}},{{2.0}}},{{{3.0}},{{4.0}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"} // CHECK: dense<[1, 2, 4, 8]> : tensor<4xui64> @@ -206,15 +206,15 @@ add { %test_conv { %arg0.1 = f32[256,32,32,6]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"} - // CHECK-NEXT: %0 = "mhlo.copy"(%arg0) {name = "{{.*}}"} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32> + // CHECK-NEXT: %0 = "mhlo.copy"(%arg0) {minor_to_major = dense<[2, 1, 3, 0]> : tensor<4xindex>} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32> %copy.1 = f32[256,32,32,6]{2,1,3,0} copy(%arg0.1), metadata={op_name="HLO_Args"} - // CHECK-NEXT: %1 = "mhlo.reshape"(%0) {name = "{{.*}}"} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32> + // CHECK-NEXT: %1 = "mhlo.reshape"(%0) {minor_to_major = dense<[2, 1, 3, 0]> : tensor<4xindex>} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32> %reshape.2 = f32[256,32,32,6]{2,1,3,0} reshape(%copy.1) // Note that double brackets "[[" have to be escaped as they denote variables // in FileCheck. The only way to do so is to drop into regex with "{{" - // CHECK-NEXT: %cst = constant {name = "{{.*}}"} dense<{{\[\[\[\[}}5.000000e-01]], {{\[\[}}-6.000000e-01]]], {{\[\[\[}}3.000000e-01]], {{\[\[}}-1.000000e-01]]]]> : tensor<2x2x1x1xf32> + // CHECK-NEXT: %cst = constant dense<{{\[\[\[\[}}5.000000e-01]], {{\[\[}}-6.000000e-01]]], {{\[\[\[}}3.000000e-01]], {{\[\[}}-1.000000e-01]]]]> : tensor<2x2x1x1xf32> %constant.3 = f32[2,2,1,1]{3,2,1,0} constant({{{{0.5}}, {{-0.6}}}, {{{0.3}}, {{-0.1}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"} // CHECK-NEXT: %2 = "mhlo.convolution"(%1, %cst) { @@ -241,10 +241,10 @@ add { %convolution.4 = f32[16,30,30,256]{2,1,3,0} convolution(%reshape.2, %constant.3), window={size=3x3 stride=4x5 pad=44_45x60_60 rhs_dilate=2x3}, dim_labels=b01f_01io->f01b, metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"} - // CHECK-NEXT: %3 = "mhlo.reshape"(%2) {name = "{{.*}}"} : (tensor<16x30x30x256xf32>) -> tensor<256x30x30x16xf32> + // CHECK-NEXT: %3 = "mhlo.reshape"(%2) : (tensor<16x30x30x256xf32>) -> tensor<256x30x30x16xf32> %reshape.5 = f32[256,30,30,16]{3,2,1,0} reshape(%convolution.4), metadata={op_name="HLO_Retvals"} - // CHECK-NEXT: "mhlo.tuple"(%3) {name = "{{.*}}"} : (tensor<256x30x30x16xf32>) -> tuple> + // CHECK-NEXT: "mhlo.tuple"(%3) : (tensor<256x30x30x16xf32>) -> tuple> ROOT %tuple.6 = (f32[256,30,30,16]{3,2,1,0}) tuple(%reshape.5), metadata={op_name="HLO_Retvals"} } @@ -263,10 +263,10 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: %0 = "mhlo.convert"(%arg0) {name = "{{.*}}"} : (tensor<4xf32>) -> tensor<4xf64> + // CHECK-NEXT: %0 = "mhlo.convert"(%arg0) : (tensor<4xf32>) -> tensor<4xf64> %convert.3 = f64[4] convert(f32[4] %Arg_0.1) - // CHECK-NEXT: %1 = "mhlo.convert"(%arg1) {name = "{{.*}}"} : (tensor<4xf32>) -> tensor<4xf64> + // CHECK-NEXT: %1 = "mhlo.convert"(%arg1) : (tensor<4xf32>) -> tensor<4xf64> %convert.4 = f64[4] convert(f32[4] %Arg_1.2) // CHECK-NEXT: mhlo.add %0, %1 @@ -277,7 +277,7 @@ add { %test_cosine (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] { %arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"} - // CHECK-NEXT: "mhlo.cosine"(%arg0) {name = "{{.*}}"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> + // CHECK-NEXT: "mhlo.cosine"(%arg0) : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> ROOT %cosine.3 = f32[1,16,16,3]{3,2,1,0} cosine(f32[1,16,16,3]{3,2,1,0} %arg0.1) } @@ -286,7 +286,7 @@ add { %test_custom_call (arg1: f32[2,3], arg2: f32[5,5]) -> f32[1,2,3] { %arg1 = f32[2,3] parameter(0) %arg2 = f32[5,5] parameter(1) -// CHECK: "mhlo.custom_call"([[ARG_0]], [[ARG_1]]) {backend_config = "bar", call_target_name = "foo", has_side_effect = true, name = {{.*}}} : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32> +// CHECK: "mhlo.custom_call"([[ARG_0]], [[ARG_1]]) {backend_config = "bar", call_target_name = "foo", has_side_effect = true, minor_to_major = {{.*}}} : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32> ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[2,3] %arg1, f32[5,5] %arg2), custom_call_target="foo", backend_config="bar", custom_call_has_side_effect=true } @@ -295,7 +295,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: mhlo.divide %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> + // CHECK-NEXT: mhlo.divide %arg0, %arg1 : tensor<4xf32> ROOT %divide.3 = f32[4] divide(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } @@ -304,17 +304,17 @@ add { %Arg_0.1 = f32[1, 4] parameter(0) %Arg_1.2 = f32[4, 1] parameter(1) - // CHECK-NEXT: %0 = "mhlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["HIGH", "HIGHEST"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor + // CHECK-NEXT: %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["HIGH", "HIGHEST"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor dot.3 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={high,highest} - // CHECK-NEXT: %1 = "mhlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["HIGHEST", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor + // CHECK-NEXT: %1 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["HIGHEST", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor dot.4 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,default} - // CHECK-NEXT: %2 = "mhlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor + // CHECK-NEXT: %2 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor %dot.5 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={default,default} // TODO(b/129709049) consider making this default precision config inferred. - // CHECK-NEXT: "mhlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor + // CHECK-NEXT: "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor ROOT %dot.6 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} } @@ -325,17 +325,17 @@ add { %Arg_0.1 = f32[4, 1] parameter(0) %Arg_1.2 = f32[1, 4] parameter(1) - // CHECK-NEXT: [[R0:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["HIGH", "HIGHEST"]} + // CHECK-NEXT: [[R0:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["HIGH", "HIGHEST"]} dot.3 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1}, operand_precision={high,highest} - // CHECK-NEXT: [[R1:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["HIGHEST", "DEFAULT"]} + // CHECK-NEXT: [[R1:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["HIGHEST", "DEFAULT"]} dot.4 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1}, operand_precision={highest,default} - // CHECK-NEXT: [[R2:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} + // CHECK-NEXT: [[R2:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} %dot.5 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1}, operand_precision={default,default} // TODO(b/129709049) consider making this default precision config inferred. - // CHECK-NEXT: "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} + // CHECK-NEXT: "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} ROOT %dot.6 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1} } @@ -376,7 +376,7 @@ add { %test_exponential (arg0.1: f32[16]) -> f32[16] { %arg0.1 = f32[16] parameter(0) - // CHECK-NEXT: "mhlo.exponential"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK-NEXT: "mhlo.exponential"(%arg0) : (tensor<16xf32>) -> tensor<16xf32> ROOT %exp.2 = f32[16] exponential(f32[16] %arg0.1) } @@ -384,7 +384,7 @@ add { %test_expm1 (arg0.1: f32[16]) -> f32[16] { %arg0.1 = f32[16] parameter(0) - // CHECK: "mhlo.exponential_minus_one"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK: "mhlo.exponential_minus_one"(%arg0) : (tensor<16xf32>) -> tensor<16xf32> ROOT %expm1.2 = f32[16] exponential-minus-one(f32[16] %arg0.1) } @@ -400,7 +400,7 @@ add { %test_floor (arg0.1: f32[16]) -> f32[16] { %arg0.1 = f32[16] parameter(0) - // CHECK-NEXT: "mhlo.floor"([[A0]]) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK-NEXT: "mhlo.floor"([[A0]]) : (tensor<16xf32>) -> tensor<16xf32> ROOT %floor.2 = f32[16] floor(f32[16] %arg0.1) } @@ -430,7 +430,7 @@ add { // CHECK-SAME: ([[ARG:%.*]]: tensor<4x2xf32>) %test_get_dimension_size (Arg_0.1: f32[4,2]) -> s32[] { %Arg_0.1 = f32[4,2] parameter(0) - // CHECK-NEXT: "mhlo.get_dimension_size"([[ARG]]) {dimension = 1 : i32, name = "{{.*}}"} : (tensor<4x2xf32>) -> tensor + // CHECK-NEXT: "mhlo.get_dimension_size"([[ARG]]) {dimension = 1 : i32} : (tensor<4x2xf32>) -> tensor ROOT %get-dimension-size.2 = s32[] get-dimension-size(f32[4,2] %Arg_0.1), dimensions={1} } @@ -438,7 +438,7 @@ add { %test_imag (Arg_0.1: c64[4]) -> f32[4] { %Arg_0.1 = c64[4] parameter(0) - // CHECK-NEXT: "mhlo.imag"(%arg0) {name = "{{.*}}"} : (tensor<4xcomplex>) -> tensor<4xf32> + // CHECK-NEXT: "mhlo.imag"(%arg0) : (tensor<4xcomplex>) -> tensor<4xf32> ROOT %imag.3 = f32[4] imag(c64[4] %Arg_0.1) } @@ -468,7 +468,7 @@ add { %test_log (arg0.1: f32[16]) -> f32[16] { %arg0.1 = f32[16] parameter(0) - // CHECK-NEXT: "mhlo.log"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK-NEXT: "mhlo.log"(%arg0) : (tensor<16xf32>) -> tensor<16xf32> ROOT %log.2 = f32[16] log(f32[16] %arg0.1) } @@ -476,7 +476,7 @@ add { %test_log1p (arg0.1: f32[16]) -> f32[16] { %arg0.1 = f32[16] parameter(0) - // CHECK: "mhlo.log_plus_one"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK: "mhlo.log_plus_one"(%arg0) : (tensor<16xf32>) -> tensor<16xf32> ROOT %log1p.2 = f32[16] log-plus-one(f32[16] %arg0.1) } @@ -507,7 +507,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: mhlo.maximum %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> + // CHECK-NEXT: mhlo.maximum %arg0, %arg1 : tensor<4xf32> ROOT %maximum.3 = f32[4] maximum(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } @@ -516,7 +516,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: mhlo.minimum %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> + // CHECK-NEXT: mhlo.minimum %arg0, %arg1 : tensor<4xf32> ROOT %minimum.3 = f32[4] minimum(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } @@ -525,7 +525,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: %0 = mhlo.multiply %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> + // CHECK-NEXT: %0 = mhlo.multiply %arg0, %arg1 : tensor<4xf32> ROOT %multiply.3 = f32[4] multiply(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } @@ -533,7 +533,7 @@ add { %test_negate (arg0.1: f32[16]) -> f32[16] { %arg0.1 = f32[16] parameter(0) - // CHECK-NEXT: "mhlo.negate"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK-NEXT: "mhlo.negate"(%arg0) : (tensor<16xf32>) -> tensor<16xf32> ROOT %negate.2 = f32[16] negate(f32[16] %arg0.1) } @@ -541,7 +541,7 @@ add { %test_not (arg0.1: pred[16]) -> pred[16] { %arg0.1 = pred[16] parameter(0) - // CHECK: "mhlo.not"(%arg0) {name = "{{.*}}"} : (tensor<16xi1>) -> tensor<16xi1> + // CHECK: "mhlo.not"(%arg0) : (tensor<16xi1>) -> tensor<16xi1> ROOT %not.2 = pred[16] not(pred[16] %arg0.1) } @@ -595,7 +595,7 @@ add { %test_popcnt (arg0.1: s32[16]) -> s32[16] { %arg0.1 = s32[16] parameter(0) - // CHECK: "mhlo.popcnt"(%arg0) {name = "{{.*}}"} : (tensor<16xi32>) -> tensor<16xi32> + // CHECK: "mhlo.popcnt"(%arg0) : (tensor<16xi32>) -> tensor<16xi32> ROOT %popcnt.2 = s32[16] popcnt(s32[16] %arg0.1) } @@ -604,7 +604,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: mhlo.power %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> + // CHECK-NEXT: mhlo.power %arg0, %arg1 : tensor<4xf32> ROOT %power.3 = f32[4] power(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } @@ -632,7 +632,7 @@ add { %test_real (Arg_0.1: c64[4]) -> f32[4] { %Arg_0.1 = c64[4] parameter(0) - // CHECK-NEXT: "mhlo.real"(%arg0) {name = "{{.*}}"} : (tensor<4xcomplex>) -> tensor<4xf32> + // CHECK-NEXT: "mhlo.real"(%arg0) : (tensor<4xcomplex>) -> tensor<4xf32> ROOT %real.3 = f32[4] real(c64[4] %Arg_0.1) } @@ -687,7 +687,7 @@ add { // CHECK: {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor %reduce.4 = f32[] reduce(%reduce.2, %Arg_2.3), dimensions={0}, to_apply=%reduce_helper.3 - // CHECK: %4 = mhlo.subtract [[VAL2]], [[VAL4]] {name = "{{.*}}"} : tensor + // CHECK: %4 = mhlo.subtract [[VAL2]], [[VAL4]] : tensor %sub.5 = f32[] subtract(%reduce.3, %reduce.4) ROOT %tuple.6 = ((f32[], f32[]), f32[]) tuple(%reduce.1, %sub.5) @@ -741,7 +741,7 @@ add { %test_rsqrt (arg0.1: f32[16]) -> f32[16] { %arg0.1 = f32[16] parameter(0) - // CHECK: "mhlo.rsqrt"([[ARG0]]) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK: "mhlo.rsqrt"([[ARG0]]) : (tensor<16xf32>) -> tensor<16xf32> ROOT %rsqrt.2 = f32[16] rsqrt(f32[16] %arg0.1) } @@ -788,7 +788,7 @@ add { %Arg_1.2 = s32[2,3] parameter(1) %Arg_2.3 = s32[2,3] parameter(2) - // CHECK: "mhlo.select"(%arg0, %arg1, %arg2) {name = "{{.*}}"} : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + // CHECK: "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> ROOT %select.4 = s32[2,3] select(pred[2,3] %Arg_0.1, s32[2,3] %Arg_1.2, s32[2,3] %Arg_2.3) } @@ -835,7 +835,7 @@ add { %test_set_dimension_size (Arg_0.1: f32[4,4], Arg_1.2: s32[]) -> f32[4,<=4] { %Arg_0.1 = f32[4,4] parameter(0) %Arg_1.2 = s32[] parameter(1) - // CHECK-NEXT: "mhlo.set_dimension_size"([[ARG]], [[SIZE]]) {dimension = 1 : i32, name = "{{.*}}"} : (tensor<4x4xf32>, tensor) -> tensor<4x4xf32> + // CHECK-NEXT: "mhlo.set_dimension_size"([[ARG]], [[SIZE]]) {dimension = 1 : i32} : (tensor<4x4xf32>, tensor) -> tensor<4x4xf32> ROOT %set-dimension-size.2 = f32[4,<=4] set-dimension-size(f32[4,4] %Arg_0.1, s32[] %Arg_1.2), dimensions={1} } @@ -843,7 +843,7 @@ add { %test_sine (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] { %arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"} - // CHECK-NEXT: "mhlo.sine"(%arg0) {name = "{{.*}}"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> + // CHECK-NEXT: "mhlo.sine"(%arg0) : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> ROOT %sine.3 = f32[1,16,16,3]{3,2,1,0} sine(f32[1,16,16,3]{3,2,1,0} %arg0.1) } @@ -862,7 +862,7 @@ add { // CHECK-SAME: [[ARG:%.*]]: tensor<1024xf32>) -> tensor<1024xf32> // CHECK: "mhlo.sort"([[ARG]]) ( { // CHECK: ^bb0([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tensor): -// CHECK: [[CMP:%.*]] = "mhlo.compare"([[ARG0]], [[ARG1]]) {comparison_direction = "LT", name = "lt"} : (tensor, tensor) -> tensor +// CHECK: [[CMP:%.*]] = "mhlo.compare"([[ARG0]], [[ARG1]]) {comparison_direction = "LT"} : (tensor, tensor) -> tensor // CHECK: "mhlo.return"([[CMP]]) : (tensor) -> () // CHECK: }) {dimension = 0 : i64, is_stable = true} : (tensor<1024xf32>) -> tensor<1024xf32> @@ -871,7 +871,7 @@ add { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: mhlo.subtract %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32> + // CHECK-NEXT: mhlo.subtract %arg0, %arg1 : tensor<4xf32> ROOT %subtract.3 = f32[4] subtract(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } @@ -879,7 +879,7 @@ add { %test_tanh (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] { %arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"} - // CHECK-NEXT: "mhlo.tanh"(%arg0) {name = "{{.*}}"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> + // CHECK-NEXT: "mhlo.tanh"(%arg0) : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> ROOT %tanh.3 = f32[1,16,16,3]{3,2,1,0} tanh(f32[1,16,16,3]{3,2,1,0} %arg0.1), metadata={op_type="Tanh" op_name="embedded_inference/tanh_model/Tanh"} } @@ -887,7 +887,7 @@ add { %test_transpose { %Arg_0.1 = s32[1,2,3,4] parameter(0) - // CHECK: "mhlo.transpose"(%arg0) {name = "{{.*}}", permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> + // CHECK: "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> ROOT %transpose.2 = s32[2,1,4,3] transpose(s32[1,2,3,4] %Arg_0.1), dimensions={1,0,3,2} } @@ -909,10 +909,10 @@ add { %Arg_0.1 = s32[1] parameter(0) %Arg_1.2 = f32[1, 2] parameter(1) - // CHECK-NEXT: %0 = "mhlo.tuple"(%arg0) {name = "{{.*}}"} : (tensor<1xi32>) -> tuple> + // CHECK-NEXT: %0 = "mhlo.tuple"(%arg0) : (tensor<1xi32>) -> tuple> %tuple.3 = (s32[1]) tuple(%Arg_0.1) - // CHECK: "mhlo.tuple"(%arg0, %arg1) {name = "{{.*}}"} : (tensor<1xi32>, tensor<1x2xf32>) -> tuple, tensor<1x2xf32>> + // CHECK: "mhlo.tuple"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xf32>) -> tuple, tensor<1x2xf32>> ROOT %tuple.4 = (s32[1], f32[1,2]) tuple(%Arg_0.1, %Arg_1.2) } @@ -934,11 +934,11 @@ add { %arg0.1 = s64[] parameter(0), metadata={op_name="HLO_Args"} // CHECK-NEXT: "mhlo.while"(%arg0) ( { // CHECK-NEXT: ^bb0(%arg1: tensor): // no predecessors - // CHECK-NEXT: [[CMP:%.*]] = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "{{.*}}"} : (tensor, tensor) -> tensor + // CHECK-NEXT: [[CMP:%.*]] = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor // CHECK-NEXT: "mhlo.return"([[CMP]]) : (tensor) -> () // CHECK-NEXT: }, { // CHECK-NEXT: ^bb0(%arg1: tensor): // no predecessors - // CHECK-NEXT: [[ADD:%.*]] = mhlo.add %arg1, %arg1 {name = "{{.*}}"} : tensor + // CHECK-NEXT: [[ADD:%.*]] = mhlo.add %arg1, %arg1 : tensor // CHECK-NEXT: "mhlo.return"([[ADD]]) : (tensor) -> () // CHECK-NEXT: }) : (tensor) -> tensor ROOT %while.2 = s64[] while(%arg0.1), body=%loop, condition=%cond @@ -992,8 +992,8 @@ add { %Arg_1.2 = c128[2] parameter(1) %abs.4 = f64[2] abs(c128[2] %Arg_1.2) - // CHECK: "mhlo.abs"(%[[ARG0]]) {name = "{{.*}}"} : (tensor<2xcomplex>) -> tensor<2xf32> - // CHECK: "mhlo.abs"(%[[ARG1]]) {name = "{{.*}}"} : (tensor<2xcomplex>) -> tensor<2xf64> + // CHECK: "mhlo.abs"(%[[ARG0]]) : (tensor<2xcomplex>) -> tensor<2xf32> + // CHECK: "mhlo.abs"(%[[ARG1]]) : (tensor<2xcomplex>) -> tensor<2xf64> ROOT %tuple.5 = (f32[2], f64[2]) tuple(f32[2] %abs.3, f64[2] %abs.4) } @@ -1002,7 +1002,7 @@ add { %unsigned_int(Arg_0.1: u16[4]) -> u16[4] { %Arg_0.1 = u16[4] parameter(0) - // CHECK: "mhlo.not"(%[[ARG0]]) {name = "{{.*}}"} : (tensor<4xui16>) -> tensor<4xui16> + // CHECK: "mhlo.not"(%[[ARG0]]) : (tensor<4xui16>) -> tensor<4xui16> ROOT %not.2 = u16[4] not(u16[4] %Arg_0.1) } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/layouts_and_names.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/layouts_and_names.hlotxt new file mode 100644 index 00000000000..da07dc0a76b --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/layouts_and_names.hlotxt @@ -0,0 +1,11 @@ +// RUN: tf-mlir-translate -mlir-print-debuginfo -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule Test + +// CHECK-LABEL: func @main +ENTRY A { + %input = f16[128,224,224,4] parameter(0) + %filter = f16[64,7,7,4] parameter(1) + // %0 = "mhlo.convolution"{{.*}}minor_to_major = dense<[1, 3, 2, 0]> : tensor<4xindex>{{.*}} loc("root.42") + ROOT %root.42 = f16[128,64,112,112]{1,3,2,0} convolution(%input, %filter), dim_labels=b01f_o01i->bf01, window={size=7x7 stride=2x2 pad=3_3x3_3} +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/layouts_and_names.mlir b/tensorflow/compiler/mlir/xla/tests/translate/layouts_and_names.mlir new file mode 100644 index 00000000000..6a7debc8c6c --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/layouts_and_names.mlir @@ -0,0 +1,30 @@ +// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text-with-layouts %s | FileCheck %s + +// Checks exporting layouts + +// CHECK: HloModule +func @main(%arg0: tensor<128x224x224x4xf16>, %arg1: tensor<64x7x7x4xf16>) -> tensor<128x64x112x112xf16> { + // CHECK: %convolution.{{.*}} = f16[128,64,112,112]{1,3,2,0} convolution{{.*}}op_name="root.42" + %0 = "mhlo.convolution"(%arg0, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = { + input_batch_dimension = 0 : i64, + input_feature_dimension = 3 : i64, + input_spatial_dimensions = dense<[ 1, 2 ]> : tensor<2xi64>, + kernel_input_feature_dimension = 3 : i64, + kernel_output_feature_dimension = 0 : i64, + kernel_spatial_dimensions = dense<[ 1, 2 ]> : tensor<2xi64>, + output_batch_dimension = 0 : i64, + output_feature_dimension = 1 : i64, + output_spatial_dimensions = dense<[ 2, 3 ]> : tensor<2xi64> + }, + feature_group_count = 1 : i64, + lhs_dilations = dense<1> : tensor<2xi64>, + minor_to_major = dense<[ 1, 3, 2, 0 ]> : tensor<4xindex>, + padding = dense<3> : tensor<2x2xi64>, + precision_config = [ "DEFAULT", "DEFAULT" ], + rhs_dilations = dense<1> : tensor<2xi64>, + window_strides = dense<2> : tensor<2xi64> + } : (tensor<128x224x224x4xf16>, tensor<64x7x7x4xf16>)-> tensor<128x64x112x112xf16> loc("root.42") + return %0 : tensor<128x64x112x112xf16> +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/simple.hlo b/tensorflow/compiler/mlir/xla/tests/translate/simple.hlo index d97c5150335..4c288aee956 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/simple.hlo +++ b/tensorflow/compiler/mlir/xla/tests/translate/simple.hlo @@ -139,8 +139,8 @@ dynamic_parameter_binding { } # CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor { -# CHECK-NEXT: %0 = mhlo.add %arg0, %arg1 {name = "add.3"} : tensor<4xf32> +# CHECK-NEXT: %0 = mhlo.add %arg0, %arg1 : tensor<4xf32> # TODO(b/129709049) consider making this default precision config inferred. -# CHECK-NEXT: %1 = "mhlo.dot"(%0, %arg1) {name = "dot.4", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor +# CHECK-NEXT: %1 = "mhlo.dot"(%0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor # CHECK-NEXT: return %1 : tensor # CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/types.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/types.hlotxt index 855b1c4bcd5..f7e1ba9ff15 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/types.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/types.hlotxt @@ -4,25 +4,25 @@ HloModule tfcompile.1 // CHECK-LABEL: func @main() -> tensor { ENTRY %tfcompile.1 { - // CHECK-NEXT: %cst = constant {name = "constant.0"} dense<1.000000e+00> : tensor + // CHECK-NEXT: %cst = constant dense<1.000000e+00> : tensor %constant.0 = f32[] constant(1) - // CHECK-NEXT: %cst_0 = constant {name = "constant.1"} dense<1.000000e+00> : tensor + // CHECK-NEXT: %cst_0 = constant dense<1.000000e+00> : tensor %constant.1 = f64[] constant(1) - // CHECK-NEXT: %cst_1 = constant {name = "constant.2"} dense<1> : tensor + // CHECK-NEXT: %cst_1 = constant dense<1> : tensor %constant.2 = s8[] constant(1) - // CHECK-NEXT: %cst_2 = constant {name = "constant.3"} dense<1> : tensor + // CHECK-NEXT: %cst_2 = constant dense<1> : tensor %constant.3 = s16[] constant(1) - // CHECK-NEXT: %cst_3 = constant {name = "constant.4"} dense<1> : tensor + // CHECK-NEXT: %cst_3 = constant dense<1> : tensor %constant.4 = s32[] constant(1) - // CHECK-NEXT: %cst_4 = constant {name = "constant.5"} dense<1> : tensor + // CHECK-NEXT: %cst_4 = constant dense<1> : tensor %constant.5 = s64[] constant(1) - // CHECK-NEXT: %cst_5 = constant {name = "constant.6"} dense : tensor + // CHECK-NEXT: %cst_5 = constant dense : tensor // CHECK-NEXT: return %cst_5 : tensor ROOT %constant.6 = pred[] constant(1) } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/while.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/while.hlotxt index 126bc88ec7a..f989104323a 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/while.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/while.hlotxt @@ -26,4 +26,4 @@ ENTRY %foo (arg0.1: s64[]) -> s64[] { // CHECK: "mhlo.return" // CHECK: }) : (tensor) -> tensor ROOT %while.2 = s64[] while(%arg0.1), body=%loop, condition=%cond -} \ No newline at end of file +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/while.mlir b/tensorflow/compiler/mlir/xla/tests/translate/while.mlir index 61d7aadb23f..f852ef06421 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/while.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/while.mlir @@ -10,11 +10,11 @@ module { // CHECK: %[[A0]] = s64[] parameter(0) // CHECK: ROOT %compare.7 = pred[] compare(s64[] %[[A0]], s64[] %[[A0]]), direction=LT ^bb0(%arg1: tensor): - %1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "compare.2"} : (tensor, tensor) -> tensor + %1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor "mhlo.return"(%1) : (tensor) -> () }, { ^bb0(%arg1: tensor): - %1 = mhlo.add %arg1, %arg1 {name = "compare.0"} : tensor + %1 = mhlo.add %arg1, %arg1 : tensor "mhlo.return"(%1) : (tensor) -> () }) : (tensor) -> tensor diff --git a/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc b/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc index 4ad44d1bd77..55833bf9939 100644 --- a/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc +++ b/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc @@ -124,8 +124,8 @@ static StatusOr> HloModuleFromProto( return HloModule::CreateFromProto(module_proto, module_config); } -static mlir::LogicalResult MlirHloToHloTextTranslateFunction( - mlir::ModuleOp module, llvm::raw_ostream& output) { +static mlir::LogicalResult MlirHloToHloTextTranslateFunctionImpl( + mlir::ModuleOp module, llvm::raw_ostream& output, bool with_layouts) { if (!module) return mlir::failure(); HloProto hloProto; @@ -146,9 +146,8 @@ static mlir::LogicalResult MlirHloToHloTextTranslateFunction( HloModule* hlo_module = statusOrHloModule.ValueOrDie().get(); - // We don't interpret or use layouts output << hlo_module->ToString( - HloPrintOptions().set_include_layout_in_shapes(false)); + HloPrintOptions().set_include_layout_in_shapes(with_layouts)); // Output alias information as comments in the HLO text. hlo_module->input_output_alias_config().ForEachAlias( @@ -162,6 +161,18 @@ static mlir::LogicalResult MlirHloToHloTextTranslateFunction( return mlir::success(); } +static mlir::LogicalResult MlirHloToHloTextTranslateFunction( + mlir::ModuleOp module, llvm::raw_ostream& output) { + return MlirHloToHloTextTranslateFunctionImpl(module, output, + /*with_layouts=*/false); +} + +static mlir::LogicalResult MlirHloToHloTextWithLayoutsTranslateFunction( + mlir::ModuleOp module, llvm::raw_ostream& output) { + return MlirHloToHloTextTranslateFunctionImpl(module, output, + /*with_layouts=*/true); +} + } // namespace xla static void RegisterInputDialects(mlir::DialectRegistry& registry) { @@ -176,6 +187,10 @@ static mlir::TranslateFromMLIRRegistration MlirHloToHloTextTranslate( "mlir-hlo-to-hlo-text", xla::MlirHloToHloTextTranslateFunction, RegisterInputDialects); +static mlir::TranslateFromMLIRRegistration MlirHloToHloTextWithLayoutsTranslate( + "mlir-hlo-to-hlo-text-with-layouts", + xla::MlirHloToHloTextWithLayoutsTranslateFunction, RegisterInputDialects); + static mlir::TranslateToMLIRRegistration HloToHloMlirTranslate( "hlo-to-mlir-hlo", xla::HloToMlirHloTranslateFunction); diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index c7bbf9f8486..168565e9b50 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -132,10 +132,10 @@ bool InstrIsSetBound(const HloInstructionProto* instr_proto) { namespace internal { -XlaOp XlaBuilderBuildFusion(XlaBuilder* builder, - absl::Span operands, - absl::string_view fusion_kind, - const XlaComputation& fused_computation) { +XlaOp XlaBuilderFriend::BuildFusion(XlaBuilder* builder, + absl::Span operands, + absl::string_view fusion_kind, + const XlaComputation& fused_computation) { return builder->ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; instr.set_fusion_kind(std::string(fusion_kind)); @@ -149,6 +149,11 @@ XlaOp XlaBuilderBuildFusion(XlaBuilder* builder, }); } +HloInstructionProto* XlaBuilderFriend::GetInstruction(XlaOp op) { + return &op.builder() + ->instructions_[op.builder()->handle_to_index_[op.handle_]]; +} + } // namespace internal XlaOp operator-(XlaOp x) { return Neg(x); } diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 55bcd86b493..b3fc3628442 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -47,13 +47,18 @@ namespace xla { class XlaBuilder; class XlaOp; +class HloInstruction; namespace internal { -XlaOp XlaBuilderBuildFusion(XlaBuilder* builder, - absl::Span operands, - absl::string_view fusion_kind, - const XlaComputation& fused_computation); +struct XlaBuilderFriend { + static XlaOp BuildFusion(XlaBuilder* builder, + absl::Span operands, + absl::string_view fusion_kind, + const XlaComputation& fused_computation); + + static HloInstructionProto* GetInstruction(XlaOp op); +}; } // namespace internal @@ -107,6 +112,7 @@ class XlaOp { friend class XlaBuilder; friend class MlirHloBuilder; + friend struct internal::XlaBuilderFriend; // < 0 means "invalid handle". int64 handle_; @@ -1306,9 +1312,7 @@ class XlaBuilder { return LookUpInstructionByHandleInternal(op.handle()); } - friend XlaOp internal::XlaBuilderBuildFusion( - XlaBuilder* builder, absl::Span operands, - absl::string_view fusion_kind, const XlaComputation& fused_computation); + friend struct internal::XlaBuilderFriend; }; // RAII-style object: sets the current sharding assignment in builder on