Updated lite/ to use tstring.
This is a part of a larger migration effort for tensorflow::tstring. See: https://github.com/tensorflow/community/pull/91 PiperOrigin-RevId: 263230002
This commit is contained in:
parent
64b01d5574
commit
ee875a8bb5
tensorflow/lite
@ -100,18 +100,20 @@ class StringTfLiteTensorBuffer : public BaseTfLiteTensorBuffer {
|
||||
|
||||
~StringTfLiteTensorBuffer() override {
|
||||
LogDeallocation();
|
||||
tensorflow::TypedAllocator::Deallocate<tensorflow::string>(
|
||||
tensorflow::cpu_allocator(), static_cast<tensorflow::string*>(data()),
|
||||
tensorflow::TypedAllocator::Deallocate<tensorflow::tstring>(
|
||||
tensorflow::cpu_allocator(), static_cast<tensorflow::tstring*>(data()),
|
||||
num_strings_);
|
||||
}
|
||||
|
||||
size_t size() const override { return num_strings_ * sizeof(string); }
|
||||
size_t size() const override {
|
||||
return num_strings_ * sizeof(tensorflow::tstring);
|
||||
}
|
||||
|
||||
private:
|
||||
StringTfLiteTensorBuffer(const TfLiteTensor* tensor, int num_strings)
|
||||
: BaseTfLiteTensorBuffer(
|
||||
num_strings != 0
|
||||
? tensorflow::TypedAllocator::Allocate<tensorflow::string>(
|
||||
? tensorflow::TypedAllocator::Allocate<tensorflow::tstring>(
|
||||
tensorflow::cpu_allocator(), num_strings,
|
||||
tensorflow::AllocationAttributes())
|
||||
: nullptr),
|
||||
@ -119,7 +121,7 @@ class StringTfLiteTensorBuffer : public BaseTfLiteTensorBuffer {
|
||||
LogAllocation();
|
||||
|
||||
if (data()) {
|
||||
string* p = static_cast<string*>(data());
|
||||
tensorflow::tstring* p = static_cast<tensorflow::tstring*>(data());
|
||||
for (size_t i = 0; i < num_strings_; ++p, ++i) {
|
||||
auto ref = GetString(tensor->data.raw, i);
|
||||
p->assign(ref.str, ref.len);
|
||||
|
@ -128,7 +128,7 @@ TEST(BufferMapTest, SetFromTfLiteString) {
|
||||
buffer_map.SetFromTfLite(0, t.get());
|
||||
ASSERT_TRUE(buffer_map.HasTensor(0));
|
||||
|
||||
EXPECT_THAT(GetTensorData<string>(buffer_map.GetTensor(0)),
|
||||
EXPECT_THAT(GetTensorData<tensorflow::tstring>(buffer_map.GetTensor(0)),
|
||||
ElementsAre("", "", "", "str1", "", ""));
|
||||
|
||||
// Also check details of the tensor.
|
||||
@ -162,7 +162,7 @@ TEST(BufferMapTest, SetFromTfLiteStringTwice) {
|
||||
buffer_map.SetFromTfLite(0, t1.get());
|
||||
buffer_map.SetFromTfLite(0, t2.get());
|
||||
|
||||
EXPECT_THAT(GetTensorData<string>(buffer_map.GetTensor(0)),
|
||||
EXPECT_THAT(GetTensorData<tensorflow::tstring>(buffer_map.GetTensor(0)),
|
||||
ElementsAre("", "", "", "s3", "", "", "s1", "s2"));
|
||||
}
|
||||
|
||||
|
@ -97,7 +97,7 @@ TfLiteStatus CopyFromBufferHandle(TfLiteContext* context,
|
||||
}
|
||||
DynamicBuffer dynamic_buffer;
|
||||
|
||||
auto tf_data = t.flat<string>();
|
||||
auto tf_data = t.flat<tensorflow::tstring>();
|
||||
for (int i = 0; i < t.NumElements(); ++i) {
|
||||
dynamic_buffer.AddString(tf_data(i).data(), tf_data(i).size());
|
||||
}
|
||||
|
@ -60,7 +60,7 @@ int FillTensorWithTfLiteHexString(tensorflow::Tensor* tensor,
|
||||
int num_strings = values_as_string.empty() ? 0 : GetStringCount(s.data());
|
||||
|
||||
if (num_strings == tensor->NumElements()) {
|
||||
auto data = tensor->flat<string>();
|
||||
auto data = tensor->flat<tensorflow::tstring>();
|
||||
for (size_t i = 0; i < num_strings; ++i) {
|
||||
auto ref = GetString(s.data(), i);
|
||||
data(i).assign(ref.str, ref.len);
|
||||
@ -87,7 +87,7 @@ string TensorDataToCsvString(const tensorflow::Tensor& tensor) {
|
||||
string TensorDataToTfLiteHexString(const tensorflow::Tensor& tensor) {
|
||||
DynamicBuffer dynamic_buffer;
|
||||
|
||||
auto data = tensor.flat<string>();
|
||||
auto data = tensor.flat<tensorflow::tstring>();
|
||||
for (int i = 0; i < tensor.NumElements(); ++i) {
|
||||
dynamic_buffer.AddString(data(i).data(), data(i).size());
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user