[XLA:GPU] Add support for PartitionId
PiperOrigin-RevId: 354599221 Change-Id: I8afe7e516507031172876bc19355127f5acf3a0b
This commit is contained in:
parent
892c668b01
commit
36c93e632e
@ -637,6 +637,14 @@ class BASE_HLO_ReplicaIdOp {
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class BASE_HLO_PartitionIdOp {
|
||||||
|
string summary = "PartitionId operator";
|
||||||
|
|
||||||
|
string description = [{
|
||||||
|
Returns the unique ID (int32 scalar) of the partition.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class BASE_HLO_AllReduceOp {
|
class BASE_HLO_AllReduceOp {
|
||||||
string summary = "AllReduce operator";
|
string summary = "AllReduce operator";
|
||||||
|
@ -608,6 +608,10 @@ def LHLO_ReplicaIdOp : LHLO_Op<"replica_id", []>, BASE_HLO_ReplicaIdOp {
|
|||||||
let arguments = (ins Arg<MemRefOf<[UI32]>, "", [MemWrite]>);
|
let arguments = (ins Arg<MemRefOf<[UI32]>, "", [MemWrite]>);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def LHLO_PartitionIdOp : LHLO_Op<"partition_id", []>, BASE_HLO_PartitionIdOp {
|
||||||
|
let arguments = (ins Arg<MemRefOf<[UI32]>, "", [MemWrite]>);
|
||||||
|
}
|
||||||
|
|
||||||
def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType]>,
|
def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType]>,
|
||||||
BASE_HLO_TriangularSolveOp {
|
BASE_HLO_TriangularSolveOp {
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
|
@ -323,6 +323,8 @@ StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(
|
|||||||
return CreateOpWithoutAttrs<lmhlo::OrOp>(instr);
|
return CreateOpWithoutAttrs<lmhlo::OrOp>(instr);
|
||||||
case HloOpcode::kOutfeed:
|
case HloOpcode::kOutfeed:
|
||||||
return EmitOutfeedOp(instr);
|
return EmitOutfeedOp(instr);
|
||||||
|
case HloOpcode::kPartitionId:
|
||||||
|
return CreateOpWithoutAttrs<lmhlo::PartitionIdOp>(instr);
|
||||||
case HloOpcode::kPad:
|
case HloOpcode::kPad:
|
||||||
return EmitPadOp(instr);
|
return EmitPadOp(instr);
|
||||||
case HloOpcode::kPopulationCount:
|
case HloOpcode::kPopulationCount:
|
||||||
|
@ -2948,15 +2948,27 @@ Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput mlir_input) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename ThunkType, typename OpT>
|
||||||
|
Status IrEmitterUnnested::EmitReplicaOrPartitionIdFromMlir(
|
||||||
|
MlirEmitterInput input) {
|
||||||
|
auto op = mlir::cast<OpT>(input.op);
|
||||||
|
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice,
|
||||||
|
GetAllocationSliceForMlir(op.getOperand()));
|
||||||
|
AddThunkToThunkSequence(
|
||||||
|
absl::make_unique<ThunkType>(input.thunk_info, result_slice));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
Status IrEmitterUnnested::HandleReplicaId(HloInstruction* hlo) {
|
Status IrEmitterUnnested::HandleReplicaId(HloInstruction* hlo) {
|
||||||
TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
|
TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
|
||||||
auto replica_id_op = mlir::cast<mlir::lmhlo::ReplicaIdOp>(input.op);
|
return EmitReplicaOrPartitionIdFromMlir<ReplicaIdThunk,
|
||||||
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice,
|
mlir::lmhlo::ReplicaIdOp>(input);
|
||||||
GetAllocationSliceForMlir(replica_id_op.getOperand()));
|
}
|
||||||
AddThunkToThunkSequence(
|
|
||||||
absl::make_unique<ReplicaIdThunk>(input.thunk_info, result_slice));
|
|
||||||
|
|
||||||
return Status::OK();
|
Status IrEmitterUnnested::HandlePartitionId(HloInstruction* hlo) {
|
||||||
|
TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
|
||||||
|
return EmitReplicaOrPartitionIdFromMlir<PartitionIdThunk,
|
||||||
|
mlir::lmhlo::PartitionIdOp>(input);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status IrEmitterUnnested::HandleCollectivePermute(HloInstruction* hlo) {
|
Status IrEmitterUnnested::HandleCollectivePermute(HloInstruction* hlo) {
|
||||||
|
@ -205,7 +205,12 @@ class IrEmitterUnnested : public IrEmitter,
|
|||||||
Status EmitAllReduceFromMlir(MlirEmitterInput mlir_input);
|
Status EmitAllReduceFromMlir(MlirEmitterInput mlir_input);
|
||||||
Status HandleAllToAll(HloInstruction* hlo) override;
|
Status HandleAllToAll(HloInstruction* hlo) override;
|
||||||
Status HandleAfterAll(HloInstruction* after_all) override;
|
Status HandleAfterAll(HloInstruction* after_all) override;
|
||||||
|
|
||||||
|
template <typename ThunkType, typename OpT>
|
||||||
|
Status EmitReplicaOrPartitionIdFromMlir(MlirEmitterInput input);
|
||||||
Status HandleReplicaId(HloInstruction* hlo) override;
|
Status HandleReplicaId(HloInstruction* hlo) override;
|
||||||
|
Status HandlePartitionId(HloInstruction* hlo) override;
|
||||||
|
|
||||||
Status HandleCollectivePermute(HloInstruction* hlo) override;
|
Status HandleCollectivePermute(HloInstruction* hlo) override;
|
||||||
|
|
||||||
Status EmitOp(MlirEmitterInput mlir_input);
|
Status EmitOp(MlirEmitterInput mlir_input);
|
||||||
|
@ -18,11 +18,7 @@ limitations under the License.
|
|||||||
namespace xla {
|
namespace xla {
|
||||||
namespace gpu {
|
namespace gpu {
|
||||||
|
|
||||||
ReplicaIdThunk::ReplicaIdThunk(ThunkInfo thunk_info,
|
Status ReplicaOrPartitionIdThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||||
const BufferAllocation::Slice& dest)
|
|
||||||
: Thunk(Kind::kReplicaId, thunk_info), dest_(dest) {}
|
|
||||||
|
|
||||||
Status ReplicaIdThunk::ExecuteOnStream(const ExecuteParams& params) {
|
|
||||||
auto op_profiler =
|
auto op_profiler =
|
||||||
params.profiler->MakeScopedInstructionProfiler(profile_index());
|
params.profiler->MakeScopedInstructionProfiler(profile_index());
|
||||||
|
|
||||||
@ -30,9 +26,10 @@ Status ReplicaIdThunk::ExecuteOnStream(const ExecuteParams& params) {
|
|||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id,
|
TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id,
|
||||||
params.GetGlobalDeviceId());
|
params.GetGlobalDeviceId());
|
||||||
TF_ASSIGN_OR_RETURN(int replica_id,
|
TF_ASSIGN_OR_RETURN(auto logical_ids, params.device_assn->LogicalIdsForDevice(
|
||||||
params.device_assn->ReplicaIdForDevice(global_device_id));
|
global_device_id));
|
||||||
params.stream->ThenMemset32(&dest_addr, replica_id, /*size=*/4);
|
int id = kind() == Kind::kReplicaId ? logical_ids.first : logical_ids.second;
|
||||||
|
params.stream->ThenMemset32(&dest_addr, id, /*size=*/4);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -23,17 +23,31 @@ limitations under the License.
|
|||||||
namespace xla {
|
namespace xla {
|
||||||
namespace gpu {
|
namespace gpu {
|
||||||
|
|
||||||
// Thunk that implements the ReplicaId HLO.
|
// Thunk that implements the ReplicaId(Idx == 0) or PartitionId(Idx == 1).
|
||||||
class ReplicaIdThunk : public Thunk {
|
class ReplicaOrPartitionIdThunk : public Thunk {
|
||||||
public:
|
|
||||||
ReplicaIdThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& dest);
|
|
||||||
|
|
||||||
Status ExecuteOnStream(const ExecuteParams& params) override;
|
Status ExecuteOnStream(const ExecuteParams& params) override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
ReplicaOrPartitionIdThunk(Kind kind, ThunkInfo thunk_info,
|
||||||
|
const BufferAllocation::Slice& dest)
|
||||||
|
: Thunk(kind, thunk_info), dest_(dest) {}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const BufferAllocation::Slice dest_;
|
const BufferAllocation::Slice dest_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class ReplicaIdThunk : public ReplicaOrPartitionIdThunk {
|
||||||
|
public:
|
||||||
|
ReplicaIdThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& dest)
|
||||||
|
: ReplicaOrPartitionIdThunk(Kind::kReplicaId, thunk_info, dest) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
class PartitionIdThunk : public ReplicaOrPartitionIdThunk {
|
||||||
|
public:
|
||||||
|
PartitionIdThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& dest)
|
||||||
|
: ReplicaOrPartitionIdThunk(Kind::kPartitionId, thunk_info, dest) {}
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
||||||
|
@ -72,6 +72,8 @@ absl::string_view ThunkKindToString(Thunk::Kind kind) {
|
|||||||
return "kOutfeed";
|
return "kOutfeed";
|
||||||
case Thunk::kReplicaId:
|
case Thunk::kReplicaId:
|
||||||
return "kReplicaId";
|
return "kReplicaId";
|
||||||
|
case Thunk::kPartitionId:
|
||||||
|
return "kPartitionId";
|
||||||
case Thunk::kSequential:
|
case Thunk::kSequential:
|
||||||
return "kSequential";
|
return "kSequential";
|
||||||
case Thunk::kTriangularSolve:
|
case Thunk::kTriangularSolve:
|
||||||
|
@ -64,6 +64,7 @@ class Thunk {
|
|||||||
kNcclAllToAll,
|
kNcclAllToAll,
|
||||||
kOutfeed,
|
kOutfeed,
|
||||||
kReplicaId,
|
kReplicaId,
|
||||||
|
kPartitionId,
|
||||||
kSequential,
|
kSequential,
|
||||||
kTriangularSolve,
|
kTriangularSolve,
|
||||||
kTuple,
|
kTuple,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user