Add BitcastConvert op to HLO dialect.

BitcastConvert op requires a custom exporter because the second operand 'new_element_type' of the HLO instruction is inferred from the result type in mlir-hlo.
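
For context, the XLA builder API takes the target element type as an explicit argument, whereas the mlir-hlo op carries it only on its result type. A minimal C++ sketch of that derivation (the helper name is hypothetical and the snippet assumes the includes and usings of the exporter file; it mirrors the ExportXlaOp overload added below):

// Illustrative only: derive HLO's explicit 'new_element_type' from the
// MLIR result type, since the mlir-hlo op has no element-type operand.
xla::XlaOp LowerBitcastConvert(xla::XlaOp operand, mlir::Type result_type) {
  xla::PrimitiveType new_element_type =
      xla::TypeToPrimitiveType(getElementTypeOrSelf(result_type));
  // xla::BitcastConvertType is the XlaBuilder entry point for bitcast-convert.
  return xla::BitcastConvertType(operand, new_element_type);
}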

PiperOrigin-RevId: 284888516
Change-Id: I2198330513166f3aa0329bb5b176ebb709b94230
Prakalp Srivastava 2019-12-10 17:51:36 -08:00 committed by TensorFlower Gardener
parent f8dbd4fcce
commit 3d79d19aa2
5 changed files with 43 additions and 0 deletions


@@ -438,6 +438,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
NoAttributeCase(kAdd, AddOp);
NoAttributeCase(kAnd, AndOp);
NoAttributeCase(kAtan2, Atan2Op);
NoAttributeCase(kBitcastConvert, BitcastConvertOp);
NoAttributeCase(kConvert, ConvertOp);
NoAttributeCase(kClamp, ClampOp);
NoAttributeCase(kComplex, ComplexOp);


@@ -586,6 +586,14 @@ def HLO_BatchNormTrainingOp : HLO_Op<"batch_norm_training", [NoSideEffect]>,
  let results = (outs HLO_Tuple);
}

def HLO_BitcastConvertOp : HLO_Op<"bitcast_convert",
    [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_BitcastConvertOp {
  let arguments = (ins HLO_Tensor:$operand);
  let results = (outs HLO_Tensor);

  let hasCustomHLOConverter = 1;
}

def HLO_BroadcastOp : HLO_Op<"broadcast",
    [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_BroadcastOp {
  let arguments = (ins


@@ -656,6 +656,20 @@ class BASE_HLO_BatchNormTrainingOp {
  }];
}

class BASE_HLO_BitcastConvertOp {
  string summary = "BitcastConvert operator";

  string description = [{
    Similar to a 'tf.bitcast' in TensorFlow, performs an element-wise bitcast
    operation from a data shape to a target shape. The dimensions must match,
    and the conversion is an element-wise one. Bitcast is implemented as a
    low-level cast, so machines with different floating-point representations
    will give different results.

    See https://www.tensorflow.org/xla/operation_semantics#bitcastconverttype.
  }];
}

class BASE_HLO_BroadcastOp {
  string summary = "Broadcast a tensor to a higher rank by prepending dimensions";
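
As a concrete illustration of the element-wise, bit-pattern-preserving semantics described above, here is a standalone C++ sketch (not part of this change): bitcast-converting the s32 bit pattern 0x3F800000 yields the f32 value 1.0.

#include <cstdint>
#include <cstring>
#include <iostream>

int main() {
  // 0x3F800000 is the IEEE-754 encoding of 1.0f; a bitcast keeps the bits
  // and only reinterprets the type, unlike a value-converting cast.
  int32_t bits = 0x3F800000;
  float value;
  std::memcpy(&value, &bits, sizeof(value));
  std::cout << bits << " bitcast-converts to " << value << "\n";  // 1065353216 -> 1
  return 0;
}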


@@ -446,6 +446,14 @@ LogicalResult ExportXlaOp(AllReduceOp op, OpLoweringContext ctx) {
  return success();
}

LogicalResult ExportXlaOp(BitcastConvertOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  value_map[op] = xla::BitcastConvertType(
      value_map[op.operand()],
      xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType())));
  return success();
}

LogicalResult ExportXlaOp(BroadcastInDimOp op, OpLoweringContext ctx) {
  auto type = op.getType().dyn_cast<RankedTensorType>();
  if (!type) return failure();


@@ -111,6 +111,18 @@ func @main(%arg0: tensor<1x4xi32>, %arg1: tensor<2x4xi32>, %arg2: tensor<2x3x4xi
// -----
// CHECK-LABEL: HloModule
func @main(%arg0: tensor<2xi32>) -> tensor<2xf32> {
  %0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}
// CHECK-LABEL: ENTRY
// CHECK: %[[ARG:.*]] = s32[2] parameter(0)
// CHECK: ROOT %[[RESULT:.*]] = f32[2] bitcast-convert(s32[2] %[[ARG]])
// -----
// CHECK-LABEL: HloModule
func @main(%arg0: tensor<4xi32>) -> tensor<1x2x3x4xi32> {
  // CHECK: [[ARG:%.*]] = s32[4] parameter(0)