Fix incorrect output resizing

PiperOrigin-RevId: 359453070
Change-Id: If6ff8e22e376dd6d2faf78684566b86825b559cc
This commit is contained in:
David Rim 2021-02-24 23:16:32 -08:00 committed by TensorFlower Gardener
parent 59b6ae1e28
commit 60c930f605
2 changed files with 131 additions and 16 deletions

View File

@ -717,8 +717,10 @@ TfLiteStatus PrepareParseExample(TfLiteContext* context, TfLiteNode* node) {
data->sparse_size = nodedef.attr().at("num_sparse").i();
}
auto dense_shapes = nodedef.attr().at("dense_shapes").list();
for (int i = 0; i < dense_shapes.shape_size(); ++i) {
data->dense_shapes.push_back(dense_shapes.shape(i));
if (data->dense_shapes.empty()) {
for (int i = 0; i < dense_shapes.shape_size(); ++i) {
data->dense_shapes.push_back(dense_shapes.shape(i));
}
}
} else {
const flexbuffers::Map& m =
@ -760,10 +762,13 @@ TfLiteStatus PrepareParseExample(TfLiteContext* context, TfLiteNode* node) {
GetOutput(context, node, data->sparse_size * 3 + i);
TfLiteIntArray* output_size = TfLiteIntArrayCopy(dense_key_tensor->dims);
if (missing_shape_info) {
RuntimeShape runtime_shape = GetTensorShape(dense_key_tensor);
data->dense_shapes.push_back(TfLiteToTfShape(output_size));
}
output_size->data[0] = batch_size * output_size->data[0];
// use original tflite tensor size if inputs are resized.
const int original_size = data->dense_shapes[i].dims() > 0
? data->dense_shapes[i].dim_size(0)
: 1;
output_size->data[0] = batch_size * original_size;
context->ResizeTensor(context, dense_key_tensor, output_size);
}

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/parse_example/parse_example.h"
#include <cstdint>
#include <initializer_list>
#include "flatbuffers/flexbuffers.h" // from @flatbuffers
@ -150,10 +151,39 @@ const char* kNodeDefTxt4 = R"pb(
}
)pb";
// NodeDef whose "dense_shapes" attr is intentionally empty. Exercises the
// kernel path where dense shape information is missing at Prepare time and
// has to be recovered from the op's output tensors instead (see the
// missing_shape_info handling in PrepareParseExample); used by
// ResizeMissingInfoTest below.
const char* kNodeDefTxt5 = R"pb(
name: "ParseExample/ParseExample"
op: "ParseExample"
input: "serialized"
input: "ParseExample/ParseExample/names"
input: "ParseExample/ParseExample/dense_keys_0"
input: "ParseExample/Const"
attr {
key: "Ndense"
value { i: 1 }
}
attr {
key: "Nsparse"
value { i: 0 }
}
attr {
key: "Tdense"
value { list { type: DT_FLOAT } }
}
attr {
key: "dense_shapes"
value {}
}
attr {
key: "sparse_types"
value { list { type: DT_FLOAT } }
}
)pb";
template <typename DefaultType>
class ParseExampleOpModel : public SingleOpModel {
public:
ParseExampleOpModel(std::string serialized_example,
ParseExampleOpModel(std::vector<std::string> serialized_examples,
std::vector<std::string> sparse_keys,
std::vector<std::string> dense_keys,
std::initializer_list<DefaultType> dense_defaults,
@ -161,7 +191,9 @@ class ParseExampleOpModel : public SingleOpModel {
std::vector<TensorType> sparse_types,
const char* text_def, int dense_size = 2) {
// Example
string_indices_.push_back(AddInput(TensorData(TensorType_STRING, {1})));
const int input_size = serialized_examples.size();
auto input_tensor_data = TensorData(TensorType_STRING, {input_size});
string_indices_.push_back(AddInput(input_tensor_data));
// Names
string_indices_.push_back(
AddConstInput<std::string>(TensorData(TensorType_STRING, {0}), {""}));
@ -206,9 +238,9 @@ class ParseExampleOpModel : public SingleOpModel {
fbb.Finish();
const auto buffer = fbb.GetBuffer();
SetCustomOp("ParseExample", buffer, Register_PARSE_EXAMPLE);
BuildInterpreter({});
BuildInterpreter({{input_size}});
int idx = 0;
PopulateStringTensor(string_indices_[idx++], {serialized_example});
PopulateStringTensor(string_indices_[idx++], serialized_examples);
PopulateStringTensor(string_indices_[idx++], {""});
for (const auto& key : sparse_keys) {
PopulateStringTensor(string_indices_[idx++], {key});
@ -218,6 +250,16 @@ class ParseExampleOpModel : public SingleOpModel {
}
}
// Resizes the model's input tensors to the given shapes, positionally.
// An empty shape entry, or an entry mapped to an optional tensor, leaves
// that input untouched.
void ResizeInputTensor(std::vector<std::vector<int>> input_shapes) {
  const size_t num_shapes = input_shapes.size();
  for (size_t idx = 0; idx < num_shapes; ++idx) {
    const std::vector<int>& new_shape = input_shapes[idx];
    if (new_shape.empty()) {
      continue;
    }
    const int tensor_index = interpreter_->inputs()[idx];
    if (tensor_index != kTfLiteOptionalTensor) {
      CHECK(interpreter_->ResizeInputTensor(tensor_index, new_shape) ==
            kTfLiteOk);
    }
  }
}
template <typename T>
std::vector<T> GetSparseIndicesOutput(int i) {
return ExtractVector<T>(sparse_indices_outputs_[i]);
@ -267,7 +309,7 @@ TEST(ParseExampleOpsTest, SimpleTest) {
tf::Example example;
tf::AppendFeatureValues<float>({1.5f, 1.5f}, "time", &example);
tf::AppendFeatureValues<float>({1.0f, 1.0f}, "num", &example);
ParseExampleOpModel<float> m(example.SerializeAsString(), {}, {"time"},
ParseExampleOpModel<float> m({example.SerializeAsString()}, {}, {"time"},
{0.f, 0.f}, {TensorType_FLOAT32}, {},
kNodeDefTxt);
m.Invoke();
@ -278,7 +320,7 @@ TEST(ParseExampleOpsTest, SimpleTest) {
TEST(ParseExampleOpsTest, SparseTest) {
tf::Example example;
tf::AppendFeatureValues<float>({1.5f}, "time", &example);
ParseExampleOpModel<float> m(example.SerializeAsString(), {"time"}, {}, {},
ParseExampleOpModel<float> m({example.SerializeAsString()}, {"time"}, {}, {},
{}, {TensorType_FLOAT32}, kNodeDefTxt2, 0);
m.Invoke();
EXPECT_THAT(m.GetSparseIndicesOutput<int64_t>(0),
@ -295,9 +337,9 @@ TEST(ParseExampleOpsTest, SimpleBytesTest) {
tf::AppendFeatureValues<tensorflow::tstring>({test_data}, "time", &example);
tf::AppendFeatureValues<float>({1.0f, 1.0f}, "num", &example);
std::string default_value = "missing";
ParseExampleOpModel<std::string> m(example.SerializeAsString(), {}, {"time"},
{default_value}, {TensorType_STRING}, {},
kNodeDefTxt3, 1);
ParseExampleOpModel<std::string> m({example.SerializeAsString()}, {},
{"time"}, {default_value},
{TensorType_STRING}, {}, kNodeDefTxt3, 1);
m.PopulateStringTensor(m.DenseDefaults(), {default_value});
m.Invoke();
std::vector<string> c = m.GetStringOutput(m.DenseOutputs(0));
@ -311,9 +353,9 @@ TEST(ParseExampleOpsTest, SparseBytesTest) {
tf::AppendFeatureValues<tensorflow::tstring>({test_data, test_data}, "time",
&example);
tf::AppendFeatureValues<float>({1.0f, 1.0f}, "num", &example);
ParseExampleOpModel<std::string> m(example.SerializeAsString(), {"time"}, {},
{}, {}, {TensorType_STRING}, kNodeDefTxt4,
0);
ParseExampleOpModel<std::string> m({example.SerializeAsString()}, {"time"},
{}, {}, {}, {TensorType_STRING},
kNodeDefTxt4, 0);
m.Invoke();
EXPECT_THAT(m.GetSparseIndicesOutput<int64_t>(0),
testing::ElementsAreArray({0, 0, 0, 1}));
@ -325,6 +367,74 @@ TEST(ParseExampleOpsTest, SparseBytesTest) {
testing::ElementsAreArray({1, 2}));
}
TEST(ParseExampleOpsTest, ResizeTest) {
  // Build three batches of serialized Examples with decreasing batch sizes
  // (6, 4, 2), alongside the dense "time" values each batch should yield.
  constexpr int kNumRuns = 3;
  std::vector<tf::Example> protos(kNumRuns);
  std::vector<std::vector<float>> want(kNumRuns);
  std::vector<std::vector<std::string>> serialized(kNumRuns);
  std::vector<int> batch_sizes;
  for (int run = 0; run < kNumRuns; ++run) {
    const float v = run;
    std::initializer_list<float> time_vals = {v + v / 10.f, -v - v / 10.f};
    tf::AppendFeatureValues<float>({v, v}, "num", &protos[run]);
    tf::AppendFeatureValues<float>(time_vals, "time", &protos[run]);
    batch_sizes.push_back((kNumRuns - run) * 2);
    // All rows of a batch are copies of the same serialized proto.
    const std::string wire = protos[run].SerializeAsString();
    for (int row = 0; row < batch_sizes.back(); ++row) {
      serialized[run].push_back(wire);
      want[run].insert(want[run].end(), time_vals.begin(), time_vals.end());
    }
  }
  // First invocation uses the batch size the model was built with.
  ParseExampleOpModel<float> m(serialized[0], {}, {"time"}, {0.f, 0.f},
                               {TensorType_FLOAT32}, {}, kNodeDefTxt);
  m.Invoke();
  EXPECT_THAT(m.GetDenseOutput<float>(0),
              ElementsAreArray(ArrayFloatNear(want[0])));
  // Later invocations shrink the input; dense outputs must be resized to
  // track the new batch size.
  for (int run = 1; run < kNumRuns; ++run) {
    m.ResizeInputTensor({{batch_sizes[run]}});
    m.AllocateAndDelegate(false);
    m.PopulateStringTensor(0, serialized[run]);
    m.Invoke();
    EXPECT_THAT(m.GetDenseOutput<float>(0),
                ElementsAreArray(ArrayFloatNear(want[run])));
  }
}
TEST(ParseExampleOpsTest, ResizeMissingInfoTest) {
  // Same scenario as ResizeTest, but built against kNodeDefTxt5, whose
  // "dense_shapes" attr is empty — so the kernel must infer dense shapes
  // when resizing instead of reading them from the NodeDef.
  constexpr int kNumRuns = 3;
  std::vector<tf::Example> protos(kNumRuns);
  std::vector<std::vector<float>> want(kNumRuns);
  std::vector<std::vector<std::string>> serialized(kNumRuns);
  std::vector<int> batch_sizes;
  for (int run = 0; run < kNumRuns; ++run) {
    const float v = run;
    std::initializer_list<float> time_vals = {v + v / 10.f, -v - v / 10.f};
    tf::AppendFeatureValues<float>({v, v}, "num", &protos[run]);
    tf::AppendFeatureValues<float>(time_vals, "time", &protos[run]);
    batch_sizes.push_back((kNumRuns - run) * 2);
    const std::string wire = protos[run].SerializeAsString();
    for (int row = 0; row < batch_sizes.back(); ++row) {
      serialized[run].push_back(wire);
      want[run].insert(want[run].end(), time_vals.begin(), time_vals.end());
    }
  }
  ParseExampleOpModel<float> m(serialized[0], {}, {"time"}, {0.f, 0.f},
                               {TensorType_FLOAT32}, {}, kNodeDefTxt5);
  m.Invoke();
  EXPECT_THAT(m.GetDenseOutput<float>(0),
              ElementsAreArray(ArrayFloatNear(want[0])));
  // Re-invoke at smaller batch sizes; outputs must resize correctly even
  // without shape info from the NodeDef.
  for (int run = 1; run < kNumRuns; ++run) {
    m.ResizeInputTensor({{batch_sizes[run]}});
    m.AllocateAndDelegate(false);
    m.PopulateStringTensor(0, serialized[run]);
    m.Invoke();
    EXPECT_THAT(m.GetDenseOutput<float>(0),
                ElementsAreArray(ArrayFloatNear(want[run])));
  }
}
} // namespace custom
} // namespace ops
} // namespace tflite