diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc index 1da4fd04ffb..1f5d1fb2ae5 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc @@ -438,6 +438,7 @@ StatusOr HloFunctionImporter::ImportInstruction( NoAttributeCase(kAdd, AddOp); NoAttributeCase(kAnd, AndOp); NoAttributeCase(kAtan2, Atan2Op); + NoAttributeCase(kBitcastConvert, BitcastConvertOp); NoAttributeCase(kConvert, ConvertOp); NoAttributeCase(kClamp, ClampOp); NoAttributeCase(kComplex, ComplexOp); diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td index 3c4fd473eb6..2d1913f8274 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td @@ -586,6 +586,14 @@ def HLO_BatchNormTrainingOp : HLO_Op<"batch_norm_training", [NoSideEffect]>, let results = (outs HLO_Tuple); } +def HLO_BitcastConvertOp : HLO_Op<"bitcast_convert", + [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_BitcastConvertOp { + + let arguments = (ins HLO_Tensor:$operand); + let results = (outs HLO_Tensor); + let hasCustomHLOConverter = 1; +} + def HLO_BroadcastOp : HLO_Op<"broadcast", [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_BroadcastOp { let arguments = (ins diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td index a6d4210b60c..3be2c26a1bf 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td @@ -656,6 +656,20 @@ class BASE_HLO_BatchNormTrainingOp { }]; } +class BASE_HLO_BitcastConvertOp { + string summary = "BitcastConvert operator"; + + string description = [{ + Similar to a 'tf.bitcast' in TensorFlow, performs an element-wise bitcast + operation from a data shape to a target shape. The dimensions must match, + and the conversion is an element-wise one. Bitcast is implemented as a + low-level cast, so machines with different floating-point representations + will give different results. + + See https://www.tensorflow.org/xla/operation_semantics#bitcastconverttype. + }]; +} + class BASE_HLO_BroadcastOp { string summary = "Broadcast a tensor to a higher rank by prepending dimensions"; diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index 26cd512aa85..c7501a91dd9 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -446,6 +446,14 @@ LogicalResult ExportXlaOp(AllReduceOp op, OpLoweringContext ctx) { return success(); } +LogicalResult ExportXlaOp(BitcastConvertOp op, OpLoweringContext ctx) { + auto& value_map = *ctx.values; + value_map[op] = xla::BitcastConvertType( + value_map[op.operand()], + xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType()))); + return success(); +} + LogicalResult ExportXlaOp(BroadcastInDimOp op, OpLoweringContext ctx) { auto type = op.getType().dyn_cast(); if (!type) return failure(); diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir index 85ed317f8c6..ff9d69b9a52 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir @@ -111,6 +111,18 @@ func @main(%arg0: tensor<1x4xi32>, %arg1: tensor<2x4xi32>, %arg2: tensor<2x3x4xi // ----- +// CHECK-LABEL: HloModule +func @main(%arg0: tensor<2xi32>) -> tensor<2xf32> { + %0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +// CHECK-LABEL: ENTRY +// CHECK: %[[ARG:.*]] = s32[2] parameter(0) +// CHECK: ROOT %[[RESULT:.*]] = f32[2] bitcast-convert(s32[2] %[[ARG]]) + +// ----- + // CHECK-LABEL: HloModule
func @main(%arg0: tensor<4xi32>) -> tensor<1x2x3x4xi32> { // CHECK: [[ARG:%.*]] = s32[4] parameter(0)