[XLA:GPU] Add support for PartitionId
PiperOrigin-RevId: 354599221
Change-Id: I8afe7e516507031172876bc19355127f5acf3a0b
parent 892c668b01
commit 36c93e632e
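In short, the diff below adds a BASE_HLO_PartitionIdOp / LHLO_PartitionIdOp op definition, lowers HloOpcode::kPartitionId to it in LhloDialectEmitter, generalizes ReplicaIdThunk into a shared ReplicaOrPartitionIdThunk, and adds a PartitionIdThunk plus the matching Thunk::kPartitionId kind. As a rough client-side sketch (not part of this commit; it assumes the xla::PartitionId free function declared in xla_builder.h), a computation that returns the executing partition's id could be built like this:

#include "tensorflow/compiler/xla/client/xla_builder.h"

// Builds a computation whose root is the 32-bit scalar id of the partition
// executing it; with this change the GPU backend lowers it to a
// PartitionIdThunk.
xla::XlaComputation BuildPartitionIdComputation() {
  xla::XlaBuilder builder("partition_id_example");
  xla::PartitionId(&builder);  // assumed client API entry point
  return builder.Build().ValueOrDie();
}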
@@ -637,6 +637,14 @@ class BASE_HLO_ReplicaIdOp {
   }];
 }
 
+class BASE_HLO_PartitionIdOp {
+  string summary = "PartitionId operator";
+
+  string description = [{
+    Returns the unique ID (int32 scalar) of the partition.
+  }];
+}
+
 class BASE_HLO_AllReduceOp {
   string summary = "AllReduce operator";
 
@@ -608,6 +608,10 @@ def LHLO_ReplicaIdOp : LHLO_Op<"replica_id", []>, BASE_HLO_ReplicaIdOp {
   let arguments = (ins Arg<MemRefOf<[UI32]>, "", [MemWrite]>);
 }
 
+def LHLO_PartitionIdOp : LHLO_Op<"partition_id", []>, BASE_HLO_PartitionIdOp {
+  let arguments = (ins Arg<MemRefOf<[UI32]>, "", [MemWrite]>);
+}
+
 def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType]>,
                             BASE_HLO_TriangularSolveOp {
   let arguments = (ins
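The new op mirrors replica_id: a single writeable memref<ui32> destination and no results. A minimal sketch of creating it from C++ (illustrative only; the header path and the TableGen-generated default builder are assumptions):

#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"  // assumed header path
#include "mlir/IR/Builders.h"

// Emits `"lmhlo.partition_id"(%dest) : (memref<ui32>) -> ()` at `loc`, where
// `dest` is the ui32 memref the op writes the partition id into.
mlir::lmhlo::PartitionIdOp CreatePartitionIdOp(mlir::OpBuilder& b,
                                               mlir::Location loc,
                                               mlir::Value dest) {
  return b.create<mlir::lmhlo::PartitionIdOp>(loc, dest);
}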
@@ -323,6 +323,8 @@ StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(
       return CreateOpWithoutAttrs<lmhlo::OrOp>(instr);
     case HloOpcode::kOutfeed:
       return EmitOutfeedOp(instr);
+    case HloOpcode::kPartitionId:
+      return CreateOpWithoutAttrs<lmhlo::PartitionIdOp>(instr);
     case HloOpcode::kPad:
       return EmitPadOp(instr);
     case HloOpcode::kPopulationCount:
@@ -2948,15 +2948,27 @@ Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput mlir_input) {
   return Status::OK();
 }
 
+template <typename ThunkType, typename OpT>
+Status IrEmitterUnnested::EmitReplicaOrPartitionIdFromMlir(
+    MlirEmitterInput input) {
+  auto op = mlir::cast<OpT>(input.op);
+  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice,
+                      GetAllocationSliceForMlir(op.getOperand()));
+  AddThunkToThunkSequence(
+      absl::make_unique<ThunkType>(input.thunk_info, result_slice));
+  return Status::OK();
+}
+
 Status IrEmitterUnnested::HandleReplicaId(HloInstruction* hlo) {
   TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
-  auto replica_id_op = mlir::cast<mlir::lmhlo::ReplicaIdOp>(input.op);
-  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice,
-                      GetAllocationSliceForMlir(replica_id_op.getOperand()));
-  AddThunkToThunkSequence(
-      absl::make_unique<ReplicaIdThunk>(input.thunk_info, result_slice));
-  return Status::OK();
+  return EmitReplicaOrPartitionIdFromMlir<ReplicaIdThunk,
+                                          mlir::lmhlo::ReplicaIdOp>(input);
+}
+
+Status IrEmitterUnnested::HandlePartitionId(HloInstruction* hlo) {
+  TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
+  return EmitReplicaOrPartitionIdFromMlir<PartitionIdThunk,
+                                          mlir::lmhlo::PartitionIdOp>(input);
 }
 
 Status IrEmitterUnnested::HandleCollectivePermute(HloInstruction* hlo) {
@@ -205,7 +205,12 @@ class IrEmitterUnnested : public IrEmitter,
   Status EmitAllReduceFromMlir(MlirEmitterInput mlir_input);
   Status HandleAllToAll(HloInstruction* hlo) override;
   Status HandleAfterAll(HloInstruction* after_all) override;
+
+  template <typename ThunkType, typename OpT>
+  Status EmitReplicaOrPartitionIdFromMlir(MlirEmitterInput input);
   Status HandleReplicaId(HloInstruction* hlo) override;
+  Status HandlePartitionId(HloInstruction* hlo) override;
+
   Status HandleCollectivePermute(HloInstruction* hlo) override;
 
   Status EmitOp(MlirEmitterInput mlir_input);
@@ -18,11 +18,7 @@ limitations under the License.
 namespace xla {
 namespace gpu {
 
-ReplicaIdThunk::ReplicaIdThunk(ThunkInfo thunk_info,
-                               const BufferAllocation::Slice& dest)
-    : Thunk(Kind::kReplicaId, thunk_info), dest_(dest) {}
-
-Status ReplicaIdThunk::ExecuteOnStream(const ExecuteParams& params) {
+Status ReplicaOrPartitionIdThunk::ExecuteOnStream(const ExecuteParams& params) {
   auto op_profiler =
       params.profiler->MakeScopedInstructionProfiler(profile_index());
 
@@ -30,9 +26,10 @@ Status ReplicaIdThunk::ExecuteOnStream(const ExecuteParams& params) {
 
   TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id,
                       params.GetGlobalDeviceId());
-  TF_ASSIGN_OR_RETURN(int replica_id,
-                      params.device_assn->ReplicaIdForDevice(global_device_id));
-  params.stream->ThenMemset32(&dest_addr, replica_id, /*size=*/4);
+  TF_ASSIGN_OR_RETURN(auto logical_ids, params.device_assn->LogicalIdsForDevice(
+                                            global_device_id));
+  int id = kind() == Kind::kReplicaId ? logical_ids.first : logical_ids.second;
+  params.stream->ThenMemset32(&dest_addr, id, /*size=*/4);
   return Status::OK();
 }
 
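The ternary above is the entire behavioral difference between the two thunk kinds: logical_ids.first is the replica id and logical_ids.second the partition id of the executing device, and the thunk memsets whichever 32-bit value matches its kind() into the destination slice. A standalone sketch of that selection (illustrative only, not XLA code):

#include <cstdint>
#include <utility>

// Stands in for the two Thunk kinds involved here.
enum class IdKind { kReplicaId, kPartitionId };

// Given the (replica_id, partition_id) pair for the current device, returns
// the 32-bit value that would be written into the destination buffer.
inline uint32_t IdToWrite(IdKind kind, std::pair<int, int> logical_ids) {
  int id = (kind == IdKind::kReplicaId) ? logical_ids.first : logical_ids.second;
  return static_cast<uint32_t>(id);
}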
@@ -23,17 +23,31 @@ limitations under the License.
 namespace xla {
 namespace gpu {
 
-// Thunk that implements the ReplicaId HLO.
-class ReplicaIdThunk : public Thunk {
+// Thunk that implements the ReplicaId(Idx == 0) or PartitionId(Idx == 1).
+class ReplicaOrPartitionIdThunk : public Thunk {
  public:
-  ReplicaIdThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& dest);
-
   Status ExecuteOnStream(const ExecuteParams& params) override;
 
+ protected:
+  ReplicaOrPartitionIdThunk(Kind kind, ThunkInfo thunk_info,
+                            const BufferAllocation::Slice& dest)
+      : Thunk(kind, thunk_info), dest_(dest) {}
+
  private:
   const BufferAllocation::Slice dest_;
 };
 
+class ReplicaIdThunk : public ReplicaOrPartitionIdThunk {
+ public:
+  ReplicaIdThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& dest)
+      : ReplicaOrPartitionIdThunk(Kind::kReplicaId, thunk_info, dest) {}
+};
+
+class PartitionIdThunk : public ReplicaOrPartitionIdThunk {
+ public:
+  PartitionIdThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& dest)
+      : ReplicaOrPartitionIdThunk(Kind::kPartitionId, thunk_info, dest) {}
+};
+
 }  // namespace gpu
 }  // namespace xla
 
@@ -72,6 +72,8 @@ absl::string_view ThunkKindToString(Thunk::Kind kind) {
       return "kOutfeed";
     case Thunk::kReplicaId:
       return "kReplicaId";
+    case Thunk::kPartitionId:
+      return "kPartitionId";
     case Thunk::kSequential:
       return "kSequential";
     case Thunk::kTriangularSolve:
@@ -64,6 +64,7 @@ class Thunk {
     kNcclAllToAll,
     kOutfeed,
     kReplicaId,
+    kPartitionId,
     kSequential,
     kTriangularSolve,
     kTuple,