diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index de7fab3304e..51c34371b00 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -3191,20 +3191,9 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( } }; - const BlockPrologueGenerator& block_prologue_generator = - kernel_generator.GetBlockPrologueGenerator(); - if (block_prologue_generator) { - block_prologue_generator(unnested_hlo, kernel_info); - } - + kernel_generator.GetBlockPrologueGenerator()(unnested_hlo, kernel_info); EmitBlock(std::move(emit_one_tile), kernel_info, &ksl, index_ty); - - const BlockEpilogueGenerator& block_epilogue_generator = - kernel_generator.GetBlockEpilogueGenerator(); - if (block_epilogue_generator) { - block_epilogue_generator(unnested_hlo, kernel_info); - } - + kernel_generator.GetBlockEpilogueGenerator()(unnested_hlo, kernel_info); return launch_dimensions; } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index e5177c28484..0e3700fc59c 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -121,8 +121,10 @@ class IrEmitterUnnested : public IrEmitter { public: explicit KernelCodeGenerator( TileElementGenerator tile_element_generator, - BlockPrologueGenerator block_prologue_generator = {}, - BlockEpilogueGenerator block_epilogue_generator = {}) + BlockPrologueGenerator block_prologue_generator = + [](HloInstruction*, KernelCodegenInfo*) {}, + BlockEpilogueGenerator block_epilogue_generator = + [](HloInstruction*, KernelCodegenInfo*) {}) : tile_element_generator_(std::move(tile_element_generator)), block_prologue_generator_(std::move(block_prologue_generator)), block_epilogue_generator_(std::move(block_epilogue_generator)) {}