From 11823c6179a5e2ca6ac4f8480aa00c42c772885e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 29 Sep 2020 10:50:13 -0700 Subject: [PATCH] Add prototype custom op for RandomStandardNormal. PiperOrigin-RevId: 334414381 Change-Id: I0de25e8261c4a2d3f22d195b717942c83c1885ee --- tensorflow/lite/kernels/BUILD | 15 +++ tensorflow/lite/kernels/custom_ops_register.h | 1 + .../lite/kernels/random_standard_normal.cc | 127 ++++++++++++++++++ .../kernels/random_standard_normal_test.cc | 103 ++++++++++++++ 4 files changed, 246 insertions(+) create mode 100644 tensorflow/lite/kernels/random_standard_normal.cc create mode 100644 tensorflow/lite/kernels/random_standard_normal_test.cc diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 592aaeebed5..99ccb67663a 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -706,6 +706,7 @@ cc_library( "complex_support.cc", "cumsum.cc", "multinomial.cc", + "random_standard_normal.cc", "rfft2d.cc", ], hdrs = ["custom_ops_register.h"], @@ -1375,6 +1376,20 @@ cc_test( ], ) +cc_test( + name = "random_standard_normal_test", + size = "small", + srcs = ["random_standard_normal_test.cc"], + deps = [ + ":custom_ops", + ":test_main", + ":test_util", + "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + cc_library( name = "reshape_test_common", testonly = 1, diff --git a/tensorflow/lite/kernels/custom_ops_register.h b/tensorflow/lite/kernels/custom_ops_register.h index 94d4fec5347..8aadd379a43 100644 --- a/tensorflow/lite/kernels/custom_ops_register.h +++ b/tensorflow/lite/kernels/custom_ops_register.h @@ -28,6 +28,7 @@ TfLiteRegistration* Register_HASHTABLE_IMPORT(); TfLiteRegistration* Register_HASHTABLE_SIZE(); TfLiteRegistration* Register_IMAG(); TfLiteRegistration* Register_MULTINOMIAL(); +TfLiteRegistration* Register_RANDOM_STANDARD_NORMAL(); TfLiteRegistration* Register_REAL(); TfLiteRegistration* Register_RFFT2D(); diff --git a/tensorflow/lite/kernels/random_standard_normal.cc b/tensorflow/lite/kernels/random_standard_normal.cc new file mode 100644 index 00000000000..9b0b7b0b5d8 --- /dev/null +++ b/tensorflow/lite/kernels/random_standard_normal.cc @@ -0,0 +1,127 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace random_standard_normal { + +struct OpData { + std::default_random_engine rng; +}; + +// Draws a sample from standard normal distribution. +template +TfLiteStatus RandomStandardNormalSample(std::default_random_engine& rng, + Float* output, size_t output_size) { + std::normal_distribution dist; + for (Float* it = output; it != output + output_size; ++it) { + *it = dist(rng); + } + return kTfLiteOk; +} + +TfLiteStatus RandomStandardNormalSample(TfLiteContext* context, + std::default_random_engine& rng, + TfLiteTensor* output, + size_t output_size) { + switch (output->type) { + case kTfLiteFloat32: + TF_LITE_ENSURE_OK(context, + RandomStandardNormalSample( + rng, GetTensorData(output), output_size)); + break; + case kTfLiteFloat64: + TF_LITE_ENSURE_OK(context, + RandomStandardNormalSample( + rng, GetTensorData(output), output_size)); + break; + default: + TF_LITE_KERNEL_LOG( + context, "Unsupported output datatype for RandomStandardNormal: %s", + TfLiteTypeGetName(output->type)); + return kTfLiteError; + } + return kTfLiteOk; +} + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + return new OpData(); +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + // TODO(b/169611265): Handle optional seed input. + TF_LITE_ENSURE_EQ(context, tflite::NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, tflite::NumOutputs(node), 1); + + // Input is a shape tensor. + const TfLiteTensor* input = tflite::GetInput(context, node, 0); + TF_LITE_ENSURE_EQ(context, tflite::NumDimensions(input), 1); + // TODO(b/169611265): Support dynamic output tensors. + TF_LITE_ENSURE(context, IsConstantTensor(input)); + + // TODO(b/169611265): Handle other input data types. + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteInt32); + + int output_dims = tflite::SizeOfDimension(input, 0); + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_dims); + for (int i = 0; i < output_dims; i++) { + output_shape->data[i] = input->data.i32[i]; + } + + TfLiteTensor* output = tflite::GetOutput(context, node, 0); + // ResizeTensor takes ownership of output_shape. + return context->ResizeTensor(context, output, output_shape); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + // TODO(b/169611265): Handle optional seed input. + OpData* params = reinterpret_cast(node->user_data); + TF_LITE_ENSURE(context, params != nullptr); + + TfLiteTensor* output = tflite::GetOutput(context, node, 0); + size_t output_size = tflite::NumElements(output); + + TF_LITE_ENSURE_OK(context, RandomStandardNormalSample(context, params->rng, + output, output_size)); + + return kTfLiteOk; +} + +} // namespace random_standard_normal + +TfLiteRegistration* Register_RANDOM_STANDARD_NORMAL() { + static TfLiteRegistration r = { + random_standard_normal::Init, random_standard_normal::Free, + random_standard_normal::Prepare, random_standard_normal::Eval}; + return &r; +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/kernels/random_standard_normal_test.cc b/tensorflow/lite/kernels/random_standard_normal_test.cc new file mode 100644 index 00000000000..88e71f27669 --- /dev/null +++ b/tensorflow/lite/kernels/random_standard_normal_test.cc @@ -0,0 +1,103 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include + +#include +#include +#include "tensorflow/lite/kernels/custom_ops_register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/schema/schema_generated.h" +#include "tensorflow/lite/testing/util.h" + +namespace tflite { +namespace { + +template +tflite::TensorType GetTTEnum(); + +template <> +tflite::TensorType GetTTEnum() { + return tflite::TensorType_FLOAT32; +} + +template <> +tflite::TensorType GetTTEnum() { + return tflite::TensorType_FLOAT64; +} + +class RandomStandardNormalOpModel : public tflite::SingleOpModel { + public: + RandomStandardNormalOpModel(const std::initializer_list& input, + tflite::TensorData output) { + input_ = AddConstInput(tflite::TensorType_INT32, input, + {static_cast(input.size())}); + output_ = AddOutput(output); + SetCustomOp("RandomStandardNormal", {}, + ops::custom::Register_RANDOM_STANDARD_NORMAL); + BuildInterpreter({GetShape(input_)}); + } + + int input_; + int output_; + + int input() { return input_; } + int output() { return output_; } + + template + std::vector GetOutput() { + return ExtractVector(output_); + } +}; + +} // namespace +} // namespace tflite + +template +class RandomStandardNormalTest : public ::testing::Test { + public: + using Float = FloatType; +}; + +using TestTypes = ::testing::Types; + +TYPED_TEST_SUITE(RandomStandardNormalTest, TestTypes); + +TYPED_TEST(RandomStandardNormalTest, TestOutput) { + using Float = typename TestFixture::Float; + tflite::RandomStandardNormalOpModel m({1000, 50, 5}, + {tflite::GetTTEnum(), {}}); + m.Invoke(); + auto output = m.GetOutput(); + EXPECT_EQ(output.size(), 1000 * 50 * 5); + + double sum = 0; + for (auto r : output) { + sum += r; + } + double avg = sum / output.size(); + ASSERT_LT(std::abs(avg), 0.05); // Average should approximately 0. + + double sum_squared = 0; + for (auto r : output) { + sum_squared += std::pow(r - avg, 2); + } + double var = sum_squared / output.size(); + EXPECT_LT(std::abs(1 - var), 0.05); // Variance should be approximately 1. +}