[XLA/GPU] Plumb through Bitcast op for LMHLO.
Also remove BitcastOp. XLA bitcast requires the input buffer to alias the output buffer, which makes bitcast always a no-op. PiperOrigin-RevId: 356884383 Change-Id: Idd8eb0b902b41d9753830e26add30d93fe470c61
This commit is contained in:
parent
f6e5db3430
commit
0992169740
@ -380,12 +380,6 @@ def LHLO_BatchNormTrainingOp : LHLO_Op<"batch_norm_training", []>,
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(timshen): add a custom verifier.
|
|
||||||
def LHLO_BitcastOp: LHLO_Op<"bitcast", []> {
|
|
||||||
let arguments = (ins Arg<LHLO_Buffer, "", [MemRead]>:$input,
|
|
||||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output);
|
|
||||||
}
|
|
||||||
|
|
||||||
def LHLO_BroadcastOp : LHLO_Op<"broadcast",
|
def LHLO_BroadcastOp : LHLO_Op<"broadcast",
|
||||||
[]>, BASE_HLO_BroadcastOp {
|
[]>, BASE_HLO_BroadcastOp {
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
|
@ -899,14 +899,6 @@ func @while_memrefs(%arg0: memref<i64>, %arg1: memref<5xf32>, %arg0_out: memref<
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @bitcast_memrefs
|
|
||||||
func @bitcast_memrefs(%arg0: memref<1xf64>, %arg_out: memref<2xi32>) -> () {
|
|
||||||
"lmhlo.bitcast"(%arg0, %arg_out) : (memref<1xf64>, memref<2xi32>) -> ()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @scatter_memrefs
|
// CHECK-LABEL: func @scatter_memrefs
|
||||||
func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32>,
|
func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32>,
|
||||||
%updates: memref<10x300xf32>, %arg_out: memref<200x100x300xf32>) -> () {
|
%updates: memref<10x300xf32>, %arg_out: memref<200x100x300xf32>) -> () {
|
||||||
|
@ -364,7 +364,7 @@ StatusOr<::xla::HloOpcode> MhloToHloOpcode(mlir::Operation* op) {
|
|||||||
return xla::HloOpcode::kRngBitGenerator;
|
return xla::HloOpcode::kRngBitGenerator;
|
||||||
} else if (isa<mlir::mhlo::FusionOp, mlir::lmhlo::FusionOp>(op)) {
|
} else if (isa<mlir::mhlo::FusionOp, mlir::lmhlo::FusionOp>(op)) {
|
||||||
return xla::HloOpcode::kFusion;
|
return xla::HloOpcode::kFusion;
|
||||||
} else if (isa<mlir::mhlo::BitcastOp, mlir::lmhlo::BitcastOp>(op)) {
|
} else if (isa<mlir::mhlo::BitcastOp>(op)) {
|
||||||
return xla::HloOpcode::kBitcast;
|
return xla::HloOpcode::kBitcast;
|
||||||
} else if (isa<mlir::mhlo::AbsOp, mlir::lmhlo::AbsOp>(op)) {
|
} else if (isa<mlir::mhlo::AbsOp, mlir::lmhlo::AbsOp>(op)) {
|
||||||
return xla::HloOpcode::kAbs;
|
return xla::HloOpcode::kAbs;
|
||||||
|
@ -259,6 +259,8 @@ StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(
|
|||||||
return CreateOpWithoutAttrs<lmhlo::AndOp>(instr);
|
return CreateOpWithoutAttrs<lmhlo::AndOp>(instr);
|
||||||
case HloOpcode::kAtan2:
|
case HloOpcode::kAtan2:
|
||||||
return CreateOpWithoutAttrs<lmhlo::Atan2Op>(instr);
|
return CreateOpWithoutAttrs<lmhlo::Atan2Op>(instr);
|
||||||
|
case HloOpcode::kBitcast:
|
||||||
|
return nullptr;
|
||||||
case HloOpcode::kBitcastConvert:
|
case HloOpcode::kBitcastConvert:
|
||||||
return CreateOpWithoutAttrs<lmhlo::BitcastConvertOp>(instr);
|
return CreateOpWithoutAttrs<lmhlo::BitcastConvertOp>(instr);
|
||||||
case HloOpcode::kBroadcast:
|
case HloOpcode::kBroadcast:
|
||||||
|
@ -166,18 +166,6 @@ Status IrEmitter::HandleConstant(HloInstruction* constant) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status IrEmitter::HandleBitcast(HloInstruction* bitcast) {
|
|
||||||
VLOG(2) << "HandleBitcast: " << bitcast->ToString();
|
|
||||||
const HloInstruction* operand = bitcast->operand(0);
|
|
||||||
// Bitcast is a no-op, but we still want to bind it to an llvm::Value
|
|
||||||
// sometimes, e.g., when it's operand is a constant or a bitcast of a
|
|
||||||
// constant.
|
|
||||||
if (bindings_.BoundToIrValue(*operand)) {
|
|
||||||
bindings_.BindHloToIrValue(*bitcast, GetBasePointer(*operand));
|
|
||||||
}
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) {
|
Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) {
|
||||||
VLOG(2) << "HandleAddDependency: " << add_dependency->ToString();
|
VLOG(2) << "HandleAddDependency: " << add_dependency->ToString();
|
||||||
const HloInstruction* operand = add_dependency->operand(0);
|
const HloInstruction* operand = add_dependency->operand(0);
|
||||||
|
@ -77,7 +77,6 @@ class IrEmitter : public DfsHloVisitorWithDefault,
|
|||||||
|
|
||||||
Status DefaultAction(HloInstruction* hlo) override;
|
Status DefaultAction(HloInstruction* hlo) override;
|
||||||
Status HandleConstant(HloInstruction* constant) override;
|
Status HandleConstant(HloInstruction* constant) override;
|
||||||
Status HandleBitcast(HloInstruction* bitcast) override;
|
|
||||||
Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
|
Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
|
||||||
Status HandleConvolution(HloInstruction* convolution) override;
|
Status HandleConvolution(HloInstruction* convolution) override;
|
||||||
Status HandleFft(HloInstruction* fft) override;
|
Status HandleFft(HloInstruction* fft) override;
|
||||||
|
@ -5858,5 +5858,11 @@ void MlirEmitterContext::SetOperation(mlir::Operation* op) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status IrEmitterUnnested::HandleBitcast(HloInstruction* bitcast) {
|
||||||
|
TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(bitcast));
|
||||||
|
DCHECK_EQ(nullptr, input.op);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -157,6 +157,7 @@ class IrEmitterUnnested : public IrEmitter,
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status DefaultAction(HloInstruction* hlo) override;
|
Status DefaultAction(HloInstruction* hlo) override;
|
||||||
|
Status HandleBitcast(HloInstruction* bitcast) override;
|
||||||
Status EmitUsingElementalIrEmitter(MlirEmitterInput input);
|
Status EmitUsingElementalIrEmitter(MlirEmitterInput input);
|
||||||
|
|
||||||
// IrEmitterUnnested handles the following instructions differently from
|
// IrEmitterUnnested handles the following instructions differently from
|
||||||
|
Loading…
x
Reference in New Issue
Block a user