From 6431ab55bc4fb003aaf838c680cd2334c6cb402f Mon Sep 17 00:00:00 2001 From: Stephan Herhut Date: Fri, 15 Nov 2019 00:45:48 -0800 Subject: [PATCH] Fix the mapping from HLO operations to kernel arguments in mlir gpu emitter. This was missing a level of indirection. The gpu.launch operation's arguments are not necessarily in the same order as the arguments to the function that contains the launch. So is now catered for. PiperOrigin-RevId: 280607552 Change-Id: I9c5fa329393eb849b354d9ed9c83acf9d721900e --- .../xla/service/mlir_gpu/mlir_compiler.cc | 30 ++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc index c2d537bfd97..b0bb7b1581e 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc @@ -226,30 +226,32 @@ using OperandToValueMap = static StatusOr> ComputeOperandToValueMap( OperandToValueMap* operand_to_value_map, const HloInstruction* instr, LaunchFuncOp launchOp, LLVMFuncOp kernel) { - auto arguments = launchOp.getParentOfType().getArguments(); auto operands = instr->operands(); std::vector ordered_operands; bool has_failed = false; - for (int i = 0; i < launchOp.getNumKernelOperands(); ++i) { - auto kernel_operand = dyn_cast(launchOp.getKernelOperand(i)); - if (!kernel_operand) { + for (int kernel_index = 0; kernel_index < launchOp.getNumKernelOperands(); + ++kernel_index) { + auto launchop_operand = + dyn_cast(launchOp.getKernelOperand(kernel_index)); + if (!launchop_operand) { launchOp.emitError("argument to kernel is not a function input"); has_failed = true; continue; } - auto pos = std::find(arguments.begin(), arguments.end(), kernel_operand); - if (pos == arguments.end()) { - launchOp.emitError("argument to kernel is not a function input"); - has_failed = true; - continue; - } - auto index = pos - arguments.begin(); - // The last argument to the outer function is the result. - auto operand = (index < operands.size()) ? operands[index] : instr; + // host_index is the argument positon to the surrounding function that + // contains the launch. This index corresponds to HLO operand indices + // by construction. + auto host_index = launchop_operand->getArgNumber(); + // The trailing argument to the outer function are the results. + auto operand = + (host_index < operands.size()) ? operands[host_index] : instr; if (!operand_to_value_map->count(operand)) { ordered_operands.push_back(operand); } - (*operand_to_value_map)[operand].push_back(kernel.getArgument(index)); + // Associate the HLO operand with the argument value of the kernel + // function. + (*operand_to_value_map)[operand].push_back( + kernel.getArgument(kernel_index)); } if (has_failed) { return InternalError("Mapping operands to kernel arguments has failed.");