Random uniform sampling for TFLite

PiperOrigin-RevId: 359683608
Change-Id: I20263546ee17bb889042f0a05c5a9e725879d865
Authored by A. Unique TensorFlower on 2021-02-25 21:34:18 -08:00, committed by TensorFlower Gardener
parent d0591b7a26
commit 94dbd28ba0
4 changed files with 376 additions and 0 deletions

tensorflow/lite/kernels/BUILD

@@ -740,6 +740,7 @@ cc_library(
srcs = [
"multinomial.cc",
"random_standard_normal.cc",
"random_uniform.cc",
],
hdrs = ["custom_ops_register.h"],
copts = tflite_copts(),
@@ -1448,6 +1449,20 @@ cc_test(
],
)
cc_test(
name = "random_uniform_test",
size = "small",
srcs = ["random_uniform_test.cc"],
deps = [
":custom_ops",
":test_main",
":test_util",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/testing:util",
"@com_google_googletest//:gtest",
],
)
cc_library(
name = "reshape_test_common",
testonly = 1,

tensorflow/lite/kernels/custom_ops_register.h

@@ -27,6 +27,8 @@ TfLiteRegistration* Register_HASHTABLE_IMPORT();
TfLiteRegistration* Register_HASHTABLE_SIZE();
TfLiteRegistration* Register_MULTINOMIAL();
TfLiteRegistration* Register_RANDOM_STANDARD_NORMAL();
TfLiteRegistration* Register_RANDOM_UNIFORM();
TfLiteRegistration* Register_RANDOM_UNIFORM_INT();
} // namespace custom
} // namespace ops

tensorflow/lite/kernels/random_uniform.cc

@@ -0,0 +1,184 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <random>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace custom {
namespace random_uniform {
struct OpData {
// This implementation uses a random generator from the standard C++ library
// on the platform where TFLite is built. This differs from the TF version of
// the kernel, which uses custom random generator implementations tailored to
// specific hardware.
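// Note that the engine is default-constructed, so each op instance produces
// the same deterministic sequence across runs; the optional seed input is not
// handled yet (see the TODO in Prepare below).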
std::default_random_engine rng;
};
namespace {
template <typename T, typename dist_type>
void RandomUniformSample(std::default_random_engine& rng, T* buffer,
size_t buffer_size, T min_value, T max_value) {
dist_type dist(min_value, max_value);
std::generate(buffer, buffer + buffer_size, [&]() { return dist(rng); });
}
TfLiteIntArray* CreateDimensionsFromTensor(const TfLiteTensor* tensor) {
const int output_dims = tflite::SizeOfDimension(tensor, 0);
TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_dims);
for (int i = 0; i < output_dims; i++) {
output_shape->data[i] = tensor->data.i32[i];
}
return output_shape;
}
} // namespace
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
return new OpData();
}
void Free(TfLiteContext* context, void* buffer) {
delete reinterpret_cast<OpData*>(buffer);
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// TODO(b/169611265): Handle optional seed input.
TF_LITE_ENSURE(context, tflite::NumInputs(node) >= 1);
TF_LITE_ENSURE_EQ(context, tflite::NumOutputs(node), 1);
// Input is a shape tensor.
const TfLiteTensor* input = tflite::GetInput(context, node, 0);
TF_LITE_ENSURE_EQ(context, tflite::NumDimensions(input), 1);
TfLiteTensor* output = tflite::GetOutput(context, node, 0);
if (!IsConstantTensor(input)) {
SetTensorToDynamic(output);
return kTfLiteOk;
}
return context->ResizeTensor(context, output,
CreateDimensionsFromTensor(input));
}
TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node) {
OpData* params = reinterpret_cast<OpData*>(node->user_data);
TF_LITE_ENSURE(context, params != nullptr);
TfLiteTensor* output = tflite::GetOutput(context, node, 0);
if (IsDynamicTensor(output)) {
const TfLiteTensor* input = tflite::GetInput(context, node, 0);
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, output,
CreateDimensionsFromTensor(input)));
}
const size_t output_size = tflite::NumElements(output);
switch (output->type) {
case kTfLiteFloat32:
RandomUniformSample<float, std::uniform_real_distribution<float>>(
params->rng, GetTensorData<float>(output), output_size, 0.f, 1.f);
break;
case kTfLiteFloat64:
RandomUniformSample<double, std::uniform_real_distribution<double>>(
params->rng, GetTensorData<double>(output), output_size, 0.f, 1.f);
break;
default:
TF_LITE_KERNEL_LOG(context,
"Unsupported output datatype for RandomUniform: %s",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
return kTfLiteOk;
}
int64_t IntValueFromTensor(const TfLiteTensor* tensor) {
switch (tensor->type) {
case kTfLiteInt8:
return *GetTensorData<int8_t>(tensor);
case kTfLiteInt32:
return *GetTensorData<int32_t>(tensor);
case kTfLiteInt64:
return *GetTensorData<int64_t>(tensor);
default:
return -1;
}
}
TfLiteStatus EvalInt(TfLiteContext* context, TfLiteNode* node) {
OpData* params = reinterpret_cast<OpData*>(node->user_data);
TF_LITE_ENSURE(context, params != nullptr);
TF_LITE_ENSURE(context, tflite::NumInputs(node) >= 3);
TfLiteTensor* output = tflite::GetOutput(context, node, 0);
if (IsDynamicTensor(output)) {
const TfLiteTensor* input = tflite::GetInput(context, node, 0);
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, output,
CreateDimensionsFromTensor(input)));
}
int64_t min_value = IntValueFromTensor(tflite::GetInput(context, node, 1));
int64_t max_value = IntValueFromTensor(tflite::GetInput(context, node, 2));
TF_LITE_ENSURE(context, min_value < max_value);
size_t output_size = tflite::NumElements(output);
switch (output->type) {
case kTfLiteInt8:
RandomUniformSample<int8_t, std::uniform_int_distribution<int8_t>>(
params->rng, GetTensorData<int8_t>(output), output_size, min_value,
max_value);
break;
case kTfLiteInt32:
RandomUniformSample<int32_t, std::uniform_int_distribution<int32_t>>(
params->rng, GetTensorData<int32_t>(output), output_size, min_value,
max_value);
break;
case kTfLiteInt64:
RandomUniformSample<int64_t, std::uniform_int_distribution<int64_t>>(
params->rng, GetTensorData<int64_t>(output), output_size, min_value,
max_value);
break;
default:
TF_LITE_KERNEL_LOG(context,
"Unsupported output datatype for RandomUniformInt: %s",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
return kTfLiteOk;
}
} // namespace random_uniform
TfLiteRegistration* Register_RANDOM_UNIFORM() {
static TfLiteRegistration r = {random_uniform::Init, random_uniform::Free,
random_uniform::Prepare,
random_uniform::EvalFloat};
return &r;
}
TfLiteRegistration* Register_RANDOM_UNIFORM_INT() {
static TfLiteRegistration r = {random_uniform::Init, random_uniform::Free,
random_uniform::Prepare,
random_uniform::EvalInt};
return &r;
}
} // namespace custom
} // namespace ops
} // namespace tflite
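
For reference (not part of this change), a minimal sketch of how an application could make the new custom ops visible to an interpreter. The resolver, builder, and AddCustom calls are the standard TFLite C++ APIs; the helper function name is illustrative only.

#include <memory>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/custom_ops_register.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

// Builds an interpreter whose resolver knows about the two new custom ops.
// Custom ops are resolved by name, so "RandomUniform" / "RandomUniformInt"
// must match the op names recorded in the converted model.
std::unique_ptr<tflite::Interpreter> BuildInterpreterWithRandomOps(
    const tflite::FlatBufferModel& model) {
  tflite::ops::builtin::BuiltinOpResolver resolver;
  resolver.AddCustom("RandomUniform",
                     tflite::ops::custom::Register_RANDOM_UNIFORM());
  resolver.AddCustom("RandomUniformInt",
                     tflite::ops::custom::Register_RANDOM_UNIFORM_INT());
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(model, resolver)(&interpreter);
  return interpreter;  // May be null if the model could not be built.
}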

tensorflow/lite/kernels/random_uniform_test.cc

@@ -0,0 +1,175 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/kernels/custom_ops_register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/testing/util.h"
namespace tflite {
namespace {
template <typename T>
tflite::TensorType GetTTEnum();
template <>
tflite::TensorType GetTTEnum<float>() {
return tflite::TensorType_FLOAT32;
}
template <>
tflite::TensorType GetTTEnum<double>() {
return tflite::TensorType_FLOAT64;
}
template <>
tflite::TensorType GetTTEnum<int32_t>() {
return tflite::TensorType_INT32;
}
template <>
tflite::TensorType GetTTEnum<int64_t>() {
return tflite::TensorType_INT64;
}
class RandomUniformOpModel : public tflite::SingleOpModel {
public:
RandomUniformOpModel(const std::initializer_list<int>& input,
tflite::TensorData output, bool dynamic_input) {
if (dynamic_input) {
input_ = AddInput({tflite::TensorType_INT32, {3}});
} else {
input_ = AddConstInput(tflite::TensorType_INT32, input,
{static_cast<int>(input.size())});
}
output_ = AddOutput(output);
SetCustomOp("RandomUniform", {}, ops::custom::Register_RANDOM_UNIFORM);
BuildInterpreter({GetShape(input_)});
if (dynamic_input) {
PopulateTensor<int32_t>(input_, std::vector<int32_t>(input));
}
}
int input_;
int output_;
int input() { return input_; }
int output() { return output_; }
template <typename T>
std::vector<T> GetOutput() {
return ExtractVector<T>(output_);
}
};
class RandomUniformIntOpModel : public tflite::SingleOpModel {
public:
RandomUniformIntOpModel(const std::initializer_list<int>& input,
tflite::TensorData output, int min_val, int max_val) {
input_ = AddConstInput(tflite::TensorType_INT32, input,
{static_cast<int>(input.size())});
input_minval_ = AddConstInput(tflite::TensorType_INT32, {min_val}, {1});
input_maxval_ = AddConstInput(tflite::TensorType_INT32, {max_val}, {1});
output_ = AddOutput(output);
SetCustomOp("RandomUniformInt", {},
ops::custom::Register_RANDOM_UNIFORM_INT);
BuildInterpreter({GetShape(input_)});
}
int input_;
int input_minval_;
int input_maxval_;
int output_;
int input() { return input_; }
int output() { return output_; }
template <typename T>
std::vector<T> GetOutput() {
return ExtractVector<T>(output_);
}
};
} // namespace
} // namespace tflite
template <typename FloatType>
class RandomUniformTest : public ::testing::Test {
public:
using Float = FloatType;
};
using TestTypes = ::testing::Types<float, double>;
TYPED_TEST_SUITE(RandomUniformTest, TestTypes);
TYPED_TEST(RandomUniformTest, TestOutput) {
using Float = typename TestFixture::Float;
for (const auto dynamic : {true, false}) {
tflite::RandomUniformOpModel m({1000, 50, 5},
{tflite::GetTTEnum<Float>(), {}}, dynamic);
m.Invoke();
auto output = m.GetOutput<Float>();
EXPECT_EQ(output.size(), 1000 * 50 * 5);
double sum = 0;
for (const auto r : output) {
sum += r;
}
double avg = sum / output.size();
ASSERT_LT(std::abs(avg - 0.5), 0.05); // Average should be approximately 0.5
double sum_squared = 0;
for (const auto r : output) {
sum_squared += std::pow(r - avg, 2);
}
double var = sum_squared / output.size();
EXPECT_LT(std::abs(1. / 12 - var),
0.05); // Variance should be approximately 1./12
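// These bounds follow from the moments of the uniform distribution:
// E[U(0,1)] = 1/2 and Var[U(a,b)] = (b - a)^2 / 12, so Var[U(0,1)] = 1/12.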
}
}
template <typename IntType>
class RandomUniformIntTest : public ::testing::Test {
public:
using Int = IntType;
};
using TestTypesInt = ::testing::Types<int32_t, int64_t>;
TYPED_TEST_SUITE(RandomUniformIntTest, TestTypesInt);
TYPED_TEST(RandomUniformIntTest, TestOutput) {
using Int = typename TestFixture::Int;
tflite::RandomUniformIntOpModel m({1000, 50, 5},
{tflite::GetTTEnum<Int>(), {}}, 0, 5);
m.Invoke();
auto output = m.GetOutput<Int>();
EXPECT_EQ(output.size(), 1000 * 50 * 5);
int counters[] = {0, 0, 0, 0, 0, 0};
for (const auto r : output) {
ASSERT_GE(r, 0);
ASSERT_LE(r, 5);
++counters[r];
}
// Check that all values occur with approximately the same frequency.
for (int i = 1; i < 6; ++i) {
EXPECT_LT(std::abs(counters[i] - counters[0]), 1000);
}
}
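
As a rough check on the tolerance in the final loop: the test draws 1000 * 50 * 5 = 250,000 samples spread over six buckets, so each counter is expected to land near 41,700, and an absolute difference of 1,000 between counters leaves a generous margin for statistical noise.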