Use GetStringCount(tensor) and GetString(tensor, i) instead of the raw pointer overloads where applicable.
PiperOrigin-RevId: 273616765
This commit is contained in:
parent
9ca19e7af7
commit
6c10e9096b
@ -94,9 +94,8 @@ class TfLiteTensorBuffer : public BaseTfLiteTensorBuffer {
|
||||
class StringTfLiteTensorBuffer : public BaseTfLiteTensorBuffer {
|
||||
public:
|
||||
explicit StringTfLiteTensorBuffer(const TfLiteTensor* tensor)
|
||||
: StringTfLiteTensorBuffer(tensor, tensor->data.raw != nullptr
|
||||
? GetStringCount(tensor->data.raw)
|
||||
: 0) {}
|
||||
: StringTfLiteTensorBuffer(
|
||||
tensor, tensor->data.raw != nullptr ? GetStringCount(tensor) : 0) {}
|
||||
|
||||
~StringTfLiteTensorBuffer() override {
|
||||
LogDeallocation();
|
||||
@ -123,7 +122,7 @@ class StringTfLiteTensorBuffer : public BaseTfLiteTensorBuffer {
|
||||
if (data()) {
|
||||
tensorflow::tstring* p = static_cast<tensorflow::tstring*>(data());
|
||||
for (size_t i = 0; i < num_strings_; ++p, ++i) {
|
||||
auto ref = GetString(tensor->data.raw, i);
|
||||
auto ref = GetString(tensor, i);
|
||||
p->assign(ref.str, ref.len);
|
||||
}
|
||||
}
|
||||
|
@ -39,9 +39,9 @@ std::vector<string> FlexModelTest::GetStringValues(int tensor_index) const {
|
||||
std::vector<string> result;
|
||||
|
||||
TfLiteTensor* tensor = interpreter_->tensor(tensor_index);
|
||||
auto num_strings = GetStringCount(tensor->data.raw);
|
||||
auto num_strings = GetStringCount(tensor);
|
||||
for (size_t i = 0; i < num_strings; ++i) {
|
||||
auto ref = GetString(tensor->data.raw, i);
|
||||
auto ref = GetString(tensor, i);
|
||||
result.push_back(string(ref.str, ref.len));
|
||||
}
|
||||
|
||||
|
@ -110,7 +110,7 @@ bool RegisterCustomOpByName(const char* registerer_name,
|
||||
#else
|
||||
dlsym(RTLD_DEFAULT, registerer_name)
|
||||
#endif // defined(_WIN32)
|
||||
);
|
||||
);
|
||||
|
||||
// Fail in an informative way if the function was not found.
|
||||
if (registerer == nullptr) {
|
||||
@ -429,9 +429,9 @@ PyObject* InterpreterWrapper::GetTensor(int i) const {
|
||||
|
||||
PyArrayObject* py_array = reinterpret_cast<PyArrayObject*>(py_object);
|
||||
PyObject** data = reinterpret_cast<PyObject**>(PyArray_DATA(py_array));
|
||||
auto num_strings = GetStringCount(tensor->data.raw);
|
||||
auto num_strings = GetStringCount(tensor);
|
||||
for (int j = 0; j < num_strings; ++j) {
|
||||
auto ref = GetString(tensor->data.raw, j);
|
||||
auto ref = GetString(tensor, j);
|
||||
|
||||
PyObject* bytes = PyBytes_FromStringAndSize(ref.str, ref.len);
|
||||
if (bytes == nullptr) {
|
||||
@ -482,7 +482,7 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
|
||||
InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
|
||||
PyObject* data, const std::vector<std::string>& registerers,
|
||||
std::string* error_msg) {
|
||||
char * buf = nullptr;
|
||||
char* buf = nullptr;
|
||||
Py_ssize_t length;
|
||||
std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
|
||||
|
||||
|
@ -177,7 +177,7 @@ bool TfLiteDriver::Expectation::TypedCheckString(bool verbose,
|
||||
return false;
|
||||
}
|
||||
int expected_num_strings = GetStringCount(data_.raw);
|
||||
int returned_num_strings = GetStringCount(tensor.data.raw);
|
||||
int returned_num_strings = GetStringCount(&tensor);
|
||||
if (expected_num_strings != returned_num_strings) {
|
||||
if (verbose) {
|
||||
std::cerr << " string count differ: got " << returned_num_strings
|
||||
@ -187,7 +187,7 @@ bool TfLiteDriver::Expectation::TypedCheckString(bool verbose,
|
||||
}
|
||||
for (int i = 0; i < returned_num_strings; ++i) {
|
||||
auto expected_ref = GetString(data_.raw, i);
|
||||
auto returned_ref = GetString(tensor.data.raw, i);
|
||||
auto returned_ref = GetString(&tensor, i);
|
||||
if (expected_ref.len != returned_ref.len) {
|
||||
if (verbose) {
|
||||
std::cerr << " index " << i << ": got string of size "
|
||||
|
Loading…
Reference in New Issue
Block a user