[XLA:GPU] Add support for PartitionId

PiperOrigin-RevId: 354599221
Change-Id: I8afe7e516507031172876bc19355127f5acf3a0b
This commit is contained in:
Rahul Joshi 2021-01-29 13:30:59 -08:00 committed by TensorFlower Gardener
parent 892c668b01
commit 36c93e632e
9 changed files with 64 additions and 19 deletions

View File

@@ -637,6 +637,14 @@ class BASE_HLO_ReplicaIdOp {
   }];
 }
 
+class BASE_HLO_PartitionIdOp {
+  string summary = "PartitionId operator";
+
+  string description = [{
+    Returns the unique ID (int32 scalar) of the partition.
+  }];
+}
+
 class BASE_HLO_AllReduceOp {
   string summary = "AllReduce operator";

View File

@@ -608,6 +608,10 @@ def LHLO_ReplicaIdOp : LHLO_Op<"replica_id", []>, BASE_HLO_ReplicaIdOp {
   let arguments = (ins Arg<MemRefOf<[UI32]>, "", [MemWrite]>);
 }
 
+def LHLO_PartitionIdOp : LHLO_Op<"partition_id", []>, BASE_HLO_PartitionIdOp {
+  let arguments = (ins Arg<MemRefOf<[UI32]>, "", [MemWrite]>);
+}
+
 def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType]>,
                             BASE_HLO_TriangularSolveOp {
   let arguments = (ins

View File

@@ -323,6 +323,8 @@ StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(
       return CreateOpWithoutAttrs<lmhlo::OrOp>(instr);
     case HloOpcode::kOutfeed:
       return EmitOutfeedOp(instr);
+    case HloOpcode::kPartitionId:
+      return CreateOpWithoutAttrs<lmhlo::PartitionIdOp>(instr);
     case HloOpcode::kPad:
       return EmitPadOp(instr);
     case HloOpcode::kPopulationCount:

View File

@@ -2948,15 +2948,27 @@ Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput mlir_input) {
   return Status::OK();
 }
 
+template <typename ThunkType, typename OpT>
+Status IrEmitterUnnested::EmitReplicaOrPartitionIdFromMlir(
+    MlirEmitterInput input) {
+  auto op = mlir::cast<OpT>(input.op);
+  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice,
+                      GetAllocationSliceForMlir(op.getOperand()));
+  AddThunkToThunkSequence(
+      absl::make_unique<ThunkType>(input.thunk_info, result_slice));
+  return Status::OK();
+}
+
 Status IrEmitterUnnested::HandleReplicaId(HloInstruction* hlo) {
   TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
-  auto replica_id_op = mlir::cast<mlir::lmhlo::ReplicaIdOp>(input.op);
-  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice,
-                      GetAllocationSliceForMlir(replica_id_op.getOperand()));
-  AddThunkToThunkSequence(
-      absl::make_unique<ReplicaIdThunk>(input.thunk_info, result_slice));
-  return Status::OK();
+  return EmitReplicaOrPartitionIdFromMlir<ReplicaIdThunk,
+                                          mlir::lmhlo::ReplicaIdOp>(input);
+}
+
+Status IrEmitterUnnested::HandlePartitionId(HloInstruction* hlo) {
+  TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
+  return EmitReplicaOrPartitionIdFromMlir<PartitionIdThunk,
+                                          mlir::lmhlo::PartitionIdOp>(input);
 }
 
 Status IrEmitterUnnested::HandleCollectivePermute(HloInstruction* hlo) {

View File

@@ -205,7 +205,12 @@ class IrEmitterUnnested : public IrEmitter,
   Status EmitAllReduceFromMlir(MlirEmitterInput mlir_input);
   Status HandleAllToAll(HloInstruction* hlo) override;
   Status HandleAfterAll(HloInstruction* after_all) override;
+
+  template <typename ThunkType, typename OpT>
+  Status EmitReplicaOrPartitionIdFromMlir(MlirEmitterInput input);
+
   Status HandleReplicaId(HloInstruction* hlo) override;
+  Status HandlePartitionId(HloInstruction* hlo) override;
   Status HandleCollectivePermute(HloInstruction* hlo) override;
   Status EmitOp(MlirEmitterInput mlir_input);

View File

@@ -18,11 +18,7 @@ limitations under the License.
 namespace xla {
 namespace gpu {
 
-ReplicaIdThunk::ReplicaIdThunk(ThunkInfo thunk_info,
-                               const BufferAllocation::Slice& dest)
-    : Thunk(Kind::kReplicaId, thunk_info), dest_(dest) {}
-
-Status ReplicaIdThunk::ExecuteOnStream(const ExecuteParams& params) {
+Status ReplicaOrPartitionIdThunk::ExecuteOnStream(const ExecuteParams& params) {
   auto op_profiler =
       params.profiler->MakeScopedInstructionProfiler(profile_index());
@@ -30,9 +26,10 @@ Status ReplicaIdThunk::ExecuteOnStream(const ExecuteParams& params) {
   TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id,
                       params.GetGlobalDeviceId());
-  TF_ASSIGN_OR_RETURN(int replica_id,
-                      params.device_assn->ReplicaIdForDevice(global_device_id));
-  params.stream->ThenMemset32(&dest_addr, replica_id, /*size=*/4);
+  TF_ASSIGN_OR_RETURN(auto logical_ids, params.device_assn->LogicalIdsForDevice(
+                                            global_device_id));
+  int id = kind() == Kind::kReplicaId ? logical_ids.first : logical_ids.second;
+  params.stream->ThenMemset32(&dest_addr, id, /*size=*/4);
   return Status::OK();
 }

View File

@@ -23,17 +23,31 @@ limitations under the License.
 namespace xla {
 namespace gpu {
 
-// Thunk that implements the ReplicaId HLO.
-class ReplicaIdThunk : public Thunk {
+// Thunk that implements the ReplicaId(Idx == 0) or PartitionId(Idx == 1).
+class ReplicaOrPartitionIdThunk : public Thunk {
  public:
-  ReplicaIdThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& dest);
-
   Status ExecuteOnStream(const ExecuteParams& params) override;
 
+ protected:
+  ReplicaOrPartitionIdThunk(Kind kind, ThunkInfo thunk_info,
+                            const BufferAllocation::Slice& dest)
+      : Thunk(kind, thunk_info), dest_(dest) {}
+
  private:
   const BufferAllocation::Slice dest_;
 };
 
+class ReplicaIdThunk : public ReplicaOrPartitionIdThunk {
+ public:
+  ReplicaIdThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& dest)
+      : ReplicaOrPartitionIdThunk(Kind::kReplicaId, thunk_info, dest) {}
+};
+
+class PartitionIdThunk : public ReplicaOrPartitionIdThunk {
+ public:
+  PartitionIdThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& dest)
+      : ReplicaOrPartitionIdThunk(Kind::kPartitionId, thunk_info, dest) {}
+};
+
 }  // namespace gpu
 }  // namespace xla

View File

@@ -72,6 +72,8 @@ absl::string_view ThunkKindToString(Thunk::Kind kind) {
       return "kOutfeed";
     case Thunk::kReplicaId:
       return "kReplicaId";
+    case Thunk::kPartitionId:
+      return "kPartitionId";
     case Thunk::kSequential:
       return "kSequential";
     case Thunk::kTriangularSolve:

View File

@@ -64,6 +64,7 @@ class Thunk {
     kNcclAllToAll,
     kOutfeed,
     kReplicaId,
+    kPartitionId,
     kSequential,
     kTriangularSolve,
     kTuple,