Use GetStringCount(tensor) and GetString(tensor, i) instead of the raw pointer overloads where applicable.
PiperOrigin-RevId: 273616765
This commit is contained in:
parent
9ca19e7af7
commit
6c10e9096b
@ -94,9 +94,8 @@ class TfLiteTensorBuffer : public BaseTfLiteTensorBuffer {
|
|||||||
class StringTfLiteTensorBuffer : public BaseTfLiteTensorBuffer {
|
class StringTfLiteTensorBuffer : public BaseTfLiteTensorBuffer {
|
||||||
public:
|
public:
|
||||||
explicit StringTfLiteTensorBuffer(const TfLiteTensor* tensor)
|
explicit StringTfLiteTensorBuffer(const TfLiteTensor* tensor)
|
||||||
: StringTfLiteTensorBuffer(tensor, tensor->data.raw != nullptr
|
: StringTfLiteTensorBuffer(
|
||||||
? GetStringCount(tensor->data.raw)
|
tensor, tensor->data.raw != nullptr ? GetStringCount(tensor) : 0) {}
|
||||||
: 0) {}
|
|
||||||
|
|
||||||
~StringTfLiteTensorBuffer() override {
|
~StringTfLiteTensorBuffer() override {
|
||||||
LogDeallocation();
|
LogDeallocation();
|
||||||
@ -123,7 +122,7 @@ class StringTfLiteTensorBuffer : public BaseTfLiteTensorBuffer {
|
|||||||
if (data()) {
|
if (data()) {
|
||||||
tensorflow::tstring* p = static_cast<tensorflow::tstring*>(data());
|
tensorflow::tstring* p = static_cast<tensorflow::tstring*>(data());
|
||||||
for (size_t i = 0; i < num_strings_; ++p, ++i) {
|
for (size_t i = 0; i < num_strings_; ++p, ++i) {
|
||||||
auto ref = GetString(tensor->data.raw, i);
|
auto ref = GetString(tensor, i);
|
||||||
p->assign(ref.str, ref.len);
|
p->assign(ref.str, ref.len);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -39,9 +39,9 @@ std::vector<string> FlexModelTest::GetStringValues(int tensor_index) const {
|
|||||||
std::vector<string> result;
|
std::vector<string> result;
|
||||||
|
|
||||||
TfLiteTensor* tensor = interpreter_->tensor(tensor_index);
|
TfLiteTensor* tensor = interpreter_->tensor(tensor_index);
|
||||||
auto num_strings = GetStringCount(tensor->data.raw);
|
auto num_strings = GetStringCount(tensor);
|
||||||
for (size_t i = 0; i < num_strings; ++i) {
|
for (size_t i = 0; i < num_strings; ++i) {
|
||||||
auto ref = GetString(tensor->data.raw, i);
|
auto ref = GetString(tensor, i);
|
||||||
result.push_back(string(ref.str, ref.len));
|
result.push_back(string(ref.str, ref.len));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -110,7 +110,7 @@ bool RegisterCustomOpByName(const char* registerer_name,
|
|||||||
#else
|
#else
|
||||||
dlsym(RTLD_DEFAULT, registerer_name)
|
dlsym(RTLD_DEFAULT, registerer_name)
|
||||||
#endif // defined(_WIN32)
|
#endif // defined(_WIN32)
|
||||||
);
|
);
|
||||||
|
|
||||||
// Fail in an informative way if the function was not found.
|
// Fail in an informative way if the function was not found.
|
||||||
if (registerer == nullptr) {
|
if (registerer == nullptr) {
|
||||||
@ -429,9 +429,9 @@ PyObject* InterpreterWrapper::GetTensor(int i) const {
|
|||||||
|
|
||||||
PyArrayObject* py_array = reinterpret_cast<PyArrayObject*>(py_object);
|
PyArrayObject* py_array = reinterpret_cast<PyArrayObject*>(py_object);
|
||||||
PyObject** data = reinterpret_cast<PyObject**>(PyArray_DATA(py_array));
|
PyObject** data = reinterpret_cast<PyObject**>(PyArray_DATA(py_array));
|
||||||
auto num_strings = GetStringCount(tensor->data.raw);
|
auto num_strings = GetStringCount(tensor);
|
||||||
for (int j = 0; j < num_strings; ++j) {
|
for (int j = 0; j < num_strings; ++j) {
|
||||||
auto ref = GetString(tensor->data.raw, j);
|
auto ref = GetString(tensor, j);
|
||||||
|
|
||||||
PyObject* bytes = PyBytes_FromStringAndSize(ref.str, ref.len);
|
PyObject* bytes = PyBytes_FromStringAndSize(ref.str, ref.len);
|
||||||
if (bytes == nullptr) {
|
if (bytes == nullptr) {
|
||||||
@ -482,7 +482,7 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
|
|||||||
InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
|
InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
|
||||||
PyObject* data, const std::vector<std::string>& registerers,
|
PyObject* data, const std::vector<std::string>& registerers,
|
||||||
std::string* error_msg) {
|
std::string* error_msg) {
|
||||||
char * buf = nullptr;
|
char* buf = nullptr;
|
||||||
Py_ssize_t length;
|
Py_ssize_t length;
|
||||||
std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
|
std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
|
||||||
|
|
||||||
|
@ -177,7 +177,7 @@ bool TfLiteDriver::Expectation::TypedCheckString(bool verbose,
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
int expected_num_strings = GetStringCount(data_.raw);
|
int expected_num_strings = GetStringCount(data_.raw);
|
||||||
int returned_num_strings = GetStringCount(tensor.data.raw);
|
int returned_num_strings = GetStringCount(&tensor);
|
||||||
if (expected_num_strings != returned_num_strings) {
|
if (expected_num_strings != returned_num_strings) {
|
||||||
if (verbose) {
|
if (verbose) {
|
||||||
std::cerr << " string count differ: got " << returned_num_strings
|
std::cerr << " string count differ: got " << returned_num_strings
|
||||||
@ -187,7 +187,7 @@ bool TfLiteDriver::Expectation::TypedCheckString(bool verbose,
|
|||||||
}
|
}
|
||||||
for (int i = 0; i < returned_num_strings; ++i) {
|
for (int i = 0; i < returned_num_strings; ++i) {
|
||||||
auto expected_ref = GetString(data_.raw, i);
|
auto expected_ref = GetString(data_.raw, i);
|
||||||
auto returned_ref = GetString(tensor.data.raw, i);
|
auto returned_ref = GetString(&tensor, i);
|
||||||
if (expected_ref.len != returned_ref.len) {
|
if (expected_ref.len != returned_ref.len) {
|
||||||
if (verbose) {
|
if (verbose) {
|
||||||
std::cerr << " index " << i << ": got string of size "
|
std::cerr << " index " << i << ": got string of size "
|
||||||
|
Loading…
x
Reference in New Issue
Block a user