Add BitcastConvert op to HLO dialect.
BitcastConvert op requires custom exporter as the second operand 'new_element_type' in hlo instruction is inferred from the result type in mlir-hlo. PiperOrigin-RevId: 284888516 Change-Id: I2198330513166f3aa0329bb5b176ebb709b94230
This commit is contained in:
parent
f8dbd4fcce
commit
3d79d19aa2
@ -438,6 +438,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
|
||||
NoAttributeCase(kAdd, AddOp);
|
||||
NoAttributeCase(kAnd, AndOp);
|
||||
NoAttributeCase(kAtan2, Atan2Op);
|
||||
NoAttributeCase(kBitcastConvert, BitcastConvertOp);
|
||||
NoAttributeCase(kConvert, ConvertOp);
|
||||
NoAttributeCase(kClamp, ClampOp);
|
||||
NoAttributeCase(kComplex, ComplexOp);
|
||||
|
@ -586,6 +586,14 @@ def HLO_BatchNormTrainingOp : HLO_Op<"batch_norm_training", [NoSideEffect]>,
|
||||
let results = (outs HLO_Tuple);
|
||||
}
|
||||
|
||||
def HLO_BitcastConvertOp : HLO_Op<"bitcast_convert",
|
||||
[NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_BitcastConvertOp {
|
||||
|
||||
let arguments = (ins HLO_Tensor:$operand);
|
||||
let results = (outs HLO_Tensor);
|
||||
let hasCustomHLOConverter = 1;
|
||||
}
|
||||
|
||||
def HLO_BroadcastOp : HLO_Op<"broadcast",
|
||||
[NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_BroadcastOp {
|
||||
let arguments = (ins
|
||||
|
@ -656,6 +656,20 @@ class BASE_HLO_BatchNormTrainingOp {
|
||||
}];
|
||||
}
|
||||
|
||||
class BASE_HLO_BitcastConvertOp {
|
||||
string summary = "BitcastConvert operator";
|
||||
|
||||
string description = [{
|
||||
Similar to a 'tf.bitcast' in TensorFlow, performs an element-wise bitcast
|
||||
operation from a data shape to a target shape. The dimensions must match,
|
||||
and the conversion is an element-wise one. Bitcast is implemented as a
|
||||
low-level cast, so machines with different floating-point representations
|
||||
will give different results.
|
||||
|
||||
See https://www.tensorflow.org/xla/operation_semantics#bitcastconverttype.
|
||||
}];
|
||||
}
|
||||
|
||||
class BASE_HLO_BroadcastOp {
|
||||
string summary = "Broadcast a tensor to a higher rank by prepending dimensions";
|
||||
|
||||
|
@ -446,6 +446,14 @@ LogicalResult ExportXlaOp(AllReduceOp op, OpLoweringContext ctx) {
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ExportXlaOp(BitcastConvertOp op, OpLoweringContext ctx) {
|
||||
auto& value_map = *ctx.values;
|
||||
value_map[op] = xla::BitcastConvertType(
|
||||
value_map[op.operand()],
|
||||
xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType())));
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ExportXlaOp(BroadcastInDimOp op, OpLoweringContext ctx) {
|
||||
auto type = op.getType().dyn_cast<RankedTensorType>();
|
||||
if (!type) return failure();
|
||||
|
@ -111,6 +111,18 @@ func @main(%arg0: tensor<1x4xi32>, %arg1: tensor<2x4xi32>, %arg2: tensor<2x3x4xi
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main(%arg0: tensor<2xi32>) -> tensor<2xf32> {
|
||||
%0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: ENTRY
|
||||
// CHECK: %[[ARG:.*]] = s32[2] parameter(0)
|
||||
// CHECK: ROOT %[[RESULT:.*]] = f32[2] bitcast-convert(s32[2] %[[ARG]])
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: HloModule
|
||||
func @main(%arg0: tensor<4xi32>) -> tensor<1x2x3x4xi32> {
|
||||
// CHECK: [[ARG:%.*]] = s32[4] parameter(0)
|
||||
|
Loading…
x
Reference in New Issue
Block a user