micro: remove lite-specific code from copy of FILL

Remove the bulk of lite-specific code from the micro
implementation of operator FILL.
- Flatten the namespace
- Don't resize output tensors
- Remove input and output types other than int8 and float32
- Don't use gtest
Ryan Kuester 2020-12-11 10:48:06 -06:00
parent 9a7c57d1d8
commit a5a87b420b
2 changed files with 2 additions and 168 deletions
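For orientation, here is a rough sketch of the flattened, float32-only shape the micro copy of FILL is moving toward, reconstructed from the lines retained in the diff below. It is not the verbatim committed file: the includes and the body of the float32 case are assumptions added for illustration.

// Sketch only: approximate post-commit structure, inferred from the retained
// lines in the diff below. Includes and the float32 fill loop are assumptions.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace {  // flattened: no ops::builtin::fill nesting

constexpr int kValueTensor = 1;
constexpr int kOutputTensor = 0;

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* value;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kValueTensor, &value));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));
  switch (output->type) {
    case kTfLiteFloat32: {
      // Copy the scalar from input 1 into every element of the output.
      const float fill_value = GetTensorData<float>(value)[0];
      float* out_data = GetTensorData<float>(output);
      const int flat_size = GetTensorShape(output).FlatSize();
      for (int i = 0; i < flat_size; ++i) out_data[i] = fill_value;
      break;
    }
    default:
      TF_LITE_KERNEL_LOG(context,
                         "Fill only currently supports float32 for input 1.");
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace
}  // namespace tflite

Note the absence of any ResizeTensor call: in the micro runtime, tensor shapes are fixed up front by the static memory plan rather than resized by the kernel, which is what the "Don't resize output tensors" bullet refers to.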


@@ -23,9 +23,6 @@ limitations under the License.
 #include "tensorflow/lite/string_util.h"
 namespace tflite {
-namespace ops {
-namespace builtin {
-namespace fill {
 namespace {
@@ -33,41 +30,6 @@ constexpr int kDimsTensor = 0;
 constexpr int kValueTensor = 1;
 constexpr int kOutputTensor = 0;
-template <typename T>
-TfLiteStatus ResizeOutputImpl(TfLiteContext* context, const TfLiteTensor* dims,
-                              TfLiteTensor* output) {
-  TfLiteIntArray* output_shape = TfLiteIntArrayCreate(dims->dims->data[0]);
-  for (int i = 0; i < output_shape->size; ++i) {
-    T data = GetTensorData<T>(dims)[i];
-    if (data < 0) {
-      TfLiteIntArrayFree(output_shape);
-      TF_LITE_KERNEL_LOG(context, "Fill dimensions must be >= 0",
-                         TfLiteTypeGetName(dims->type));
-      return kTfLiteError;
-    }
-    output_shape->data[i] = data;
-  }
-  return context->ResizeTensor(context, output, output_shape);
-}
-TfLiteStatus ResizeOutput(TfLiteContext* context, const TfLiteTensor* dims,
-                          TfLiteTensor* output) {
-  switch (dims->type) {
-    case kTfLiteInt32:
-      return ResizeOutputImpl<int32_t>(context, dims, output);
-    case kTfLiteInt64:
-      return ResizeOutputImpl<int64_t>(context, dims, output);
-    default:
-      TF_LITE_KERNEL_LOG(
-          context,
-          "Fill only currently supports int32, int64 for input 0, got %s.",
-          TfLiteTypeGetName(dims->type));
-      return kTfLiteError;
-  }
-}
-}  // namespace
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@@ -100,20 +62,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   return kTfLiteOk;
 }
-TfLiteStatus FillString(const TfLiteTensor* value, TfLiteTensor* output) {
-  DynamicBuffer buffer;
-  const auto string_ref = GetString(value, 0);
-  int n = 1;
-  for (int i = 0; i < output->dims->size; ++i) {
-    n *= output->dims->data[i];
-  }
-  for (int i = 0; i < n; ++i) {
-    buffer.AddString(string_ref.str, string_ref.len);
-  }
-  buffer.WriteToTensor(output, /*new_shape=*/nullptr);
-  return kTfLiteOk;
-}
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteTensor* value;
   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kValueTensor, &value));
@@ -132,26 +80,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
                       GetTensorShape(output), \
                       GetTensorData<data_type>(output))
   switch (output->type) {
-    case kTfLiteInt32:
-      TF_LITE_FILL(int32_t);
-      break;
-    case kTfLiteInt64:
-      TF_LITE_FILL(int64_t);
-      break;
     case kTfLiteFloat32:
       TF_LITE_FILL(float);
       break;
-    case kTfLiteBool:
-      TF_LITE_FILL(bool);
-      break;
-    case kTfLiteString:
-      FillString(value, output);
-      break;
     default:
       TF_LITE_KERNEL_LOG(
          context,
-          "Fill only currently supports int32, int64, float32, bool, string "
-          "for input 1, got %d.",
+          "Fill only currently supports float32 for input 1, got %d.",
          TfLiteTypeGetName(value->type));
      return kTfLiteError;
  }
@@ -159,7 +94,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   return kTfLiteOk;
 }
-} // namespace fill
+} // namespace
 TfLiteRegistration* Register_FILL() {
   static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
@@ -167,6 +102,4 @@ TfLiteRegistration* Register_FILL() {
   return &r;
 }
-} // namespace builtin
-} // namespace ops
 } // namespace tflite


@@ -12,85 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include <stdint.h>
-#include <initializer_list>
-#include <string>
-#include <vector>
-#include "tensorflow/lite/kernels/test_util.h"
-#include "tensorflow/lite/schema/schema_generated.h"
-#include "tensorflow/lite/string_type.h"
-namespace tflite {
 namespace {
-using ::testing::ElementsAreArray;
-using ::testing::IsEmpty;
-enum class TestType {
-  kConst = 0,
-  kDynamic = 1,
-};
-template <typename dims_type, typename value_type>
-class FillOpModel : public SingleOpModel {
- public:
-  explicit FillOpModel(TensorType dims_tensor_type,
-                       std::initializer_list<int> dims_shape,
-                       std::initializer_list<dims_type> dims_data,
-                       value_type value, TestType input_tensor_types) {
-    if (input_tensor_types == TestType::kDynamic) {
-      dims_ = AddInput(dims_tensor_type);
-      value_ = AddInput(GetTensorType<value_type>());
-    } else {
-      dims_ = AddConstInput(dims_tensor_type, dims_data, dims_shape);
-      value_ = AddConstInput(GetTensorType<value_type>(), {value}, {});
-    }
-    output_ = AddOutput(GetTensorType<value_type>());
-    SetBuiltinOp(BuiltinOperator_FILL, BuiltinOptions_FillOptions,
-                 CreateFillOptions(builder_).Union());
-    BuildInterpreter({dims_shape, {}});
-    if (input_tensor_types == TestType::kDynamic) {
-      if (dims_data.size() > 0) {
-        PopulateTensor<dims_type>(dims_, dims_data);
-      }
-      PopulateTensor<value_type>(value_, {value});
-    }
-  }
-  std::vector<value_type> GetOutput() {
-    return ExtractVector<value_type>(output_);
-  }
-  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
- protected:
-  int dims_;
-  int value_;
-  int output_;
-};
-class FillOpTest : public ::testing::TestWithParam<TestType> {};
-TEST_P(FillOpTest, FillInt32) {
-  FillOpModel<int32_t, int32_t> m(TensorType_INT32, {2}, {2, 3}, -11,
-                                  GetParam());
-  m.Invoke();
-  EXPECT_THAT(m.GetOutput(), ElementsAreArray({-11, -11, -11, -11, -11, -11}));
-  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3}));
-}
-TEST_P(FillOpTest, FillInt64) {
-  FillOpModel<int64_t, int64_t> m(TensorType_INT64, {2}, {2, 4}, 1LL << 45,
-                                  GetParam());
-  m.Invoke();
-  EXPECT_THAT(m.GetOutput(),
-              ElementsAreArray({1LL << 45, 1LL << 45, 1LL << 45, 1LL << 45,
-                                1LL << 45, 1LL << 45, 1LL << 45, 1LL << 45}));
-  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 4}));
-}
 TEST_P(FillOpTest, FillFloat) {
   FillOpModel<int64_t, float> m(TensorType_INT64, {3}, {2, 2, 2}, 4.0,
                                 GetParam());
@@ -116,27 +40,4 @@ TEST_P(FillOpTest, FillOutputScalar) {
   EXPECT_THAT(m.GetOutputShape(), IsEmpty());
 }
-TEST_P(FillOpTest, FillBool) {
-  FillOpModel<int64_t, bool> m(TensorType_INT64, {3}, {2, 2, 2}, true,
-                               GetParam());
-  m.Invoke();
-  EXPECT_THAT(m.GetOutput(), ElementsAreArray({true, true, true, true, true,
-                                               true, true, true}));
-  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2}));
-}
-TEST(FillOpTest, FillString) {
-  FillOpModel<int64_t, std::string> m(TensorType_INT64, {3}, {2, 2, 2}, "AB",
-                                      TestType::kDynamic);
-  m.Invoke();
-  EXPECT_THAT(m.GetOutput(), ElementsAreArray({"AB", "AB", "AB", "AB", "AB",
-                                               "AB", "AB", "AB"}));
-  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2}));
-}
-INSTANTIATE_TEST_SUITE_P(FillOpTest, FillOpTest,
-                         ::testing::Values(TestType::kConst,
-                                           TestType::kDynamic));
 } // namespace
-} // namespace tflite
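The tests removed above are gtest-based (SingleOpModel, TEST_P, EXPECT_THAT), which the commit message calls out under "Don't use gtest". For contrast, here is a minimal sketch of a gtest-free test in the style used elsewhere in the micro tree, assuming the macros from tensorflow/lite/micro/testing/micro_test.h; the FillFloats helper is a hypothetical stand-in for actually building and invoking the registered FILL kernel and is not part of this commit.

#include "tensorflow/lite/micro/testing/micro_test.h"

namespace {
// Hypothetical stand-in for running the FILL kernel: write `count` copies of
// `value` into `output`. A real test would invoke the registered op instead.
void FillFloats(float value, float* output, int count) {
  for (int i = 0; i < count; ++i) output[i] = value;
}
}  // namespace

TF_LITE_MICRO_TESTS_BEGIN

TF_LITE_MICRO_TEST(FillFloat) {
  // Mirrors the retained FillFloat case: dims {2, 2, 2}, value 4.0f.
  float output[8];
  FillFloats(4.0f, output, 8);
  for (int i = 0; i < 8; ++i) {
    TF_LITE_MICRO_EXPECT_NEAR(4.0f, output[i], 1e-5f);
  }
}

TF_LITE_MICRO_TESTS_END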