Updated lite/ to use tstring.

This is a part of a larger migration effort for tensorflow::tstring.
See: https://github.com/tensorflow/community/pull/91
PiperOrigin-RevId: 263230002
This commit is contained in:
Dero Gharibian 2019-08-13 15:35:15 -07:00 committed by TensorFlower Gardener
parent 64b01d5574
commit ee875a8bb5
4 changed files with 12 additions and 10 deletions
tensorflow/lite

View File

@ -100,18 +100,20 @@ class StringTfLiteTensorBuffer : public BaseTfLiteTensorBuffer {
~StringTfLiteTensorBuffer() override {
LogDeallocation();
tensorflow::TypedAllocator::Deallocate<tensorflow::string>(
tensorflow::cpu_allocator(), static_cast<tensorflow::string*>(data()),
tensorflow::TypedAllocator::Deallocate<tensorflow::tstring>(
tensorflow::cpu_allocator(), static_cast<tensorflow::tstring*>(data()),
num_strings_);
}
size_t size() const override { return num_strings_ * sizeof(string); }
size_t size() const override {
return num_strings_ * sizeof(tensorflow::tstring);
}
private:
StringTfLiteTensorBuffer(const TfLiteTensor* tensor, int num_strings)
: BaseTfLiteTensorBuffer(
num_strings != 0
? tensorflow::TypedAllocator::Allocate<tensorflow::string>(
? tensorflow::TypedAllocator::Allocate<tensorflow::tstring>(
tensorflow::cpu_allocator(), num_strings,
tensorflow::AllocationAttributes())
: nullptr),
@ -119,7 +121,7 @@ class StringTfLiteTensorBuffer : public BaseTfLiteTensorBuffer {
LogAllocation();
if (data()) {
string* p = static_cast<string*>(data());
tensorflow::tstring* p = static_cast<tensorflow::tstring*>(data());
for (size_t i = 0; i < num_strings_; ++p, ++i) {
auto ref = GetString(tensor->data.raw, i);
p->assign(ref.str, ref.len);

View File

@ -128,7 +128,7 @@ TEST(BufferMapTest, SetFromTfLiteString) {
buffer_map.SetFromTfLite(0, t.get());
ASSERT_TRUE(buffer_map.HasTensor(0));
EXPECT_THAT(GetTensorData<string>(buffer_map.GetTensor(0)),
EXPECT_THAT(GetTensorData<tensorflow::tstring>(buffer_map.GetTensor(0)),
ElementsAre("", "", "", "str1", "", ""));
// Also check details of the tensor.
@ -162,7 +162,7 @@ TEST(BufferMapTest, SetFromTfLiteStringTwice) {
buffer_map.SetFromTfLite(0, t1.get());
buffer_map.SetFromTfLite(0, t2.get());
EXPECT_THAT(GetTensorData<string>(buffer_map.GetTensor(0)),
EXPECT_THAT(GetTensorData<tensorflow::tstring>(buffer_map.GetTensor(0)),
ElementsAre("", "", "", "s3", "", "", "s1", "s2"));
}

View File

@ -97,7 +97,7 @@ TfLiteStatus CopyFromBufferHandle(TfLiteContext* context,
}
DynamicBuffer dynamic_buffer;
auto tf_data = t.flat<string>();
auto tf_data = t.flat<tensorflow::tstring>();
for (int i = 0; i < t.NumElements(); ++i) {
dynamic_buffer.AddString(tf_data(i).data(), tf_data(i).size());
}

View File

@ -60,7 +60,7 @@ int FillTensorWithTfLiteHexString(tensorflow::Tensor* tensor,
int num_strings = values_as_string.empty() ? 0 : GetStringCount(s.data());
if (num_strings == tensor->NumElements()) {
auto data = tensor->flat<string>();
auto data = tensor->flat<tensorflow::tstring>();
for (size_t i = 0; i < num_strings; ++i) {
auto ref = GetString(s.data(), i);
data(i).assign(ref.str, ref.len);
@ -87,7 +87,7 @@ string TensorDataToCsvString(const tensorflow::Tensor& tensor) {
string TensorDataToTfLiteHexString(const tensorflow::Tensor& tensor) {
DynamicBuffer dynamic_buffer;
auto data = tensor.flat<string>();
auto data = tensor.flat<tensorflow::tstring>();
for (int i = 0; i < tensor.NumElements(); ++i) {
dynamic_buffer.AddString(data(i).data(), data(i).size());
}