[XLA:GPU] Add support for PartitionId

PiperOrigin-RevId: 354599221
Change-Id: I8afe7e516507031172876bc19355127f5acf3a0b
Rahul Joshi 2021-01-29 13:30:59 -08:00 committed by TensorFlower Gardener
parent 892c668b01
commit 36c93e632e
9 changed files with 64 additions and 19 deletions

View File

@@ -637,6 +637,14 @@ class BASE_HLO_ReplicaIdOp {
}];
}
class BASE_HLO_PartitionIdOp {
string summary = "PartitionId operator";
string description = [{
Returns the unique ID (int32 scalar) of the partition.
}];
}
class BASE_HLO_AllReduceOp {
string summary = "AllReduce operator";

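For context (this aside is not part of the commit): under SPMD execution the device grid is organized by replica and partition, and the new op returns the partition coordinate. A minimal standalone sketch of the intended semantics, with the grid shape chosen purely for illustration:

#include <cstdint>
#include <cstdio>

// Toy model of the op's semantics (illustrative, not XLA code): every device
// runs the same program; partition_id evaluates to that device's partition
// coordinate as a ui32 scalar, independent of its replica coordinate.
int main() {
  const uint32_t num_replicas = 2, num_partitions = 4;
  for (uint32_t r = 0; r < num_replicas; ++r)
    for (uint32_t p = 0; p < num_partitions; ++p)
      std::printf("device(replica=%u, partition=%u): partition_id() == %u\n",
                  r, p, p);
  return 0;
}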
View File

@@ -608,6 +608,10 @@ def LHLO_ReplicaIdOp : LHLO_Op<"replica_id", []>, BASE_HLO_ReplicaIdOp {
let arguments = (ins Arg<MemRefOf<[UI32]>, "", [MemWrite]>);
}
def LHLO_PartitionIdOp : LHLO_Op<"partition_id", []>, BASE_HLO_PartitionIdOp {
let arguments = (ins Arg<MemRefOf<[UI32]>, "", [MemWrite]>);
}
def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType]>,
BASE_HLO_TriangularSolveOp {
let arguments = (ins

View File

@@ -323,6 +323,8 @@ StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(
return CreateOpWithoutAttrs<lmhlo::OrOp>(instr);
case HloOpcode::kOutfeed:
return EmitOutfeedOp(instr);
case HloOpcode::kPartitionId:
return CreateOpWithoutAttrs<lmhlo::PartitionIdOp>(instr);
case HloOpcode::kPad:
return EmitPadOp(instr);
case HloOpcode::kPopulationCount:

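The kPartitionId case reuses the generic CreateOpWithoutAttrs path, which builds an LMHLO op whose operands are just the instruction's buffer views. A rough sketch of that shape; every signature and helper name below is an assumption for illustration, not the file's actual code:

// Hypothetical sketch of the CreateOpWithoutAttrs pattern (helper names and
// signatures are assumptions; see mhlo_to_lhlo_with_xla.cc for the real code).
template <typename OpType>
StatusOr<mlir::Operation*> LhloDialectEmitter::CreateOpWithoutAttrs(
    const HloInstruction* instr) {
  llvm::SmallVector<mlir::Value, 4> buffers;
  // Collect memrefs for the instruction's operands and results; partition_id
  // has no operands, so this would yield only the ui32 output buffer.
  TF_RETURN_IF_ERROR(GetOrCreateView(instr, &buffers));
  return builder_
      .create<OpType>(getLocation(instr), /*resultTypes=*/llvm::None, buffers)
      .getOperation();
}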
View File

@@ -2948,15 +2948,27 @@ Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput mlir_input) {
return Status::OK();
}
template <typename ThunkType, typename OpT>
Status IrEmitterUnnested::EmitReplicaOrPartitionIdFromMlir(
MlirEmitterInput input) {
auto op = mlir::cast<OpT>(input.op);
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice,
GetAllocationSliceForMlir(op.getOperand()));
AddThunkToThunkSequence(
absl::make_unique<ThunkType>(input.thunk_info, result_slice));
return Status::OK();
}
Status IrEmitterUnnested::HandleReplicaId(HloInstruction* hlo) {
TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
auto replica_id_op = mlir::cast<mlir::lmhlo::ReplicaIdOp>(input.op);
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice,
GetAllocationSliceForMlir(replica_id_op.getOperand()));
AddThunkToThunkSequence(
absl::make_unique<ReplicaIdThunk>(input.thunk_info, result_slice));
return Status::OK();
return EmitReplicaOrPartitionIdFromMlir<ReplicaIdThunk,
mlir::lmhlo::ReplicaIdOp>(input);
}
Status IrEmitterUnnested::HandlePartitionId(HloInstruction* hlo) {
TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
return EmitReplicaOrPartitionIdFromMlir<PartitionIdThunk,
mlir::lmhlo::PartitionIdOp>(input);
}
Status IrEmitterUnnested::HandleCollectivePermute(HloInstruction* hlo) {

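A standalone reduction of the refactor above (the function and thunk names mirror the diff; everything else is illustrative): one templated emitter replaces the duplicated handler bodies, and each HLO handler merely picks its instantiation.

#include <cstdio>

// Illustrative stand-ins for the two thunk types; only the name differs here,
// just as only the Kind tag differs in the real thunks.
struct ReplicaIdThunk { static constexpr const char* kName = "ReplicaIdThunk"; };
struct PartitionIdThunk { static constexpr const char* kName = "PartitionIdThunk"; };

template <typename ThunkType>
void EmitReplicaOrPartitionId() {
  std::printf("adding %s to the thunk sequence\n", ThunkType::kName);
}

int main() {
  EmitReplicaOrPartitionId<ReplicaIdThunk>();    // what HandleReplicaId does
  EmitReplicaOrPartitionId<PartitionIdThunk>();  // what HandlePartitionId does
}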
View File

@@ -205,7 +205,12 @@ class IrEmitterUnnested : public IrEmitter,
Status EmitAllReduceFromMlir(MlirEmitterInput mlir_input);
Status HandleAllToAll(HloInstruction* hlo) override;
Status HandleAfterAll(HloInstruction* after_all) override;
template <typename ThunkType, typename OpT>
Status EmitReplicaOrPartitionIdFromMlir(MlirEmitterInput input);
Status HandleReplicaId(HloInstruction* hlo) override;
Status HandlePartitionId(HloInstruction* hlo) override;
Status HandleCollectivePermute(HloInstruction* hlo) override;
Status EmitOp(MlirEmitterInput mlir_input);

View File

@@ -18,11 +18,7 @@ limitations under the License.
namespace xla {
namespace gpu {
ReplicaIdThunk::ReplicaIdThunk(ThunkInfo thunk_info,
const BufferAllocation::Slice& dest)
: Thunk(Kind::kReplicaId, thunk_info), dest_(dest) {}
Status ReplicaIdThunk::ExecuteOnStream(const ExecuteParams& params) {
Status ReplicaOrPartitionIdThunk::ExecuteOnStream(const ExecuteParams& params) {
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(profile_index());
@@ -30,9 +26,10 @@ Status ReplicaIdThunk::ExecuteOnStream(const ExecuteParams& params) {
TF_ASSIGN_OR_RETURN(GlobalDeviceId global_device_id,
params.GetGlobalDeviceId());
TF_ASSIGN_OR_RETURN(int replica_id,
params.device_assn->ReplicaIdForDevice(global_device_id));
params.stream->ThenMemset32(&dest_addr, replica_id, /*size=*/4);
TF_ASSIGN_OR_RETURN(auto logical_ids, params.device_assn->LogicalIdsForDevice(
global_device_id));
int id = kind() == Kind::kReplicaId ? logical_ids.first : logical_ids.second;
params.stream->ThenMemset32(&dest_addr, id, /*size=*/4);
return Status::OK();
}

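The key behavioral change: instead of fetching only the replica id, the thunk asks the device assignment for the (replica, partition) pair and selects by Kind. A runnable sketch of that selection; the row-major mapping from device ordinal to logical ids is an assumption for illustration, not DeviceAssignment's actual layout:

#include <cstdint>
#include <cstdio>
#include <utility>

// Assumed mapping for illustration: devices laid out row-major as
// [num_replicas][num_partitions]. The real lookup goes through
// DeviceAssignment::LogicalIdsForDevice.
std::pair<int, int> LogicalIdsForDevice(int device_ordinal, int num_partitions) {
  return {device_ordinal / num_partitions, device_ordinal % num_partitions};
}

enum class Kind { kReplicaId, kPartitionId };

// Mirrors the select in ExecuteOnStream: first = replica id, second =
// partition id; the chosen value is memset into the 4-byte output buffer.
uint32_t IdToWrite(Kind kind, std::pair<int, int> logical_ids) {
  return kind == Kind::kReplicaId ? logical_ids.first : logical_ids.second;
}

int main() {
  auto ids = LogicalIdsForDevice(/*device_ordinal=*/6, /*num_partitions=*/4);
  std::printf("replica=%u partition=%u\n",
              IdToWrite(Kind::kReplicaId, ids),     // 1
              IdToWrite(Kind::kPartitionId, ids));  // 2
}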
View File

@@ -23,17 +23,31 @@ limitations under the License.
namespace xla {
namespace gpu {
// Thunk that implements the ReplicaId HLO.
class ReplicaIdThunk : public Thunk {
public:
ReplicaIdThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& dest);
// Thunk that implements the ReplicaId (Idx == 0) or PartitionId (Idx == 1) HLO.
class ReplicaOrPartitionIdThunk : public Thunk {
Status ExecuteOnStream(const ExecuteParams& params) override;
protected:
ReplicaOrPartitionIdThunk(Kind kind, ThunkInfo thunk_info,
const BufferAllocation::Slice& dest)
: Thunk(kind, thunk_info), dest_(dest) {}
private:
const BufferAllocation::Slice dest_;
};
class ReplicaIdThunk : public ReplicaOrPartitionIdThunk {
public:
ReplicaIdThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& dest)
: ReplicaOrPartitionIdThunk(Kind::kReplicaId, thunk_info, dest) {}
};
class PartitionIdThunk : public ReplicaOrPartitionIdThunk {
public:
PartitionIdThunk(ThunkInfo thunk_info, const BufferAllocation::Slice& dest)
: ReplicaOrPartitionIdThunk(Kind::kPartitionId, thunk_info, dest) {}
};
} // namespace gpu
} // namespace xla

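A self-contained analog of the hierarchy above (only the class and Kind names come from the diff; the rest is illustrative): the base class owns the kind tag and the shared execute path, while each subclass contributes nothing but its Kind.

#include <cstdio>

class ReplicaOrPartitionIdThunk {
 public:
  enum class Kind { kReplicaId, kPartitionId };

  // Shared execute path: pick the id by kind, then write it to the output.
  void ExecuteOnStream(int replica_id, int partition_id) const {
    int id = kind_ == Kind::kReplicaId ? replica_id : partition_id;
    std::printf("Memset32(dest, %d)\n", id);
  }

 protected:
  explicit ReplicaOrPartitionIdThunk(Kind kind) : kind_(kind) {}

 private:
  const Kind kind_;
};

// The subclasses are constructors only, like the diff's thin wrappers.
struct ReplicaIdThunk : ReplicaOrPartitionIdThunk {
  ReplicaIdThunk() : ReplicaOrPartitionIdThunk(Kind::kReplicaId) {}
};
struct PartitionIdThunk : ReplicaOrPartitionIdThunk {
  PartitionIdThunk() : ReplicaOrPartitionIdThunk(Kind::kPartitionId) {}
};

int main() {
  ReplicaIdThunk().ExecuteOnStream(1, 2);    // writes 1
  PartitionIdThunk().ExecuteOnStream(1, 2);  // writes 2
}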
View File

@@ -72,6 +72,8 @@ absl::string_view ThunkKindToString(Thunk::Kind kind) {
return "kOutfeed";
case Thunk::kReplicaId:
return "kReplicaId";
case Thunk::kPartitionId:
return "kPartitionId";
case Thunk::kSequential:
return "kSequential";
case Thunk::kTriangularSolve:

View File

@@ -64,6 +64,7 @@ class Thunk {
kNcclAllToAll,
kOutfeed,
kReplicaId,
kPartitionId,
kSequential,
kTriangularSolve,
kTuple,