STT-tensorflow/tensorflow/lite/kernels/multinomial_test.cc
A. Unique TensorFlower e1b710dd04 Remove extraneous testing code.
PiperOrigin-RevId: 334260725
Change-Id: I6f5ad63c4cd652edc3761a116e43bb53f2ea0364
2020-09-28 16:29:29 -07:00

226 lines
7.0 KiB
C++

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <random>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/kernels/custom_ops_register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/testing/util.h"
namespace tflite {
namespace {
// Maps a C++ element type to the matching tflite::TensorType enumerator.
// Only the explicit specializations that follow are defined here; the primary
// template is declared without a body.
template <typename T>
tflite::TensorType GetTTEnum();

// float -> FLOAT32
template <>
tflite::TensorType GetTTEnum<float>() { return tflite::TensorType_FLOAT32; }

// double -> FLOAT64
template <>
tflite::TensorType GetTTEnum<double>() { return tflite::TensorType_FLOAT64; }

// int -> INT32
template <>
tflite::TensorType GetTTEnum<int>() { return tflite::TensorType_INT32; }

// int64_t -> INT64
template <>
tflite::TensorType GetTTEnum<int64_t>() { return tflite::TensorType_INT64; }
class MultinomialOpModel : public tflite::SingleOpModel {
public:
MultinomialOpModel(tflite::TensorData logits, int num_samples,
tflite::TensorData output) {
logits_ = AddInput(logits);
num_samples_ = AddConstInput(tflite::TensorType_INT32, {num_samples}, {});
output_ = AddOutput(output);
SetCustomOp("Multinomial", {}, ops::custom::Register_MULTINOMIAL);
BuildInterpreter({GetShape(logits_), GetShape(num_samples_)});
}
int logits_;
int num_samples_;
int output_;
int logits() { return logits_; }
int num_samples() { return num_samples_; }
int output() { return output_; }
template <typename T>
std::vector<T> GetOutput() {
return ExtractVector<T>(output_);
}
};
} // namespace
} // namespace tflite
// Compile-time pair of types, exposed as the member aliases T1 and T2.
// Used to feed two type parameters through a single typed-test parameter.
template <typename First, typename Second>
struct TypePair {
  using T1 = First;   // first type of the pair
  using T2 = Second;  // second type of the pair
};
// Typed-test fixture. `Pair` is expected to provide member types T1 and T2;
// they are re-exported as FloatType and IntegralType so that test bodies can
// reach them through TestFixture.
template <typename Pair>
class MultinomialTest : public ::testing::Test {
 public:
  using FloatType = typename Pair::T1;
  using IntegralType = typename Pair::T2;
};
// All (FloatType, IntegralType) combinations the typed tests are run with.
using TestTypes = ::testing::Types<TypePair<float, int>,
                                   TypePair<double, int>,
                                   TypePair<float, int64_t>,
                                   TypePair<double, int64_t>>;
TYPED_TEST_SUITE(MultinomialTest, TestTypes);
// With all-zero logits in every batch, each of the three labels should be
// drawn roughly equally often across the whole output.
TYPED_TEST(MultinomialTest, TestMultiBatch) {
  using Float = typename TestFixture::FloatType;
  using Int = typename TestFixture::IntegralType;
  const int kNumSamples = 1000;

  tflite::MultinomialOpModel m({tflite::GetTTEnum<Float>(), {3, 3}},
                               kNumSamples, {tflite::GetTTEnum<Int>(), {}});

  // Three batches of three logits each, all set to zero.
  const std::vector<Float> zero_logits(9, static_cast<Float>(0.0f));
  m.PopulateTensor<Float>(m.logits(), zero_logits);
  m.Invoke();

  const auto samples = m.GetOutput<Int>();
  EXPECT_EQ(samples.size(), kNumSamples * 3);

  // Number of times `label` appears anywhere in the output.
  const auto occurrences = [&samples](int label) -> int {
    return std::count(samples.begin(), samples.end(), label);
  };
  const int hits[] = {occurrences(0), occurrences(1), occurrences(2)};

  // Every sample must be one of the three valid labels.
  EXPECT_EQ(hits[0] + hits[1] + hits[2], 3 * kNumSamples);

  // Each label is expected about 1000 times; allow +/- 250.
  for (const int count : hits) {
    EXPECT_GT(count, 750);
  }
  for (const int count : hits) {
    EXPECT_LT(count, 1250);
  }
}
// The label with the largest logit must be drawn more often than the others.
TYPED_TEST(MultinomialTest, TestSampleHighLogOdds) {
  using Float = typename TestFixture::FloatType;
  using Int = typename TestFixture::IntegralType;
  const int kNumSamples = 1000;

  tflite::MultinomialOpModel m({tflite::GetTTEnum<Float>(), {1, 3}},
                               kNumSamples, {tflite::GetTTEnum<Int>(), {}});

  // One batch of three logits; only the middle one is raised.
  const Float kLow = static_cast<Float>(0.0f);
  const Float kHigh = static_cast<Float>(1.0f);
  m.PopulateTensor<Float>(m.logits(), {kLow, kHigh, kLow});
  m.Invoke();

  const auto samples = m.GetOutput<Int>();
  EXPECT_EQ(samples.size(), kNumSamples);

  const auto first = samples.begin();
  const auto last = samples.end();
  const int num_zeros = std::count(first, last, 0);
  const int num_ones = std::count(first, last, 1);
  const int num_twos = std::count(first, last, 2);

  // Only labels 0..2 may appear.
  EXPECT_EQ(num_zeros + num_ones + num_twos, kNumSamples);
  EXPECT_GT(num_ones, num_zeros);
  EXPECT_GT(num_ones, num_twos);
}
// Labels whose logits are extremely low relative to another label must never
// show up in the output.
TYPED_TEST(MultinomialTest, TestVeryLowLogOdds) {
  using Float = typename TestFixture::FloatType;
  using Int = typename TestFixture::IntegralType;
  const int kNumSamples = 1000;

  tflite::MultinomialOpModel m({tflite::GetTTEnum<Float>(), {1, 3}},
                               kNumSamples, {tflite::GetTTEnum<Int>(), {}});

  // One batch of three logits: two at -1000, the last at 0.
  const Float kNegligible = static_cast<Float>(-1000.0f);
  const Float kDominant = static_cast<Float>(0.0f);
  m.PopulateTensor<Float>(m.logits(), {kNegligible, kNegligible, kDominant});
  m.Invoke();

  const auto drawn = m.GetOutput<Int>();
  EXPECT_EQ(drawn.size(), kNumSamples);

  // Count how often each label was drawn.
  const auto tally = [&drawn](int label) -> int {
    return std::count(drawn.begin(), drawn.end(), label);
  };
  const int zeros = tally(0);
  const int ones = tally(1);
  const int twos = tally(2);

  EXPECT_EQ(zeros, 0);
  EXPECT_EQ(ones, 0);
  EXPECT_EQ(twos, kNumSamples);
}
// Two successive invocations on the same inputs must not yield the same
// sequence of samples. With 1000 equally weighted labels and 5 draws, an
// accidental exact repeat is vanishingly unlikely.
TYPED_TEST(MultinomialTest, TestSamplesDifferent) {
  using Float = typename TestFixture::FloatType;
  using Int = typename TestFixture::IntegralType;
  const int kNumSamples = 5;
  const int kNumLogits = 1000;

  tflite::MultinomialOpModel m({tflite::GetTTEnum<Float>(), {1, kNumLogits}},
                               kNumSamples, {tflite::GetTTEnum<Int>(), {}});

  // One batch of kNumLogits logits, all zero.
  const std::vector<Float> logits(kNumLogits, static_cast<Float>(0.0f));
  m.PopulateTensor<Float>(m.logits(), logits);

  m.Invoke();
  const auto output1 = m.GetOutput<Int>();
  m.Invoke();
  const auto output2 = m.GetOutput<Int>();

  // Check the sizes first so that a wrongly sized output is reported as a
  // test failure instead of being read out of bounds.
  ASSERT_EQ(output1.size(), static_cast<std::size_t>(kNumSamples));
  ASSERT_EQ(output2.size(), static_cast<std::size_t>(kNumSamples));

  // With equal sizes, vector inequality means "differs in at least one
  // position", which is what the test requires.
  EXPECT_NE(output1, output2);
}
// Checks the empirical sampling frequency against the expected probability
// for two large logits that differ by exactly 1.
//
// The reference value 0.26894142137 is 1 / (1 + e), i.e. the softmax weight
// of the smaller of two logits separated by 1.
// NOTE(review): the logits are presumably chosen this large to exercise
// numerical stability of the kernel -- confirm against the op implementation.
TYPED_TEST(MultinomialTest, TestSamplesPrecise) {
  using Float = typename TestFixture::FloatType;
  using Int = typename TestFixture::IntegralType;
  const int kNumSamples = 100000;
  const int kNumLogits = 2;

  tflite::MultinomialOpModel m({tflite::GetTTEnum<Float>(), {1, kNumLogits}},
                               kNumSamples, {tflite::GetTTEnum<Int>(), {}});

  // Both values are cast to the fixture's Float type (the second one was
  // previously cast to `float` regardless of the type under test).
  const std::vector<Float> logits = {static_cast<Float>(1000.0),
                                     static_cast<Float>(1001.0)};
  m.PopulateTensor<Float>(m.logits(), logits);
  m.Invoke();

  const auto output = m.GetOutput<Int>();
  ASSERT_EQ(output.size(), static_cast<std::size_t>(kNumSamples));

  const int c0 = std::count(output.begin(), output.end(), 0);
  const int c1 = std::count(output.begin(), output.end(), 1);
  // Every sample must be label 0 or 1; this also guarantees a non-zero
  // denominator below.
  ASSERT_EQ(c0 + c1, kNumSamples);

  const double p0 = static_cast<double>(c0) / (c0 + c1);
  EXPECT_NEAR(p0, 0.26894142137, 0.01);
}