Fix ubsan warning in TFE_Py_FastPathExecute_C
PiperOrigin-RevId: 234236587
This commit is contained in:
parent
27ca6327e7
commit
5dfe49e2d8
@ -16,7 +16,6 @@ limitations under the License.
|
|||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
|
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
|
||||||
#include "tensorflow/python/eager/pywrap_tfe.h"
|
#include "tensorflow/python/eager/pywrap_tfe.h"
|
||||||
|
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
@ -25,6 +24,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/c_api_internal.h"
|
#include "tensorflow/c/c_api_internal.h"
|
||||||
#include "tensorflow/c/eager/c_api_internal.h"
|
#include "tensorflow/c/eager/c_api_internal.h"
|
||||||
#include "tensorflow/c/eager/tape.h"
|
#include "tensorflow/c/eager/tape.h"
|
||||||
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||||
#include "tensorflow/core/lib/gtl/compactptrset.h"
|
#include "tensorflow/core/lib/gtl/compactptrset.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||||
@ -2448,14 +2448,14 @@ bool RaiseIfNotPySequence(PyObject* seq, const string& attr_name) {
|
|||||||
|
|
||||||
bool RunCallbacks(
|
bool RunCallbacks(
|
||||||
const FastPathOpExecInfo& op_exec_info, PyObject* args,
|
const FastPathOpExecInfo& op_exec_info, PyObject* args,
|
||||||
const std::vector<tensorflow::Safe_PyObjectPtr>& flattened_inputs,
|
const std::vector<tensorflow::Safe_PyObjectPtr>* const flattened_inputs,
|
||||||
const std::vector<tensorflow::Safe_PyObjectPtr>& flattened_attrs,
|
const std::vector<tensorflow::Safe_PyObjectPtr>* const flattened_attrs,
|
||||||
PyObject* flattened_result) {
|
PyObject* flattened_result) {
|
||||||
if (!op_exec_info.run_callbacks) return true;
|
if (!op_exec_info.run_callbacks) return true;
|
||||||
|
|
||||||
tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(flattened_inputs.size()));
|
tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(flattened_inputs->size()));
|
||||||
for (int i = 0; i < flattened_inputs.size(); i++) {
|
for (int i = 0; i < flattened_inputs->size(); i++) {
|
||||||
PyObject* input = flattened_inputs[i].get();
|
PyObject* input = (*flattened_inputs)[i].get();
|
||||||
Py_INCREF(input);
|
Py_INCREF(input);
|
||||||
PyTuple_SET_ITEM(inputs.get(), i, input);
|
PyTuple_SET_ITEM(inputs.get(), i, input);
|
||||||
}
|
}
|
||||||
@ -2463,7 +2463,7 @@ bool RunCallbacks(
|
|||||||
int num_non_inferred_attrs = PyTuple_GET_SIZE(args) -
|
int num_non_inferred_attrs = PyTuple_GET_SIZE(args) -
|
||||||
op_exec_info.op_def->input_arg_size() -
|
op_exec_info.op_def->input_arg_size() -
|
||||||
kFastPathExecuteInputStartIndex;
|
kFastPathExecuteInputStartIndex;
|
||||||
int num_attrs = flattened_attrs.size() + num_non_inferred_attrs;
|
int num_attrs = flattened_attrs->size() + num_non_inferred_attrs;
|
||||||
tensorflow::Safe_PyObjectPtr attrs(PyTuple_New(num_attrs));
|
tensorflow::Safe_PyObjectPtr attrs(PyTuple_New(num_attrs));
|
||||||
|
|
||||||
for (int i = 0; i < num_non_inferred_attrs; i++) {
|
for (int i = 0; i < num_non_inferred_attrs; i++) {
|
||||||
@ -2475,7 +2475,7 @@ bool RunCallbacks(
|
|||||||
}
|
}
|
||||||
for (int i = num_non_inferred_attrs; i < num_attrs; i++) {
|
for (int i = num_non_inferred_attrs; i < num_attrs; i++) {
|
||||||
PyObject* attr_or_name =
|
PyObject* attr_or_name =
|
||||||
flattened_attrs.at(i - num_non_inferred_attrs).get();
|
flattened_attrs->at(i - num_non_inferred_attrs).get();
|
||||||
Py_INCREF(attr_or_name);
|
Py_INCREF(attr_or_name);
|
||||||
PyTuple_SET_ITEM(attrs.get(), i, attr_or_name);
|
PyTuple_SET_ITEM(attrs.get(), i, attr_or_name);
|
||||||
}
|
}
|
||||||
@ -2795,8 +2795,8 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
|
|||||||
PyList_SET_ITEM(flat_result.get(), i, EagerTensorFromHandle(retvals[i]));
|
PyList_SET_ITEM(flat_result.get(), i, EagerTensorFromHandle(retvals[i]));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!RunCallbacks(op_exec_info, args, *flattened_inputs, *flattened_attrs,
|
if (!RunCallbacks(op_exec_info, args, flattened_inputs.get(),
|
||||||
flat_result.get())) {
|
flattened_attrs.get(), flat_result.get())) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user