Roll forward with fix
PiperOrigin-RevId: 334675913
Change-Id: I2bc41094bac933926906dcb1570f580d42468fc1

parent ce321f0c9d
commit e35bac6a94
@@ -174,15 +174,16 @@ tensorflow::Status HloFunctionImporter::ImportInstructions(
return tensorflow::Status::OK();
}

StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
HloInstruction* instruction, mlir::OpBuilder* func_builder) {
TF_ASSIGN_OR_RETURN(auto operands, GetOperands(instruction));
TF_ASSIGN_OR_RETURN(auto result_type, ConvertShapeToType<RankedTensorType>(
instruction->shape(), *builder_));
llvm::SmallVector<NamedAttribute, 10> attributes = {builder_->getNamedAttr(
"name", builder_->getStringAttr(instruction->name()))};
mlir::Location loc = func_builder->getUnknownLoc();
mlir::Location loc =
mlir::NameLoc::get(func_builder->getIdentifier(instruction->name()),
func_builder->getContext());

llvm::SmallVector<NamedAttribute, 10> attributes;
switch (instruction->opcode()) {
case HloOpcode::kParameter: {
return nullptr;
@@ -214,8 +215,8 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
return new_operation; \
}
case HloOpcode::kBroadcast: {
// Note that the HLO broadcast is more powerful than the XLA broadcast op.
// BroadcastInDim offers a superset of the HLO op's functionality.
// Note that the HLO broadcast is more powerful than the XLA broadcast
// op. BroadcastInDim offers a superset of the HLO op's functionality.
attributes.push_back(
builder_->getNamedAttr("broadcast_dimensions",
ConvertDimensions(instruction->dimensions())));
@@ -458,7 +459,8 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
return op.getOperation();
}

// Otherwise, it is a indexed conditional and should be mapped to Case op.
// Otherwise, it is a indexed conditional and should be mapped to Case
// op.
TF_RETURN_IF_ERROR(GetMlirTypes(
{instruction->branch_computation(0)->root_instruction()}, &rets));

@@ -474,8 +476,8 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
return op.getOperation();
}
case HloOpcode::kConcatenate: {
// TODO(b/132057942): Support taking an uint64_t instead of an IntegerAttr
// for concatenate dimension.
// TODO(b/132057942): Support taking an uint64_t instead of an
// IntegerAttr for concatenate dimension.
return func_builder
->create<mlir::mhlo::ConcatenateOp>(
loc, result_type, operands,
@@ -703,9 +705,9 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
NoAttributeCase(kReal, RealOp);
NoAttributeCase(kRemainder, RemOp);
NoAttributeCase(kReplicaId, ReplicaIdOp);
// The dimensions attribute is not present on the HLO Reshape instruction.
// If dimensions are non-default, the XLA builder implements it as a
// separate transpose.
// The dimensions attribute is not present on the HLO Reshape
// instruction. If dimensions are non-default, the XLA builder
// implements it as a separate transpose.
NoAttributeCase(kReshape, ReshapeOp);
NoAttributeCase(kRoundNearestAfz, RoundOp);
NoAttributeCase(kRsqrt, RsqrtOp);
@@ -720,9 +722,9 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
NoAttributeCase(kTanh, TanhOp);
NoAttributeCase(kTuple, TupleOp);
NoAttributeCase(kXor, XorOp);
// TODO(b/129422361) Copy needs special handling because it is not defined
// in tensorflow/compiler/xla/client/xla_builder.h.
// See operation semantics in
// TODO(b/129422361) Copy needs special handling because it is not
// defined in tensorflow/compiler/xla/client/xla_builder.h. See
// operation semantics in
// g3doc/platforms/xla/g3doc/internal/hlo_semantics#copy
NoAttributeCase(kCopy, CopyOp);
#undef NoAttributeCase
@@ -754,6 +756,35 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
}
}

StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
HloInstruction* instruction, mlir::OpBuilder* func_builder) {
TF_ASSIGN_OR_RETURN(mlir::Operation * op,
ImportInstructionImpl(instruction, func_builder));
if (op == nullptr) return op;

// Best-effort propagation of the layouts. These layouts serve as performance
// hints to the backend.
//
// Minor-to-major is a permutation of [0, rank), presenting tensor dimensions
// in physical minor-to-major order.
//
// Note that non-array shapes are not carrying layouts, and users have to
// figure out the proper layouts of them through context. This is one of the
// reasons why the attribute-based solution is temporary.
//
// TODO(timshen): Investigate the necessity of having layouts in MHLO.
if (instruction->shape().IsArray() &&
instruction->shape().layout() !=
LayoutUtil::MakeDescendingLayout(
instruction->shape().dimensions().size())) {
llvm::SmallVector<int64_t, 4> minor_to_major(
instruction->shape().layout().minor_to_major().begin(),
instruction->shape().layout().minor_to_major().end());
op->setAttr("minor_to_major", builder_->getIndexTensorAttr(minor_to_major));
}
return op;
}

StatusOr<llvm::SmallVector<mlir::Value, 4>> HloFunctionImporter::GetOperands(
HloInstruction* instruction) {
llvm::SmallVector<mlir::Value, 4> operands;
@@ -84,6 +84,8 @@ class HloFunctionImporter {
// Imports an instruction.
StatusOr<mlir::Operation*> ImportInstruction(xla::HloInstruction* instruction,
mlir::OpBuilder* func_builder);
StatusOr<mlir::Operation*> ImportInstructionImpl(
HloInstruction* instruction, mlir::OpBuilder* func_builder);

// Gets the MLIR operand values from an HLO Instruction.
StatusOr<llvm::SmallVector<mlir::Value, 4>> GetOperands(
@@ -1064,7 +1064,7 @@ LogicalResult ExportXlaOp(FusionOp op, OpLoweringContext ctx) {
llvm::SmallVector<xla::XlaOp, 4> operands;
for (auto operand : op.operands()) operands.push_back(values[operand]);

xla::XlaOp fusion = xla::internal::XlaBuilderBuildFusion(
xla::XlaOp fusion = xla::internal::XlaBuilderFriend::BuildFusion(
ctx.builder, operands,
absl::string_view(op.fusion_kind()->data(), op.fusion_kind()->size()),
fused_computation);
@@ -1160,7 +1160,32 @@ LogicalResult ConvertToHloModule::Lower(
xla::XlaBuilder* builder,
ConvertToHloModule::ValueLoweringMap* value_lowering,
xla::XlaComputation* result) {
// See hlo_function_importer.cc for documentation about layouts in MHLO.
auto propagate_layouts = [](mlir::Operation* inst, xla::XlaOp xla_op) {
auto attr =
inst->getAttrOfType<mlir::DenseIntElementsAttr>("minor_to_major");
if (!attr) return;

auto* v = xla::internal::XlaBuilderFriend::GetInstruction(xla_op)
->mutable_shape()
->mutable_layout()
->mutable_minor_to_major();
v->Clear();
for (const llvm::APInt& i : attr) {
*v->Add() = i.getZExtValue();
}
};

if (succeeded(ExportXlaOperator(inst, {value_lowering, this, builder}))) {
if (inst->getNumResults() == 1) {
auto iter = value_lowering->find(inst->getResult(0));
if (iter == value_lowering->end()) {
inst->emitOpError(
"inst has a result, but it's not found in value_lowering");
return failure();
}
propagate_layouts(inst, iter->second);
}
return success();
}

@@ -1186,16 +1211,17 @@ LogicalResult ConvertToHloModule::Lower(
if (failed(GetXlaOp(operand, value_map, &xla_operand, op)))
return failure();
value_map[op.getResult()] = xla_operand;
propagate_layouts(inst, xla_operand);
return success();
}

// TODO(jpienaar): This doesn't support layouts yet.
if (matchPattern(inst, m_Constant(&const_attr))) {
auto literal_or = CreateLiteralFromAttr(const_attr);
if (!literal_or.ok())
return inst->emitError(literal_or.status().ToString());
value_map[inst->getResult(0)] =
xla::ConstantLiteral(builder, literal_or.ValueOrDie());
auto constant = xla::ConstantLiteral(builder, literal_or.ValueOrDie());
value_map[inst->getResult(0)] = constant;
propagate_layouts(inst, constant);
return success();
}
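For context, a minimal MLIR sketch (not part of this commit; the function name and shapes are illustrative) of how the best-effort layout hint introduced above appears in MHLO: a result whose HLO shape carries a non-default layout, e.g. f32[256,32,32,6]{2,1,3,0}, is tagged with a minor_to_major index-tensor attribute on the producing op, which the exporter then writes back into the XLA layout proto.

// Illustrative sketch only; the op and attribute syntax follow the tests below,
// the function name is made up.
func @layout_hint_example(%arg0: tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32> {
  // minor_to_major = [2, 1, 3, 0] encodes the physical dimension order.
  %0 = "mhlo.copy"(%arg0) {minor_to_major = dense<[2, 1, 3, 0]> : tensor<4xindex>} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32>
  return %0 : tensor<256x32x32x6xf32>
}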
@@ -1,10 +1,10 @@
// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FILECHECK_OPTS="" FileCheck %s

func @main() -> tensor<f32> {
%cst = constant {name = "constant"} dense<1> : tensor<i32>
%cst_0 = constant {name = "constant.1"} dense<5.600000e+01> : tensor<f32>
%cst_1 = constant {name = "constant.2"} dense<1.200000e+01> : tensor<f32>
%cst_2 = constant {name = "constant.3"} dense<1.300000e+01> : tensor<f32>
%cst = constant dense<1> : tensor<i32>
%cst_0 = constant dense<5.600000e+01> : tensor<f32>
%cst_1 = constant dense<1.200000e+01> : tensor<f32>
%cst_2 = constant dense<1.300000e+01> : tensor<f32>
%0 = "mhlo.case"(%cst, %cst_0, %cst_1, %cst_2) ( {
^bb0(%arg0: tensor<f32>):
%1 = "mhlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32>
@@ -17,7 +17,7 @@ func @main() -> tensor<f32> {
^bb0(%arg0: tensor<f32>):
%1 = "mhlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32>
"mhlo.return"(%1) : (tensor<f32>) -> ()
}) {name = "conditional"} : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
}) : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}

@@ -48,23 +48,23 @@ func @main() -> tensor<f32> {
// -----

func @main() -> (tensor<f32>, tensor<f32>) {
%cst = constant {name = "constant"} dense<1> : tensor<i32>
%cst_0 = constant {name = "constant.1"} dense<5.600000e+01> : tensor<f32>
%cst_1 = constant {name = "constant.2"} dense<1.200000e+01> : tensor<f32>
%cst_2 = constant {name = "constant.3"} dense<1.300000e+01> : tensor<f32>
%cst = constant dense<1> : tensor<i32>
%cst_0 = constant dense<5.600000e+01> : tensor<f32>
%cst_1 = constant dense<1.200000e+01> : tensor<f32>
%cst_2 = constant dense<1.300000e+01> : tensor<f32>
%0:2 = "mhlo.case"(%cst, %cst_0, %cst_1, %cst_2) ( {
^bb0(%arg0: tensor<f32>):
%1 = "mhlo.negate"(%arg0) {name = "negate"} : (tensor<f32>) -> tensor<f32>
%1 = "mhlo.negate"(%arg0) : (tensor<f32>) -> tensor<f32>
"mhlo.return"(%1, %1) : (tensor<f32>, tensor<f32>) -> ()
}, {
^bb0(%arg0: tensor<f32>):
%1 = "mhlo.copy"(%arg0) {name = "copy"} : (tensor<f32>) -> tensor<f32>
%1 = "mhlo.copy"(%arg0) : (tensor<f32>) -> tensor<f32>
"mhlo.return"(%1, %1) : (tensor<f32>, tensor<f32>) -> ()
}, {
^bb0(%arg0: tensor<f32>):
%1 = "mhlo.floor"(%arg0) {name = "floor"} : (tensor<f32>) -> tensor<f32>
%1 = "mhlo.floor"(%arg0) : (tensor<f32>) -> tensor<f32>
"mhlo.return"(%1, %1) : (tensor<f32>, tensor<f32>) -> ()
}) {name = "conditional"} : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
}) : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<f32>, tensor<f32>)
return %0#0, %0#1 : tensor<f32>, tensor<f32>
}
@ -26,21 +26,21 @@ ENTRY %indexed_conditional () -> f32[] {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @main() -> tensor<f32>
|
||||
// CHECK: %[[INDEX:.*]] = constant {name = "constant"} dense<1> : tensor<i32>
|
||||
// CHECK: %[[OPERAND_1:.*]] = constant {name = "{{.*}}"} dense<5.600000e+01> : tensor<f32>
|
||||
// CHECK: %[[OPERAND_2:.*]] = constant {name = "{{.*}}"} dense<1.200000e+01> : tensor<f32>
|
||||
// CHECK: %[[OPERAND_3:.*]] = constant {name = "{{.*}}"} dense<1.300000e+01> : tensor<f32>
|
||||
// CHECK: %[[INDEX:.*]] = constant dense<1> : tensor<i32>
|
||||
// CHECK: %[[OPERAND_1:.*]] = constant dense<5.600000e+01> : tensor<f32>
|
||||
// CHECK: %[[OPERAND_2:.*]] = constant dense<1.200000e+01> : tensor<f32>
|
||||
// CHECK: %[[OPERAND_3:.*]] = constant dense<1.300000e+01> : tensor<f32>
|
||||
// CHECK: %[[RESULT:.*]] = "mhlo.case"(%[[INDEX]], %[[OPERAND_1]], %[[OPERAND_2]], %[[OPERAND_3]]) ( {
|
||||
// CHECK: ^bb0(%[[ARG_1:.*]]: tensor<f32>):
|
||||
// CHECK: %[[RES_1:.*]] = "mhlo.negate"(%[[ARG_1]]) {name = "{{.*}}"} : (tensor<f32>) -> tensor<f32>
|
||||
// CHECK: %[[RES_1:.*]] = "mhlo.negate"(%[[ARG_1]]) : (tensor<f32>) -> tensor<f32>
|
||||
// CHECK: "mhlo.return"(%[[RES_1]]) : (tensor<f32>) -> ()
|
||||
// CHECK: }, {
|
||||
// CHECK: ^bb0(%[[ARG_2:.*]]: tensor<f32>):
|
||||
// CHECK: %[[RES_2:.*]] = "mhlo.copy"(%[[ARG_2]]) {name = "{{.*}}"} : (tensor<f32>) -> tensor<f32>
|
||||
// CHECK: %[[RES_2:.*]] = "mhlo.copy"(%[[ARG_2]]) : (tensor<f32>) -> tensor<f32>
|
||||
// CHECK: "mhlo.return"(%[[RES_2]]) : (tensor<f32>) -> ()
|
||||
// CHECK: }, {
|
||||
// CHECK: ^bb0(%[[ARG_3:.*]]: tensor<f32>):
|
||||
// CHECK: %[[RES_3:.*]] = "mhlo.floor"(%[[ARG_3]]) {name = "{{.*}}"} : (tensor<f32>) -> tensor<f32>
|
||||
// CHECK: %[[RES_3:.*]] = "mhlo.floor"(%[[ARG_3]]) : (tensor<f32>) -> tensor<f32>
|
||||
// CHECK: "mhlo.return"(%[[RES_3]]) : (tensor<f32>) -> ()
|
||||
// CHECK: }) {name = "{{.*}}"} : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: }) : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: return %[[RESULT]] : tensor<f32>
|
||||
|
@ -439,7 +439,7 @@ func @main(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<10
|
||||
// CHECK-SAME: index_vector_dim=1
|
||||
// CHECK-SAME: slice_sizes={1,1,300}
|
||||
// CHECK-SAME: indices_are_sorted=true
|
||||
%0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<[0, 1]> : tensor<2xi64>, index_vector_dim = 1 : i64, offset_dims = dense<1> : tensor<1xi64>, start_index_map = dense<[0, 1]> : tensor<2xi64>}, indices_are_sorted = true, name = "gather", slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>} : (tensor<200x100x300xf32>, tensor<10x2xi32>) -> tensor<10x300xf32>
|
||||
%0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<[0, 1]> : tensor<2xi64>, index_vector_dim = 1 : i64, offset_dims = dense<1> : tensor<1xi64>, start_index_map = dense<[0, 1]> : tensor<2xi64>}, indices_are_sorted = true, slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>} : (tensor<200x100x300xf32>, tensor<10x2xi32>) -> tensor<10x300xf32>
|
||||
return %0 : tensor<10x300xf32>
|
||||
}
|
||||
|
||||
@ -502,7 +502,7 @@ func @main() -> tensor<1x10xf32> {
|
||||
func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
%0 = "mhlo.map"(%arg0, %arg1) ( {
|
||||
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>): // no predecessors
|
||||
%1 = mhlo.add %arg2, %arg3 {name = "add"} : tensor<f32>
|
||||
%1 = mhlo.add %arg2, %arg3 : tensor<f32>
|
||||
"mhlo.return"(%1) : (tensor<f32>) -> ()
|
||||
}) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %0 : tensor<4xf32>
|
||||
@ -739,7 +739,7 @@ func @main(%arg0: tensor<i1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) ->
|
||||
// CHECK: %[[ARG2:.*]] = s32[2,3] parameter(2)
|
||||
|
||||
// CHECK: ROOT %[[RES:.*]] = s32[2,3] select(pred[2,3] %[[COND]], s32[2,3] %[[ARG1]], s32[2,3] %[[ARG2]])
|
||||
%0 = "mhlo.select"(%arg0, %arg1, %arg2) {name = "select.4"} : (tensor<i1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
%0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
return %0 : tensor<2x3xi32>
|
||||
}
|
||||
|
||||
|
@ -9,95 +9,95 @@ ENTRY %tfcompile.48 {
|
||||
%arg0.1 = f32[1,300] parameter(0)
|
||||
%arg1.2 = f32[1,300,3,1] parameter(1)
|
||||
|
||||
// CHECK-NEXT: %0 = "mhlo.reshape"(%arg0) {name = "reshape.3"} : (tensor<1x300xf32>) -> tensor<1x300xf32>
|
||||
// CHECK-NEXT: %0 = "mhlo.reshape"(%arg0) : (tensor<1x300xf32>) -> tensor<1x300xf32>
|
||||
%reshape.3 = f32[1,300] reshape(%arg0.1)
|
||||
|
||||
// CHECK-NEXT: %1 = "mhlo.transpose"(%0) {name = "transpose.27", permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x300xf32>) -> tensor<300x1xf32>
|
||||
// CHECK-NEXT: %1 = "mhlo.transpose"(%0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x300xf32>) -> tensor<300x1xf32>
|
||||
%transpose.27 = f32[300,1] transpose(%reshape.3), dimensions={1,0}
|
||||
|
||||
// CHECK-NEXT: %2 = "mhlo.reshape"(%1) {name = "reshape.28"} : (tensor<300x1xf32>) -> tensor<300x1x1xf32>
|
||||
// CHECK-NEXT: %2 = "mhlo.reshape"(%1) : (tensor<300x1xf32>) -> tensor<300x1x1xf32>
|
||||
%reshape.28 = f32[300,1,1] reshape(%transpose.27)
|
||||
|
||||
// CHECK-NEXT: %3 = "mhlo.reshape"(%2) {name = "reshape.29"} : (tensor<300x1x1xf32>) -> tensor<300x1xf32>
|
||||
// CHECK-NEXT: %3 = "mhlo.reshape"(%2) : (tensor<300x1x1xf32>) -> tensor<300x1xf32>
|
||||
%reshape.29 = f32[300,1] reshape(%reshape.28)
|
||||
|
||||
// CHECK-NEXT: %4 = "mhlo.broadcast_in_dim"(%3) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, name = "broadcast.30"} : (tensor<300x1xf32>) -> tensor<300x1x5xf32>
|
||||
// CHECK-NEXT: %4 = "mhlo.broadcast_in_dim"(%3) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<300x1xf32>) -> tensor<300x1x5xf32>
|
||||
%broadcast.30 = f32[300,1,5] broadcast(%reshape.29), dimensions={0,1}
|
||||
|
||||
// CHECK-NEXT: %cst = constant {name = "constant.8"} dense<1.000000e+00> : tensor<f32>
|
||||
// CHECK-NEXT: %cst = constant dense<1.000000e+00> : tensor<f32>
|
||||
%constant.8 = f32[] constant(1)
|
||||
|
||||
// CHECK-NEXT: %5 = "mhlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<> : tensor<0xi64>, name = "broadcast.9"} : (tensor<f32>) -> tensor<300x1x5xf32>
|
||||
// CHECK-NEXT: %5 = "mhlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x1x5xf32>
|
||||
%broadcast.9 = f32[300,1,5] broadcast(%constant.8), dimensions={}
|
||||
|
||||
// CHECK-NEXT: %6 = mhlo.multiply %4, %5 {name = "multiply.31"} : tensor<300x1x5xf32>
|
||||
// CHECK-NEXT: %6 = mhlo.multiply %4, %5 : tensor<300x1x5xf32>
|
||||
%multiply.31 = f32[300,1,5] multiply(%broadcast.30, %broadcast.9)
|
||||
|
||||
// CHECK-NEXT: %cst_0 = constant {name = "constant.32"} dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK-NEXT: %cst_0 = constant dense<0.000000e+00> : tensor<f32>
|
||||
%constant.32 = f32[] constant(0)
|
||||
|
||||
// CHECK-NEXT: %7 = "mhlo.broadcast_in_dim"(%cst_0) {broadcast_dimensions = dense<> : tensor<0xi64>, name = "broadcast.33"} : (tensor<f32>) -> tensor<300x1x5xf32>
|
||||
// CHECK-NEXT: %7 = "mhlo.broadcast_in_dim"(%cst_0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x1x5xf32>
|
||||
%broadcast.33 = f32[300,1,5] broadcast(%constant.32), dimensions={}
|
||||
|
||||
// CHECK-NEXT: %8 = "mhlo.compare"(%6, %7) {comparison_direction = "GT", name = "compare.34"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xi1>
|
||||
// CHECK-NEXT: %8 = "mhlo.compare"(%6, %7) {comparison_direction = "GT"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xi1>
|
||||
%compare.34 = pred[300,1,5] compare(%multiply.31, %broadcast.33), direction=GT
|
||||
|
||||
// CHECK-NEXT: %cst_1 = constant {name = "constant.10"} dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK-NEXT: %cst_1 = constant dense<0.000000e+00> : tensor<f32>
|
||||
%constant.10 = f32[] constant(0)
|
||||
|
||||
// CHECK-NEXT: %9 = "mhlo.broadcast_in_dim"(%cst_1) {broadcast_dimensions = dense<> : tensor<0xi64>, name = "broadcast.11"} : (tensor<f32>) -> tensor<300x1x5xf32>
|
||||
// CHECK-NEXT: %9 = "mhlo.broadcast_in_dim"(%cst_1) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x1x5xf32>
|
||||
%broadcast.11 = f32[300,1,5] broadcast(%constant.10), dimensions={}
|
||||
|
||||
// CHECK-NEXT: %cst_2 = constant {name = "constant.40"} dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK-NEXT: %cst_2 = constant dense<0.000000e+00> : tensor<f32>
|
||||
%constant.40 = f32[] constant(0)
|
||||
|
||||
// CHECK-NEXT: %10 = "mhlo.broadcast_in_dim"(%cst_2) {broadcast_dimensions = dense<> : tensor<0xi64>, name = "broadcast.41"} : (tensor<f32>) -> tensor<300x5xf32>
|
||||
// CHECK-NEXT: %10 = "mhlo.broadcast_in_dim"(%cst_2) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<300x5xf32>
|
||||
%broadcast.41 = f32[300,5] broadcast(%constant.40), dimensions={}
|
||||
|
||||
// CHECK-NEXT: %11 = "mhlo.copy"(%arg1) {name = "copy.1"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32>
|
||||
// CHECK-NEXT: %11 = "mhlo.copy"(%arg1) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32>
|
||||
%copy.1 = f32[1,300,3,1] copy(%arg1.2)
|
||||
|
||||
// CHECK-NEXT: %12 = "mhlo.reshape"(%11) {name = "reshape.4"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32>
|
||||
// CHECK-NEXT: %12 = "mhlo.reshape"(%11) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32>
|
||||
%reshape.4 = f32[1,300,3,1] reshape(%copy.1)
|
||||
|
||||
// CHECK-NEXT: %13 = "mhlo.reshape"(%12) {name = "reshape.24"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3xf32>
|
||||
// CHECK-NEXT: %13 = "mhlo.reshape"(%12) : (tensor<1x300x3x1xf32>) -> tensor<1x300x3xf32>
|
||||
%reshape.24 = f32[1,300,3] reshape(%reshape.4)
|
||||
|
||||
// CHECK-NEXT: %14 = "mhlo.transpose"(%13) {name = "transpose.25", permutation = dense<[1, 0, 2]> : tensor<3xi64>} : (tensor<1x300x3xf32>) -> tensor<300x1x3xf32>
|
||||
// CHECK-NEXT: %14 = "mhlo.transpose"(%13) {permutation = dense<[1, 0, 2]> : tensor<3xi64>} : (tensor<1x300x3xf32>) -> tensor<300x1x3xf32>
|
||||
%transpose.25 = f32[300,1,3] transpose(%reshape.24), dimensions={1,0,2}
|
||||
|
||||
// CHECK-NEXT: %15 = "mhlo.reshape"(%14) {name = "reshape.26"} : (tensor<300x1x3xf32>) -> tensor<300x3xf32>
|
||||
// CHECK-NEXT: %15 = "mhlo.reshape"(%14) : (tensor<300x1x3xf32>) -> tensor<300x3xf32>
|
||||
%reshape.26 = f32[300,3] reshape(%transpose.25)
|
||||
|
||||
// CHECK-NEXT: %cst_3 = constant {name = "constant.35"} dense<{{\[\[}}-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]> : tensor<3x5xf32>
|
||||
// CHECK-NEXT: %cst_3 = constant dense<{{\[\[}}-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]> : tensor<3x5xf32>
|
||||
%constant.35 = f32[3,5] constant({ { -0.106023, 0.121505, 0.800239, -0.768885, 0.0966113 }, { 0.689014, -0.407056, -0.797853, 0.00378925, -0.208881 }, { -0.608529, 0.0276617, 0.268557, 0.577401, -0.428437 } })
|
||||
|
||||
// TODO(b/129709049) consider making this default precision config implied.
|
||||
// CHECK-NEXT: %16 = "mhlo.dot"(%15, %cst_3) {name = "dot.36", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32>
|
||||
// CHECK-NEXT: %16 = "mhlo.dot"(%15, %cst_3) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32>
|
||||
%dot.36 = f32[300,5] dot(%reshape.26, %constant.35), lhs_contracting_dims={1}, rhs_contracting_dims={0}
|
||||
|
||||
// CHECK-NEXT: %cst_4 = constant {name = "constant.37"} dense<0.000000e+00> : tensor<5xf32>
|
||||
// CHECK-NEXT: %cst_4 = constant dense<0.000000e+00> : tensor<5xf32>
|
||||
%constant.37 = f32[5]{0} constant({0, 0, 0, 0, 0})
|
||||
|
||||
// CHECK-NEXT: %17 = "mhlo.broadcast_in_dim"(%cst_4) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "broadcast.38"} : (tensor<5xf32>) -> tensor<300x5xf32>
|
||||
// CHECK-NEXT: %17 = "mhlo.broadcast_in_dim"(%cst_4) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<5xf32>) -> tensor<300x5xf32>
|
||||
%broadcast.38 = f32[300,5] broadcast(%constant.37), dimensions={1}
|
||||
|
||||
// CHECK-NEXT: %18 = mhlo.add %16, %17 {name = "add.39"} : tensor<300x5xf32>
|
||||
// CHECK-NEXT: %18 = mhlo.add %16, %17 : tensor<300x5xf32>
|
||||
%add.39 = f32[300,5] add(%dot.36, %broadcast.38)
|
||||
|
||||
// CHECK-NEXT: %19 = mhlo.maximum %10, %18 {name = "maximum.42"} : tensor<300x5xf32>
|
||||
// CHECK-NEXT: %19 = mhlo.maximum %10, %18 : tensor<300x5xf32>
|
||||
%maximum.42 = f32[300,5] maximum(%broadcast.41, %add.39)
|
||||
|
||||
// CHECK-NEXT: %20 = "mhlo.reshape"(%19) {name = "reshape.44"} : (tensor<300x5xf32>) -> tensor<300x1x5xf32>
|
||||
// CHECK-NEXT: %20 = "mhlo.reshape"(%19) : (tensor<300x5xf32>) -> tensor<300x1x5xf32>
|
||||
%reshape.44 = f32[300,1,5] reshape(%maximum.42)
|
||||
|
||||
// CHECK-NEXT: %21 = "mhlo.select"(%8, %9, %20) {name = "select.45"} : (tensor<300x1x5xi1>, tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
|
||||
// CHECK-NEXT: %21 = "mhlo.select"(%8, %9, %20) : (tensor<300x1x5xi1>, tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
|
||||
%select.45 = f32[300,1,5] select(%compare.34, %broadcast.11, %reshape.44)
|
||||
|
||||
// CHECK-NEXT: %22 = "mhlo.reshape"(%21) {name = "reshape.46"} : (tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
|
||||
// CHECK-NEXT: %22 = "mhlo.reshape"(%21) : (tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
|
||||
%reshape.46 = f32[300,1,5] reshape(%select.45)
|
||||
|
||||
// CHECK-NEXT: %23 = "mhlo.tuple"(%22) {name = "tuple.47"} : (tensor<300x1x5xf32>) -> tuple<tensor<300x1x5xf32>>
|
||||
// CHECK-NEXT: %23 = "mhlo.tuple"(%22) : (tensor<300x1x5xf32>) -> tuple<tensor<300x1x5xf32>>
|
||||
// CHECK-NEXT: return %23 : tuple<tensor<300x1x5xf32>>
|
||||
ROOT %tuple.47 = (f32[300,1,5]) tuple(%reshape.46)
|
||||
}
|
||||
|
@ -13,12 +13,12 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
|
||||
%Arg_0.1 = f32[4]{0} parameter(0)
|
||||
%Arg_1.2 = f32[4]{0} parameter(1)
|
||||
|
||||
// CHECK-NEXT: mhlo.add %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
|
||||
%add.3 = f32[4]{0} add(f32[4]{0} %Arg_0.1, f32[4]{0} %Arg_1.2)
|
||||
// CHECK-NEXT: mhlo.add %arg0, %arg1 : tensor<4xf32>
|
||||
%add.42 = f32[4]{0} add(f32[4]{0} %Arg_0.1, f32[4]{0} %Arg_1.2)
|
||||
|
||||
// TODO(b/129709049) consider making this default precision config inferred.
|
||||
// CHECK-NEXT: "mhlo.dot"(%0, %arg1) {name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor<f32>
|
||||
ROOT %dot.4 = f32[] dot(f32[4]{0} %add.3, f32[4]{0} %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0}
|
||||
// CHECK-NEXT: "mhlo.dot"(%0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor<f32>
|
||||
ROOT %dot.4 = f32[] dot(f32[4]{0} %add.42, f32[4]{0} %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_after_all
|
||||
@ -26,7 +26,7 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
|
||||
%test_after_all (token0: token[], token1: token[] ) -> token[] {
|
||||
token0 = token[] parameter(0)
|
||||
token1 = token[] parameter(1)
|
||||
// CHECK-NEXT: "mhlo.after_all"([[VAL_0]], [[VAL_1]]) {name = "{{.*}}"} : (!mhlo.token, !mhlo.token) -> !mhlo.token
|
||||
// CHECK-NEXT: "mhlo.after_all"([[VAL_0]], [[VAL_1]]) : (!mhlo.token, !mhlo.token) -> !mhlo.token
|
||||
ROOT after-all = token[] after-all(token0, token1)
|
||||
}
|
||||
|
||||
@ -75,10 +75,10 @@ add {
|
||||
%test_broadcast_in_dim {
|
||||
%Arg_0.1 = f32[1, 2] parameter(0)
|
||||
|
||||
// CHECK-NEXT: "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, name = "{{.*}}"} : (tensor<1x2xf32>) -> tensor<1x2x3xf32>
|
||||
// CHECK-NEXT: "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<1x2x3xf32>
|
||||
%broadcast.2 = f32[1,2,3] broadcast(%Arg_0.1), dimensions={0,1}
|
||||
|
||||
// CHECK-NEXT: "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>, name = "{{.*}}"} : (tensor<1x2xf32>) -> tensor<3x1x2xf32>
|
||||
// CHECK-NEXT: "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<3x1x2xf32>
|
||||
ROOT broadcast.4 = f32[3,1,2] broadcast(%Arg_0.1), dimensions={1, 2}
|
||||
}
|
||||
|
||||
@ -113,7 +113,7 @@ add {
|
||||
// CHECK-SAME: ([[ARG:%.*]]: tensor<1x291x291xf32>) -> tensor<1x291x291xf32>
|
||||
%test_cholesky (a: f32[1,291,291]) -> f32[1,291,291] {
|
||||
%a = f32[1,291,291] parameter(0)
|
||||
// CHECK-NEXT: "mhlo.cholesky"([[ARG]]) {lower = true, name = {{.*}}} : (tensor<1x291x291xf32>) -> tensor<1x291x291xf32>
|
||||
// CHECK-NEXT: "mhlo.cholesky"([[ARG]]) {lower = true} : (tensor<1x291x291xf32>) -> tensor<1x291x291xf32>
|
||||
ROOT %out = f32[1,291,291] cholesky(f32[1,291,291] %a), lower=true
|
||||
}
|
||||
|
||||
@ -124,16 +124,16 @@ add {
|
||||
%Arg_1.2 = f32[4] parameter(1)
|
||||
%Arg_2.3 = f32[] parameter(2)
|
||||
|
||||
// CHECK-NEXT: "mhlo.clamp"(%arg0, %arg1, %arg2) {name = "{{.*}}"} : (tensor<f32>, tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
|
||||
// CHECK-NEXT: "mhlo.clamp"(%arg0, %arg1, %arg2) : (tensor<f32>, tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
|
||||
ROOT %clamp.3 = f32[4] clamp(f32[] %Arg_0.1, f32[4] %Arg_1.2, f32[] %Arg_2.3)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_collective_permute
|
||||
// CHECK-SAME: ([[ARG:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32>
|
||||
%test_collective_permute (input: f32[128,32]) -> f32[128,32] {
|
||||
%input = f32[128,32]{0,1} parameter(0)
|
||||
// CHECK-NEXT: "mhlo.collective_permute"([[ARG]]) {name = {{.*}}, source_target_pairs = dense<{{\[\[}}0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>} : (tensor<128x32xf32>) -> tensor<128x32xf32>
|
||||
ROOT root = f32[128,32]{0,1} collective-permute(%input), source_target_pairs={{0,1},{1,2},{2,3}}
|
||||
%input = f32[128,32]{1,0} parameter(0)
|
||||
// CHECK-NEXT: "mhlo.collective_permute"([[ARG]]) {source_target_pairs = dense<{{\[\[}}0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>} : (tensor<128x32xf32>) -> tensor<128x32xf32>
|
||||
ROOT root = f32[128,32]{1,0} collective-permute(%input), source_target_pairs={{0,1},{1,2},{2,3}}
|
||||
}
|
||||
|
||||
|
||||
@ -143,14 +143,14 @@ add {
|
||||
%Arg_1.2 = f32[3] parameter(1)
|
||||
%Arg_2.3 = f32[3] parameter(2)
|
||||
|
||||
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
|
||||
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
|
||||
%compare.4 = pred[3] compare(Arg_0.1, Arg_1.2), direction=EQ
|
||||
|
||||
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LE", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
|
||||
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "LE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
|
||||
%compare.5 = pred[3] compare(Arg_0.1, Arg_1.2), direction=LE
|
||||
|
||||
// Requires broadcast of compatible tensors.
|
||||
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg2) {comparison_direction = "GT", name = "{{.*}}"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
|
||||
// CHECK-NEXT: "mhlo.compare"(%arg0, %arg2) {comparison_direction = "GT"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
|
||||
ROOT %compare.6 = pred[3] compare(Arg_0.1, Arg_2.3), direction=GT
|
||||
}
|
||||
|
||||
@ -159,7 +159,7 @@ add {
|
||||
%Arg_0.1 = f32[4] parameter(0)
|
||||
%Arg_1.2 = f32[4] parameter(1)
|
||||
|
||||
// CHECK-NEXT: "mhlo.complex"(%arg0, %arg1) {name = "{{.*}}"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
|
||||
// CHECK-NEXT: "mhlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
|
||||
ROOT %complex.3 = c64[4] complex(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
|
||||
}
|
||||
|
||||
@ -176,12 +176,12 @@ add {
|
||||
%test_constant {
|
||||
|
||||
// Scalar/0D tensor constant
|
||||
// CHECK-NEXT: %cst = constant {name = "{{.*}}"} dense<1> : tensor<i64>
|
||||
// CHECK-NEXT: %cst = constant dense<1> : tensor<i64>
|
||||
%constant.0 = s64[] constant(1)
|
||||
|
||||
// Note that double brackets "[[" have to be escaped as they denote variables
|
||||
// in FileCheck. The only way to do so is to drop into regex with "{{"
|
||||
// CHECK-NEXT: constant {name = "{{.*}}"} dense<{{\[\[\[\[}}1.000000e+00]], {{\[\[}}2.000000e+00]]], {{\[\[\[}}3.000000e+00]], {{\[\[}}4.000000e+00]]]]> : tensor<2x2x1x1xf32>
|
||||
// CHECK-NEXT: constant dense<{{\[\[\[\[}}1.000000e+00]], {{\[\[}}2.000000e+00]]], {{\[\[\[}}3.000000e+00]], {{\[\[}}4.000000e+00]]]]> : tensor<2x2x1x1xf32>
|
||||
%constant.1 = f32[2,2,1,1]{3,2,1,0} constant({{{{1.0}},{{2.0}}},{{{3.0}},{{4.0}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
|
||||
|
||||
// CHECK: dense<[1, 2, 4, 8]> : tensor<4xui64>
|
||||
@ -206,15 +206,15 @@ add {
|
||||
%test_conv {
|
||||
%arg0.1 = f32[256,32,32,6]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"}
|
||||
|
||||
// CHECK-NEXT: %0 = "mhlo.copy"(%arg0) {name = "{{.*}}"} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32>
|
||||
// CHECK-NEXT: %0 = "mhlo.copy"(%arg0) {minor_to_major = dense<[2, 1, 3, 0]> : tensor<4xindex>} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32>
|
||||
%copy.1 = f32[256,32,32,6]{2,1,3,0} copy(%arg0.1), metadata={op_name="HLO_Args"}
|
||||
|
||||
// CHECK-NEXT: %1 = "mhlo.reshape"(%0) {name = "{{.*}}"} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32>
|
||||
// CHECK-NEXT: %1 = "mhlo.reshape"(%0) {minor_to_major = dense<[2, 1, 3, 0]> : tensor<4xindex>} : (tensor<256x32x32x6xf32>) -> tensor<256x32x32x6xf32>
|
||||
%reshape.2 = f32[256,32,32,6]{2,1,3,0} reshape(%copy.1)
|
||||
|
||||
// Note that double brackets "[[" have to be escaped as they denote variables
|
||||
// in FileCheck. The only way to do so is to drop into regex with "{{"
|
||||
// CHECK-NEXT: %cst = constant {name = "{{.*}}"} dense<{{\[\[\[\[}}5.000000e-01]], {{\[\[}}-6.000000e-01]]], {{\[\[\[}}3.000000e-01]], {{\[\[}}-1.000000e-01]]]]> : tensor<2x2x1x1xf32>
|
||||
// CHECK-NEXT: %cst = constant dense<{{\[\[\[\[}}5.000000e-01]], {{\[\[}}-6.000000e-01]]], {{\[\[\[}}3.000000e-01]], {{\[\[}}-1.000000e-01]]]]> : tensor<2x2x1x1xf32>
|
||||
%constant.3 = f32[2,2,1,1]{3,2,1,0} constant({{{{0.5}}, {{-0.6}}}, {{{0.3}}, {{-0.1}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
|
||||
|
||||
// CHECK-NEXT: %2 = "mhlo.convolution"(%1, %cst) {
|
||||
@ -241,10 +241,10 @@ add {
|
||||
|
||||
%convolution.4 = f32[16,30,30,256]{2,1,3,0} convolution(%reshape.2, %constant.3), window={size=3x3 stride=4x5 pad=44_45x60_60 rhs_dilate=2x3}, dim_labels=b01f_01io->f01b, metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
|
||||
|
||||
// CHECK-NEXT: %3 = "mhlo.reshape"(%2) {name = "{{.*}}"} : (tensor<16x30x30x256xf32>) -> tensor<256x30x30x16xf32>
|
||||
// CHECK-NEXT: %3 = "mhlo.reshape"(%2) : (tensor<16x30x30x256xf32>) -> tensor<256x30x30x16xf32>
|
||||
%reshape.5 = f32[256,30,30,16]{3,2,1,0} reshape(%convolution.4), metadata={op_name="HLO_Retvals"}
|
||||
|
||||
// CHECK-NEXT: "mhlo.tuple"(%3) {name = "{{.*}}"} : (tensor<256x30x30x16xf32>) -> tuple<tensor<256x30x30x16xf32>>
|
||||
// CHECK-NEXT: "mhlo.tuple"(%3) : (tensor<256x30x30x16xf32>) -> tuple<tensor<256x30x30x16xf32>>
|
||||
ROOT %tuple.6 = (f32[256,30,30,16]{3,2,1,0}) tuple(%reshape.5), metadata={op_name="HLO_Retvals"}
|
||||
}
|
||||
|
||||
@ -263,10 +263,10 @@ add {
|
||||
%Arg_0.1 = f32[4] parameter(0)
|
||||
%Arg_1.2 = f32[4] parameter(1)
|
||||
|
||||
// CHECK-NEXT: %0 = "mhlo.convert"(%arg0) {name = "{{.*}}"} : (tensor<4xf32>) -> tensor<4xf64>
|
||||
// CHECK-NEXT: %0 = "mhlo.convert"(%arg0) : (tensor<4xf32>) -> tensor<4xf64>
|
||||
%convert.3 = f64[4] convert(f32[4] %Arg_0.1)
|
||||
|
||||
// CHECK-NEXT: %1 = "mhlo.convert"(%arg1) {name = "{{.*}}"} : (tensor<4xf32>) -> tensor<4xf64>
|
||||
// CHECK-NEXT: %1 = "mhlo.convert"(%arg1) : (tensor<4xf32>) -> tensor<4xf64>
|
||||
%convert.4 = f64[4] convert(f32[4] %Arg_1.2)
|
||||
|
||||
// CHECK-NEXT: mhlo.add %0, %1
|
||||
@ -277,7 +277,7 @@ add {
|
||||
%test_cosine (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] {
|
||||
%arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"}
|
||||
|
||||
// CHECK-NEXT: "mhlo.cosine"(%arg0) {name = "{{.*}}"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32>
|
||||
// CHECK-NEXT: "mhlo.cosine"(%arg0) : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32>
|
||||
ROOT %cosine.3 = f32[1,16,16,3]{3,2,1,0} cosine(f32[1,16,16,3]{3,2,1,0} %arg0.1)
|
||||
}
|
||||
|
||||
@ -286,7 +286,7 @@ add {
|
||||
%test_custom_call (arg1: f32[2,3], arg2: f32[5,5]) -> f32[1,2,3] {
|
||||
%arg1 = f32[2,3] parameter(0)
|
||||
%arg2 = f32[5,5] parameter(1)
|
||||
// CHECK: "mhlo.custom_call"([[ARG_0]], [[ARG_1]]) {backend_config = "bar", call_target_name = "foo", has_side_effect = true, name = {{.*}}} : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32>
|
||||
// CHECK: "mhlo.custom_call"([[ARG_0]], [[ARG_1]]) {backend_config = "bar", call_target_name = "foo", has_side_effect = true, minor_to_major = {{.*}}} : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32>
|
||||
ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[2,3] %arg1, f32[5,5] %arg2), custom_call_target="foo", backend_config="bar", custom_call_has_side_effect=true
|
||||
}
|
||||
|
||||
@ -295,7 +295,7 @@ add {
|
||||
%Arg_0.1 = f32[4] parameter(0)
|
||||
%Arg_1.2 = f32[4] parameter(1)
|
||||
|
||||
// CHECK-NEXT: mhlo.divide %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
|
||||
// CHECK-NEXT: mhlo.divide %arg0, %arg1 : tensor<4xf32>
|
||||
ROOT %divide.3 = f32[4] divide(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
|
||||
}
|
||||
|
||||
@ -304,17 +304,17 @@ add {
|
||||
%Arg_0.1 = f32[1, 4] parameter(0)
|
||||
%Arg_1.2 = f32[4, 1] parameter(1)
|
||||
|
||||
// CHECK-NEXT: %0 = "mhlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["HIGH", "HIGHEST"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
|
||||
// CHECK-NEXT: %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["HIGH", "HIGHEST"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
|
||||
dot.3 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={high,highest}
|
||||
|
||||
// CHECK-NEXT: %1 = "mhlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["HIGHEST", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
|
||||
// CHECK-NEXT: %1 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["HIGHEST", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
|
||||
dot.4 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,default}
|
||||
|
||||
// CHECK-NEXT: %2 = "mhlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
|
||||
// CHECK-NEXT: %2 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
|
||||
%dot.5 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={default,default}
|
||||
|
||||
// TODO(b/129709049) consider making this default precision config inferred.
|
||||
// CHECK-NEXT: "mhlo.dot"(%arg0, %arg1) {name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
|
||||
// CHECK-NEXT: "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
|
||||
ROOT %dot.6 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
|
||||
}
|
||||
|
||||
@ -325,17 +325,17 @@ add {
|
||||
%Arg_0.1 = f32[4, 1] parameter(0)
|
||||
%Arg_1.2 = f32[1, 4] parameter(1)
|
||||
|
||||
// CHECK-NEXT: [[R0:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["HIGH", "HIGHEST"]}
|
||||
// CHECK-NEXT: [[R0:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["HIGH", "HIGHEST"]}
|
||||
dot.3 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1}, operand_precision={high,highest}
|
||||
|
||||
// CHECK-NEXT: [[R1:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["HIGHEST", "DEFAULT"]}
|
||||
// CHECK-NEXT: [[R1:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["HIGHEST", "DEFAULT"]}
|
||||
dot.4 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1}, operand_precision={highest,default}
|
||||
|
||||
// CHECK-NEXT: [[R2:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]}
|
||||
// CHECK-NEXT: [[R2:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]}
|
||||
%dot.5 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1}, operand_precision={default,default}
|
||||
|
||||
// TODO(b/129709049) consider making this default precision config inferred.
|
||||
// CHECK-NEXT: "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, name = "{{.*}}", precision_config = ["DEFAULT", "DEFAULT"]}
|
||||
// CHECK-NEXT: "mhlo.dot_general"([[ARG0]], [[ARG1]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<> : tensor<0xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]}
|
||||
ROOT %dot.6 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1}
|
||||
}
|
||||
|
||||
@ -376,7 +376,7 @@ add {
|
||||
%test_exponential (arg0.1: f32[16]) -> f32[16] {
|
||||
%arg0.1 = f32[16] parameter(0)
|
||||
|
||||
// CHECK-NEXT: "mhlo.exponential"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
|
||||
// CHECK-NEXT: "mhlo.exponential"(%arg0) : (tensor<16xf32>) -> tensor<16xf32>
|
||||
ROOT %exp.2 = f32[16] exponential(f32[16] %arg0.1)
|
||||
}
|
||||
|
||||
@ -384,7 +384,7 @@ add {
|
||||
%test_expm1 (arg0.1: f32[16]) -> f32[16] {
|
||||
%arg0.1 = f32[16] parameter(0)
|
||||
|
||||
// CHECK: "mhlo.exponential_minus_one"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
|
||||
// CHECK: "mhlo.exponential_minus_one"(%arg0) : (tensor<16xf32>) -> tensor<16xf32>
|
||||
ROOT %expm1.2 = f32[16] exponential-minus-one(f32[16] %arg0.1)
|
||||
}
|
||||
|
||||
@ -400,7 +400,7 @@ add {
|
||||
%test_floor (arg0.1: f32[16]) -> f32[16] {
|
||||
%arg0.1 = f32[16] parameter(0)
|
||||
|
||||
// CHECK-NEXT: "mhlo.floor"([[A0]]) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
|
||||
// CHECK-NEXT: "mhlo.floor"([[A0]]) : (tensor<16xf32>) -> tensor<16xf32>
|
||||
ROOT %floor.2 = f32[16] floor(f32[16] %arg0.1)
|
||||
}
|
||||
|
||||
@ -430,7 +430,7 @@ add {
|
||||
// CHECK-SAME: ([[ARG:%.*]]: tensor<4x2xf32>)
|
||||
%test_get_dimension_size (Arg_0.1: f32[4,2]) -> s32[] {
|
||||
%Arg_0.1 = f32[4,2] parameter(0)
|
||||
// CHECK-NEXT: "mhlo.get_dimension_size"([[ARG]]) {dimension = 1 : i32, name = "{{.*}}"} : (tensor<4x2xf32>) -> tensor<i32>
|
||||
// CHECK-NEXT: "mhlo.get_dimension_size"([[ARG]]) {dimension = 1 : i32} : (tensor<4x2xf32>) -> tensor<i32>
|
||||
ROOT %get-dimension-size.2 = s32[] get-dimension-size(f32[4,2] %Arg_0.1), dimensions={1}
|
||||
}
|
||||
|
||||
@ -438,7 +438,7 @@ add {
|
||||
%test_imag (Arg_0.1: c64[4]) -> f32[4] {
|
||||
%Arg_0.1 = c64[4] parameter(0)
|
||||
|
||||
// CHECK-NEXT: "mhlo.imag"(%arg0) {name = "{{.*}}"} : (tensor<4xcomplex<f32>>) -> tensor<4xf32>
|
||||
// CHECK-NEXT: "mhlo.imag"(%arg0) : (tensor<4xcomplex<f32>>) -> tensor<4xf32>
|
||||
ROOT %imag.3 = f32[4] imag(c64[4] %Arg_0.1)
|
||||
}
|
||||
|
||||
@ -468,7 +468,7 @@ add {
|
||||
%test_log (arg0.1: f32[16]) -> f32[16] {
|
||||
%arg0.1 = f32[16] parameter(0)
|
||||
|
||||
// CHECK-NEXT: "mhlo.log"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
|
||||
// CHECK-NEXT: "mhlo.log"(%arg0) : (tensor<16xf32>) -> tensor<16xf32>
|
||||
ROOT %log.2 = f32[16] log(f32[16] %arg0.1)
|
||||
}
|
||||
|
||||
@ -476,7 +476,7 @@ add {
|
||||
%test_log1p (arg0.1: f32[16]) -> f32[16] {
|
||||
%arg0.1 = f32[16] parameter(0)
|
||||
|
||||
// CHECK: "mhlo.log_plus_one"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
|
||||
// CHECK: "mhlo.log_plus_one"(%arg0) : (tensor<16xf32>) -> tensor<16xf32>
|
||||
ROOT %log1p.2 = f32[16] log-plus-one(f32[16] %arg0.1)
|
||||
}
|
||||
|
||||
@ -507,7 +507,7 @@ add {
|
||||
%Arg_0.1 = f32[4] parameter(0)
|
||||
%Arg_1.2 = f32[4] parameter(1)
|
||||
|
||||
// CHECK-NEXT: mhlo.maximum %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
|
||||
// CHECK-NEXT: mhlo.maximum %arg0, %arg1 : tensor<4xf32>
|
||||
ROOT %maximum.3 = f32[4] maximum(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
|
||||
}
|
||||
|
||||
@ -516,7 +516,7 @@ add {
|
||||
%Arg_0.1 = f32[4] parameter(0)
|
||||
%Arg_1.2 = f32[4] parameter(1)
|
||||
|
||||
// CHECK-NEXT: mhlo.minimum %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
|
||||
// CHECK-NEXT: mhlo.minimum %arg0, %arg1 : tensor<4xf32>
|
||||
ROOT %minimum.3 = f32[4] minimum(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
|
||||
}
|
||||
|
||||
@ -525,7 +525,7 @@ add {
|
||||
%Arg_0.1 = f32[4] parameter(0)
|
||||
%Arg_1.2 = f32[4] parameter(1)
|
||||
|
||||
// CHECK-NEXT: %0 = mhlo.multiply %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
|
||||
// CHECK-NEXT: %0 = mhlo.multiply %arg0, %arg1 : tensor<4xf32>
|
||||
ROOT %multiply.3 = f32[4] multiply(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
|
||||
}
|
||||
|
||||
@ -533,7 +533,7 @@ add {
|
||||
%test_negate (arg0.1: f32[16]) -> f32[16] {
|
||||
%arg0.1 = f32[16] parameter(0)
|
||||
|
||||
// CHECK-NEXT: "mhlo.negate"(%arg0) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
|
||||
// CHECK-NEXT: "mhlo.negate"(%arg0) : (tensor<16xf32>) -> tensor<16xf32>
|
||||
ROOT %negate.2 = f32[16] negate(f32[16] %arg0.1)
|
||||
}
|
||||
|
||||
@ -541,7 +541,7 @@ add {
|
||||
%test_not (arg0.1: pred[16]) -> pred[16] {
|
||||
%arg0.1 = pred[16] parameter(0)
|
||||
|
||||
// CHECK: "mhlo.not"(%arg0) {name = "{{.*}}"} : (tensor<16xi1>) -> tensor<16xi1>
|
||||
// CHECK: "mhlo.not"(%arg0) : (tensor<16xi1>) -> tensor<16xi1>
|
||||
ROOT %not.2 = pred[16] not(pred[16] %arg0.1)
|
||||
}
|
||||
|
||||
@ -595,7 +595,7 @@ add {
|
||||
%test_popcnt (arg0.1: s32[16]) -> s32[16] {
|
||||
%arg0.1 = s32[16] parameter(0)
|
||||
|
||||
// CHECK: "mhlo.popcnt"(%arg0) {name = "{{.*}}"} : (tensor<16xi32>) -> tensor<16xi32>
|
||||
// CHECK: "mhlo.popcnt"(%arg0) : (tensor<16xi32>) -> tensor<16xi32>
|
||||
ROOT %popcnt.2 = s32[16] popcnt(s32[16] %arg0.1)
|
||||
}
|
||||
|
||||
@ -604,7 +604,7 @@ add {
|
||||
%Arg_0.1 = f32[4] parameter(0)
|
||||
%Arg_1.2 = f32[4] parameter(1)
|
||||
|
||||
// CHECK-NEXT: mhlo.power %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
|
||||
// CHECK-NEXT: mhlo.power %arg0, %arg1 : tensor<4xf32>
|
||||
ROOT %power.3 = f32[4] power(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
|
||||
}
|
||||
|
||||
@ -632,7 +632,7 @@ add {
|
||||
%test_real (Arg_0.1: c64[4]) -> f32[4] {
|
||||
%Arg_0.1 = c64[4] parameter(0)
|
||||
|
||||
// CHECK-NEXT: "mhlo.real"(%arg0) {name = "{{.*}}"} : (tensor<4xcomplex<f32>>) -> tensor<4xf32>
|
||||
// CHECK-NEXT: "mhlo.real"(%arg0) : (tensor<4xcomplex<f32>>) -> tensor<4xf32>
|
||||
ROOT %real.3 = f32[4] real(c64[4] %Arg_0.1)
|
||||
}
|
||||
|
||||
@ -687,7 +687,7 @@ add {
|
||||
// CHECK: {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<f32>
|
||||
%reduce.4 = f32[] reduce(%reduce.2, %Arg_2.3), dimensions={0}, to_apply=%reduce_helper.3
|
||||
|
||||
// CHECK: %4 = mhlo.subtract [[VAL2]], [[VAL4]] {name = "{{.*}}"} : tensor<f32>
|
||||
// CHECK: %4 = mhlo.subtract [[VAL2]], [[VAL4]] : tensor<f32>
|
||||
%sub.5 = f32[] subtract(%reduce.3, %reduce.4)
|
||||
|
||||
ROOT %tuple.6 = ((f32[], f32[]), f32[]) tuple(%reduce.1, %sub.5)
|
||||
@ -741,7 +741,7 @@ add {
|
||||
%test_rsqrt (arg0.1: f32[16]) -> f32[16] {
|
||||
%arg0.1 = f32[16] parameter(0)
|
||||
|
||||
// CHECK: "mhlo.rsqrt"([[ARG0]]) {name = "{{.*}}"} : (tensor<16xf32>) -> tensor<16xf32>
|
||||
// CHECK: "mhlo.rsqrt"([[ARG0]]) : (tensor<16xf32>) -> tensor<16xf32>
|
||||
ROOT %rsqrt.2 = f32[16] rsqrt(f32[16] %arg0.1)
|
||||
}
|
||||
|
||||
@ -788,7 +788,7 @@ add {
|
||||
%Arg_1.2 = s32[2,3] parameter(1)
|
||||
%Arg_2.3 = s32[2,3] parameter(2)
|
||||
|
||||
// CHECK: "mhlo.select"(%arg0, %arg1, %arg2) {name = "{{.*}}"} : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
// CHECK: "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
ROOT %select.4 = s32[2,3] select(pred[2,3] %Arg_0.1, s32[2,3] %Arg_1.2, s32[2,3] %Arg_2.3)
|
||||
}
|
||||
|
||||
@ -835,7 +835,7 @@ add {
|
||||
%test_set_dimension_size (Arg_0.1: f32[4,4], Arg_1.2: s32[]) -> f32[4,<=4] {
|
||||
%Arg_0.1 = f32[4,4] parameter(0)
|
||||
%Arg_1.2 = s32[] parameter(1)
|
||||
// CHECK-NEXT: "mhlo.set_dimension_size"([[ARG]], [[SIZE]]) {dimension = 1 : i32, name = "{{.*}}"} : (tensor<4x4xf32>, tensor<i32>) -> tensor<4x4xf32>
|
||||
// CHECK-NEXT: "mhlo.set_dimension_size"([[ARG]], [[SIZE]]) {dimension = 1 : i32} : (tensor<4x4xf32>, tensor<i32>) -> tensor<4x4xf32>
|
||||
ROOT %set-dimension-size.2 = f32[4,<=4] set-dimension-size(f32[4,4] %Arg_0.1, s32[] %Arg_1.2), dimensions={1}
|
||||
}
|
||||
|
||||
@ -843,7 +843,7 @@ add {
|
||||
%test_sine (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] {
|
||||
%arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"}
|
||||
|
||||
// CHECK-NEXT: "mhlo.sine"(%arg0) {name = "{{.*}}"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32>
|
||||
// CHECK-NEXT: "mhlo.sine"(%arg0) : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32>
|
||||
ROOT %sine.3 = f32[1,16,16,3]{3,2,1,0} sine(f32[1,16,16,3]{3,2,1,0} %arg0.1)
|
||||
}
|
||||
|
||||
@ -862,7 +862,7 @@ add {
|
||||
// CHECK-SAME: [[ARG:%.*]]: tensor<1024xf32>) -> tensor<1024xf32>
|
||||
// CHECK: "mhlo.sort"([[ARG]]) ( {
|
||||
// CHECK: ^bb0([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>):
|
||||
// CHECK: [[CMP:%.*]] = "mhlo.compare"([[ARG0]], [[ARG1]]) {comparison_direction = "LT", name = "lt"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: [[CMP:%.*]] = "mhlo.compare"([[ARG0]], [[ARG1]]) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
// CHECK: "mhlo.return"([[CMP]]) : (tensor<i1>) -> ()
|
||||
// CHECK: }) {dimension = 0 : i64, is_stable = true} : (tensor<1024xf32>) -> tensor<1024xf32>
|
||||
|
||||
@ -871,7 +871,7 @@ add {
|
||||
%Arg_0.1 = f32[4] parameter(0)
|
||||
%Arg_1.2 = f32[4] parameter(1)
|
||||
|
||||
// CHECK-NEXT: mhlo.subtract %arg0, %arg1 {name = "{{.*}}"} : tensor<4xf32>
|
||||
// CHECK-NEXT: mhlo.subtract %arg0, %arg1 : tensor<4xf32>
|
||||
ROOT %subtract.3 = f32[4] subtract(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
|
||||
}
|
||||
|
||||
@ -879,7 +879,7 @@ add {
|
||||
%test_tanh (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] {
|
||||
%arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"}
|
||||
|
||||
// CHECK-NEXT: "mhlo.tanh"(%arg0) {name = "{{.*}}"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32>
|
||||
// CHECK-NEXT: "mhlo.tanh"(%arg0) : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32>
|
||||
ROOT %tanh.3 = f32[1,16,16,3]{3,2,1,0} tanh(f32[1,16,16,3]{3,2,1,0} %arg0.1), metadata={op_type="Tanh" op_name="embedded_inference/tanh_model/Tanh"}
|
||||
}
|
||||
|
||||
@ -887,7 +887,7 @@ add {
|
||||
%test_transpose {
|
||||
%Arg_0.1 = s32[1,2,3,4] parameter(0)
|
||||
|
||||
// CHECK: "mhlo.transpose"(%arg0) {name = "{{.*}}", permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
|
||||
// CHECK: "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
|
||||
ROOT %transpose.2 = s32[2,1,4,3] transpose(s32[1,2,3,4] %Arg_0.1), dimensions={1,0,3,2}
|
||||
}
|
||||
|
||||
@ -909,10 +909,10 @@ add {
|
||||
%Arg_0.1 = s32[1] parameter(0)
|
||||
%Arg_1.2 = f32[1, 2] parameter(1)
|
||||
|
||||
// CHECK-NEXT: %0 = "mhlo.tuple"(%arg0) {name = "{{.*}}"} : (tensor<1xi32>) -> tuple<tensor<1xi32>>
|
||||
// CHECK-NEXT: %0 = "mhlo.tuple"(%arg0) : (tensor<1xi32>) -> tuple<tensor<1xi32>>
|
||||
%tuple.3 = (s32[1]) tuple(%Arg_0.1)
|
||||
|
||||
// CHECK: "mhlo.tuple"(%arg0, %arg1) {name = "{{.*}}"} : (tensor<1xi32>, tensor<1x2xf32>) -> tuple<tensor<1xi32>, tensor<1x2xf32>>
|
||||
// CHECK: "mhlo.tuple"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xf32>) -> tuple<tensor<1xi32>, tensor<1x2xf32>>
|
||||
ROOT %tuple.4 = (s32[1], f32[1,2]) tuple(%Arg_0.1, %Arg_1.2)
|
||||
}
|
||||
|
||||
@ -934,11 +934,11 @@ add {
|
||||
%arg0.1 = s64[] parameter(0), metadata={op_name="HLO_Args"}
|
||||
// CHECK-NEXT: "mhlo.while"(%arg0) ( {
|
||||
// CHECK-NEXT: ^bb0(%arg1: tensor<i64>): // no predecessors
|
||||
// CHECK-NEXT: [[CMP:%.*]] = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "{{.*}}"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||
// CHECK-NEXT: [[CMP:%.*]] = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||
// CHECK-NEXT: "mhlo.return"([[CMP]]) : (tensor<i1>) -> ()
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: ^bb0(%arg1: tensor<i64>): // no predecessors
|
||||
// CHECK-NEXT: [[ADD:%.*]] = mhlo.add %arg1, %arg1 {name = "{{.*}}"} : tensor<i64>
|
||||
// CHECK-NEXT: [[ADD:%.*]] = mhlo.add %arg1, %arg1 : tensor<i64>
|
||||
// CHECK-NEXT: "mhlo.return"([[ADD]]) : (tensor<i64>) -> ()
|
||||
// CHECK-NEXT: }) : (tensor<i64>) -> tensor<i64>
|
||||
ROOT %while.2 = s64[] while(%arg0.1), body=%loop, condition=%cond
|
||||
@ -992,8 +992,8 @@ add {
|
||||
%Arg_1.2 = c128[2] parameter(1)
|
||||
%abs.4 = f64[2] abs(c128[2] %Arg_1.2)
|
||||
|
||||
// CHECK: "mhlo.abs"(%[[ARG0]]) {name = "{{.*}}"} : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
|
||||
// CHECK: "mhlo.abs"(%[[ARG1]]) {name = "{{.*}}"} : (tensor<2xcomplex<f64>>) -> tensor<2xf64>
|
||||
// CHECK: "mhlo.abs"(%[[ARG0]]) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
|
||||
// CHECK: "mhlo.abs"(%[[ARG1]]) : (tensor<2xcomplex<f64>>) -> tensor<2xf64>
|
||||
ROOT %tuple.5 = (f32[2], f64[2]) tuple(f32[2] %abs.3, f64[2] %abs.4)
|
||||
}
|
||||
|
||||
@ -1002,7 +1002,7 @@ add {
%unsigned_int(Arg_0.1: u16[4]) -> u16[4] {
%Arg_0.1 = u16[4] parameter(0)

// CHECK: "mhlo.not"(%[[ARG0]]) {name = "{{.*}}"} : (tensor<4xui16>) -> tensor<4xui16>
// CHECK: "mhlo.not"(%[[ARG0]]) : (tensor<4xui16>) -> tensor<4xui16>
ROOT %not.2 = u16[4] not(u16[4] %Arg_0.1)
}

@ -0,0 +1,11 @@
// RUN: tf-mlir-translate -mlir-print-debuginfo -hlo-text-to-mlir-hlo %s -o - | FileCheck %s

HloModule Test

// CHECK-LABEL: func @main
ENTRY A {
%input = f16[128,224,224,4] parameter(0)
%filter = f16[64,7,7,4] parameter(1)
// %0 = "mhlo.convolution"{{.*}}minor_to_major = dense<[1, 3, 2, 0]> : tensor<4xindex>{{.*}} loc("root.42")
ROOT %root.42 = f16[128,64,112,112]{1,3,2,0} convolution(%input, %filter), dim_labels=b01f_o01i->bf01, window={size=7x7 stride=2x2 pad=3_3x3_3}
}

@ -0,0 +1,30 @@
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text-with-layouts %s | FileCheck %s

// Checks exporting layouts

// CHECK: HloModule
func @main(%arg0: tensor<128x224x224x4xf16>, %arg1: tensor<64x7x7x4xf16>) -> tensor<128x64x112x112xf16> {
  // CHECK: %convolution.{{.*}} = f16[128,64,112,112]{1,3,2,0} convolution{{.*}}op_name="root.42"
  %0 = "mhlo.convolution"(%arg0, %arg1) {
    batch_group_count = 1 : i64,
    dimension_numbers = {
      input_batch_dimension = 0 : i64,
      input_feature_dimension = 3 : i64,
      input_spatial_dimensions = dense<[ 1, 2 ]> : tensor<2xi64>,
      kernel_input_feature_dimension = 3 : i64,
      kernel_output_feature_dimension = 0 : i64,
      kernel_spatial_dimensions = dense<[ 1, 2 ]> : tensor<2xi64>,
      output_batch_dimension = 0 : i64,
      output_feature_dimension = 1 : i64,
      output_spatial_dimensions = dense<[ 2, 3 ]> : tensor<2xi64>
    },
    feature_group_count = 1 : i64,
    lhs_dilations = dense<1> : tensor<2xi64>,
    minor_to_major = dense<[ 1, 3, 2, 0 ]> : tensor<4xindex>,
    padding = dense<3> : tensor<2x2xi64>,
    precision_config = [ "DEFAULT", "DEFAULT" ],
    rhs_dilations = dense<1> : tensor<2xi64>,
    window_strides = dense<2> : tensor<2xi64>
  } : (tensor<128x224x224x4xf16>, tensor<64x7x7x4xf16>)-> tensor<128x64x112x112xf16> loc("root.42")
  return %0 : tensor<128x64x112x112xf16>
}
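For readers unfamiliar with XLA layouts: the minor_to_major = dense<[1, 3, 2, 0]> attribute above lists the result dimensions from fastest-varying to slowest-varying, which is what the {1,3,2,0} suffix in the exported f16[128,64,112,112]{1,3,2,0} shape encodes. The standalone C++ sketch below only illustrates that interpretation by deriving per-dimension element strides; the function and variable names are invented for the example and are not part of this change.

#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative only: given dimension sizes and an XLA-style minor_to_major
// permutation (dimensions listed from most-minor to most-major), compute the
// element stride of each logical dimension.
std::vector<int64_t> StridesFromMinorToMajor(
    const std::vector<int64_t>& dims,
    const std::vector<int64_t>& minor_to_major) {
  std::vector<int64_t> strides(dims.size(), 0);
  int64_t stride = 1;
  for (int64_t dim : minor_to_major) {  // walk from the most-minor dimension outward
    strides[dim] = stride;
    stride *= dims[dim];
  }
  return strides;
}

int main() {
  // f16[128,64,112,112]{1,3,2,0}: dimension 1 is most minor, dimension 0 most major.
  std::vector<int64_t> dims = {128, 64, 112, 112};
  std::vector<int64_t> minor_to_major = {1, 3, 2, 0};
  for (int64_t s : StridesFromMinorToMajor(dims, minor_to_major))
    std::cout << s << " ";  // prints: 802816 1 7168 64
  std::cout << "\n";
}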

@ -139,8 +139,8 @@ dynamic_parameter_binding {
}

# CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<f32> {
# CHECK-NEXT: %0 = mhlo.add %arg0, %arg1 {name = "add.3"} : tensor<4xf32>
# CHECK-NEXT: %0 = mhlo.add %arg0, %arg1 : tensor<4xf32>
# TODO(b/129709049) consider making this default precision config inferred.
# CHECK-NEXT: %1 = "mhlo.dot"(%0, %arg1) {name = "dot.4", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor<f32>
# CHECK-NEXT: %1 = "mhlo.dot"(%0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor<f32>
# CHECK-NEXT: return %1 : tensor<f32>
# CHECK-NEXT: }

@ -4,25 +4,25 @@ HloModule tfcompile.1

// CHECK-LABEL: func @main() -> tensor<i1> {
ENTRY %tfcompile.1 {
// CHECK-NEXT: %cst = constant {name = "constant.0"} dense<1.000000e+00> : tensor<f32>
// CHECK-NEXT: %cst = constant dense<1.000000e+00> : tensor<f32>
%constant.0 = f32[] constant(1)

// CHECK-NEXT: %cst_0 = constant {name = "constant.1"} dense<1.000000e+00> : tensor<f64>
// CHECK-NEXT: %cst_0 = constant dense<1.000000e+00> : tensor<f64>
%constant.1 = f64[] constant(1)

// CHECK-NEXT: %cst_1 = constant {name = "constant.2"} dense<1> : tensor<i8>
// CHECK-NEXT: %cst_1 = constant dense<1> : tensor<i8>
%constant.2 = s8[] constant(1)

// CHECK-NEXT: %cst_2 = constant {name = "constant.3"} dense<1> : tensor<i16>
// CHECK-NEXT: %cst_2 = constant dense<1> : tensor<i16>
%constant.3 = s16[] constant(1)

// CHECK-NEXT: %cst_3 = constant {name = "constant.4"} dense<1> : tensor<i32>
// CHECK-NEXT: %cst_3 = constant dense<1> : tensor<i32>
%constant.4 = s32[] constant(1)

// CHECK-NEXT: %cst_4 = constant {name = "constant.5"} dense<1> : tensor<i64>
// CHECK-NEXT: %cst_4 = constant dense<1> : tensor<i64>
%constant.5 = s64[] constant(1)

// CHECK-NEXT: %cst_5 = constant {name = "constant.6"} dense<true> : tensor<i1>
// CHECK-NEXT: %cst_5 = constant dense<true> : tensor<i1>
// CHECK-NEXT: return %cst_5 : tensor<i1>
ROOT %constant.6 = pred[] constant(1)
}

@ -26,4 +26,4 @@ ENTRY %foo (arg0.1: s64[]) -> s64[] {
// CHECK: "mhlo.return"
// CHECK: }) : (tensor<i64>) -> tensor<i64>
ROOT %while.2 = s64[] while(%arg0.1), body=%loop, condition=%cond
}
}

@ -10,11 +10,11 @@ module {
// CHECK: %[[A0]] = s64[] parameter(0)
// CHECK: ROOT %compare.7 = pred[] compare(s64[] %[[A0]], s64[] %[[A0]]), direction=LT
^bb0(%arg1: tensor<i64>):
%1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "compare.2"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
%1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"mhlo.return"(%1) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i64>):
%1 = mhlo.add %arg1, %arg1 {name = "compare.0"} : tensor<i64>
%1 = mhlo.add %arg1, %arg1 : tensor<i64>
"mhlo.return"(%1) : (tensor<i64>) -> ()
}) : (tensor<i64>) -> tensor<i64>

@ -124,8 +124,8 @@ static StatusOr<std::unique_ptr<HloModule>> HloModuleFromProto(
  return HloModule::CreateFromProto(module_proto, module_config);
}

static mlir::LogicalResult MlirHloToHloTextTranslateFunction(
    mlir::ModuleOp module, llvm::raw_ostream& output) {
static mlir::LogicalResult MlirHloToHloTextTranslateFunctionImpl(
    mlir::ModuleOp module, llvm::raw_ostream& output, bool with_layouts) {
  if (!module) return mlir::failure();

  HloProto hloProto;

@ -146,9 +146,8 @@ static mlir::LogicalResult MlirHloToHloTextTranslateFunction(

  HloModule* hlo_module = statusOrHloModule.ValueOrDie().get();

  // We don't interpret or use layouts
  output << hlo_module->ToString(
      HloPrintOptions().set_include_layout_in_shapes(false));
      HloPrintOptions().set_include_layout_in_shapes(with_layouts));

  // Output alias information as comments in the HLO text.
  hlo_module->input_output_alias_config().ForEachAlias(

@ -162,6 +161,18 @@ static mlir::LogicalResult MlirHloToHloTextTranslateFunction(
  return mlir::success();
}

static mlir::LogicalResult MlirHloToHloTextTranslateFunction(
    mlir::ModuleOp module, llvm::raw_ostream& output) {
  return MlirHloToHloTextTranslateFunctionImpl(module, output,
                                               /*with_layouts=*/false);
}

static mlir::LogicalResult MlirHloToHloTextWithLayoutsTranslateFunction(
    mlir::ModuleOp module, llvm::raw_ostream& output) {
  return MlirHloToHloTextTranslateFunctionImpl(module, output,
                                               /*with_layouts=*/true);
}

} // namespace xla

static void RegisterInputDialects(mlir::DialectRegistry& registry) {

@ -176,6 +187,10 @@ static mlir::TranslateFromMLIRRegistration MlirHloToHloTextTranslate(
    "mlir-hlo-to-hlo-text", xla::MlirHloToHloTextTranslateFunction,
    RegisterInputDialects);

static mlir::TranslateFromMLIRRegistration MlirHloToHloTextWithLayoutsTranslate(
    "mlir-hlo-to-hlo-text-with-layouts",
    xla::MlirHloToHloTextWithLayoutsTranslateFunction, RegisterInputDialects);

static mlir::TranslateToMLIRRegistration HloToHloMlirTranslate(
    "hlo-to-mlir-hlo", xla::HloToMlirHloTranslateFunction);

@ -132,10 +132,10 @@ bool InstrIsSetBound(const HloInstructionProto* instr_proto) {

namespace internal {

XlaOp XlaBuilderBuildFusion(XlaBuilder* builder,
                            absl::Span<const XlaOp> operands,
                            absl::string_view fusion_kind,
                            const XlaComputation& fused_computation) {
XlaOp XlaBuilderFriend::BuildFusion(XlaBuilder* builder,
                                    absl::Span<const XlaOp> operands,
                                    absl::string_view fusion_kind,
                                    const XlaComputation& fused_computation) {
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    instr.set_fusion_kind(std::string(fusion_kind));

@ -149,6 +149,11 @@ XlaOp XlaBuilderBuildFusion(XlaBuilder* builder,
  });
}

HloInstructionProto* XlaBuilderFriend::GetInstruction(XlaOp op) {
  return &op.builder()
              ->instructions_[op.builder()->handle_to_index_[op.handle_]];
}

} // namespace internal

XlaOp operator-(XlaOp x) { return Neg(x); }

@ -47,13 +47,18 @@ namespace xla {

class XlaBuilder;
class XlaOp;
class HloInstruction;

namespace internal {

XlaOp XlaBuilderBuildFusion(XlaBuilder* builder,
                            absl::Span<const XlaOp> operands,
                            absl::string_view fusion_kind,
                            const XlaComputation& fused_computation);
struct XlaBuilderFriend {
  static XlaOp BuildFusion(XlaBuilder* builder,
                           absl::Span<const XlaOp> operands,
                           absl::string_view fusion_kind,
                           const XlaComputation& fused_computation);

  static HloInstructionProto* GetInstruction(XlaOp op);
};

} // namespace internal

@ -107,6 +112,7 @@ class XlaOp {

  friend class XlaBuilder;
  friend class MlirHloBuilder;
  friend struct internal::XlaBuilderFriend;

  // < 0 means "invalid handle".
  int64 handle_;

@ -1306,9 +1312,7 @@ class XlaBuilder {
    return LookUpInstructionByHandleInternal<InstructionType>(op.handle());
  }

  friend XlaOp internal::XlaBuilderBuildFusion(
      XlaBuilder* builder, absl::Span<const XlaOp> operands,
      absl::string_view fusion_kind, const XlaComputation& fused_computation);
  friend struct internal::XlaBuilderFriend;
};

// RAII-style object: sets the current sharding assignment in builder on
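The XlaBuilderFriend hunks above swap a set of free friend functions for a single friend struct, so new accessors (such as GetInstruction) can be added without growing XlaBuilder's friend list each time. Below is a minimal, self-contained C++ sketch of that access pattern; Builder, BuilderFriend, and GetState are invented names for illustration only and are not XLA APIs.

#include <iostream>

namespace internal { struct BuilderFriend; }  // forward declaration

class Builder {
 public:
  explicit Builder(int seed) : state_(seed) {}

 private:
  // Only this helper struct may touch the private state. Adding another
  // static member to the friend struct later does not require another
  // friend declaration here, unlike a list of free friend functions.
  friend struct internal::BuilderFriend;
  int state_;
};

namespace internal {
struct BuilderFriend {
  static int* GetState(Builder* b) { return &b->state_; }
};
}  // namespace internal

int main() {
  Builder b(42);
  *internal::BuilderFriend::GetState(&b) += 1;
  std::cout << *internal::BuilderFriend::GetState(&b) << "\n";  // prints 43
}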