diff --git a/tensorflow/lite/kernels/random_uniform.cc b/tensorflow/lite/kernels/random_uniform.cc index 1b19a80cc43..1a248755907 100644 --- a/tensorflow/lite/kernels/random_uniform.cc +++ b/tensorflow/lite/kernels/random_uniform.cc @@ -140,7 +140,7 @@ TfLiteStatus EvalInt(TfLiteContext* context, TfLiteNode* node) { size_t output_size = tflite::NumElements(output); switch (output->type) { case kTfLiteInt8: - RandomUniformSample>( + RandomUniformSample>( params->rng, GetTensorData(output), output_size, min_value, max_value); break; diff --git a/tensorflow/lite/kernels/random_uniform_test.cc b/tensorflow/lite/kernels/random_uniform_test.cc index 28a795470af..d852f69e482 100644 --- a/tensorflow/lite/kernels/random_uniform_test.cc +++ b/tensorflow/lite/kernels/random_uniform_test.cc @@ -36,6 +36,11 @@ tflite::TensorType GetTTEnum() { return tflite::TensorType_FLOAT64; } +template <> +tflite::TensorType GetTTEnum() { + return tflite::TensorType_INT8; +} + template <> tflite::TensorType GetTTEnum() { return tflite::TensorType_INT32; @@ -150,7 +155,7 @@ class RandomUniformIntTest : public ::testing::Test { using Int = IntType; }; -using TestTypesInt = ::testing::Types; +using TestTypesInt = ::testing::Types; TYPED_TEST_SUITE(RandomUniformIntTest, TestTypesInt);