Fix the mapping from HLO operations to kernel arguments in mlir gpu emitter.

This was missing a level of indirection. The gpu.launch operation's arguments
are not necessarily in the same order as the arguments to the function that
contains the launch. This is now catered for.

PiperOrigin-RevId: 280607552
Change-Id: I9c5fa329393eb849b354d9ed9c83acf9d721900e
This commit is contained in:
Stephan Herhut 2019-11-15 00:45:48 -08:00 committed by TensorFlower Gardener
parent abc6b6d2fb
commit 6431ab55bc

View File

@ -226,30 +226,32 @@ using OperandToValueMap =
static StatusOr<std::vector<const HloInstruction*>> ComputeOperandToValueMap( static StatusOr<std::vector<const HloInstruction*>> ComputeOperandToValueMap(
OperandToValueMap* operand_to_value_map, const HloInstruction* instr, OperandToValueMap* operand_to_value_map, const HloInstruction* instr,
LaunchFuncOp launchOp, LLVMFuncOp kernel) { LaunchFuncOp launchOp, LLVMFuncOp kernel) {
auto arguments = launchOp.getParentOfType<FuncOp>().getArguments();
auto operands = instr->operands(); auto operands = instr->operands();
std::vector<const HloInstruction*> ordered_operands; std::vector<const HloInstruction*> ordered_operands;
bool has_failed = false; bool has_failed = false;
for (int i = 0; i < launchOp.getNumKernelOperands(); ++i) { for (int kernel_index = 0; kernel_index < launchOp.getNumKernelOperands();
auto kernel_operand = dyn_cast<BlockArgument>(launchOp.getKernelOperand(i)); ++kernel_index) {
if (!kernel_operand) { auto launchop_operand =
dyn_cast<BlockArgument>(launchOp.getKernelOperand(kernel_index));
if (!launchop_operand) {
launchOp.emitError("argument to kernel is not a function input"); launchOp.emitError("argument to kernel is not a function input");
has_failed = true; has_failed = true;
continue; continue;
} }
auto pos = std::find(arguments.begin(), arguments.end(), kernel_operand); // host_index is the argument position to the surrounding function that
if (pos == arguments.end()) { // contains the launch. This index corresponds to HLO operand indices
launchOp.emitError("argument to kernel is not a function input"); // by construction.
has_failed = true; auto host_index = launchop_operand->getArgNumber();
continue; // The trailing arguments to the outer function are the results.
} auto operand =
auto index = pos - arguments.begin(); (host_index < operands.size()) ? operands[host_index] : instr;
// The last argument to the outer function is the result.
auto operand = (index < operands.size()) ? operands[index] : instr;
if (!operand_to_value_map->count(operand)) { if (!operand_to_value_map->count(operand)) {
ordered_operands.push_back(operand); ordered_operands.push_back(operand);
} }
(*operand_to_value_map)[operand].push_back(kernel.getArgument(index)); // Associate the HLO operand with the argument value of the kernel
// function.
(*operand_to_value_map)[operand].push_back(
kernel.getArgument(kernel_index));
} }
if (has_failed) { if (has_failed) {
return InternalError("Mapping operands to kernel arguments has failed."); return InternalError("Mapping operands to kernel arguments has failed.");