[XLA/GPU] Plumb through Bitcast op for LMHLO.
Also remove BitcastOp. XLA bitcast requires the input buffer to alias the output buffer, which makes bitcast always a no-op. PiperOrigin-RevId: 356884383 Change-Id: Idd8eb0b902b41d9753830e26add30d93fe470c61
This commit is contained in:
parent
f6e5db3430
commit
0992169740
@ -380,12 +380,6 @@ def LHLO_BatchNormTrainingOp : LHLO_Op<"batch_norm_training", []>,
|
||||
);
|
||||
}
|
||||
|
||||
// TODO(timshen): add a custom verifier.
|
||||
def LHLO_BitcastOp: LHLO_Op<"bitcast", []> {
|
||||
let arguments = (ins Arg<LHLO_Buffer, "", [MemRead]>:$input,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output);
|
||||
}
|
||||
|
||||
def LHLO_BroadcastOp : LHLO_Op<"broadcast",
|
||||
[]>, BASE_HLO_BroadcastOp {
|
||||
let arguments = (ins
|
||||
|
@ -899,14 +899,6 @@ func @while_memrefs(%arg0: memref<i64>, %arg1: memref<5xf32>, %arg0_out: memref<
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @bitcast_memrefs
|
||||
func @bitcast_memrefs(%arg0: memref<1xf64>, %arg_out: memref<2xi32>) -> () {
|
||||
"lmhlo.bitcast"(%arg0, %arg_out) : (memref<1xf64>, memref<2xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @scatter_memrefs
|
||||
func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32>,
|
||||
%updates: memref<10x300xf32>, %arg_out: memref<200x100x300xf32>) -> () {
|
||||
|
@ -364,7 +364,7 @@ StatusOr<::xla::HloOpcode> MhloToHloOpcode(mlir::Operation* op) {
|
||||
return xla::HloOpcode::kRngBitGenerator;
|
||||
} else if (isa<mlir::mhlo::FusionOp, mlir::lmhlo::FusionOp>(op)) {
|
||||
return xla::HloOpcode::kFusion;
|
||||
} else if (isa<mlir::mhlo::BitcastOp, mlir::lmhlo::BitcastOp>(op)) {
|
||||
} else if (isa<mlir::mhlo::BitcastOp>(op)) {
|
||||
return xla::HloOpcode::kBitcast;
|
||||
} else if (isa<mlir::mhlo::AbsOp, mlir::lmhlo::AbsOp>(op)) {
|
||||
return xla::HloOpcode::kAbs;
|
||||
|
@ -259,6 +259,8 @@ StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(
|
||||
return CreateOpWithoutAttrs<lmhlo::AndOp>(instr);
|
||||
case HloOpcode::kAtan2:
|
||||
return CreateOpWithoutAttrs<lmhlo::Atan2Op>(instr);
|
||||
case HloOpcode::kBitcast:
|
||||
return nullptr;
|
||||
case HloOpcode::kBitcastConvert:
|
||||
return CreateOpWithoutAttrs<lmhlo::BitcastConvertOp>(instr);
|
||||
case HloOpcode::kBroadcast:
|
||||
|
@ -166,18 +166,6 @@ Status IrEmitter::HandleConstant(HloInstruction* constant) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status IrEmitter::HandleBitcast(HloInstruction* bitcast) {
|
||||
VLOG(2) << "HandleBitcast: " << bitcast->ToString();
|
||||
const HloInstruction* operand = bitcast->operand(0);
|
||||
// Bitcast is a no-op, but we still want to bind it to an llvm::Value
|
||||
// sometimes, e.g., when it's operand is a constant or a bitcast of a
|
||||
// constant.
|
||||
if (bindings_.BoundToIrValue(*operand)) {
|
||||
bindings_.BindHloToIrValue(*bitcast, GetBasePointer(*operand));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) {
|
||||
VLOG(2) << "HandleAddDependency: " << add_dependency->ToString();
|
||||
const HloInstruction* operand = add_dependency->operand(0);
|
||||
|
@ -77,7 +77,6 @@ class IrEmitter : public DfsHloVisitorWithDefault,
|
||||
|
||||
Status DefaultAction(HloInstruction* hlo) override;
|
||||
Status HandleConstant(HloInstruction* constant) override;
|
||||
Status HandleBitcast(HloInstruction* bitcast) override;
|
||||
Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
|
||||
Status HandleConvolution(HloInstruction* convolution) override;
|
||||
Status HandleFft(HloInstruction* fft) override;
|
||||
|
@ -5858,5 +5858,11 @@ void MlirEmitterContext::SetOperation(mlir::Operation* op) {
|
||||
}
|
||||
}
|
||||
|
||||
Status IrEmitterUnnested::HandleBitcast(HloInstruction* bitcast) {
|
||||
TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(bitcast));
|
||||
DCHECK_EQ(nullptr, input.op);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
@ -157,6 +157,7 @@ class IrEmitterUnnested : public IrEmitter,
|
||||
}
|
||||
|
||||
Status DefaultAction(HloInstruction* hlo) override;
|
||||
Status HandleBitcast(HloInstruction* bitcast) override;
|
||||
Status EmitUsingElementalIrEmitter(MlirEmitterInput input);
|
||||
|
||||
// IrEmitterUnnested handles the following instructions differently from
|
||||
|
Loading…
x
Reference in New Issue
Block a user