Rename WriteToTensor to WriteToTensorAsVector to make explicit that it sets the tensor shape to {size}
PiperOrigin-RevId: 223521732
Commit 3e66cee177 (parent f5291176d4)
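For context, a minimal sketch of the distinction this rename draws, assuming only the tflite::DynamicBuffer declarations shown in the hunks below; the include path and the tensor pointers vec_out and mat_out are assumptions, not part of this commit:

#include "tensorflow/lite/string_util.h"  // assumed include path for DynamicBuffer

// Sketch only: vec_out and mat_out are assumed to be allocated kTfLiteString tensors.
void SketchUsage(TfLiteTensor* vec_out, TfLiteTensor* mat_out) {
  tflite::DynamicBuffer buf;
  buf.AddString("hello", 5);
  buf.AddString("world", 5);

  // Renamed entry point: writes the strings and forces the shape to
  // {num_strings}, i.e. {2} here.
  buf.WriteToTensorAsVector(vec_out);

  // Two-argument overload: the caller supplies the shape (ownership passes to
  // the callee), or nullptr to keep the tensor's existing shape.
  TfLiteIntArray* shape = TfLiteIntArrayCreate(2);
  shape->data[0] = 2;
  shape->data[1] = 1;
  buf.WriteToTensor(mat_out, shape);
}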
@@ -566,7 +566,7 @@ TEST(BasicInterpreter, ThreeStepAllocate) {
     DynamicBuffer buf;
     StringRef str_ref = GetString(input, 0);
     buf.AddString(str_ref);
-    buf.WriteToTensor(output);
+    buf.WriteToTensorAsVector(output);
     return kTfLiteOk;
   };
 
@@ -278,7 +278,7 @@ void WriteMultiDimensionalStringArray(JNIEnv* env, jobject src,
   tflite::DynamicBuffer dst_buffer;
   PopulateStringDynamicBuffer(env, src, &dst_buffer, tensor->dims->size);
   if (!env->ExceptionCheck()) {
-    dst_buffer.WriteToTensor(tensor);
+    dst_buffer.WriteToTensor(tensor, /*new_shape=*/nullptr);
   }
 }
 
@@ -118,7 +118,7 @@ TfLiteStatus GatherStrings(TfLiteContext* context, const TfLiteTensor* input,
     const auto string_ref = GetString(input, pos);
     buffer.AddString(string_ref.str, string_ref.len);
   }
-  buffer.WriteToTensor(output);
+  buffer.WriteToTensorAsVector(output);
   return kTfLiteOk;
 }
 
@@ -137,7 +137,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
     }
   }
   if (output->type == kTfLiteString) {
-    buf.WriteToTensor(output);
+    buf.WriteToTensorAsVector(output);
   }
 
   return kTfLiteOk;
@@ -107,7 +107,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   // Generate n-grams recursively.
   tflite::DynamicBuffer buf;
   if (words.size() < params->ngram_size) {
-    buf.WriteToTensor(GetOutput(context, node, 0));
+    buf.WriteToTensorAsVector(GetOutput(context, node, 0));
     return kTfLiteOk;
   }
 
@@ -145,7 +145,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
     }
   }
 
-  buf.WriteToTensor(GetOutput(context, node, 0));
+  buf.WriteToTensorAsVector(GetOutput(context, node, 0));
   return kTfLiteOk;
 }
 }  // namespace
@@ -199,7 +199,7 @@ class SingleOpModel {
     for (const string& s : content) {
       buf.AddString(s.data(), s.length());
     }
-    buf.WriteToTensor(tensor);
+    buf.WriteToTensor(tensor, /*new_shape=*/nullptr);
   }
 
   // Populate the tensor given its index.
@@ -92,7 +92,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 
   tflite::DynamicBuffer buf;
   buf.AddString(result.data(), result.length());
-  buf.WriteToTensor(GetOutput(context, node, 0));
+  buf.WriteToTensorAsVector(GetOutput(context, node, 0));
   return kTfLiteOk;
 }
 
@@ -49,7 +49,7 @@ void ExecuteTfLite(const std::string& sentence,
   TfLiteTensor* input = interpreter->tensor(interpreter->inputs()[0]);
   tflite::DynamicBuffer buf;
   buf.AddString(sentence.data(), sentence.length());
-  buf.WriteToTensor(input);
+  buf.WriteToTensorAsVector(input);
   interpreter->AllocateTensors();
 
   interpreter->Invoke();
@@ -96,8 +96,7 @@ int DynamicBuffer::WriteToBuffer(char** buffer) {
   return bytes;
 }
 
-void DynamicBuffer::WriteToTensor(TfLiteTensor* tensor) {
-  // Set tensor content pointer to tensor_buffer, and release original data.
+void DynamicBuffer::WriteToTensorAsVector(TfLiteTensor* tensor) {
   auto dims = TfLiteIntArrayCreate(1);
   dims->data[0] = offset_.size() - 1;  // Store number of strings.
   WriteToTensor(tensor, dims);
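As the trailing comment notes, the single dimension stores the number of strings; in this implementation offset_ holds a starting offset plus one end offset per string, so offset_.size() - 1 is that count. A hedged sketch of the observable effect, where t is an assumed kTfLiteString / kTfLiteDynamic tensor (not part of this commit):

tflite::DynamicBuffer buf;
buf.AddString("a", 1);
buf.AddString("bc", 2);
buf.WriteToTensorAsVector(t);  // t: assumed, already allocated string tensor

// The shape is now forced to {2}: one dimension holding the string count.
assert(t->dims->size == 1);
assert(t->dims->data[0] == 2);
assert(tflite::GetStringCount(t) == 2);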
@@ -108,6 +107,10 @@ void DynamicBuffer::WriteToTensor(TfLiteTensor* tensor,
   char* tensor_buffer;
   int bytes = WriteToBuffer(&tensor_buffer);
 
+  if (new_shape == nullptr) {
+    new_shape = TfLiteIntArrayCopy(tensor->dims);
+  }
+
   // Set tensor content pointer to tensor_buffer, and release original data.
   TfLiteTensorReset(tensor->type, tensor->name, new_shape, tensor->params,
                     tensor_buffer, bytes, kTfLiteDynamic, tensor->allocation,
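The added branch makes the two-argument overload shape-preserving on demand: passing nullptr copies tensor->dims via TfLiteIntArrayCopy before the reset. A hedged sketch, assuming t is a string tensor whose dims were already set to {2, 1} (the tensor and its setup are not part of this commit):

tflite::DynamicBuffer buf;
buf.AddString("row0", 4);
buf.AddString("row1", 4);

// nullptr keeps the existing {2, 1} shape instead of overwriting it.
buf.WriteToTensor(t, /*new_shape=*/nullptr);
assert(t->dims->size == 2);
assert(t->dims->data[0] == 2 && t->dims->data[1] == 1);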
@@ -74,12 +74,18 @@ class DynamicBuffer {
   // The function allocates space for the buffer but does NOT take ownership.
   int WriteToBuffer(char** buffer);
 
-  // Fill content into a string tensor, with the given new_shape. The new
-  // shape must match the number of strings in this object.
+  // Fill content into a string tensor, with the given new_shape. The new shape
+  // must match the number of strings in this object. Caller relinquishes
+  // ownership of new_shape. If 'new_shape' is nullptr, keep the tensor's
+  // existing shape.
   void WriteToTensor(TfLiteTensor* tensor, TfLiteIntArray* new_shape);
 
   // Fill content into a string tensor. Set shape to {num_strings}.
-  void WriteToTensor(TfLiteTensor* tensor);
+  void WriteToTensorAsVector(TfLiteTensor* tensor);
+
+  // Deprecated. Use WriteToTensorAsVector() or pass in the new shape.
+  // TODO(b/120230709): remove when people migrate away.
+  void WriteToTensor(TfLiteTensor* tensor) { WriteToTensorAsVector(tensor); }
 
  private:
   // Data buffer to store contents of strings, not including headers.
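The inline one-argument overload at the end is only a migration shim: existing callers keep compiling and now get the explicit vector behavior. A hedged sketch of the equivalence (buf and t as in the sketches above, both hypothetical):

buf.WriteToTensor(t);          // deprecated shim: forwards to the call below
buf.WriteToTensorAsVector(t);  // explicit: shape becomes {num_strings}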
@@ -55,7 +55,7 @@ TEST(StringUtil, TestStringUtil) {
   new_shape->data[0] = 2;
   new_shape->data[1] = 1;
   buf0.WriteToTensor(t0, new_shape);
-  buf1.WriteToTensor(t1);
+  buf1.WriteToTensorAsVector(t1);
 
   // Check tensor shapes.
   EXPECT_EQ(t0->dims->size, 2);
@@ -99,7 +99,7 @@ TEST(StringUtil, TestAddJoinedString) {
 
   DynamicBuffer buf;
   buf.AddJoinedString({{s0, 3}, {s1, 4}, {s2, 0}, {s3, 3}}, ' ');
-  buf.WriteToTensor(t0);
+  buf.WriteToTensorAsVector(t0);
 
   ASSERT_EQ(GetStringCount(t0), 1);
   StringRef str_ref;
@@ -115,7 +115,7 @@ TEST(StringUtil, TestEmptyList) {
   t0->type = kTfLiteString;
   t0->allocation_type = kTfLiteDynamic;
   DynamicBuffer buf;
-  buf.WriteToTensor(t0);
+  buf.WriteToTensorAsVector(t0);
 
   ASSERT_EQ(GetStringCount(t0), 0);
   ASSERT_EQ(t0->bytes, 8);
@@ -279,7 +279,7 @@ void BenchmarkTfLiteModel::PrepareInputsAndOutputs() {
       FillRandomString(&buffer, sizes, []() {
         return "we're have some friends over saturday to hang out in the yard";
       });
-      buffer.WriteToTensor(interpreter->tensor(i));
+      buffer.WriteToTensor(interpreter->tensor(i), /*new_shape=*/nullptr);
     } else {
       TFLITE_LOG(FATAL) << "Don't know how to populate tensor " << t->name
                         << " of type " << t->type;
|
Loading…
x
Reference in New Issue
Block a user