Use GetStringCount(tensor) and GetString(tensor, i) instead of the raw pointer overloads where applicable.

PiperOrigin-RevId: 273616765
This commit is contained in:
Robert David 2019-10-08 15:16:11 -07:00 committed by TensorFlower Gardener
parent 9ca19e7af7
commit 6c10e9096b
4 changed files with 11 additions and 12 deletions

View File

@ -94,9 +94,8 @@ class TfLiteTensorBuffer : public BaseTfLiteTensorBuffer {
class StringTfLiteTensorBuffer : public BaseTfLiteTensorBuffer { class StringTfLiteTensorBuffer : public BaseTfLiteTensorBuffer {
public: public:
explicit StringTfLiteTensorBuffer(const TfLiteTensor* tensor) explicit StringTfLiteTensorBuffer(const TfLiteTensor* tensor)
: StringTfLiteTensorBuffer(tensor, tensor->data.raw != nullptr : StringTfLiteTensorBuffer(
? GetStringCount(tensor->data.raw) tensor, tensor->data.raw != nullptr ? GetStringCount(tensor) : 0) {}
: 0) {}
~StringTfLiteTensorBuffer() override { ~StringTfLiteTensorBuffer() override {
LogDeallocation(); LogDeallocation();
@ -123,7 +122,7 @@ class StringTfLiteTensorBuffer : public BaseTfLiteTensorBuffer {
if (data()) { if (data()) {
tensorflow::tstring* p = static_cast<tensorflow::tstring*>(data()); tensorflow::tstring* p = static_cast<tensorflow::tstring*>(data());
for (size_t i = 0; i < num_strings_; ++p, ++i) { for (size_t i = 0; i < num_strings_; ++p, ++i) {
auto ref = GetString(tensor->data.raw, i); auto ref = GetString(tensor, i);
p->assign(ref.str, ref.len); p->assign(ref.str, ref.len);
} }
} }

View File

@ -39,9 +39,9 @@ std::vector<string> FlexModelTest::GetStringValues(int tensor_index) const {
std::vector<string> result; std::vector<string> result;
TfLiteTensor* tensor = interpreter_->tensor(tensor_index); TfLiteTensor* tensor = interpreter_->tensor(tensor_index);
auto num_strings = GetStringCount(tensor->data.raw); auto num_strings = GetStringCount(tensor);
for (size_t i = 0; i < num_strings; ++i) { for (size_t i = 0; i < num_strings; ++i) {
auto ref = GetString(tensor->data.raw, i); auto ref = GetString(tensor, i);
result.push_back(string(ref.str, ref.len)); result.push_back(string(ref.str, ref.len));
} }

View File

@ -110,7 +110,7 @@ bool RegisterCustomOpByName(const char* registerer_name,
#else #else
dlsym(RTLD_DEFAULT, registerer_name) dlsym(RTLD_DEFAULT, registerer_name)
#endif // defined(_WIN32) #endif // defined(_WIN32)
); );
// Fail in an informative way if the function was not found. // Fail in an informative way if the function was not found.
if (registerer == nullptr) { if (registerer == nullptr) {
@ -429,9 +429,9 @@ PyObject* InterpreterWrapper::GetTensor(int i) const {
PyArrayObject* py_array = reinterpret_cast<PyArrayObject*>(py_object); PyArrayObject* py_array = reinterpret_cast<PyArrayObject*>(py_object);
PyObject** data = reinterpret_cast<PyObject**>(PyArray_DATA(py_array)); PyObject** data = reinterpret_cast<PyObject**>(PyArray_DATA(py_array));
auto num_strings = GetStringCount(tensor->data.raw); auto num_strings = GetStringCount(tensor);
for (int j = 0; j < num_strings; ++j) { for (int j = 0; j < num_strings; ++j) {
auto ref = GetString(tensor->data.raw, j); auto ref = GetString(tensor, j);
PyObject* bytes = PyBytes_FromStringAndSize(ref.str, ref.len); PyObject* bytes = PyBytes_FromStringAndSize(ref.str, ref.len);
if (bytes == nullptr) { if (bytes == nullptr) {
@ -482,7 +482,7 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer( InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
PyObject* data, const std::vector<std::string>& registerers, PyObject* data, const std::vector<std::string>& registerers,
std::string* error_msg) { std::string* error_msg) {
char * buf = nullptr; char* buf = nullptr;
Py_ssize_t length; Py_ssize_t length;
std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter); std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);

View File

@ -177,7 +177,7 @@ bool TfLiteDriver::Expectation::TypedCheckString(bool verbose,
return false; return false;
} }
int expected_num_strings = GetStringCount(data_.raw); int expected_num_strings = GetStringCount(data_.raw);
int returned_num_strings = GetStringCount(tensor.data.raw); int returned_num_strings = GetStringCount(&tensor);
if (expected_num_strings != returned_num_strings) { if (expected_num_strings != returned_num_strings) {
if (verbose) { if (verbose) {
std::cerr << " string count differ: got " << returned_num_strings std::cerr << " string count differ: got " << returned_num_strings
@ -187,7 +187,7 @@ bool TfLiteDriver::Expectation::TypedCheckString(bool verbose,
} }
for (int i = 0; i < returned_num_strings; ++i) { for (int i = 0; i < returned_num_strings; ++i) {
auto expected_ref = GetString(data_.raw, i); auto expected_ref = GetString(data_.raw, i);
auto returned_ref = GetString(tensor.data.raw, i); auto returned_ref = GetString(&tensor, i);
if (expected_ref.len != returned_ref.len) { if (expected_ref.len != returned_ref.len) {
if (verbose) { if (verbose) {
std::cerr << " index " << i << ": got string of size " std::cerr << " index " << i << ": got string of size "