Add prototype custom op for tf.multinomial.
PiperOrigin-RevId: 333302488 Change-Id: I16f2f21a93ee3c0ce4f3e2a45640548145d21d9a
This commit is contained in:
parent
77c5c05aef
commit
70621a5df8
tensorflow/lite/kernels
@ -700,6 +700,7 @@ cc_library(
|
||||
srcs = [
|
||||
"complex_support.cc",
|
||||
"cumsum.cc",
|
||||
"multinomial.cc",
|
||||
"rfft2d.cc",
|
||||
],
|
||||
hdrs = ["custom_ops_register.h"],
|
||||
@ -1340,6 +1341,20 @@ cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
# Unit test for the prototype Multinomial custom op (multinomial_test.cc).
cc_test(
    name = "multinomial_test",
    size = "small",
    srcs = ["multinomial_test.cc"],
    deps = [
        ":custom_ops",
        ":test_main",
        ":test_util",
        "//tensorflow/lite/schema:schema_fbs",
        "//tensorflow/lite/testing:util",
        "@com_google_googletest//:gtest",
    ],
)
|
||||
|
||||
cc_test(
|
||||
name = "pad_test",
|
||||
size = "small",
|
||||
|
@ -22,13 +22,14 @@ namespace ops {
|
||||
namespace custom {
|
||||
|
||||
TfLiteRegistration* Register_CUMSUM();
|
||||
TfLiteRegistration* Register_RFFT2D();
|
||||
TfLiteRegistration* Register_HASHTABLE();
|
||||
TfLiteRegistration* Register_HASHTABLE_FIND();
|
||||
TfLiteRegistration* Register_HASHTABLE_IMPORT();
|
||||
TfLiteRegistration* Register_HASHTABLE_SIZE();
|
||||
TfLiteRegistration* Register_REAL();
|
||||
TfLiteRegistration* Register_IMAG();
|
||||
TfLiteRegistration* Register_MULTINOMIAL();
|
||||
TfLiteRegistration* Register_REAL();
|
||||
TfLiteRegistration* Register_RFFT2D();
|
||||
|
||||
} // namespace custom
|
||||
} // namespace ops
|
||||
|
206
tensorflow/lite/kernels/multinomial.cc
Normal file
206
tensorflow/lite/kernels/multinomial.cc
Normal file
@ -0,0 +1,206 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <random>
#include <vector>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace ops {
|
||||
namespace custom {
|
||||
namespace multinomial {
|
||||
|
||||
// Per-node state of the multinomial op. Init() allocates one instance and
// Free() destroys it.
struct MultinomialParams {
  // Random engine that Eval() hands to MultinomialSample(). It is
  // default-constructed; nothing in this file seeds it explicitly.
  std::default_random_engine rng;
};
|
||||
|
||||
// Draws a sample from a categorical distribution.
|
||||
template <typename FloatType, typename IntegralType>
|
||||
TfLiteStatus MultinomialSample(std::default_random_engine& rng,
|
||||
const FloatType* logits, int logits_size,
|
||||
IntegralType* outputs, int output_size) {
|
||||
// Computes arg_max(cumsum(exp(logits)) > rand()).
|
||||
// TODO(b/169166131): Remove hard-coded double for constrained use-cases.
|
||||
std::vector<double> cumulative_odds;
|
||||
cumulative_odds.reserve(logits_size);
|
||||
double last_odds = 0.0;
|
||||
|
||||
// Compute max logit for numerical stability.
|
||||
FloatType max_logit = std::numeric_limits<FloatType>::lowest();
|
||||
for (int i = 0; i < logits_size; i++) {
|
||||
max_logit = std::max(max_logit, logits[i]);
|
||||
}
|
||||
|
||||
for (int i = 0; i < logits_size; i++) {
|
||||
FloatType odds = std::exp(logits[i] - max_logit) + last_odds;
|
||||
cumulative_odds.push_back(odds);
|
||||
last_odds = odds;
|
||||
}
|
||||
|
||||
std::uniform_real_distribution<double> distribution{0.0,
|
||||
cumulative_odds.back()};
|
||||
|
||||
for (int i = 0; i < output_size; i++) {
|
||||
double sample = distribution(rng);
|
||||
auto it = std::lower_bound(cumulative_odds.begin(), cumulative_odds.end(),
|
||||
sample);
|
||||
if (it == cumulative_odds.end()) {
|
||||
// This should be impossible by construction.
|
||||
return kTfLiteError;
|
||||
}
|
||||
*outputs++ = static_cast<IntegralType>(it - cumulative_odds.begin());
|
||||
}
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
template <typename FloatType>
|
||||
TfLiteStatus MultinomialSample(TfLiteContext* context,
|
||||
std::default_random_engine& rng,
|
||||
const FloatType* logits, int logits_size,
|
||||
TfLiteTensor* output, int outputs_offset,
|
||||
int output_size) {
|
||||
switch (output->type) {
|
||||
case kTfLiteInt32:
|
||||
return MultinomialSample<FloatType, int32_t>(
|
||||
rng, logits, logits_size,
|
||||
GetTensorData<int32_t>(output) + outputs_offset, output_size);
|
||||
break;
|
||||
case kTfLiteInt64:
|
||||
return MultinomialSample<FloatType, int64_t>(
|
||||
rng, logits, logits_size,
|
||||
GetTensorData<int64_t>(output) + outputs_offset, output_size);
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context,
|
||||
"Unsupported datatype for multinomial output: %s",
|
||||
TfLiteTypeGetName(output->type));
|
||||
return kTfLiteError;
|
||||
}
|
||||
}
|
||||
|
||||
// Selects the floating-point element type from `logits->type` and forwards to
// the typed sampler. It reads `logits_size` logits starting `logits_offset`
// elements into the logits tensor's data, and writes `output_size` samples
// starting `outputs_offset` elements into `output`.
//
// Float32 and Float64 logits are handled. Float16 and every other type are
// logged through `context` and reported as kTfLiteError. Otherwise the status
// of the typed sampler is returned.
TfLiteStatus MultinomialSample(TfLiteContext* context,
                               std::default_random_engine& rng,
                               const TfLiteTensor* logits, int logits_offset,
                               int logits_size, TfLiteTensor* output,
                               int outputs_offset, int output_size) {
  switch (logits->type) {
    case kTfLiteFloat16:
      TF_LITE_KERNEL_LOG(context, "TfLiteFloat16 is currently not supported.");
      return kTfLiteError;
    case kTfLiteFloat32:
      return MultinomialSample<float>(
          context, rng, GetTensorData<float>(logits) + logits_offset,
          logits_size, output, outputs_offset, output_size);
    case kTfLiteFloat64:
      return MultinomialSample<double>(
          context, rng, GetTensorData<double>(logits) + logits_offset,
          logits_size, output, outputs_offset, output_size);
    default:
      // Report the type of the offending logits tensor, not of the output.
      TF_LITE_KERNEL_LOG(context,
                         "Unsupported datatype for multinomial logit input: %s",
                         TfLiteTypeGetName(logits->type));
      return kTfLiteError;
  }
}
|
||||
|
||||
// Allocates the per-node MultinomialParams and returns it as an opaque
// pointer; Free() is the matching deallocator. None of the arguments are
// read.
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* params = new MultinomialParams();
  return params;
}
|
||||
|
||||
// Destroys the MultinomialParams that Init() allocated. `buffer` is that
// object passed back as an opaque pointer; `context` is not read.
void Free(TfLiteContext* context, void* buffer) {
  auto* params = static_cast<MultinomialParams*>(buffer);
  delete params;
}
|
||||
|
||||
// Validates the node's inputs and resizes the output tensor to
// [batch_size, num_samples].
//
// The TF_LITE_ENSURE* checks fail unless the node has exactly two inputs and
// one output, input 0 ('logits') is two-dimensional, and input 1
// ('num_samples') is a constant, zero-dimensional Int32 tensor holding a
// non-negative value. Otherwise the status of ResizeTensor is returned.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // TODO(b/169166131): Handle optional seed input.
  TF_LITE_ENSURE_EQ(context, tflite::NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, tflite::NumOutputs(node), 1);

  // 'logits' is a float matrix [batch_size, num_categories]
  const TfLiteTensor* logits_input = tflite::GetInput(context, node, 0);
  TF_LITE_ENSURE_EQ(context, tflite::NumDimensions(logits_input), 2);
  const int batch_size = tflite::SizeOfDimension(logits_input, 0);

  // 'num_samples' is an int scalar.
  const TfLiteTensor* num_samples_input = tflite::GetInput(context, node, 1);
  TF_LITE_ENSURE_EQ(context, tflite::NumDimensions(num_samples_input), 0);
  // TODO(b/169166131): Allow different integer input types.
  TF_LITE_ENSURE_EQ(context, num_samples_input->type, kTfLiteInt32);
  // TODO(b/169166131): Support dynamic output tensors.
  TF_LITE_ENSURE(context, IsConstantTensor(num_samples_input));

  const int num_samples = *num_samples_input->data.i32;
  // A negative sample count must not be used as an output dimension.
  TF_LITE_ENSURE(context, num_samples >= 0);

  TfLiteIntArray* output_shape = TfLiteIntArrayCreate(2);
  output_shape->data[0] = batch_size;
  output_shape->data[1] = num_samples;

  TfLiteTensor* output = tflite::GetOutput(context, node, 0);
  // ResizeTensor takes ownership of output_shape.
  return context->ResizeTensor(context, output, output_shape);
}
|
||||
|
||||
// Runs the op: for every row of the 'logits' input (input 0) it draws the
// number of samples stored in the 'num_samples' input (input 1) and writes
// them to the corresponding row of output 0.
//
// Fails if the node carries no MultinomialParams in `user_data`, or as soon as
// sampling a row reports an error.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  // TODO(b/169166131): Handle optional seed input.
  auto* params = static_cast<MultinomialParams*>(node->user_data);
  TF_LITE_ENSURE(context, params != nullptr);

  const TfLiteTensor* logits_tensor = tflite::GetInput(context, node, 0);
  const TfLiteTensor* sample_count_tensor = tflite::GetInput(context, node, 1);
  TfLiteTensor* output_tensor = tflite::GetOutput(context, node, 0);

  const int row_count = tflite::SizeOfDimension(logits_tensor, 0);
  const int logits_per_row = tflite::SizeOfDimension(logits_tensor, 1);
  const int samples_per_row = *sample_count_tensor->data.i32;

  for (int row = 0; row < row_count; ++row) {
    // Rows are addressed as contiguous runs of elements in both tensors.
    const int logits_start = logits_per_row * row;
    const int output_start = samples_per_row * row;
    TF_LITE_ENSURE_OK(
        context,
        MultinomialSample(context, params->rng, logits_tensor, logits_start,
                          logits_per_row, output_tensor, output_start,
                          samples_per_row));
  }

  return kTfLiteOk;
}
|
||||
|
||||
} // namespace multinomial
|
||||
|
||||
// Returns the registration for the multinomial custom op. The registration is
// a function-local static, so every call yields the same pointer.
TfLiteRegistration* Register_MULTINOMIAL() {
  static TfLiteRegistration registration = {
      multinomial::Init,
      multinomial::Free,
      multinomial::Prepare,
      multinomial::Eval,
  };
  return &registration;
}
|
||||
|
||||
} // namespace custom
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
243
tensorflow/lite/kernels/multinomial_test.cc
Normal file
243
tensorflow/lite/kernels/multinomial_test.cc
Normal file
@ -0,0 +1,243 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
#include <limits>
|
||||
#include <random>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/kernels/custom_ops_register.h"
|
||||
#include "tensorflow/lite/kernels/test_util.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/testing/util.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
tflite::TensorType GetTTEnum();
|
||||
|
||||
template <>
|
||||
tflite::TensorType GetTTEnum<float>() {
|
||||
return tflite::TensorType_FLOAT32;
|
||||
}
|
||||
|
||||
template <>
|
||||
tflite::TensorType GetTTEnum<double>() {
|
||||
return tflite::TensorType_FLOAT64;
|
||||
}
|
||||
|
||||
template <>
|
||||
tflite::TensorType GetTTEnum<int>() {
|
||||
return tflite::TensorType_INT32;
|
||||
}
|
||||
|
||||
template <>
|
||||
tflite::TensorType GetTTEnum<int64_t>() {
|
||||
return tflite::TensorType_INT64;
|
||||
}
|
||||
|
||||
class MultinomialOpModel : public tflite::SingleOpModel {
|
||||
public:
|
||||
MultinomialOpModel(tflite::TensorData logits, int num_samples,
|
||||
tflite::TensorData output) {
|
||||
logits_ = AddInput(logits);
|
||||
num_samples_ = AddConstInput(tflite::TensorType_INT32, {num_samples}, {});
|
||||
output_ = AddOutput(output);
|
||||
SetCustomOp("Multinomial", {}, ops::custom::Register_MULTINOMIAL);
|
||||
BuildInterpreter({GetShape(logits_), GetShape(num_samples_)});
|
||||
}
|
||||
|
||||
int logits_;
|
||||
int num_samples_;
|
||||
int output_;
|
||||
|
||||
int logits() { return logits_; }
|
||||
int num_samples() { return num_samples_; }
|
||||
int output() { return output_; }
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> GetOutput() {
|
||||
return ExtractVector<T>(output_);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
||||
// Bundles two types so that a typed test can be parameterized on both at
// once; the members are exposed as T1 and T2.
template <typename First, typename Second>
struct TypePair {
  using T1 = First;
  using T2 = Second;
};
|
||||
|
||||
template <typename TypePair>
|
||||
class MultinomialTest : public ::testing::Test {
|
||||
public:
|
||||
using FloatType = typename TypePair::T1;
|
||||
using IntegralType = typename TypePair::T2;
|
||||
};
|
||||
|
||||
using TestTypes =
|
||||
::testing::Types<TypePair<float, int>, TypePair<double, int>,
|
||||
TypePair<float, int64_t>, TypePair<double, int64_t> >;
|
||||
|
||||
TYPED_TEST_SUITE(MultinomialTest, TestTypes);
|
||||
|
||||
// Samples three batches of three equal logits and checks that each category
// is drawn roughly a third of the time.
TYPED_TEST(MultinomialTest, TestMultiBatch) {
  // Sanity check of std::default_random_engine itself: 1000 draws from a
  // uniform distribution over [0, 10) should average near 5 and come close to
  // both ends of the range.
  std::default_random_engine rng;
  std::uniform_real_distribution<float> dist(0, 10);
  std::vector<float> results;
  for (int draw = 0; draw < 1000; ++draw) {
    results.push_back(dist(rng));
  }
  double sum = 0;
  for (auto value : results) {
    sum += value;
  }
  double avg = sum / results.size();
  const auto extremes = std::minmax_element(results.begin(), results.end());
  double min = *extremes.first;
  double max = *extremes.second;
  EXPECT_GT(avg, 4.8);
  EXPECT_LT(avg, 5.2);
  EXPECT_LT(min, 0.5);
  EXPECT_GT(max, 9.5);

  using Float = typename TestFixture::FloatType;
  using Int = typename TestFixture::IntegralType;
  const int kNumSamples = 1000;
  tflite::MultinomialOpModel m({tflite::GetTTEnum<Float>(), {3, 3}},
                               kNumSamples, {tflite::GetTTEnum<Int>(), {}});
  // Add 3 batches of 3 logits each.
  const std::vector<Float> equal_logits(9, static_cast<Float>(0.0f));
  m.PopulateTensor<Float>(m.logits(), equal_logits);

  m.Invoke();
  auto output = m.GetOutput<Int>();
  EXPECT_EQ(output.size(), kNumSamples * 3);

  int c0 = std::count(output.begin(), output.end(), 0);
  int c1 = std::count(output.begin(), output.end(), 1);
  int c2 = std::count(output.begin(), output.end(), 2);

  // Every sample must be one of the three categories.
  EXPECT_EQ(c0 + c1 + c2, 3 * kNumSamples);

  // Make sure they're all sampled with roughly equal probability.
  EXPECT_GT(c0, 750);
  EXPECT_GT(c1, 750);
  EXPECT_GT(c2, 750);
  EXPECT_LT(c0, 1250);
  EXPECT_LT(c1, 1250);
  EXPECT_LT(c2, 1250);
}
|
||||
|
||||
// Test that higher log odds are sampled more often.
|
||||
TYPED_TEST(MultinomialTest, TestSampleHighLogOdds) {
|
||||
const int kNumSamples = 1000;
|
||||
using Float = typename TestFixture::FloatType;
|
||||
using Int = typename TestFixture::IntegralType;
|
||||
tflite::MultinomialOpModel m({tflite::GetTTEnum<Float>(), {1, 3}},
|
||||
kNumSamples, {tflite::GetTTEnum<Int>(), {}});
|
||||
|
||||
// Add 1 batch of 3 logits.
|
||||
m.PopulateTensor<Float>(m.logits(),
|
||||
{static_cast<Float>(0.0f), static_cast<Float>(1.0f),
|
||||
static_cast<Float>(0.0f)});
|
||||
m.Invoke();
|
||||
auto output = m.GetOutput<Int>();
|
||||
EXPECT_EQ(output.size(), kNumSamples);
|
||||
|
||||
int c0 = std::count(output.begin(), output.end(), 0);
|
||||
int c1 = std::count(output.begin(), output.end(), 1);
|
||||
int c2 = std::count(output.begin(), output.end(), 2);
|
||||
EXPECT_EQ(c0 + c1 + c2, kNumSamples);
|
||||
EXPECT_GT(c1, c0);
|
||||
EXPECT_GT(c1, c2);
|
||||
}
|
||||
|
||||
// Test that very low log odds are never sampled.
|
||||
TYPED_TEST(MultinomialTest, TestVeryLowLogOdds) {
|
||||
const int kNumSamples = 1000;
|
||||
using Float = typename TestFixture::FloatType;
|
||||
using Int = typename TestFixture::IntegralType;
|
||||
tflite::MultinomialOpModel m({tflite::GetTTEnum<Float>(), {1, 3}},
|
||||
kNumSamples, {tflite::GetTTEnum<Int>(), {}});
|
||||
|
||||
// Add 1 batch of 3 logits.
|
||||
m.PopulateTensor<Float>(
|
||||
m.logits(), {static_cast<Float>(-1000.0f), static_cast<Float>(-1000.0f),
|
||||
static_cast<Float>(0.0f)});
|
||||
m.Invoke();
|
||||
auto output = m.GetOutput<Int>();
|
||||
EXPECT_EQ(output.size(), kNumSamples);
|
||||
|
||||
int c0 = std::count(output.begin(), output.end(), 0);
|
||||
int c1 = std::count(output.begin(), output.end(), 1);
|
||||
int c2 = std::count(output.begin(), output.end(), 2);
|
||||
EXPECT_EQ(c0, 0);
|
||||
EXPECT_EQ(c1, 0);
|
||||
EXPECT_EQ(c2, kNumSamples);
|
||||
}
|
||||
|
||||
// Invokes the op twice on the same model and checks that the two sets of
// samples are not identical.
TYPED_TEST(MultinomialTest, TestSamplesDifferent) {
  using Float = typename TestFixture::FloatType;
  using Int = typename TestFixture::IntegralType;
  const int kNumLogits = 1000;
  const int kNumSamples = 5;

  tflite::MultinomialOpModel m({tflite::GetTTEnum<Float>(), {1, kNumLogits}},
                               kNumSamples, {tflite::GetTTEnum<Int>(), {}});

  // All categories get the same logit.
  std::vector<Float> logits(kNumLogits, static_cast<Float>(0.0f));
  m.PopulateTensor<Float>(m.logits(), logits);

  m.Invoke();
  auto output1 = m.GetOutput<Int>();
  m.Invoke();
  auto output2 = m.GetOutput<Int>();

  // The runs differ as soon as any of the first kNumSamples positions differ.
  const bool successive_samples_are_different = !std::equal(
      output1.begin(), output1.begin() + kNumSamples, output2.begin());
  EXPECT_TRUE(successive_samples_are_different);
}
|
||||
|
||||
// Checks the sampled frequency of the first of two categories against its
// expected probability when both logits are large.
TYPED_TEST(MultinomialTest, TestSamplesPrecise) {
  using Float = typename TestFixture::FloatType;
  using Int = typename TestFixture::IntegralType;
  const int kNumSamples = 100000;
  const int kNumLogits = 2;

  tflite::MultinomialOpModel m({tflite::GetTTEnum<Float>(), {1, kNumLogits}},
                               kNumSamples, {tflite::GetTTEnum<Int>(), {}});

  // Both logits are cast to the fixture's Float type.
  std::vector<Float> logits(
      {static_cast<Float>(1000.0), static_cast<Float>(1001.0)});
  m.PopulateTensor<Float>(m.logits(), logits);

  m.Invoke();
  auto output = m.GetOutput<Int>();
  int c0 = std::count(output.begin(), output.end(), 0);
  int c1 = std::count(output.begin(), output.end(), 1);

  // The logits differ by 1, so the first category is expected with
  // probability 1 / (1 + e) ~= 0.26894142137.
  double p0 = static_cast<double>(c0) / (c0 + c1);
  EXPECT_LT(std::abs(p0 - 0.26894142137), 0.01);
}
|
Loading…
Reference in New Issue
Block a user