diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD
index f09cbad8a37..040815b0a82 100644
--- a/tensorflow/lite/kernels/BUILD
+++ b/tensorflow/lite/kernels/BUILD
@@ -700,6 +700,7 @@ cc_library(
     srcs = [
         "complex_support.cc",
         "cumsum.cc",
+        "multinomial.cc",
         "rfft2d.cc",
     ],
     hdrs = ["custom_ops_register.h"],
@@ -1340,6 +1341,20 @@ cc_test(
     ],
 )
 
+cc_test(
+    name = "multinomial_test",
+    size = "small",
+    srcs = ["multinomial_test.cc"],
+    deps = [
+        ":custom_ops",
+        ":test_main",
+        ":test_util",
+        "//tensorflow/lite/schema:schema_fbs",
+        "//tensorflow/lite/testing:util",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
 cc_test(
     name = "pad_test",
     size = "small",
diff --git a/tensorflow/lite/kernels/custom_ops_register.h b/tensorflow/lite/kernels/custom_ops_register.h
index 659091f12fd..94d4fec5347 100644
--- a/tensorflow/lite/kernels/custom_ops_register.h
+++ b/tensorflow/lite/kernels/custom_ops_register.h
@@ -22,13 +22,14 @@ namespace ops {
 namespace custom {
 
 TfLiteRegistration* Register_CUMSUM();
-TfLiteRegistration* Register_RFFT2D();
 TfLiteRegistration* Register_HASHTABLE();
 TfLiteRegistration* Register_HASHTABLE_FIND();
 TfLiteRegistration* Register_HASHTABLE_IMPORT();
 TfLiteRegistration* Register_HASHTABLE_SIZE();
-TfLiteRegistration* Register_REAL();
 TfLiteRegistration* Register_IMAG();
+TfLiteRegistration* Register_MULTINOMIAL();
+TfLiteRegistration* Register_REAL();
+TfLiteRegistration* Register_RFFT2D();
 
 }  // namespace custom
 }  // namespace ops
diff --git a/tensorflow/lite/kernels/multinomial.cc b/tensorflow/lite/kernels/multinomial.cc
new file mode 100644
index 00000000000..ea471802992
--- /dev/null
+++ b/tensorflow/lite/kernels/multinomial.cc
@@ -0,0 +1,230 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <limits>
+#include <random>
+#include <vector>
+
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace multinomial {
+
+// Per-node state: the random engine that is reused across Eval() calls.
+struct MultinomialParams {
+  std::default_random_engine rng;
+};
+
+// Draws `output_size` samples from the categorical distribution described by
+// the `logits_size` unnormalized log-probabilities in `logits`, writing the
+// sampled category indices to `outputs`. Fails if samples are requested from
+// an empty distribution.
+template <typename FloatType, typename IntegralType>
+TfLiteStatus MultinomialSample(std::default_random_engine& rng,
+                               const FloatType* logits, int logits_size,
+                               IntegralType* outputs, int output_size) {
+  if (logits_size <= 0) {
+    // There is nothing to sample from (and `back()` below would be undefined
+    // behavior on an empty vector). Only an error if samples were requested.
+    return output_size > 0 ? kTfLiteError : kTfLiteOk;
+  }
+
+  // Computes arg_max(cumsum(exp(logits)) > rand()).
+  // TODO(b/169166131): Remove hard-coded double for constrained use-cases.
+  std::vector<double> cumulative_odds;
+  cumulative_odds.reserve(logits_size);
+  double last_odds = 0.0;
+
+  // Compute max logit for numerical stability.
+  FloatType max_logit = std::numeric_limits<FloatType>::lowest();
+  for (int i = 0; i < logits_size; i++) {
+    max_logit = std::max(max_logit, logits[i]);
+  }
+
+  for (int i = 0; i < logits_size; i++) {
+    // Accumulate in double so the running sum is not rounded back to
+    // FloatType on every step.
+    const double odds = std::exp(logits[i] - max_logit) + last_odds;
+    cumulative_odds.push_back(odds);
+    last_odds = odds;
+  }
+
+  std::uniform_real_distribution<double> distribution{0.0,
+                                                      cumulative_odds.back()};
+
+  for (int i = 0; i < output_size; i++) {
+    const double sample = distribution(rng);
+    // First category whose cumulative odds strictly exceed the sample, so a
+    // category with zero probability mass can never be selected.
+    const auto it = std::upper_bound(cumulative_odds.begin(),
+                                     cumulative_odds.end(), sample);
+    if (it == cumulative_odds.end()) {
+      // This should be impossible by construction.
+      return kTfLiteError;
+    }
+    *outputs++ = static_cast<IntegralType>(it - cumulative_odds.begin());
+  }
+  return kTfLiteOk;
+}
+
+// Dispatches on the output tensor type (int32 or int64) and samples
+// `output_size` indices into `output`, starting at element `outputs_offset`.
+template <typename FloatType>
+TfLiteStatus MultinomialSample(TfLiteContext* context,
+                               std::default_random_engine& rng,
+                               const FloatType* logits, int logits_size,
+                               TfLiteTensor* output, int outputs_offset,
+                               int output_size) {
+  switch (output->type) {
+    case kTfLiteInt32:
+      return MultinomialSample<FloatType, int32_t>(
+          rng, logits, logits_size,
+          GetTensorData<int32_t>(output) + outputs_offset, output_size);
+    case kTfLiteInt64:
+      return MultinomialSample<FloatType, int64_t>(
+          rng, logits, logits_size,
+          GetTensorData<int64_t>(output) + outputs_offset, output_size);
+    default:
+      TF_LITE_KERNEL_LOG(context,
+                         "Unsupported datatype for multinomial output: %s",
+                         TfLiteTypeGetName(output->type));
+      return kTfLiteError;
+  }
+}
+
+// Dispatches on the logits tensor type (float32 or float64) and samples
+// `output_size` indices for the logits starting at element `logits_offset`.
+TfLiteStatus MultinomialSample(TfLiteContext* context,
+                               std::default_random_engine& rng,
+                               const TfLiteTensor* logits, int logits_offset,
+                               int logits_size, TfLiteTensor* output,
+                               int outputs_offset, int output_size) {
+  switch (logits->type) {
+    case kTfLiteFloat16:
+      TF_LITE_KERNEL_LOG(context, "TfLiteFloat16 is currently not supported.");
+      return kTfLiteError;
+    case kTfLiteFloat32:
+      TF_LITE_ENSURE_OK(
+          context,
+          MultinomialSample<float>(
+              context, rng, GetTensorData<float>(logits) + logits_offset,
+              logits_size, output, outputs_offset, output_size));
+      break;
+    case kTfLiteFloat64:
+      TF_LITE_ENSURE_OK(
+          context,
+          MultinomialSample<double>(
+              context, rng, GetTensorData<double>(logits) + logits_offset,
+              logits_size, output, outputs_offset, output_size));
+      break;
+    default:
+      // Report the offending *input* type (not the output type).
+      TF_LITE_KERNEL_LOG(context,
+                         "Unsupported datatype for multinomial logit input: %s",
+                         TfLiteTypeGetName(logits->type));
+      return kTfLiteError;
+  }
+  return kTfLiteOk;
+}
+
+// Allocates the per-node state; released in Free().
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  return new MultinomialParams();
+}
+
+// Releases the state allocated by Init().
+void Free(TfLiteContext* context, void* buffer) {
+  delete reinterpret_cast<MultinomialParams*>(buffer);
+}
+
+// Validates the inputs and resizes the output to [batch_size, num_samples].
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  // TODO(b/169166131): Handle optional seed input.
+  TF_LITE_ENSURE_EQ(context, tflite::NumInputs(node), 2);
+  TF_LITE_ENSURE_EQ(context, tflite::NumOutputs(node), 1);
+
+  // 'logits' is a float matrix [batch_size, num_categories]
+  const TfLiteTensor* logits_input = tflite::GetInput(context, node, 0);
+  TF_LITE_ENSURE_EQ(context, tflite::NumDimensions(logits_input), 2);
+  int batch_size = tflite::SizeOfDimension(logits_input, 0);
+
+  // 'num_samples' is an int scalar.
+  const TfLiteTensor* num_samples_input = tflite::GetInput(context, node, 1);
+  TF_LITE_ENSURE_EQ(context, tflite::NumDimensions(num_samples_input), 0);
+  // TODO(b/169166131): Allow different integer input types.
+  TF_LITE_ENSURE_EQ(context, num_samples_input->type, kTfLiteInt32);
+  // TODO(b/169166131): Support dynamic output tensors.
+  TF_LITE_ENSURE(context, IsConstantTensor(num_samples_input));
+
+  int num_samples = *num_samples_input->data.i32;
+  // A negative sample count would yield an invalid output shape.
+  TF_LITE_ENSURE(context, num_samples >= 0);
+
+  TfLiteIntArray* output_shape = TfLiteIntArrayCreate(2);
+  output_shape->data[0] = batch_size;
+  output_shape->data[1] = num_samples;
+
+  TfLiteTensor* output = tflite::GetOutput(context, node, 0);
+  // ResizeTensor takes ownership of output_shape.
+  return context->ResizeTensor(context, output, output_shape);
+}
+
+// Samples `num_samples` category indices for every batch row of the logits.
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  // TODO(b/169166131): Handle optional seed input.
+  MultinomialParams* params =
+      reinterpret_cast<MultinomialParams*>(node->user_data);
+  TF_LITE_ENSURE(context, params != nullptr);
+
+  const TfLiteTensor* logits = tflite::GetInput(context, node, 0);
+  int batch_size = tflite::SizeOfDimension(logits, 0);
+  int logits_size = tflite::SizeOfDimension(logits, 1);
+
+  const TfLiteTensor* num_samples_input = tflite::GetInput(context, node, 1);
+  int output_size = *num_samples_input->data.i32;
+
+  TfLiteTensor* output = tflite::GetOutput(context, node, 0);
+
+  for (int batch = 0; batch < batch_size; ++batch) {
+    int logits_offset = logits_size * batch;
+    int output_offset = output_size * batch;
+
+    TF_LITE_ENSURE_OK(
+        context,
+        MultinomialSample(context, params->rng, logits, logits_offset,
+                          logits_size, output, output_offset, output_size));
+  }
+
+  return kTfLiteOk;
+}
+
+}  // namespace multinomial
+
+TfLiteRegistration* Register_MULTINOMIAL() {
+  static TfLiteRegistration r = {multinomial::Init, multinomial::Free,
+                                 multinomial::Prepare, multinomial::Eval};
+  return &r;
+}
+
+}  // namespace custom
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/lite/kernels/multinomial_test.cc b/tensorflow/lite/kernels/multinomial_test.cc
new file mode 100644
index 00000000000..f1e3d7a039e
--- /dev/null
+++ b/tensorflow/lite/kernels/multinomial_test.cc
@@ -0,0 +1,255 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <cmath>
+#include <cstddef>
+#include <cstdint>
+#include <limits>
+#include <random>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/lite/kernels/custom_ops_register.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+#include "tensorflow/lite/testing/util.h"
+
+namespace tflite {
+namespace {
+
+// Maps a C++ element type to the corresponding TFLite tensor type.
+template <typename T>
+tflite::TensorType GetTTEnum();
+
+template <>
+tflite::TensorType GetTTEnum<float>() {
+  return tflite::TensorType_FLOAT32;
+}
+
+template <>
+tflite::TensorType GetTTEnum<double>() {
+  return tflite::TensorType_FLOAT64;
+}
+
+template <>
+tflite::TensorType GetTTEnum<int>() {
+  return tflite::TensorType_INT32;
+}
+
+template <>
+tflite::TensorType GetTTEnum<int64_t>() {
+  return tflite::TensorType_INT64;
+}
+
+// Single-op model wrapping the custom Multinomial op: a logits input, a
+// constant scalar num_samples input and an integer output.
+class MultinomialOpModel : public tflite::SingleOpModel {
+ public:
+  MultinomialOpModel(tflite::TensorData logits, int num_samples,
+                     tflite::TensorData output) {
+    logits_ = AddInput(logits);
+    num_samples_ = AddConstInput(tflite::TensorType_INT32, {num_samples}, {});
+    output_ = AddOutput(output);
+    SetCustomOp("Multinomial", {}, ops::custom::Register_MULTINOMIAL);
+    BuildInterpreter({GetShape(logits_), GetShape(num_samples_)});
+  }
+
+  int logits_;
+  int num_samples_;
+  int output_;
+
+  int logits() { return logits_; }
+  int num_samples() { return num_samples_; }
+  int output() { return output_; }
+
+  template <typename T>
+  std::vector<T> GetOutput() {
+    return ExtractVector<T>(output_);
+  }
+};
+
+}  // namespace
+}  // namespace tflite
+
+// Bundles the (logits type, output index type) pair used by one typed test.
+template <typename Type1, typename Type2>
+struct TypePair {
+  using T1 = Type1;
+  using T2 = Type2;
+};
+
+// Typed fixture exposing the float/integral types under test.
+template <typename TypePair>
+class MultinomialTest : public ::testing::Test {
+ public:
+  using FloatType = typename TypePair::T1;
+  using IntegralType = typename TypePair::T2;
+};
+
+using TestTypes =
+    ::testing::Types<TypePair<float, int>, TypePair<double, int>,
+                     TypePair<float, int64_t>, TypePair<double, int64_t> >;
+
+TYPED_TEST_SUITE(MultinomialTest, TestTypes);
+
+// Test that uniform logits are sampled uniformly across multiple batches.
+TYPED_TEST(MultinomialTest, TestMultiBatch) {
+  // Sanity-check the standard library generator the kernel relies on.
+  std::default_random_engine rng;
+  std::uniform_real_distribution<float> dist(0, 10);
+  std::vector<float> results;
+  for (int i = 0; i < 1000; ++i) {
+    results.push_back(dist(rng));
+  }
+  double sum = 0;
+  for (auto f : results) {
+    sum += f;
+  }
+  double avg = sum / results.size();
+  double min = *std::min_element(results.begin(), results.end());
+  double max = *std::max_element(results.begin(), results.end());
+  EXPECT_GT(avg, 4.8);
+  EXPECT_LT(avg, 5.2);
+  EXPECT_LT(min, 0.5);
+  EXPECT_GT(max, 9.5);
+
+  const int kNumSamples = 1000;
+  using Float = typename TestFixture::FloatType;
+  using Int = typename TestFixture::IntegralType;
+  tflite::MultinomialOpModel m({tflite::GetTTEnum<Float>(), {3, 3}},
+                               kNumSamples, {tflite::GetTTEnum<Int>(), {}});
+  // Add 3 batches of 3 logits each.
+  m.PopulateTensor<Float>(m.logits(),
+                          std::vector<Float>(9, static_cast<Float>(0.0f)));
+
+  m.Invoke();
+  auto output = m.GetOutput<Int>();
+  EXPECT_EQ(output.size(), kNumSamples * 3);
+
+  int c0 = std::count(output.begin(), output.end(), 0);
+  int c1 = std::count(output.begin(), output.end(), 1);
+  int c2 = std::count(output.begin(), output.end(), 2);
+
+  EXPECT_EQ(c0 + c1 + c2, 3 * kNumSamples);
+
+  // Make sure they're all sampled with roughly equal probability.
+  EXPECT_GT(c0, 750);
+  EXPECT_GT(c1, 750);
+  EXPECT_GT(c2, 750);
+  EXPECT_LT(c0, 1250);
+  EXPECT_LT(c1, 1250);
+  EXPECT_LT(c2, 1250);
+}
+
+// Test that higher log odds are sampled more often.
+TYPED_TEST(MultinomialTest, TestSampleHighLogOdds) {
+  const int kNumSamples = 1000;
+  using Float = typename TestFixture::FloatType;
+  using Int = typename TestFixture::IntegralType;
+  tflite::MultinomialOpModel m({tflite::GetTTEnum<Float>(), {1, 3}},
+                               kNumSamples, {tflite::GetTTEnum<Int>(), {}});
+
+  // Add 1 batch of 3 logits.
+  m.PopulateTensor<Float>(m.logits(),
+                          {static_cast<Float>(0.0f), static_cast<Float>(1.0f),
+                           static_cast<Float>(0.0f)});
+  m.Invoke();
+  auto output = m.GetOutput<Int>();
+  EXPECT_EQ(output.size(), kNumSamples);
+
+  int c0 = std::count(output.begin(), output.end(), 0);
+  int c1 = std::count(output.begin(), output.end(), 1);
+  int c2 = std::count(output.begin(), output.end(), 2);
+  EXPECT_EQ(c0 + c1 + c2, kNumSamples);
+  EXPECT_GT(c1, c0);
+  EXPECT_GT(c1, c2);
+}
+
+// Test that very low log odds are never sampled.
+TYPED_TEST(MultinomialTest, TestVeryLowLogOdds) {
+  const int kNumSamples = 1000;
+  using Float = typename TestFixture::FloatType;
+  using Int = typename TestFixture::IntegralType;
+  tflite::MultinomialOpModel m({tflite::GetTTEnum<Float>(), {1, 3}},
+                               kNumSamples, {tflite::GetTTEnum<Int>(), {}});
+
+  // Add 1 batch of 3 logits.
+  m.PopulateTensor<Float>(
+      m.logits(), {static_cast<Float>(-1000.0f), static_cast<Float>(-1000.0f),
+                   static_cast<Float>(0.0f)});
+  m.Invoke();
+  auto output = m.GetOutput<Int>();
+  EXPECT_EQ(output.size(), kNumSamples);
+
+  int c0 = std::count(output.begin(), output.end(), 0);
+  int c1 = std::count(output.begin(), output.end(), 1);
+  int c2 = std::count(output.begin(), output.end(), 2);
+  EXPECT_EQ(c0, 0);
+  EXPECT_EQ(c1, 0);
+  EXPECT_EQ(c2, kNumSamples);
+}
+
+// Test that successive invocations do not repeat the same samples.
+TYPED_TEST(MultinomialTest, TestSamplesDifferent) {
+  using Float = typename TestFixture::FloatType;
+  using Int = typename TestFixture::IntegralType;
+  const int kNumSamples = 5;
+  const int kNumLogits = 1000;
+
+  tflite::MultinomialOpModel m({tflite::GetTTEnum<Float>(), {1, kNumLogits}},
+                               kNumSamples, {tflite::GetTTEnum<Int>(), {}});
+
+  std::vector<Float> logits(kNumLogits, static_cast<Float>(0.0f));
+  m.PopulateTensor<Float>(m.logits(), logits);
+
+  m.Invoke();
+  auto output1 = m.GetOutput<Int>();
+  m.Invoke();
+  auto output2 = m.GetOutput<Int>();
+
+  bool successive_samples_are_different = false;
+  for (int i = 0; i < kNumSamples; ++i) {
+    if (output1[i] == output2[i]) continue;
+    successive_samples_are_different = true;
+    break;
+  }
+  EXPECT_TRUE(successive_samples_are_different);
+}
+
+// Test that sampling frequencies match softmax probabilities even for large
+// logits: P(0) = 1 / (1 + e) ~= 0.26894142137.
+TYPED_TEST(MultinomialTest, TestSamplesPrecise) {
+  using Float = typename TestFixture::FloatType;
+  using Int = typename TestFixture::IntegralType;
+  const int kNumSamples = 100000;
+  const int kNumLogits = 2;
+
+  tflite::MultinomialOpModel m({tflite::GetTTEnum<Float>(), {1, kNumLogits}},
+                               kNumSamples, {tflite::GetTTEnum<Int>(), {}});
+
+  std::vector<Float> logits(
+      {static_cast<Float>(1000.0), static_cast<Float>(1001.0)});
+  m.PopulateTensor<Float>(m.logits(), logits);
+
+  m.Invoke();
+  auto output = m.GetOutput<Int>();
+  int c0 = std::count(output.begin(), output.end(), 0);
+  int c1 = std::count(output.begin(), output.end(), 1);
+
+  double p0 = static_cast<double>(c0) / (c0 + c1);
+  EXPECT_LT(std::abs(p0 - 0.26894142137), 0.01);
+}