Rename WriteToTensor to make it explicit the fact that it set the shape to {size}

PiperOrigin-RevId: 223521732
This commit is contained in:
A. Unique TensorFlower 2018-11-30 08:48:24 -08:00 committed by TensorFlower Gardener
parent f5291176d4
commit 3e66cee177
12 changed files with 27 additions and 18 deletions

View File

@ -566,7 +566,7 @@ TEST(BasicInterpreter, ThreeStepAllocate) {
DynamicBuffer buf;
StringRef str_ref = GetString(input, 0);
buf.AddString(str_ref);
buf.WriteToTensor(output);
buf.WriteToTensorAsVector(output);
return kTfLiteOk;
};

View File

@ -278,7 +278,7 @@ void WriteMultiDimensionalStringArray(JNIEnv* env, jobject src,
tflite::DynamicBuffer dst_buffer;
PopulateStringDynamicBuffer(env, src, &dst_buffer, tensor->dims->size);
if (!env->ExceptionCheck()) {
dst_buffer.WriteToTensor(tensor);
dst_buffer.WriteToTensor(tensor, /*new_shape=*/nullptr);
}
}

View File

@ -118,7 +118,7 @@ TfLiteStatus GatherStrings(TfLiteContext* context, const TfLiteTensor* input,
const auto string_ref = GetString(input, pos);
buffer.AddString(string_ref.str, string_ref.len);
}
buffer.WriteToTensor(output);
buffer.WriteToTensorAsVector(output);
return kTfLiteOk;
}

View File

@ -137,7 +137,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
}
if (output->type == kTfLiteString) {
buf.WriteToTensor(output);
buf.WriteToTensorAsVector(output);
}
return kTfLiteOk;

View File

@ -107,7 +107,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Generate n-grams recursively.
tflite::DynamicBuffer buf;
if (words.size() < params->ngram_size) {
buf.WriteToTensor(GetOutput(context, node, 0));
buf.WriteToTensorAsVector(GetOutput(context, node, 0));
return kTfLiteOk;
}
@ -145,7 +145,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
}
buf.WriteToTensor(GetOutput(context, node, 0));
buf.WriteToTensorAsVector(GetOutput(context, node, 0));
return kTfLiteOk;
}
} // namespace

View File

@ -199,7 +199,7 @@ class SingleOpModel {
for (const string& s : content) {
buf.AddString(s.data(), s.length());
}
buf.WriteToTensor(tensor);
buf.WriteToTensor(tensor, /*new_shape=*/nullptr);
}
// Populate the tensor given its index.

View File

@ -92,7 +92,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
tflite::DynamicBuffer buf;
buf.AddString(result.data(), result.length());
buf.WriteToTensor(GetOutput(context, node, 0));
buf.WriteToTensorAsVector(GetOutput(context, node, 0));
return kTfLiteOk;
}

View File

@ -49,7 +49,7 @@ void ExecuteTfLite(const std::string& sentence,
TfLiteTensor* input = interpreter->tensor(interpreter->inputs()[0]);
tflite::DynamicBuffer buf;
buf.AddString(sentence.data(), sentence.length());
buf.WriteToTensor(input);
buf.WriteToTensorAsVector(input);
interpreter->AllocateTensors();
interpreter->Invoke();

View File

@ -96,8 +96,7 @@ int DynamicBuffer::WriteToBuffer(char** buffer) {
return bytes;
}
void DynamicBuffer::WriteToTensor(TfLiteTensor* tensor) {
// Set tensor content pointer to tensor_buffer, and release original data.
void DynamicBuffer::WriteToTensorAsVector(TfLiteTensor* tensor) {
auto dims = TfLiteIntArrayCreate(1);
dims->data[0] = offset_.size() - 1; // Store number of strings.
WriteToTensor(tensor, dims);
@ -108,6 +107,10 @@ void DynamicBuffer::WriteToTensor(TfLiteTensor* tensor,
char* tensor_buffer;
int bytes = WriteToBuffer(&tensor_buffer);
if (new_shape == nullptr) {
new_shape = TfLiteIntArrayCopy(tensor->dims);
}
// Set tensor content pointer to tensor_buffer, and release original data.
TfLiteTensorReset(tensor->type, tensor->name, new_shape, tensor->params,
tensor_buffer, bytes, kTfLiteDynamic, tensor->allocation,

View File

@ -74,12 +74,18 @@ class DynamicBuffer {
// The function allocates space for the buffer but does NOT take ownership.
int WriteToBuffer(char** buffer);
// Fill content into a string tensor, with the given new_shape. The new
// shape must match the number of strings in this object.
// Fill content into a string tensor, with the given new_shape. The new shape
// must match the number of strings in this object. Caller relinquishes
// ownership of new_shape. If 'new_shape' is nullptr, keep the tensor's
// existing shape.
void WriteToTensor(TfLiteTensor* tensor, TfLiteIntArray* new_shape);
// Fill content into a string tensor. Set shape to {num_strings}.
void WriteToTensor(TfLiteTensor* tensor);
void WriteToTensorAsVector(TfLiteTensor* tensor);
// Deprecated. Use WriteToTensorAsVector() or pass in the new shpe.
// TODO(b/120230709): remove when people migrate away.
void WriteToTensor(TfLiteTensor* tensor) { WriteToTensorAsVector(tensor); }
private:
// Data buffer to store contents of strings, not including headers.

View File

@ -55,7 +55,7 @@ TEST(StringUtil, TestStringUtil) {
new_shape->data[0] = 2;
new_shape->data[1] = 1;
buf0.WriteToTensor(t0, new_shape);
buf1.WriteToTensor(t1);
buf1.WriteToTensorAsVector(t1);
// Check tensor shapes.
EXPECT_EQ(t0->dims->size, 2);
@ -99,7 +99,7 @@ TEST(StringUtil, TestAddJoinedString) {
DynamicBuffer buf;
buf.AddJoinedString({{s0, 3}, {s1, 4}, {s2, 0}, {s3, 3}}, ' ');
buf.WriteToTensor(t0);
buf.WriteToTensorAsVector(t0);
ASSERT_EQ(GetStringCount(t0), 1);
StringRef str_ref;
@ -115,7 +115,7 @@ TEST(StringUtil, TestEmptyList) {
t0->type = kTfLiteString;
t0->allocation_type = kTfLiteDynamic;
DynamicBuffer buf;
buf.WriteToTensor(t0);
buf.WriteToTensorAsVector(t0);
ASSERT_EQ(GetStringCount(t0), 0);
ASSERT_EQ(t0->bytes, 8);

View File

@ -279,7 +279,7 @@ void BenchmarkTfLiteModel::PrepareInputsAndOutputs() {
FillRandomString(&buffer, sizes, []() {
return "we're have some friends over saturday to hang out in the yard";
});
buffer.WriteToTensor(interpreter->tensor(i));
buffer.WriteToTensor(interpreter->tensor(i), /*new_shape=*/nullptr);
} else {
TFLITE_LOG(FATAL) << "Don't know how to populate tensor " << t->name
<< " of type " << t->type;