Add prototype custom op for tf.multinomial.

PiperOrigin-RevId: 333302488
Change-Id: I16f2f21a93ee3c0ce4f3e2a45640548145d21d9a
This commit is contained in:
A. Unique TensorFlower 2020-09-23 09:00:22 -07:00 committed by TensorFlower Gardener
parent 77c5c05aef
commit 70621a5df8
4 changed files with 467 additions and 2 deletions

View File

@ -700,6 +700,7 @@ cc_library(
srcs = [
"complex_support.cc",
"cumsum.cc",
"multinomial.cc",
"rfft2d.cc",
],
hdrs = ["custom_ops_register.h"],
@ -1340,6 +1341,20 @@ cc_test(
],
)
# Unit test target for the Multinomial custom op. It is built from
# multinomial_test.cc and links against the :custom_ops library.
cc_test(
name = "multinomial_test",
size = "small",
srcs = ["multinomial_test.cc"],
deps = [
":custom_ops",
":test_main",
":test_util",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/testing:util",
"@com_google_googletest//:gtest",
],
)
cc_test(
name = "pad_test",
size = "small",

View File

@ -22,13 +22,14 @@ namespace ops {
namespace custom {
TfLiteRegistration* Register_CUMSUM();
TfLiteRegistration* Register_RFFT2D();
TfLiteRegistration* Register_HASHTABLE();
TfLiteRegistration* Register_HASHTABLE_FIND();
TfLiteRegistration* Register_HASHTABLE_IMPORT();
TfLiteRegistration* Register_HASHTABLE_SIZE();
TfLiteRegistration* Register_REAL();
TfLiteRegistration* Register_IMAG();
TfLiteRegistration* Register_MULTINOMIAL();
TfLiteRegistration* Register_REAL();
TfLiteRegistration* Register_RFFT2D();
} // namespace custom
} // namespace ops

View File

@ -0,0 +1,206 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <random>
#include <vector>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace custom {
namespace multinomial {
// Per-node state for the Multinomial op. An instance is allocated by Init,
// destroyed by Free, and read by Eval through node->user_data.
struct MultinomialParams {
// Random engine that Eval passes to the sampler. It is default-constructed;
// nothing in this file seeds it explicitly.
std::default_random_engine rng;
};
// Draws `output_size` samples from the categorical distribution described by
// `logits` and writes the sampled category indices to `outputs`.
//
// Args:
//   rng:         random engine used for the draws; its state is advanced.
//   logits:      unnormalized log-probabilities, `logits_size` entries.
//   logits_size: number of categories.
//   outputs:     destination buffer with room for `output_size` indices.
//   output_size: number of samples to draw.
//
// Returns kTfLiteOk on success (including when no samples are requested).
// Returns kTfLiteError when samples are requested but there are no categories,
// or when a drawn value falls beyond the cumulative table.
template <typename FloatType, typename IntegralType>
TfLiteStatus MultinomialSample(std::default_random_engine& rng,
                               const FloatType* logits, int logits_size,
                               IntegralType* outputs, int output_size) {
  // Nothing to draw; leave `outputs` untouched.
  if (output_size <= 0) {
    return kTfLiteOk;
  }
  // With no categories the cumulative table would be empty and reading its
  // last element would be undefined behaviour.
  if (logits_size <= 0) {
    return kTfLiteError;
  }

  // Compute max logit for numerical stability.
  FloatType max_logit = std::numeric_limits<FloatType>::lowest();
  for (int i = 0; i < logits_size; i++) {
    max_logit = std::max(max_logit, logits[i]);
  }

  // Computes arg_max(cumsum(exp(logits)) > rand()).
  // TODO(b/169166131): Remove hard-coded double for constrained use-cases.
  std::vector<double> cumulative_odds;
  cumulative_odds.reserve(logits_size);
  // The running sum is kept in double. Holding it in FloatType would round
  // every partial sum to float precision for float logits, which can make
  // low-probability categories indistinguishable from their predecessor.
  double running_odds = 0.0;
  for (int i = 0; i < logits_size; i++) {
    running_odds += std::exp(static_cast<double>(logits[i]) -
                             static_cast<double>(max_logit));
    cumulative_odds.push_back(running_odds);
  }

  std::uniform_real_distribution<double> distribution{0.0,
                                                      cumulative_odds.back()};
  for (int i = 0; i < output_size; i++) {
    const double sample = distribution(rng);
    // The first cumulative entry that is not below the sample identifies the
    // category whose interval contains it.
    const auto it = std::lower_bound(cumulative_odds.begin(),
                                     cumulative_odds.end(), sample);
    if (it == cumulative_odds.end()) {
      // This should be impossible by construction.
      return kTfLiteError;
    }
    *outputs++ = static_cast<IntegralType>(it - cumulative_odds.begin());
  }
  return kTfLiteOk;
}
// Selects the typed sampler that matches the element type of `output` and
// lets it write `output_size` indices starting `outputs_offset` elements into
// the output buffer. int32 and int64 outputs are handled; any other type is
// logged through `context` and reported as kTfLiteError.
template <typename FloatType>
TfLiteStatus MultinomialSample(TfLiteContext* context,
                               std::default_random_engine& rng,
                               const FloatType* logits, int logits_size,
                               TfLiteTensor* output, int outputs_offset,
                               int output_size) {
  const auto output_type = output->type;

  if (output_type == kTfLiteInt32) {
    int32_t* destination = GetTensorData<int32_t>(output) + outputs_offset;
    return MultinomialSample<FloatType, int32_t>(rng, logits, logits_size,
                                                 destination, output_size);
  }

  if (output_type == kTfLiteInt64) {
    int64_t* destination = GetTensorData<int64_t>(output) + outputs_offset;
    return MultinomialSample<FloatType, int64_t>(rng, logits, logits_size,
                                                 destination, output_size);
  }

  TF_LITE_KERNEL_LOG(context,
                     "Unsupported datatype for multinomial output: %s",
                     TfLiteTypeGetName(output->type));
  return kTfLiteError;
}
// Dispatches on the element type of `logits` and samples one batch row.
//
// Args:
//   context:        used for error logging.
//   rng:            random engine forwarded to the typed sampler.
//   logits:         tensor holding the logits.
//   logits_offset:  element offset of the row's first logit within `logits`.
//   logits_size:    number of logits in the row.
//   output:         tensor receiving the sampled indices.
//   outputs_offset: element offset of the row's first sample within `output`.
//   output_size:    number of samples to draw for the row.
//
// Returns kTfLiteOk on success. Returns kTfLiteError for float16 or any other
// unsupported logits type, or when the typed sampler fails.
TfLiteStatus MultinomialSample(TfLiteContext* context,
                               std::default_random_engine& rng,
                               const TfLiteTensor* logits, int logits_offset,
                               int logits_size, TfLiteTensor* output,
                               int outputs_offset, int output_size) {
  switch (logits->type) {
    case kTfLiteFloat16:
      TF_LITE_KERNEL_LOG(context, "TfLiteFloat16 is currently not supported.");
      return kTfLiteError;
    case kTfLiteFloat32:
      TF_LITE_ENSURE_OK(
          context,
          MultinomialSample<float>(
              context, rng, GetTensorData<float>(logits) + logits_offset,
              logits_size, output, outputs_offset, output_size));
      break;
    case kTfLiteFloat64:
      TF_LITE_ENSURE_OK(
          context,
          MultinomialSample<double>(
              context, rng, GetTensorData<double>(logits) + logits_offset,
              logits_size, output, outputs_offset, output_size));
      break;
    default:
      // Report the type of the logits tensor, since that is the tensor being
      // rejected here (not the output tensor).
      TF_LITE_KERNEL_LOG(context,
                         "Unsupported datatype for multinomial logit input: %s",
                         TfLiteTypeGetName(logits->type));
      return kTfLiteError;
  }
  return kTfLiteOk;
}
// Creates the per-node MultinomialParams and returns it as an opaque pointer.
// None of the arguments are consulted. The object is destroyed by Free.
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_state = new MultinomialParams();
  return op_state;
}
// Destroys the MultinomialParams that `buffer` points to. `context` is not
// used.
void Free(TfLiteContext* context, void* buffer) {
  auto* op_state = static_cast<MultinomialParams*>(buffer);
  delete op_state;
}
// Validates the node's inputs and resizes the output tensor to
// [batch_size, num_samples].
//
// Requirements checked here:
//   * exactly two inputs and one output;
//   * input 0 ('logits') is a rank-2 tensor with at least one category;
//   * input 1 ('num_samples') is a constant, scalar, int32 tensor holding a
//     non-negative value.
//
// Returns the status of the resize, or kTfLiteError if a check fails.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // TODO(b/169166131): Handle optional seed input.
  TF_LITE_ENSURE_EQ(context, tflite::NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, tflite::NumOutputs(node), 1);

  // 'logits' is a float matrix [batch_size, num_categories]
  const TfLiteTensor* logits_input = tflite::GetInput(context, node, 0);
  TF_LITE_ENSURE(context, logits_input != nullptr);
  TF_LITE_ENSURE_EQ(context, tflite::NumDimensions(logits_input), 2);
  int batch_size = tflite::SizeOfDimension(logits_input, 0);
  // There must be something to sample from; the sampler reads the last entry
  // of a table that has one element per category.
  TF_LITE_ENSURE(context, tflite::SizeOfDimension(logits_input, 1) > 0);

  // 'num_samples' is an int scalar.
  const TfLiteTensor* num_samples_input = tflite::GetInput(context, node, 1);
  TF_LITE_ENSURE(context, num_samples_input != nullptr);
  TF_LITE_ENSURE_EQ(context, tflite::NumDimensions(num_samples_input), 0);
  // TODO(b/169166131): Allow different integer input types.
  TF_LITE_ENSURE_EQ(context, num_samples_input->type, kTfLiteInt32);
  // TODO(b/169166131): Support dynamic output tensors.
  TF_LITE_ENSURE(context, IsConstantTensor(num_samples_input));
  int num_samples = *num_samples_input->data.i32;
  // The value becomes an output dimension, so it must not be negative.
  TF_LITE_ENSURE(context, num_samples >= 0);

  // Fetch and check the output before allocating the shape array so that a
  // failed check cannot leak it.
  TfLiteTensor* output = tflite::GetOutput(context, node, 0);
  TF_LITE_ENSURE(context, output != nullptr);

  TfLiteIntArray* output_shape = TfLiteIntArrayCreate(2);
  output_shape->data[0] = batch_size;
  output_shape->data[1] = num_samples;
  // ResizeTensor takes ownership of output_shape.
  return context->ResizeTensor(context, output, output_shape);
}
// Runs the op: for every row of the logits matrix, draws the number of
// samples given by input 1 and stores them in the matching row of the output.
// Fails if the node carries no MultinomialParams or if sampling a row fails.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  // TODO(b/169166131): Handle optional seed input.
  auto* params = static_cast<MultinomialParams*>(node->user_data);
  TF_LITE_ENSURE(context, params != nullptr);

  const TfLiteTensor* logits_tensor = tflite::GetInput(context, node, 0);
  const TfLiteTensor* sample_count_tensor = tflite::GetInput(context, node, 1);
  TfLiteTensor* output_tensor = tflite::GetOutput(context, node, 0);

  const int num_rows = tflite::SizeOfDimension(logits_tensor, 0);
  const int num_categories = tflite::SizeOfDimension(logits_tensor, 1);
  const int samples_per_row = *sample_count_tensor->data.i32;

  int row = 0;
  while (row < num_rows) {
    // Rows are stored back to back in both tensors, so the element offsets
    // are the row index times the respective row length.
    TF_LITE_ENSURE_OK(
        context,
        MultinomialSample(context, params->rng, logits_tensor,
                          num_categories * row, num_categories, output_tensor,
                          samples_per_row * row, samples_per_row));
    ++row;
  }
  return kTfLiteOk;
}
} // namespace multinomial
// Returns the registration record for the Multinomial custom op. The record
// is a function-local static, so every call yields the same address.
TfLiteRegistration* Register_MULTINOMIAL() {
  static TfLiteRegistration registration = {
      multinomial::Init,
      multinomial::Free,
      multinomial::Prepare,
      multinomial::Eval,
  };
  return &registration;
}
} // namespace custom
} // namespace ops
} // namespace tflite

View File

@ -0,0 +1,243 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <random>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/kernels/custom_ops_register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/testing/util.h"
namespace tflite {
namespace {
// Maps a C++ element type to the tflite::TensorType enumerator the tests use
// to describe tensors of that type. Only the specializations below are
// defined; the primary template is declared but has no definition.
template <typename T>
tflite::TensorType GetTTEnum();
// float -> TensorType_FLOAT32.
template <>
tflite::TensorType GetTTEnum<float>() {
return tflite::TensorType_FLOAT32;
}
// double -> TensorType_FLOAT64.
template <>
tflite::TensorType GetTTEnum<double>() {
return tflite::TensorType_FLOAT64;
}
// int -> TensorType_INT32.
template <>
tflite::TensorType GetTTEnum<int>() {
return tflite::TensorType_INT32;
}
// int64_t -> TensorType_INT64.
template <>
tflite::TensorType GetTTEnum<int64_t>() {
return tflite::TensorType_INT64;
}
class MultinomialOpModel : public tflite::SingleOpModel {
public:
MultinomialOpModel(tflite::TensorData logits, int num_samples,
tflite::TensorData output) {
logits_ = AddInput(logits);
num_samples_ = AddConstInput(tflite::TensorType_INT32, {num_samples}, {});
output_ = AddOutput(output);
SetCustomOp("Multinomial", {}, ops::custom::Register_MULTINOMIAL);
BuildInterpreter({GetShape(logits_), GetShape(num_samples_)});
}
int logits_;
int num_samples_;
int output_;
int logits() { return logits_; }
int num_samples() { return num_samples_; }
int output() { return output_; }
template <typename T>
std::vector<T> GetOutput() {
return ExtractVector<T>(output_);
}
};
} // namespace
} // namespace tflite
// Compile-time pair of types, exposed as the member aliases T1 and T2.
template <typename First, typename Second>
struct TypePair {
  using T1 = First;
  using T2 = Second;
};
// Typed test fixture. The type parameter supplies the logits element type as
// its T1 member and the output element type as its T2 member.
template <typename Pair>
class MultinomialTest : public ::testing::Test {
 public:
  using FloatType = typename Pair::T1;
  using IntegralType = typename Pair::T2;
};
// Every (logits type, output type) combination the typed tests run with:
// float and double logits, each paired with int and int64_t outputs.
using TestTypes =
::testing::Types<TypePair<float, int>, TypePair<double, int>,
TypePair<float, int64_t>, TypePair<double, int64_t> >;
// Instantiates each TYPED_TEST of MultinomialTest once per entry in TestTypes.
TYPED_TEST_SUITE(MultinomialTest, TestTypes);
// Samples three batches of uniform logits and checks that, across all
// batches, each of the three categories is drawn roughly equally often.
TYPED_TEST(MultinomialTest, TestMultiBatch) {
  // First sanity-check the standard library's uniform distribution on
  // [0, 10): mean near 5 and extremes near the interval ends.
  std::default_random_engine engine;
  std::uniform_real_distribution<float> uniform(0, 10);
  std::vector<float> draws;
  for (int draw = 0; draw < 1000; ++draw) {
    draws.push_back(uniform(engine));
  }

  double total = 0;
  for (std::size_t idx = 0; idx < draws.size(); ++idx) {
    total += draws[idx];
  }
  double avg = total / draws.size();
  const auto extremes = std::minmax_element(draws.begin(), draws.end());
  double min = *extremes.first;
  double max = *extremes.second;
  EXPECT_GT(avg, 4.8);
  EXPECT_LT(avg, 5.2);
  EXPECT_LT(min, 0.5);
  EXPECT_GT(max, 9.5);

  using Float = typename TestFixture::FloatType;
  using Int = typename TestFixture::IntegralType;
  const int kNumSamples = 1000;
  tflite::MultinomialOpModel model({tflite::GetTTEnum<Float>(), {3, 3}},
                                   kNumSamples, {tflite::GetTTEnum<Int>(), {}});

  // Add 3 batches of 3 logits each.
  const std::vector<Float> zero_logits(9, static_cast<Float>(0.0f));
  model.template PopulateTensor<Float>(model.logits(), zero_logits);
  model.Invoke();

  auto output = model.template GetOutput<Int>();
  EXPECT_EQ(output.size(), kNumSamples * 3);

  // Tally how often each category index appears.
  int c0 = 0;
  int c1 = 0;
  int c2 = 0;
  for (const Int sampled : output) {
    if (sampled == 0) {
      ++c0;
    } else if (sampled == 1) {
      ++c1;
    } else if (sampled == 2) {
      ++c2;
    }
  }
  EXPECT_EQ(c0 + c1 + c2, 3 * kNumSamples);

  // Make sure they're all sampled with roughly equal probability.
  EXPECT_GT(c0, 750);
  EXPECT_GT(c1, 750);
  EXPECT_GT(c2, 750);
  EXPECT_LT(c0, 1250);
  EXPECT_LT(c1, 1250);
  EXPECT_LT(c2, 1250);
}
// Checks that the category with the larger logit is drawn more often than
// either of the others.
TYPED_TEST(MultinomialTest, TestSampleHighLogOdds) {
  using Float = typename TestFixture::FloatType;
  using Int = typename TestFixture::IntegralType;
  const int kNumSamples = 1000;
  tflite::MultinomialOpModel model({tflite::GetTTEnum<Float>(), {1, 3}},
                                   kNumSamples, {tflite::GetTTEnum<Int>(), {}});

  // Add 1 batch of 3 logits; the middle one is favoured.
  const std::vector<Float> row = {static_cast<Float>(0.0f),
                                  static_cast<Float>(1.0f),
                                  static_cast<Float>(0.0f)};
  model.template PopulateTensor<Float>(model.logits(), row);
  model.Invoke();

  auto output = model.template GetOutput<Int>();
  EXPECT_EQ(output.size(), kNumSamples);

  // Tally how often each category index appears.
  int c0 = 0;
  int c1 = 0;
  int c2 = 0;
  for (const Int sampled : output) {
    if (sampled == 0) {
      ++c0;
    } else if (sampled == 1) {
      ++c1;
    } else if (sampled == 2) {
      ++c2;
    }
  }
  EXPECT_EQ(c0 + c1 + c2, kNumSamples);
  EXPECT_GT(c1, c0);
  EXPECT_GT(c1, c2);
}
// Checks that categories with extremely low logits are never drawn, so every
// sample lands on the one remaining category.
TYPED_TEST(MultinomialTest, TestVeryLowLogOdds) {
  using Float = typename TestFixture::FloatType;
  using Int = typename TestFixture::IntegralType;
  const int kNumSamples = 1000;
  tflite::MultinomialOpModel model({tflite::GetTTEnum<Float>(), {1, 3}},
                                   kNumSamples, {tflite::GetTTEnum<Int>(), {}});

  // Add 1 batch of 3 logits; only the last one is plausible.
  const std::vector<Float> row = {static_cast<Float>(-1000.0f),
                                  static_cast<Float>(-1000.0f),
                                  static_cast<Float>(0.0f)};
  model.template PopulateTensor<Float>(model.logits(), row);
  model.Invoke();

  auto output = model.template GetOutput<Int>();
  EXPECT_EQ(output.size(), kNumSamples);

  // Tally how often each category index appears.
  int c0 = 0;
  int c1 = 0;
  int c2 = 0;
  for (const Int sampled : output) {
    if (sampled == 0) {
      ++c0;
    } else if (sampled == 1) {
      ++c1;
    } else if (sampled == 2) {
      ++c2;
    }
  }
  EXPECT_EQ(c0, 0);
  EXPECT_EQ(c1, 0);
  EXPECT_EQ(c2, kNumSamples);
}
// Invokes the op twice on the same uniform logits and checks that the two
// runs do not produce the same sequence of samples.
TYPED_TEST(MultinomialTest, TestSamplesDifferent) {
  using Float = typename TestFixture::FloatType;
  using Int = typename TestFixture::IntegralType;
  const int kNumLogits = 1000;
  const int kNumSamples = 5;
  tflite::MultinomialOpModel model(
      {tflite::GetTTEnum<Float>(), {1, kNumLogits}}, kNumSamples,
      {tflite::GetTTEnum<Int>(), {}});

  const std::vector<Float> uniform_logits(kNumLogits, static_cast<Float>(0.0f));
  model.template PopulateTensor<Float>(model.logits(), uniform_logits);

  model.Invoke();
  const auto first_run = model.template GetOutput<Int>();
  model.Invoke();
  const auto second_run = model.template GetOutput<Int>();

  // The runs differ as soon as any of the first kNumSamples positions
  // disagrees.
  bool successive_samples_are_different = !std::equal(
      first_run.begin(), first_run.begin() + kNumSamples, second_run.begin());
  EXPECT_TRUE(successive_samples_are_different);
}
// Checks that large logits are handled without overflow and that the sampled
// frequency matches the expected probability. For logits {1000, 1001} the
// probability of category 0 is exp(1000) / (exp(1000) + exp(1001)), which
// simplifies to 1 / (1 + e) ~= 0.26894142137.
TYPED_TEST(MultinomialTest, TestSamplesPrecise) {
  using Float = typename TestFixture::FloatType;
  using Int = typename TestFixture::IntegralType;
  const int kNumSamples = 100000;
  const int kNumLogits = 2;
  tflite::MultinomialOpModel m({tflite::GetTTEnum<Float>(), {1, kNumLogits}},
                               kNumSamples, {tflite::GetTTEnum<Int>(), {}});
  // Both logits are cast to the fixture's Float type so the double-typed
  // instantiations are populated consistently.
  std::vector<Float> logits(
      {static_cast<Float>(1000.0), static_cast<Float>(1001.0)});
  m.PopulateTensor<Float>(m.logits(), logits);
  m.Invoke();
  auto output = m.GetOutput<Int>();
  EXPECT_EQ(output.size(), kNumSamples);
  int c0 = std::count(output.begin(), output.end(), 0);
  int c1 = std::count(output.begin(), output.end(), 1);
  // Every sample must be one of the two categories. This also guarantees the
  // denominator below is non-zero.
  ASSERT_EQ(c0 + c1, kNumSamples);
  double p0 = static_cast<double>(c0) / (c0 + c1);
  EXPECT_LT(std::abs(p0 - 0.26894142137), 0.01);
}