Fix the mapping from HLO operations to kernel arguments in mlir gpu emitter.
This was missing a level of indirection. The gpu.launch operation's arguments are not necessarily in the same order as the arguments to the function that contains the launch. So is now catered for. PiperOrigin-RevId: 280607552 Change-Id: I9c5fa329393eb849b354d9ed9c83acf9d721900e
This commit is contained in:
parent
abc6b6d2fb
commit
6431ab55bc
@ -226,30 +226,32 @@ using OperandToValueMap =
|
|||||||
static StatusOr<std::vector<const HloInstruction*>> ComputeOperandToValueMap(
|
static StatusOr<std::vector<const HloInstruction*>> ComputeOperandToValueMap(
|
||||||
OperandToValueMap* operand_to_value_map, const HloInstruction* instr,
|
OperandToValueMap* operand_to_value_map, const HloInstruction* instr,
|
||||||
LaunchFuncOp launchOp, LLVMFuncOp kernel) {
|
LaunchFuncOp launchOp, LLVMFuncOp kernel) {
|
||||||
auto arguments = launchOp.getParentOfType<FuncOp>().getArguments();
|
|
||||||
auto operands = instr->operands();
|
auto operands = instr->operands();
|
||||||
std::vector<const HloInstruction*> ordered_operands;
|
std::vector<const HloInstruction*> ordered_operands;
|
||||||
bool has_failed = false;
|
bool has_failed = false;
|
||||||
for (int i = 0; i < launchOp.getNumKernelOperands(); ++i) {
|
for (int kernel_index = 0; kernel_index < launchOp.getNumKernelOperands();
|
||||||
auto kernel_operand = dyn_cast<BlockArgument>(launchOp.getKernelOperand(i));
|
++kernel_index) {
|
||||||
if (!kernel_operand) {
|
auto launchop_operand =
|
||||||
|
dyn_cast<BlockArgument>(launchOp.getKernelOperand(kernel_index));
|
||||||
|
if (!launchop_operand) {
|
||||||
launchOp.emitError("argument to kernel is not a function input");
|
launchOp.emitError("argument to kernel is not a function input");
|
||||||
has_failed = true;
|
has_failed = true;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto pos = std::find(arguments.begin(), arguments.end(), kernel_operand);
|
// host_index is the argument positon to the surrounding function that
|
||||||
if (pos == arguments.end()) {
|
// contains the launch. This index corresponds to HLO operand indices
|
||||||
launchOp.emitError("argument to kernel is not a function input");
|
// by construction.
|
||||||
has_failed = true;
|
auto host_index = launchop_operand->getArgNumber();
|
||||||
continue;
|
// The trailing argument to the outer function are the results.
|
||||||
}
|
auto operand =
|
||||||
auto index = pos - arguments.begin();
|
(host_index < operands.size()) ? operands[host_index] : instr;
|
||||||
// The last argument to the outer function is the result.
|
|
||||||
auto operand = (index < operands.size()) ? operands[index] : instr;
|
|
||||||
if (!operand_to_value_map->count(operand)) {
|
if (!operand_to_value_map->count(operand)) {
|
||||||
ordered_operands.push_back(operand);
|
ordered_operands.push_back(operand);
|
||||||
}
|
}
|
||||||
(*operand_to_value_map)[operand].push_back(kernel.getArgument(index));
|
// Associate the HLO operand with the argument value of the kernel
|
||||||
|
// function.
|
||||||
|
(*operand_to_value_map)[operand].push_back(
|
||||||
|
kernel.getArgument(kernel_index));
|
||||||
}
|
}
|
||||||
if (has_failed) {
|
if (has_failed) {
|
||||||
return InternalError("Mapping operands to kernel arguments has failed.");
|
return InternalError("Mapping operands to kernel arguments has failed.");
|
||||||
|
Loading…
Reference in New Issue
Block a user