[MLIR] Add cbrt, reduce-precision, and bitcast ops to MHLO.
PiperOrigin-RevId: 335109804 Change-Id: I0984a3db18191db07de39107886e626e5b8e090a
This commit is contained in:
parent
70302db204
commit
c52875771f
@ -157,6 +157,9 @@ def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs",
|
||||
>];
|
||||
}
|
||||
|
||||
def HLO_CbrtOp: HLO_UnaryElementwiseOp<"cbrt",
|
||||
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_CbrtOp;
|
||||
|
||||
def HLO_CeilOp: HLO_UnaryElementwiseOp<"ceil",
|
||||
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_CeilOp;
|
||||
|
||||
@ -1423,4 +1426,21 @@ def HLO_FusionOp : HLO_Op<"fusion", []> {
|
||||
let hasCustomHLOConverter = 1;
|
||||
}
|
||||
|
||||
// This is an op for purposes internal to XLA/GPU.
|
||||
def HLO_BitcastOp : HLO_Op<"bitcast", [NoSideEffect]>, BASE_HLO_BitcastOp {
|
||||
let arguments = (ins HLO_Tensor:$operand);
|
||||
let results = (outs HLO_Tensor);
|
||||
let hasCustomHLOConverter = 1;
|
||||
}
|
||||
|
||||
def HLO_ReducePrecisionOp: HLO_Op<"reduce_precision", [SameOperandsAndResultShape]>,
|
||||
BASE_HLO_ReducePrecisionOp {
|
||||
let arguments = (ins
|
||||
HLO_FpTensor:$operand,
|
||||
I32Attr:$exponent_bits,
|
||||
I32Attr:$mantissa_bits
|
||||
);
|
||||
let results = (outs HLO_FpTensor:$output);
|
||||
}
|
||||
|
||||
#endif // HLO_OPS
|
||||
|
@ -127,6 +127,17 @@ class BASE_HLO_AbsOp {
|
||||
}];
|
||||
}
|
||||
|
||||
class BASE_HLO_CbrtOp {
|
||||
string summary = "Cubic root operator";
|
||||
|
||||
string description = [{
|
||||
Returns element-wise cubic root of the operand.
|
||||
|
||||
See
|
||||
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
|
||||
}];
|
||||
}
|
||||
|
||||
class BASE_HLO_CeilOp {
|
||||
string summary = "Ceil operator";
|
||||
|
||||
@ -1336,4 +1347,17 @@ class BASE_HLO_WhileOp {
|
||||
}];
|
||||
}
|
||||
|
||||
class BASE_HLO_BitcastOp {
|
||||
string summary = "Bitcast operator";
|
||||
|
||||
string description = [{
|
||||
This op changes the shape of the input in the way that the physical
|
||||
arrangement of elements is unchanged.
|
||||
|
||||
However, the op needs layout information to make sense of "physical
|
||||
arrangement of elements". Layout support in MHLO is currently under
|
||||
exploration.
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // HLO_OPS_BASE
|
||||
|
@ -1193,3 +1193,24 @@ func @incompatible_shapes(%arg0: tensor<?xf32>, %shape: tensor<2xindex>) -> tens
|
||||
%0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor<?xf32>, tensor<2xindex>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @cbrt(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> {
|
||||
%0 = "mhlo.cbrt"(%arg) : (tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||
return %0 : tensor<2x4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @bitcast(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> {
|
||||
%0 = "mhlo.bitcast"(%arg) : (tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||
return %0 : tensor<2x4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @bitcast(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> {
|
||||
%0 = "mhlo.reduce_precision"(%arg) {exponent_bits=2 : i32, mantissa_bits=3 : i32} : (tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||
return %0 : tensor<2x4xf32>
|
||||
}
|
||||
|
@ -681,6 +681,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
|
||||
NoAttributeCase(kAnd, AndOp);
|
||||
NoAttributeCase(kAtan2, Atan2Op);
|
||||
NoAttributeCase(kBitcastConvert, BitcastConvertOp);
|
||||
NoAttributeCase(kCbrt, CbrtOp);
|
||||
NoAttributeCase(kConvert, ConvertOp);
|
||||
NoAttributeCase(kCeil, CeilOp);
|
||||
NoAttributeCase(kClamp, ClampOp);
|
||||
@ -738,6 +739,20 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
|
||||
&fusion.fused_computation()));
|
||||
return fusion.getOperation();
|
||||
}
|
||||
case HloOpcode::kBitcast:
|
||||
return func_builder
|
||||
->create<mlir::mhlo::BitcastOp>(loc, result_type, operands,
|
||||
attributes)
|
||||
.getOperation();
|
||||
case HloOpcode::kReducePrecision: {
|
||||
auto op = func_builder->create<mlir::mhlo::ReducePrecisionOp>(
|
||||
loc, result_type, operands[0], attributes);
|
||||
op.exponent_bitsAttr(func_builder->getIntegerAttr(
|
||||
func_builder->getI32Type(), instruction->exponent_bits()));
|
||||
op.mantissa_bitsAttr(func_builder->getIntegerAttr(
|
||||
func_builder->getI32Type(), instruction->mantissa_bits()));
|
||||
return op.getOperation();
|
||||
}
|
||||
case HloOpcode::kAddDependency:
|
||||
// Arbitrary op code that I suspect we will not implement for quite a
|
||||
// while and allows testing handling of unknown ops. Selected because it
|
||||
|
@ -1082,6 +1082,15 @@ LogicalResult ExportXlaOp(FusionOp op, OpLoweringContext ctx) {
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ExportXlaOp(BitcastOp op, OpLoweringContext ctx) {
|
||||
auto& value_map = *ctx.values;
|
||||
xla::XlaOp operand;
|
||||
if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure();
|
||||
value_map[op] = xla::internal::XlaBuilderFriend::BuildBitcast(
|
||||
ctx.builder, operand, xla::TypeToShape(op.getType()));
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
||||
|
@ -1102,3 +1102,33 @@ func @main(%arg: tensor<3xui64>) -> tuple<tensor<3xui64>, tensor<2x2xui32>> {
|
||||
%0 = "mhlo.rng_bit_generator"(%arg) {rng_algorithm = 2 : i32} : (tensor<3xui64>) -> tuple<tensor<3xui64>, tensor<2x2xui32>>
|
||||
return %0 : tuple<tensor<3xui64>, tensor<2x2xui32>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: HloModule
|
||||
func @main(%arg: tensor<3x4xf32>) -> tensor<3x4xf32> {
|
||||
// CHECK: %[[ARG0:.*]] = f32[3,4] parameter(0)
|
||||
// CHECK: ROOT %[[RESULT:.*]] = f32[3,4] cbrt(f32[3,4] %[[ARG0]])
|
||||
%0 = "mhlo.cbrt"(%arg) : (tensor<3x4xf32>) -> tensor<3x4xf32>
|
||||
return %0 : tensor<3x4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: HloModule
|
||||
func @main(%arg: tensor<3x4xf32>) -> tensor<3x4xf32> {
|
||||
// CHECK: %[[ARG0:.*]] = f32[3,4] parameter(0)
|
||||
// CHECK: ROOT %[[RESULT:.*]] = f32[3,4] reduce-precision(f32[3,4] %[[ARG0]]), exponent_bits=8, mantissa_bits=10
|
||||
%0 = "mhlo.reduce_precision"(%arg) {exponent_bits = 8 : i32, mantissa_bits = 10 : i32} : (tensor<3x4xf32>) -> tensor<3x4xf32>
|
||||
return %0 : tensor<3x4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: HloModule
|
||||
func @main(%arg: tensor<3x4xf32>) -> tensor<3x4x1xf32> {
|
||||
// CHECK: %[[ARG0:.*]] = f32[3,4] parameter(0)
|
||||
// CHECK: ROOT %[[RESULT:.*]] = f32[3,4,1] bitcast(f32[3,4] %[[ARG0]])
|
||||
%0 = "mhlo.bitcast"(%arg) : (tensor<3x4xf32>) -> tensor<3x4x1xf32>
|
||||
return %0 : tensor<3x4x1xf32>
|
||||
}
|
||||
|
@ -1014,3 +1014,26 @@ add {
|
||||
ROOT %rng-bit-generator.2 = (u64[3], u32[2,2]) rng-bit-generator(u64[3] %Arg_0.1), algorithm=rng_philox
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @cbrt
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x4xf32>)
|
||||
%cbrt (Arg_0.1: f32[3,4]) -> f32[3,4] {
|
||||
%Arg_0.1 = f32[3,4] parameter(0)
|
||||
// CHECK: "mhlo.cbrt"(%[[ARG0]]) : (tensor<3x4xf32>) -> tensor<3x4xf32>
|
||||
ROOT %cbrt = f32[3,4] cbrt(f32[3,4] %Arg_0.1)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @bitcast
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x4xf32>) -> tensor<3x4x1xf32>
|
||||
%bitcast (Arg_0.1: f32[3,4]) -> f32[3,4,1] {
|
||||
%Arg_0.1 = f32[3,4] parameter(0)
|
||||
// CHECK: "mhlo.bitcast"(%[[ARG0]]) : (tensor<3x4xf32>) -> tensor<3x4x1xf32>
|
||||
ROOT %bitcast = f32[3,4,1] bitcast(f32[3,4] %Arg_0.1)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @reduce_precision
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x4xf32>)
|
||||
%reduce_precision (Arg_0.1: f32[3,4]) -> f32[3,4] {
|
||||
%Arg_0.1 = f32[3,4] parameter(0)
|
||||
// CHECK: "mhlo.reduce_precision"(%[[ARG0]]) {exponent_bits = 8 : i32, mantissa_bits = 10 : i32} : (tensor<3x4xf32>) -> tensor<3x4xf32>
|
||||
ROOT %reduce_precision = f32[3,4] reduce-precision(f32[3,4] %Arg_0.1), exponent_bits=8, mantissa_bits=10
|
||||
}
|
||||
|
@ -149,6 +149,16 @@ XlaOp XlaBuilderFriend::BuildFusion(XlaBuilder* builder,
|
||||
});
|
||||
}
|
||||
|
||||
XlaOp XlaBuilderFriend::BuildBitcast(XlaBuilder* builder, XlaOp operand,
|
||||
const Shape& shape) {
|
||||
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
return builder->AddInstruction(std::move(instr), HloOpcode::kBitcast,
|
||||
{operand});
|
||||
});
|
||||
}
|
||||
|
||||
HloInstructionProto* XlaBuilderFriend::GetInstruction(XlaOp op) {
|
||||
return &op.builder()
|
||||
->instructions_[op.builder()->handle_to_index_[op.handle_]];
|
||||
|
@ -57,6 +57,9 @@ struct XlaBuilderFriend {
|
||||
absl::string_view fusion_kind,
|
||||
const XlaComputation& fused_computation);
|
||||
|
||||
static XlaOp BuildBitcast(XlaBuilder* builder, XlaOp operand,
|
||||
const Shape& shape);
|
||||
|
||||
static HloInstructionProto* GetInstruction(XlaOp op);
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user