Fix the mapping from HLO operations to kernel arguments in mlir gpu emitter.
This was missing a level of indirection. The gpu.launch operation's arguments are not necessarily in the same order as the arguments to the function that contains the launch. This is now catered for. PiperOrigin-RevId: 280607552 Change-Id: I9c5fa329393eb849b354d9ed9c83acf9d721900e
This commit is contained in:
parent
abc6b6d2fb
commit
6431ab55bc
@ -226,30 +226,32 @@ using OperandToValueMap =
|
||||
static StatusOr<std::vector<const HloInstruction*>> ComputeOperandToValueMap(
|
||||
OperandToValueMap* operand_to_value_map, const HloInstruction* instr,
|
||||
LaunchFuncOp launchOp, LLVMFuncOp kernel) {
|
||||
auto arguments = launchOp.getParentOfType<FuncOp>().getArguments();
|
||||
auto operands = instr->operands();
|
||||
std::vector<const HloInstruction*> ordered_operands;
|
||||
bool has_failed = false;
|
||||
for (int i = 0; i < launchOp.getNumKernelOperands(); ++i) {
|
||||
auto kernel_operand = dyn_cast<BlockArgument>(launchOp.getKernelOperand(i));
|
||||
if (!kernel_operand) {
|
||||
for (int kernel_index = 0; kernel_index < launchOp.getNumKernelOperands();
|
||||
++kernel_index) {
|
||||
auto launchop_operand =
|
||||
dyn_cast<BlockArgument>(launchOp.getKernelOperand(kernel_index));
|
||||
if (!launchop_operand) {
|
||||
launchOp.emitError("argument to kernel is not a function input");
|
||||
has_failed = true;
|
||||
continue;
|
||||
}
|
||||
auto pos = std::find(arguments.begin(), arguments.end(), kernel_operand);
|
||||
if (pos == arguments.end()) {
|
||||
launchOp.emitError("argument to kernel is not a function input");
|
||||
has_failed = true;
|
||||
continue;
|
||||
}
|
||||
auto index = pos - arguments.begin();
|
||||
// The last argument to the outer function is the result.
|
||||
auto operand = (index < operands.size()) ? operands[index] : instr;
|
||||
// host_index is the argument position in the surrounding function that
|
||||
// contains the launch. This index corresponds to HLO operand indices
|
||||
// by construction.
|
||||
auto host_index = launchop_operand->getArgNumber();
|
||||
// The trailing arguments to the outer function are the results.
|
||||
auto operand =
|
||||
(host_index < operands.size()) ? operands[host_index] : instr;
|
||||
if (!operand_to_value_map->count(operand)) {
|
||||
ordered_operands.push_back(operand);
|
||||
}
|
||||
(*operand_to_value_map)[operand].push_back(kernel.getArgument(index));
|
||||
// Associate the HLO operand with the argument value of the kernel
|
||||
// function.
|
||||
(*operand_to_value_map)[operand].push_back(
|
||||
kernel.getArgument(kernel_index));
|
||||
}
|
||||
if (has_failed) {
|
||||
return InternalError("Mapping operands to kernel arguments has failed.");
|
||||
|
Loading…
Reference in New Issue
Block a user