Add prototype custom op for RandomStandardNormal.

PiperOrigin-RevId: 334414381
Change-Id: I0de25e8261c4a2d3f22d195b717942c83c1885ee
This commit is contained in:
A. Unique TensorFlower 2020-09-29 10:50:13 -07:00 committed by TensorFlower Gardener
parent 4cb5e1ebc3
commit 11823c6179
4 changed files with 246 additions and 0 deletions

View File

@ -706,6 +706,7 @@ cc_library(
"complex_support.cc",
"cumsum.cc",
"multinomial.cc",
"random_standard_normal.cc",
"rfft2d.cc",
],
hdrs = ["custom_ops_register.h"],
@ -1375,6 +1376,20 @@ cc_test(
],
)
cc_test(
name = "random_standard_normal_test",
size = "small",
srcs = ["random_standard_normal_test.cc"],
deps = [
":custom_ops",
":test_main",
":test_util",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/testing:util",
"@com_google_googletest//:gtest",
],
)
cc_library(
name = "reshape_test_common",
testonly = 1,

View File

@ -28,6 +28,7 @@ TfLiteRegistration* Register_HASHTABLE_IMPORT();
TfLiteRegistration* Register_HASHTABLE_SIZE();
TfLiteRegistration* Register_IMAG();
TfLiteRegistration* Register_MULTINOMIAL();
TfLiteRegistration* Register_RANDOM_STANDARD_NORMAL();
TfLiteRegistration* Register_REAL();
TfLiteRegistration* Register_RFFT2D();

View File

@ -0,0 +1,127 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cmath>
#include <cstdint>
#include <limits>
#include <random>
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace custom {
namespace random_standard_normal {
struct OpData {
std::default_random_engine rng;
};
// Draws a sample from standard normal distribution.
template <typename Float>
TfLiteStatus RandomStandardNormalSample(std::default_random_engine& rng,
Float* output, size_t output_size) {
std::normal_distribution<Float> dist;
for (Float* it = output; it != output + output_size; ++it) {
*it = dist(rng);
}
return kTfLiteOk;
}
TfLiteStatus RandomStandardNormalSample(TfLiteContext* context,
std::default_random_engine& rng,
TfLiteTensor* output,
size_t output_size) {
switch (output->type) {
case kTfLiteFloat32:
TF_LITE_ENSURE_OK(context,
RandomStandardNormalSample<float>(
rng, GetTensorData<float>(output), output_size));
break;
case kTfLiteFloat64:
TF_LITE_ENSURE_OK(context,
RandomStandardNormalSample<double>(
rng, GetTensorData<double>(output), output_size));
break;
default:
TF_LITE_KERNEL_LOG(
context, "Unsupported output datatype for RandomStandardNormal: %s",
TfLiteTypeGetName(output->type));
return kTfLiteError;
}
return kTfLiteOk;
}
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
return new OpData();
}
void Free(TfLiteContext* context, void* buffer) {
delete reinterpret_cast<OpData*>(buffer);
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// TODO(b/169611265): Handle optional seed input.
TF_LITE_ENSURE_EQ(context, tflite::NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, tflite::NumOutputs(node), 1);
// Input is a shape tensor.
const TfLiteTensor* input = tflite::GetInput(context, node, 0);
TF_LITE_ENSURE_EQ(context, tflite::NumDimensions(input), 1);
// TODO(b/169611265): Support dynamic output tensors.
TF_LITE_ENSURE(context, IsConstantTensor(input));
// TODO(b/169611265): Handle other input data types.
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteInt32);
int output_dims = tflite::SizeOfDimension(input, 0);
TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_dims);
for (int i = 0; i < output_dims; i++) {
output_shape->data[i] = input->data.i32[i];
}
TfLiteTensor* output = tflite::GetOutput(context, node, 0);
// ResizeTensor takes ownership of output_shape.
return context->ResizeTensor(context, output, output_shape);
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// TODO(b/169611265): Handle optional seed input.
OpData* params = reinterpret_cast<OpData*>(node->user_data);
TF_LITE_ENSURE(context, params != nullptr);
TfLiteTensor* output = tflite::GetOutput(context, node, 0);
size_t output_size = tflite::NumElements(output);
TF_LITE_ENSURE_OK(context, RandomStandardNormalSample(context, params->rng,
output, output_size));
return kTfLiteOk;
}
} // namespace random_standard_normal
TfLiteRegistration* Register_RANDOM_STANDARD_NORMAL() {
static TfLiteRegistration r = {
random_standard_normal::Init, random_standard_normal::Free,
random_standard_normal::Prepare, random_standard_normal::Eval};
return &r;
}
} // namespace custom
} // namespace ops
} // namespace tflite

View File

@ -0,0 +1,103 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <random>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/kernels/custom_ops_register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/testing/util.h"
namespace tflite {
namespace {
template <typename T>
tflite::TensorType GetTTEnum();
template <>
tflite::TensorType GetTTEnum<float>() {
return tflite::TensorType_FLOAT32;
}
template <>
tflite::TensorType GetTTEnum<double>() {
return tflite::TensorType_FLOAT64;
}
class RandomStandardNormalOpModel : public tflite::SingleOpModel {
public:
RandomStandardNormalOpModel(const std::initializer_list<int>& input,
tflite::TensorData output) {
input_ = AddConstInput(tflite::TensorType_INT32, input,
{static_cast<int>(input.size())});
output_ = AddOutput(output);
SetCustomOp("RandomStandardNormal", {},
ops::custom::Register_RANDOM_STANDARD_NORMAL);
BuildInterpreter({GetShape(input_)});
}
int input_;
int output_;
int input() { return input_; }
int output() { return output_; }
template <typename T>
std::vector<T> GetOutput() {
return ExtractVector<T>(output_);
}
};
} // namespace
} // namespace tflite
template <typename FloatType>
class RandomStandardNormalTest : public ::testing::Test {
public:
using Float = FloatType;
};
using TestTypes = ::testing::Types<float, double>;
TYPED_TEST_SUITE(RandomStandardNormalTest, TestTypes);
TYPED_TEST(RandomStandardNormalTest, TestOutput) {
using Float = typename TestFixture::Float;
tflite::RandomStandardNormalOpModel m({1000, 50, 5},
{tflite::GetTTEnum<Float>(), {}});
m.Invoke();
auto output = m.GetOutput<Float>();
EXPECT_EQ(output.size(), 1000 * 50 * 5);
double sum = 0;
for (auto r : output) {
sum += r;
}
double avg = sum / output.size();
ASSERT_LT(std::abs(avg), 0.05); // Average should approximately 0.
double sum_squared = 0;
for (auto r : output) {
sum_squared += std::pow(r - avg, 2);
}
double var = sum_squared / output.size();
EXPECT_LT(std::abs(1 - var), 0.05); // Variance should be approximately 1.
}