[NFC] Eliminate references to HLO inst in CustomCallThunk

- HLO Inst was mostly used for validation, moved that validation to IR emitter, and
  eliminated HLO inst pointer from CustomCallThunk

PiperOrigin-RevId: 335671028
Change-Id: I6d1966d11841be5ac52ee167dd4f196e3fa4022e
This commit is contained in:
Rahul Joshi 2020-10-06 10:36:46 -07:00 committed by TensorFlower Gardener
parent a4f4855c82
commit ebfdff6a87
3 changed files with 17 additions and 20 deletions

View File

@ -26,27 +26,10 @@ CustomCallThunk::CustomCallThunk(
std::vector<ShapeTree<BufferAllocation::Slice>> operand_slices,
ShapeTree<BufferAllocation::Slice> result_slices, std::string opaque)
: Thunk(Thunk::kCustomCall, thunk_info),
hlo_instruction_(thunk_info.hlo_instruction),
call_target_(call_target),
operand_slices_(std::move(operand_slices)),
result_slices_(std::move(result_slices)),
opaque_(std::move(opaque)) {
const HloInstruction* instr = hlo_instruction_;
CHECK_EQ(instr->operand_count(), operand_slices_.size());
for (int64 i = 0; i < instr->operand_count(); ++i) {
const auto& s1 = operand_slices_[i].shape();
const auto& s2 = instr->operand(i)->shape();
CHECK(ShapeUtil::Equal(s1, s2)) << absl::StreamFormat(
"Shape mismatch between instr->operand(%d) and "
"operand_slices[%d].shape(): %s vs %s",
i, i, s1.ToString(), s2.ToString());
}
CHECK(ShapeUtil::Equal(instr->shape(), result_slices.shape()))
<< absl::StreamFormat(
"Shape mismatch between instr->shape() and result_slices.shape(): "
"%s vs %s.",
instr->shape().ToString(), result_slices.shape().ToString());
}
opaque_(std::move(opaque)) {}
// For each leaf in a preorder traversal of `slices`, appends its device address
// to `buffers`.

View File

@ -46,7 +46,6 @@ class CustomCallThunk : public Thunk {
Status ExecuteOnStream(const ExecuteParams& params) override;
private:
const HloInstruction* hlo_instruction_;
void* call_target_;
std::vector<ShapeTree<BufferAllocation::Slice>> operand_slices_;
ShapeTree<BufferAllocation::Slice> result_slices_;

View File

@ -386,11 +386,26 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
return slices;
};
std::vector<ShapeTree<BufferAllocation::Slice>> operand_slices;
for (const auto* operand : custom_call->operands()) {
for (int64 i = 0; i < custom_call->operand_count(); i++) {
const auto* operand = custom_call->operand(i);
operand_slices.push_back(get_slices_for_instr(operand));
const auto& s1 = operand_slices.back().shape();
const auto& s2 = operand->shape();
CHECK(ShapeUtil::Equal(s1, s2)) << absl::StreamFormat(
"Shape mismatch between operand shape and "
"slice shape for operand %d: %s vs %s",
i, s1.ToString(), s2.ToString());
}
ShapeTree<BufferAllocation::Slice> result_slices =
get_slices_for_instr(custom_call);
CHECK(ShapeUtil::Equal(custom_call->shape(), result_slices.shape()))
<< absl::StreamFormat(
"Shape mismatch between instr->shape() and "
"result_slices.shape(): "
"%s vs %s.",
custom_call->shape().ToString(),
result_slices.shape().ToString());
AddThunkToThunkSequence(absl::make_unique<CustomCallThunk>(
context_->GetThunkInfo(custom_call), call_target,
std::move(operand_slices), std::move(result_slices),