Random uniform sampling for TFLite
PiperOrigin-RevId: 359683608 Change-Id: I20263546ee17bb889042f0a05c5a9e725879d865
This commit is contained in:
parent
d0591b7a26
commit
94dbd28ba0
@ -740,6 +740,7 @@ cc_library(
|
||||
srcs = [
|
||||
"multinomial.cc",
|
||||
"random_standard_normal.cc",
|
||||
"random_uniform.cc",
|
||||
],
|
||||
hdrs = ["custom_ops_register.h"],
|
||||
copts = tflite_copts(),
|
||||
@ -1448,6 +1449,20 @@ cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "random_uniform_test",
|
||||
size = "small",
|
||||
srcs = ["random_uniform_test.cc"],
|
||||
deps = [
|
||||
":custom_ops",
|
||||
":test_main",
|
||||
":test_util",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"//tensorflow/lite/testing:util",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "reshape_test_common",
|
||||
testonly = 1,
|
||||
|
@ -27,6 +27,8 @@ TfLiteRegistration* Register_HASHTABLE_IMPORT();
|
||||
TfLiteRegistration* Register_HASHTABLE_SIZE();
|
||||
TfLiteRegistration* Register_MULTINOMIAL();
|
||||
TfLiteRegistration* Register_RANDOM_STANDARD_NORMAL();
|
||||
TfLiteRegistration* Register_RANDOM_UNIFORM();
|
||||
TfLiteRegistration* Register_RANDOM_UNIFORM_INT();
|
||||
|
||||
} // namespace custom
|
||||
} // namespace ops
|
||||
|
184
tensorflow/lite/kernels/random_uniform.cc
Normal file
184
tensorflow/lite/kernels/random_uniform.cc
Normal file
@ -0,0 +1,184 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
#include <random>
|
||||
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace custom {
|
||||
namespace random_uniform {
|
||||
|
||||
struct OpData {
|
||||
// This implementation uses a random generator from the standard C++ library
|
||||
// on the platform where TFLite is build. This is different from the TF
|
||||
// version of the kernel that uses custom implementations of random
|
||||
// generator, different for different hardware.
|
||||
std::default_random_engine rng;
|
||||
};
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T, typename dist_type>
|
||||
void RandomUniformSample(std::default_random_engine& rng, T* buffer,
|
||||
size_t buffer_size, T min_value, T max_value) {
|
||||
dist_type dist(min_value, max_value);
|
||||
std::generate(buffer, buffer + buffer_size, [&]() { return dist(rng); });
|
||||
}
|
||||
|
||||
TfLiteIntArray* CreateDimensionsFromTensor(const TfLiteTensor* tensor) {
|
||||
const int output_dims = tflite::SizeOfDimension(tensor, 0);
|
||||
TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_dims);
|
||||
for (int i = 0; i < output_dims; i++) {
|
||||
output_shape->data[i] = tensor->data.i32[i];
|
||||
}
|
||||
return output_shape;
|
||||
}
|
||||
} // namespace
|
||||
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||
return new OpData();
|
||||
}
|
||||
|
||||
void Free(TfLiteContext* context, void* buffer) {
|
||||
delete reinterpret_cast<OpData*>(buffer);
|
||||
}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
// TODO(b/169611265): Handle optional seed input.
|
||||
TF_LITE_ENSURE(context, tflite::NumInputs(node) >= 1);
|
||||
TF_LITE_ENSURE_EQ(context, tflite::NumOutputs(node), 1);
|
||||
|
||||
// Input is a shape tensor.
|
||||
const TfLiteTensor* input = tflite::GetInput(context, node, 0);
|
||||
TF_LITE_ENSURE_EQ(context, tflite::NumDimensions(input), 1);
|
||||
TfLiteTensor* output = tflite::GetOutput(context, node, 0);
|
||||
if (!IsConstantTensor(input)) {
|
||||
SetTensorToDynamic(output);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
return context->ResizeTensor(context, output,
|
||||
CreateDimensionsFromTensor(input));
|
||||
}
|
||||
|
||||
TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node) {
|
||||
OpData* params = reinterpret_cast<OpData*>(node->user_data);
|
||||
TF_LITE_ENSURE(context, params != nullptr);
|
||||
|
||||
TfLiteTensor* output = tflite::GetOutput(context, node, 0);
|
||||
if (IsDynamicTensor(output)) {
|
||||
const TfLiteTensor* input = tflite::GetInput(context, node, 0);
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
context->ResizeTensor(context, output,
|
||||
CreateDimensionsFromTensor(input)));
|
||||
}
|
||||
const size_t output_size = tflite::NumElements(output);
|
||||
switch (output->type) {
|
||||
case kTfLiteFloat32:
|
||||
RandomUniformSample<float, std::uniform_real_distribution<float>>(
|
||||
params->rng, GetTensorData<float>(output), output_size, 0.f, 1.f);
|
||||
break;
|
||||
case kTfLiteFloat64:
|
||||
RandomUniformSample<double, std::uniform_real_distribution<double>>(
|
||||
params->rng, GetTensorData<double>(output), output_size, 0.f, 1.f);
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"Unsupported output datatype for RandomUniform: %s",
|
||||
TfLiteTypeGetName(output->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
int64_t IntValueFromTensor(const TfLiteTensor* tensor) {
|
||||
switch (tensor->type) {
|
||||
case kTfLiteInt8:
|
||||
return *GetTensorData<int8_t>(tensor);
|
||||
case kTfLiteInt32:
|
||||
return *GetTensorData<int32_t>(tensor);
|
||||
case kTfLiteInt64:
|
||||
return *GetTensorData<int64_t>(tensor);
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
TfLiteStatus EvalInt(TfLiteContext* context, TfLiteNode* node) {
|
||||
OpData* params = reinterpret_cast<OpData*>(node->user_data);
|
||||
TF_LITE_ENSURE(context, params != nullptr);
|
||||
|
||||
TF_LITE_ENSURE(context, tflite::NumInputs(node) >= 3);
|
||||
TfLiteTensor* output = tflite::GetOutput(context, node, 0);
|
||||
if (IsDynamicTensor(output)) {
|
||||
const TfLiteTensor* input = tflite::GetInput(context, node, 0);
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
context->ResizeTensor(context, output,
|
||||
CreateDimensionsFromTensor(input)));
|
||||
}
|
||||
int64_t min_value = IntValueFromTensor(tflite::GetInput(context, node, 1));
|
||||
int64_t max_value = IntValueFromTensor(tflite::GetInput(context, node, 2));
|
||||
TF_LITE_ENSURE(context, min_value < max_value);
|
||||
size_t output_size = tflite::NumElements(output);
|
||||
switch (output->type) {
|
||||
case kTfLiteInt8:
|
||||
RandomUniformSample<int8_t, std::uniform_int_distribution<int8_t>>(
|
||||
params->rng, GetTensorData<int8_t>(output), output_size, min_value,
|
||||
max_value);
|
||||
break;
|
||||
case kTfLiteInt32:
|
||||
RandomUniformSample<int32_t, std::uniform_int_distribution<int32_t>>(
|
||||
params->rng, GetTensorData<int32_t>(output), output_size, min_value,
|
||||
max_value);
|
||||
break;
|
||||
case kTfLiteInt64:
|
||||
RandomUniformSample<int64_t, std::uniform_int_distribution<int64_t>>(
|
||||
params->rng, GetTensorData<int64_t>(output), output_size, min_value,
|
||||
max_value);
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"Unsupported output datatype for RandomUniformInt: %s",
|
||||
TfLiteTypeGetName(output->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace random_uniform
|
||||
|
||||
TfLiteRegistration* Register_RANDOM_UNIFORM() {
|
||||
static TfLiteRegistration r = {random_uniform::Init, random_uniform::Free,
|
||||
random_uniform::Prepare,
|
||||
random_uniform::EvalFloat};
|
||||
return &r;
|
||||
}
|
||||
|
||||
TfLiteRegistration* Register_RANDOM_UNIFORM_INT() {
|
||||
static TfLiteRegistration r = {random_uniform::Init, random_uniform::Free,
|
||||
random_uniform::Prepare,
|
||||
random_uniform::EvalInt};
|
||||
return &r;
|
||||
}
|
||||
} // namespace custom
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
175
tensorflow/lite/kernels/random_uniform_test.cc
Normal file
175
tensorflow/lite/kernels/random_uniform_test.cc
Normal file
@ -0,0 +1,175 @@
|
||||
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <cstdint>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/kernels/custom_ops_register.h"
|
||||
#include "tensorflow/lite/kernels/test_util.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/testing/util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
template <typename T>
|
||||
tflite::TensorType GetTTEnum();
|
||||
|
||||
template <>
|
||||
tflite::TensorType GetTTEnum<float>() {
|
||||
return tflite::TensorType_FLOAT32;
|
||||
}
|
||||
|
||||
template <>
|
||||
tflite::TensorType GetTTEnum<double>() {
|
||||
return tflite::TensorType_FLOAT64;
|
||||
}
|
||||
|
||||
template <>
|
||||
tflite::TensorType GetTTEnum<int32_t>() {
|
||||
return tflite::TensorType_INT32;
|
||||
}
|
||||
|
||||
template <>
|
||||
tflite::TensorType GetTTEnum<int64_t>() {
|
||||
return tflite::TensorType_INT64;
|
||||
}
|
||||
|
||||
class RandomUniformOpModel : public tflite::SingleOpModel {
|
||||
public:
|
||||
RandomUniformOpModel(const std::initializer_list<int>& input,
|
||||
tflite::TensorData output, bool dynamic_input) {
|
||||
if (dynamic_input) {
|
||||
input_ = AddInput({tflite::TensorType_INT32, {3}});
|
||||
} else {
|
||||
input_ = AddConstInput(tflite::TensorType_INT32, input,
|
||||
{static_cast<int>(input.size())});
|
||||
}
|
||||
output_ = AddOutput(output);
|
||||
SetCustomOp("RandomUniform", {}, ops::custom::Register_RANDOM_UNIFORM);
|
||||
BuildInterpreter({GetShape(input_)});
|
||||
if (dynamic_input) {
|
||||
PopulateTensor<int32_t>(input_, std::vector<int32_t>(input));
|
||||
}
|
||||
}
|
||||
|
||||
int input_;
|
||||
int output_;
|
||||
|
||||
int input() { return input_; }
|
||||
int output() { return output_; }
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> GetOutput() {
|
||||
return ExtractVector<T>(output_);
|
||||
}
|
||||
};
|
||||
|
||||
class RandomUniformIntOpModel : public tflite::SingleOpModel {
|
||||
public:
|
||||
RandomUniformIntOpModel(const std::initializer_list<int>& input,
|
||||
tflite::TensorData output, int min_val, int max_val) {
|
||||
input_ = AddConstInput(tflite::TensorType_INT32, input,
|
||||
{static_cast<int>(input.size())});
|
||||
input_minval_ = AddConstInput(tflite::TensorType_INT32, {min_val}, {1});
|
||||
input_maxval_ = AddConstInput(tflite::TensorType_INT32, {max_val}, {1});
|
||||
output_ = AddOutput(output);
|
||||
SetCustomOp("RandomUniformInt", {},
|
||||
ops::custom::Register_RANDOM_UNIFORM_INT);
|
||||
BuildInterpreter({GetShape(input_)});
|
||||
}
|
||||
|
||||
int input_;
|
||||
int input_minval_;
|
||||
int input_maxval_;
|
||||
|
||||
int output_;
|
||||
|
||||
int input() { return input_; }
|
||||
int output() { return output_; }
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> GetOutput() {
|
||||
return ExtractVector<T>(output_);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
template <typename FloatType>
|
||||
class RandomUniformTest : public ::testing::Test {
|
||||
public:
|
||||
using Float = FloatType;
|
||||
};
|
||||
|
||||
using TestTypes = ::testing::Types<float, double>;
|
||||
|
||||
TYPED_TEST_SUITE(RandomUniformTest, TestTypes);
|
||||
|
||||
TYPED_TEST(RandomUniformTest, TestOutput) {
|
||||
using Float = typename TestFixture::Float;
|
||||
for (const auto dynamic : {true, false}) {
|
||||
tflite::RandomUniformOpModel m({1000, 50, 5},
|
||||
{tflite::GetTTEnum<Float>(), {}}, dynamic);
|
||||
m.Invoke();
|
||||
auto output = m.GetOutput<Float>();
|
||||
EXPECT_EQ(output.size(), 1000 * 50 * 5);
|
||||
|
||||
double sum = 0;
|
||||
for (const auto r : output) {
|
||||
sum += r;
|
||||
}
|
||||
double avg = sum / output.size();
|
||||
ASSERT_LT(std::abs(avg - 0.5), 0.05); // Average should approximately 0.5
|
||||
|
||||
double sum_squared = 0;
|
||||
for (const auto r : output) {
|
||||
sum_squared += std::pow(r - avg, 2);
|
||||
}
|
||||
double var = sum_squared / output.size();
|
||||
EXPECT_LT(std::abs(1. / 12 - var),
|
||||
0.05); // Variance should be approximately 1./12
|
||||
}
|
||||
}
|
||||
|
||||
template <typename IntType>
|
||||
class RandomUniformIntTest : public ::testing::Test {
|
||||
public:
|
||||
using Int = IntType;
|
||||
};
|
||||
|
||||
using TestTypesInt = ::testing::Types<int32_t, int64_t>;
|
||||
|
||||
TYPED_TEST_SUITE(RandomUniformIntTest, TestTypesInt);
|
||||
|
||||
TYPED_TEST(RandomUniformIntTest, TestOutput) {
|
||||
using Int = typename TestFixture::Int;
|
||||
tflite::RandomUniformIntOpModel m({1000, 50, 5},
|
||||
{tflite::GetTTEnum<Int>(), {}}, 0, 5);
|
||||
m.Invoke();
|
||||
auto output = m.GetOutput<Int>();
|
||||
EXPECT_EQ(output.size(), 1000 * 50 * 5);
|
||||
|
||||
int counters[] = {0, 0, 0, 0, 0, 0};
|
||||
for (const auto r : output) {
|
||||
ASSERT_GE(r, 0);
|
||||
ASSERT_LE(r, 5);
|
||||
++counters[r];
|
||||
}
|
||||
// Check that all numbers are meet with near the same frequency.
|
||||
for (int i = 1; i < 6; ++i) {
|
||||
EXPECT_LT(std::abs(counters[i] - counters[0]), 1000);
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user