Fixed reading out of bounds in Pooling when no output_indices.
PiperOrigin-RevId: 307635625 Change-Id: I1294acd8ab58df6ad181e50f3743704f002294a8
This commit is contained in:
parent
8b9771cae3
commit
7f8add7ab8
|
@ -185,9 +185,11 @@ std::string GetMaxPoolingKernelCode(
|
|||
TensorCodeGenerator dst_tensor(
|
||||
"dst_data", WHSPoint{"dst_size.x", "dst_size.y", "dst_size.z"},
|
||||
op_def.dst_tensors[0]);
|
||||
const auto dst_ind_def =
|
||||
output_indices ? op_def.dst_tensors[1] : op_def.dst_tensors[0];
|
||||
TensorCodeGenerator indices_tensor(
|
||||
"dst_indices", WHSPoint{"dst_size.x", "dst_size.y", "dst_size.z"},
|
||||
op_def.dst_tensors[1]);
|
||||
dst_ind_def);
|
||||
|
||||
std::string c = GetCommonDefines(op_def.precision);
|
||||
|
||||
|
@ -281,10 +283,12 @@ std::string GetMaxPooling3DKernelCode(
|
|||
"dst_data",
|
||||
WHDSPoint{"dst_size.x", "dst_size.y", "dst_size.z", "dst_size.w"},
|
||||
op_def.dst_tensors[0]);
|
||||
const auto dst_ind_def =
|
||||
output_indices ? op_def.dst_tensors[1] : op_def.dst_tensors[0];
|
||||
TensorCodeGenerator indices_tensor(
|
||||
"dst_indices",
|
||||
WHDSPoint{"dst_size.x", "dst_size.y", "dst_size.z", "dst_size.w"},
|
||||
op_def.dst_tensors[1]);
|
||||
dst_ind_def);
|
||||
|
||||
std::string c = GetCommonDefines(op_def.precision);
|
||||
|
||||
|
|
Loading…
Reference in New Issue