Rename WriteToTensor to make it explicit the fact that it set the shape to {size}
PiperOrigin-RevId: 223521732
This commit is contained in:
parent
f5291176d4
commit
3e66cee177
@ -566,7 +566,7 @@ TEST(BasicInterpreter, ThreeStepAllocate) {
|
||||
DynamicBuffer buf;
|
||||
StringRef str_ref = GetString(input, 0);
|
||||
buf.AddString(str_ref);
|
||||
buf.WriteToTensor(output);
|
||||
buf.WriteToTensorAsVector(output);
|
||||
return kTfLiteOk;
|
||||
};
|
||||
|
||||
|
@ -278,7 +278,7 @@ void WriteMultiDimensionalStringArray(JNIEnv* env, jobject src,
|
||||
tflite::DynamicBuffer dst_buffer;
|
||||
PopulateStringDynamicBuffer(env, src, &dst_buffer, tensor->dims->size);
|
||||
if (!env->ExceptionCheck()) {
|
||||
dst_buffer.WriteToTensor(tensor);
|
||||
dst_buffer.WriteToTensor(tensor, /*new_shape=*/nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -118,7 +118,7 @@ TfLiteStatus GatherStrings(TfLiteContext* context, const TfLiteTensor* input,
|
||||
const auto string_ref = GetString(input, pos);
|
||||
buffer.AddString(string_ref.str, string_ref.len);
|
||||
}
|
||||
buffer.WriteToTensor(output);
|
||||
buffer.WriteToTensorAsVector(output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
@ -137,7 +137,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
}
|
||||
}
|
||||
if (output->type == kTfLiteString) {
|
||||
buf.WriteToTensor(output);
|
||||
buf.WriteToTensorAsVector(output);
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
|
@ -107,7 +107,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
// Generate n-grams recursively.
|
||||
tflite::DynamicBuffer buf;
|
||||
if (words.size() < params->ngram_size) {
|
||||
buf.WriteToTensor(GetOutput(context, node, 0));
|
||||
buf.WriteToTensorAsVector(GetOutput(context, node, 0));
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
@ -145,7 +145,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
}
|
||||
}
|
||||
|
||||
buf.WriteToTensor(GetOutput(context, node, 0));
|
||||
buf.WriteToTensorAsVector(GetOutput(context, node, 0));
|
||||
return kTfLiteOk;
|
||||
}
|
||||
} // namespace
|
||||
|
@ -199,7 +199,7 @@ class SingleOpModel {
|
||||
for (const string& s : content) {
|
||||
buf.AddString(s.data(), s.length());
|
||||
}
|
||||
buf.WriteToTensor(tensor);
|
||||
buf.WriteToTensor(tensor, /*new_shape=*/nullptr);
|
||||
}
|
||||
|
||||
// Populate the tensor given its index.
|
||||
|
@ -92,7 +92,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
|
||||
tflite::DynamicBuffer buf;
|
||||
buf.AddString(result.data(), result.length());
|
||||
buf.WriteToTensor(GetOutput(context, node, 0));
|
||||
buf.WriteToTensorAsVector(GetOutput(context, node, 0));
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
|
@ -49,7 +49,7 @@ void ExecuteTfLite(const std::string& sentence,
|
||||
TfLiteTensor* input = interpreter->tensor(interpreter->inputs()[0]);
|
||||
tflite::DynamicBuffer buf;
|
||||
buf.AddString(sentence.data(), sentence.length());
|
||||
buf.WriteToTensor(input);
|
||||
buf.WriteToTensorAsVector(input);
|
||||
interpreter->AllocateTensors();
|
||||
|
||||
interpreter->Invoke();
|
||||
|
@ -96,8 +96,7 @@ int DynamicBuffer::WriteToBuffer(char** buffer) {
|
||||
return bytes;
|
||||
}
|
||||
|
||||
void DynamicBuffer::WriteToTensor(TfLiteTensor* tensor) {
|
||||
// Set tensor content pointer to tensor_buffer, and release original data.
|
||||
void DynamicBuffer::WriteToTensorAsVector(TfLiteTensor* tensor) {
|
||||
auto dims = TfLiteIntArrayCreate(1);
|
||||
dims->data[0] = offset_.size() - 1; // Store number of strings.
|
||||
WriteToTensor(tensor, dims);
|
||||
@ -108,6 +107,10 @@ void DynamicBuffer::WriteToTensor(TfLiteTensor* tensor,
|
||||
char* tensor_buffer;
|
||||
int bytes = WriteToBuffer(&tensor_buffer);
|
||||
|
||||
if (new_shape == nullptr) {
|
||||
new_shape = TfLiteIntArrayCopy(tensor->dims);
|
||||
}
|
||||
|
||||
// Set tensor content pointer to tensor_buffer, and release original data.
|
||||
TfLiteTensorReset(tensor->type, tensor->name, new_shape, tensor->params,
|
||||
tensor_buffer, bytes, kTfLiteDynamic, tensor->allocation,
|
||||
|
@ -74,12 +74,18 @@ class DynamicBuffer {
|
||||
// The function allocates space for the buffer but does NOT take ownership.
|
||||
int WriteToBuffer(char** buffer);
|
||||
|
||||
// Fill content into a string tensor, with the given new_shape. The new
|
||||
// shape must match the number of strings in this object.
|
||||
// Fill content into a string tensor, with the given new_shape. The new shape
|
||||
// must match the number of strings in this object. Caller relinquishes
|
||||
// ownership of new_shape. If 'new_shape' is nullptr, keep the tensor's
|
||||
// existing shape.
|
||||
void WriteToTensor(TfLiteTensor* tensor, TfLiteIntArray* new_shape);
|
||||
|
||||
// Fill content into a string tensor. Set shape to {num_strings}.
|
||||
void WriteToTensor(TfLiteTensor* tensor);
|
||||
void WriteToTensorAsVector(TfLiteTensor* tensor);
|
||||
|
||||
// Deprecated. Use WriteToTensorAsVector() or pass in the new shpe.
|
||||
// TODO(b/120230709): remove when people migrate away.
|
||||
void WriteToTensor(TfLiteTensor* tensor) { WriteToTensorAsVector(tensor); }
|
||||
|
||||
private:
|
||||
// Data buffer to store contents of strings, not including headers.
|
||||
|
@ -55,7 +55,7 @@ TEST(StringUtil, TestStringUtil) {
|
||||
new_shape->data[0] = 2;
|
||||
new_shape->data[1] = 1;
|
||||
buf0.WriteToTensor(t0, new_shape);
|
||||
buf1.WriteToTensor(t1);
|
||||
buf1.WriteToTensorAsVector(t1);
|
||||
|
||||
// Check tensor shapes.
|
||||
EXPECT_EQ(t0->dims->size, 2);
|
||||
@ -99,7 +99,7 @@ TEST(StringUtil, TestAddJoinedString) {
|
||||
|
||||
DynamicBuffer buf;
|
||||
buf.AddJoinedString({{s0, 3}, {s1, 4}, {s2, 0}, {s3, 3}}, ' ');
|
||||
buf.WriteToTensor(t0);
|
||||
buf.WriteToTensorAsVector(t0);
|
||||
|
||||
ASSERT_EQ(GetStringCount(t0), 1);
|
||||
StringRef str_ref;
|
||||
@ -115,7 +115,7 @@ TEST(StringUtil, TestEmptyList) {
|
||||
t0->type = kTfLiteString;
|
||||
t0->allocation_type = kTfLiteDynamic;
|
||||
DynamicBuffer buf;
|
||||
buf.WriteToTensor(t0);
|
||||
buf.WriteToTensorAsVector(t0);
|
||||
|
||||
ASSERT_EQ(GetStringCount(t0), 0);
|
||||
ASSERT_EQ(t0->bytes, 8);
|
||||
|
@ -279,7 +279,7 @@ void BenchmarkTfLiteModel::PrepareInputsAndOutputs() {
|
||||
FillRandomString(&buffer, sizes, []() {
|
||||
return "we're have some friends over saturday to hang out in the yard";
|
||||
});
|
||||
buffer.WriteToTensor(interpreter->tensor(i));
|
||||
buffer.WriteToTensor(interpreter->tensor(i), /*new_shape=*/nullptr);
|
||||
} else {
|
||||
TFLITE_LOG(FATAL) << "Don't know how to populate tensor " << t->name
|
||||
<< " of type " << t->type;
|
||||
|
Loading…
Reference in New Issue
Block a user