Fix the mapping from HLO operations to kernel arguments in mlir gpu emitter.

This was missing a level of indirection. The gpu.launch operation's arguments
are not necessarily in the same order as the arguments to the function that
contains the launch. So is now catered for.

PiperOrigin-RevId: 280607552
Change-Id: I9c5fa329393eb849b354d9ed9c83acf9d721900e
This commit is contained in:
Stephan Herhut 2019-11-15 00:45:48 -08:00 committed by TensorFlower Gardener
parent abc6b6d2fb
commit 6431ab55bc

View File

@ -226,30 +226,32 @@ using OperandToValueMap =
static StatusOr<std::vector<const HloInstruction*>> ComputeOperandToValueMap(
OperandToValueMap* operand_to_value_map, const HloInstruction* instr,
LaunchFuncOp launchOp, LLVMFuncOp kernel) {
auto arguments = launchOp.getParentOfType<FuncOp>().getArguments();
auto operands = instr->operands();
std::vector<const HloInstruction*> ordered_operands;
bool has_failed = false;
for (int i = 0; i < launchOp.getNumKernelOperands(); ++i) {
auto kernel_operand = dyn_cast<BlockArgument>(launchOp.getKernelOperand(i));
if (!kernel_operand) {
for (int kernel_index = 0; kernel_index < launchOp.getNumKernelOperands();
++kernel_index) {
auto launchop_operand =
dyn_cast<BlockArgument>(launchOp.getKernelOperand(kernel_index));
if (!launchop_operand) {
launchOp.emitError("argument to kernel is not a function input");
has_failed = true;
continue;
}
auto pos = std::find(arguments.begin(), arguments.end(), kernel_operand);
if (pos == arguments.end()) {
launchOp.emitError("argument to kernel is not a function input");
has_failed = true;
continue;
}
auto index = pos - arguments.begin();
// The last argument to the outer function is the result.
auto operand = (index < operands.size()) ? operands[index] : instr;
// host_index is the argument positon to the surrounding function that
// contains the launch. This index corresponds to HLO operand indices
// by construction.
auto host_index = launchop_operand->getArgNumber();
// The trailing argument to the outer function are the results.
auto operand =
(host_index < operands.size()) ? operands[host_index] : instr;
if (!operand_to_value_map->count(operand)) {
ordered_operands.push_back(operand);
}
(*operand_to_value_map)[operand].push_back(kernel.getArgument(index));
// Associate the HLO operand with the argument value of the kernel
// function.
(*operand_to_value_map)[operand].push_back(
kernel.getArgument(kernel_index));
}
if (has_failed) {
return InternalError("Mapping operands to kernel arguments has failed.");