diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/pooling.cc b/tensorflow/lite/delegates/gpu/cl/kernels/pooling.cc index fb985461c02..e292f2dad7d 100644 --- a/tensorflow/lite/delegates/gpu/cl/kernels/pooling.cc +++ b/tensorflow/lite/delegates/gpu/cl/kernels/pooling.cc @@ -185,9 +185,11 @@ std::string GetMaxPoolingKernelCode( TensorCodeGenerator dst_tensor( "dst_data", WHSPoint{"dst_size.x", "dst_size.y", "dst_size.z"}, op_def.dst_tensors[0]); + const auto dst_ind_def = + output_indices ? op_def.dst_tensors[1] : op_def.dst_tensors[0]; TensorCodeGenerator indices_tensor( "dst_indices", WHSPoint{"dst_size.x", "dst_size.y", "dst_size.z"}, - op_def.dst_tensors[1]); + dst_ind_def); std::string c = GetCommonDefines(op_def.precision); @@ -281,10 +283,12 @@ std::string GetMaxPooling3DKernelCode( "dst_data", WHDSPoint{"dst_size.x", "dst_size.y", "dst_size.z", "dst_size.w"}, op_def.dst_tensors[0]); + const auto dst_ind_def = + output_indices ? op_def.dst_tensors[1] : op_def.dst_tensors[0]; TensorCodeGenerator indices_tensor( "dst_indices", WHDSPoint{"dst_size.x", "dst_size.y", "dst_size.z", "dst_size.w"}, - op_def.dst_tensors[1]); + dst_ind_def); std::string c = GetCommonDefines(op_def.precision);