[NFC] Eliminate references to HLO inst in CustomCallThunk
- HLO Inst was mostly used for validation, moved that validation to IR emitter, and eliminated HLO inst pointer from CustomCallThunk PiperOrigin-RevId: 335671028 Change-Id: I6d1966d11841be5ac52ee167dd4f196e3fa4022e
This commit is contained in:
parent
a4f4855c82
commit
ebfdff6a87
@ -26,27 +26,10 @@ CustomCallThunk::CustomCallThunk(
|
||||
std::vector<ShapeTree<BufferAllocation::Slice>> operand_slices,
|
||||
ShapeTree<BufferAllocation::Slice> result_slices, std::string opaque)
|
||||
: Thunk(Thunk::kCustomCall, thunk_info),
|
||||
hlo_instruction_(thunk_info.hlo_instruction),
|
||||
call_target_(call_target),
|
||||
operand_slices_(std::move(operand_slices)),
|
||||
result_slices_(std::move(result_slices)),
|
||||
opaque_(std::move(opaque)) {
|
||||
const HloInstruction* instr = hlo_instruction_;
|
||||
CHECK_EQ(instr->operand_count(), operand_slices_.size());
|
||||
for (int64 i = 0; i < instr->operand_count(); ++i) {
|
||||
const auto& s1 = operand_slices_[i].shape();
|
||||
const auto& s2 = instr->operand(i)->shape();
|
||||
CHECK(ShapeUtil::Equal(s1, s2)) << absl::StreamFormat(
|
||||
"Shape mismatch between instr->operand(%d) and "
|
||||
"operand_slices[%d].shape(): %s vs %s",
|
||||
i, i, s1.ToString(), s2.ToString());
|
||||
}
|
||||
CHECK(ShapeUtil::Equal(instr->shape(), result_slices.shape()))
|
||||
<< absl::StreamFormat(
|
||||
"Shape mismatch between instr->shape() and result_slices.shape(): "
|
||||
"%s vs %s.",
|
||||
instr->shape().ToString(), result_slices.shape().ToString());
|
||||
}
|
||||
opaque_(std::move(opaque)) {}
|
||||
|
||||
// For each leaf in a preorder traversal of `slices`, appends its device address
|
||||
// to `buffers`.
|
||||
|
||||
@ -46,7 +46,6 @@ class CustomCallThunk : public Thunk {
|
||||
Status ExecuteOnStream(const ExecuteParams& params) override;
|
||||
|
||||
private:
|
||||
const HloInstruction* hlo_instruction_;
|
||||
void* call_target_;
|
||||
std::vector<ShapeTree<BufferAllocation::Slice>> operand_slices_;
|
||||
ShapeTree<BufferAllocation::Slice> result_slices_;
|
||||
|
||||
@ -386,11 +386,26 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
|
||||
return slices;
|
||||
};
|
||||
std::vector<ShapeTree<BufferAllocation::Slice>> operand_slices;
|
||||
for (const auto* operand : custom_call->operands()) {
|
||||
for (int64 i = 0; i < custom_call->operand_count(); i++) {
|
||||
const auto* operand = custom_call->operand(i);
|
||||
operand_slices.push_back(get_slices_for_instr(operand));
|
||||
const auto& s1 = operand_slices.back().shape();
|
||||
const auto& s2 = operand->shape();
|
||||
CHECK(ShapeUtil::Equal(s1, s2)) << absl::StreamFormat(
|
||||
"Shape mismatch between operand shape and "
|
||||
"slice shape for operand %d: %s vs %s",
|
||||
i, s1.ToString(), s2.ToString());
|
||||
}
|
||||
ShapeTree<BufferAllocation::Slice> result_slices =
|
||||
get_slices_for_instr(custom_call);
|
||||
CHECK(ShapeUtil::Equal(custom_call->shape(), result_slices.shape()))
|
||||
<< absl::StreamFormat(
|
||||
"Shape mismatch between instr->shape() and "
|
||||
"result_slices.shape(): "
|
||||
"%s vs %s.",
|
||||
custom_call->shape().ToString(),
|
||||
result_slices.shape().ToString());
|
||||
|
||||
AddThunkToThunkSequence(absl::make_unique<CustomCallThunk>(
|
||||
context_->GetThunkInfo(custom_call), call_target,
|
||||
std::move(operand_slices), std::move(result_slices),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user