Fixed reading out of bounds in Pooling when no output_indices.

PiperOrigin-RevId: 307635625
Change-Id: I1294acd8ab58df6ad181e50f3743704f002294a8
This commit is contained in:
Raman Sarokin 2020-04-21 10:45:01 -07:00 committed by TensorFlower Gardener
parent 8b9771cae3
commit 7f8add7ab8
1 changed files with 6 additions and 2 deletions

View File

@ -185,9 +185,11 @@ std::string GetMaxPoolingKernelCode(
TensorCodeGenerator dst_tensor(
"dst_data", WHSPoint{"dst_size.x", "dst_size.y", "dst_size.z"},
op_def.dst_tensors[0]);
const auto dst_ind_def =
output_indices ? op_def.dst_tensors[1] : op_def.dst_tensors[0];
TensorCodeGenerator indices_tensor(
"dst_indices", WHSPoint{"dst_size.x", "dst_size.y", "dst_size.z"},
op_def.dst_tensors[1]);
dst_ind_def);
std::string c = GetCommonDefines(op_def.precision);
@ -281,10 +283,12 @@ std::string GetMaxPooling3DKernelCode(
"dst_data",
WHDSPoint{"dst_size.x", "dst_size.y", "dst_size.z", "dst_size.w"},
op_def.dst_tensors[0]);
const auto dst_ind_def =
output_indices ? op_def.dst_tensors[1] : op_def.dst_tensors[0];
TensorCodeGenerator indices_tensor(
"dst_indices",
WHDSPoint{"dst_size.x", "dst_size.y", "dst_size.z", "dst_size.w"},
op_def.dst_tensors[1]);
dst_ind_def);
std::string c = GetCommonDefines(op_def.precision);