[XLA/GPU] Plumb through Bitcast op for LMHLO.

Also remove BitcastOp. XLA bitcast requires the input buffer to alias the output buffer, which makes bitcast always a no-op.

PiperOrigin-RevId: 356884383
Change-Id: Idd8eb0b902b41d9753830e26add30d93fe470c61
This commit is contained in:
Tim Shen 2021-02-10 19:44:15 -08:00 committed by TensorFlower Gardener
parent f6e5db3430
commit 0992169740
8 changed files with 10 additions and 28 deletions

View File

@ -380,12 +380,6 @@ def LHLO_BatchNormTrainingOp : LHLO_Op<"batch_norm_training", []>,
);
}
// TODO(timshen): add a custom verifier.
def LHLO_BitcastOp: LHLO_Op<"bitcast", []> {
let arguments = (ins Arg<LHLO_Buffer, "", [MemRead]>:$input,
Arg<LHLO_Buffer, "", [MemWrite]>:$output);
}
def LHLO_BroadcastOp : LHLO_Op<"broadcast",
[]>, BASE_HLO_BroadcastOp {
let arguments = (ins

View File

@ -899,14 +899,6 @@ func @while_memrefs(%arg0: memref<i64>, %arg1: memref<5xf32>, %arg0_out: memref<
// -----
// CHECK-LABEL: func @bitcast_memrefs
func @bitcast_memrefs(%arg0: memref<1xf64>, %arg_out: memref<2xi32>) -> () {
"lmhlo.bitcast"(%arg0, %arg_out) : (memref<1xf64>, memref<2xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @scatter_memrefs
func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32>,
%updates: memref<10x300xf32>, %arg_out: memref<200x100x300xf32>) -> () {

View File

@ -364,7 +364,7 @@ StatusOr<::xla::HloOpcode> MhloToHloOpcode(mlir::Operation* op) {
return xla::HloOpcode::kRngBitGenerator;
} else if (isa<mlir::mhlo::FusionOp, mlir::lmhlo::FusionOp>(op)) {
return xla::HloOpcode::kFusion;
} else if (isa<mlir::mhlo::BitcastOp, mlir::lmhlo::BitcastOp>(op)) {
} else if (isa<mlir::mhlo::BitcastOp>(op)) {
return xla::HloOpcode::kBitcast;
} else if (isa<mlir::mhlo::AbsOp, mlir::lmhlo::AbsOp>(op)) {
return xla::HloOpcode::kAbs;

View File

@ -259,6 +259,8 @@ StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(
return CreateOpWithoutAttrs<lmhlo::AndOp>(instr);
case HloOpcode::kAtan2:
return CreateOpWithoutAttrs<lmhlo::Atan2Op>(instr);
case HloOpcode::kBitcast:
return nullptr;
case HloOpcode::kBitcastConvert:
return CreateOpWithoutAttrs<lmhlo::BitcastConvertOp>(instr);
case HloOpcode::kBroadcast:

View File

@ -166,18 +166,6 @@ Status IrEmitter::HandleConstant(HloInstruction* constant) {
return Status::OK();
}
Status IrEmitter::HandleBitcast(HloInstruction* bitcast) {
VLOG(2) << "HandleBitcast: " << bitcast->ToString();
const HloInstruction* operand = bitcast->operand(0);
// Bitcast is a no-op, but we still want to bind it to an llvm::Value
// sometimes, e.g., when it's operand is a constant or a bitcast of a
// constant.
if (bindings_.BoundToIrValue(*operand)) {
bindings_.BindHloToIrValue(*bitcast, GetBasePointer(*operand));
}
return Status::OK();
}
Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) {
VLOG(2) << "HandleAddDependency: " << add_dependency->ToString();
const HloInstruction* operand = add_dependency->operand(0);

View File

@ -77,7 +77,6 @@ class IrEmitter : public DfsHloVisitorWithDefault,
Status DefaultAction(HloInstruction* hlo) override;
Status HandleConstant(HloInstruction* constant) override;
Status HandleBitcast(HloInstruction* bitcast) override;
Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
Status HandleConvolution(HloInstruction* convolution) override;
Status HandleFft(HloInstruction* fft) override;

View File

@ -5858,5 +5858,11 @@ void MlirEmitterContext::SetOperation(mlir::Operation* op) {
}
}
Status IrEmitterUnnested::HandleBitcast(HloInstruction* bitcast) {
TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(bitcast));
DCHECK_EQ(nullptr, input.op);
return Status::OK();
}
} // namespace gpu
} // namespace xla

View File

@ -157,6 +157,7 @@ class IrEmitterUnnested : public IrEmitter,
}
Status DefaultAction(HloInstruction* hlo) override;
Status HandleBitcast(HloInstruction* bitcast) override;
Status EmitUsingElementalIrEmitter(MlirEmitterInput input);
// IrEmitterUnnested handles the following instructions differently from