diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
index 3448dde8ce1..82fb62cae8d 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
+++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include <cstdlib>
 #include <iostream>
 #include <memory>
+#include <random>
 #include <string>
 #include <unordered_set>
 #include <vector>
@@ -290,11 +291,9 @@ BenchmarkParams BenchmarkTfLiteModel::DefaultParams() {
   return default_params;
 }
 
-BenchmarkTfLiteModel::BenchmarkTfLiteModel()
-    : BenchmarkTfLiteModel(DefaultParams()) {}
-
 BenchmarkTfLiteModel::BenchmarkTfLiteModel(BenchmarkParams params)
-    : BenchmarkModel(std::move(params)) {}
+    : BenchmarkModel(std::move(params)),
+      random_engine_(std::random_device()()) {}
 
 void BenchmarkTfLiteModel::CleanUp() {
   // Free up any pre-allocated tensor data during PrepareInputData.
@@ -453,22 +452,16 @@ TfLiteStatus BenchmarkTfLiteModel::PrepareInputData() {
     }
     InputTensorData t_data;
     if (t->type == kTfLiteFloat32) {
-      t_data = InputTensorData::Create<float>(num_elements, []() {
-        return static_cast<float>(rand()) / RAND_MAX - 0.5f;
-      });
+      t_data = CreateInputTensorData<float>(
+          num_elements, std::uniform_real_distribution<float>(-0.5f, 0.5f));
     } else if (t->type == kTfLiteFloat16) {
       // TODO(b/138843274): Remove this preprocessor guard when bug is fixed.
 #if TFLITE_ENABLE_FP16_CPU_BENCHMARKS
 #if __GNUC__ && \
     (__clang__ || __ARM_FP16_FORMAT_IEEE || __ARM_FP16_FORMAT_ALTERNATIVE)
       // __fp16 is available on Clang or when __ARM_FP16_FORMAT_* is defined.
-      t_data = InputTensorData::Create<TfLiteFloat16>(
-          num_elements, []() -> TfLiteFloat16 {
-            __fp16 f16_value = static_cast<float>(rand()) / RAND_MAX - 0.5f;
-            TfLiteFloat16 f16_placeholder_value;
-            memcpy(&f16_placeholder_value, &f16_value, sizeof(TfLiteFloat16));
-            return f16_placeholder_value;
-          });
+      t_data = CreateInputTensorData<__fp16>(
+          num_elements, std::uniform_real_distribution<float>(-0.5f, 0.5f));
 #else
       TFLITE_LOG(FATAL) << "Don't know how to populate tensor " << t->name
                         << " of type FLOAT16 on this platform.";
@@ -484,33 +477,28 @@ TfLiteStatus BenchmarkTfLiteModel::PrepareInputData() {
     } else if (t->type == kTfLiteInt64) {
       int low = has_value_range ? low_range : 0;
       int high = has_value_range ? high_range : 99;
-      t_data = InputTensorData::Create<int64_t>(num_elements, [=]() {
-        return static_cast<int64_t>(rand() % (high - low + 1) + low);
-      });
+      t_data = CreateInputTensorData<int64_t>(
+          num_elements, std::uniform_int_distribution<int64_t>(low, high));
     } else if (t->type == kTfLiteInt32) {
       int low = has_value_range ? low_range : 0;
       int high = has_value_range ? high_range : 99;
-      t_data = InputTensorData::Create<int32_t>(num_elements, [=]() {
-        return static_cast<int32_t>(rand() % (high - low + 1) + low);
-      });
+      t_data = CreateInputTensorData<int32_t>(
+          num_elements, std::uniform_int_distribution<int32_t>(low, high));
     } else if (t->type == kTfLiteInt16) {
       int low = has_value_range ? low_range : 0;
       int high = has_value_range ? high_range : 99;
-      t_data = InputTensorData::Create<int16_t>(num_elements, [=]() {
-        return static_cast<int16_t>(rand() % (high - low + 1) + low);
-      });
+      t_data = CreateInputTensorData<int16_t>(
+          num_elements, std::uniform_int_distribution<int16_t>(low, high));
     } else if (t->type == kTfLiteUInt8) {
       int low = has_value_range ? low_range : 0;
       int high = has_value_range ? high_range : 254;
-      t_data = InputTensorData::Create<uint8_t>(num_elements, [=]() {
-        return static_cast<uint8_t>(rand() % (high - low + 1) + low);
-      });
+      t_data = CreateInputTensorData<uint8_t>(
+          num_elements, std::uniform_int_distribution<uint8_t>(low, high));
     } else if (t->type == kTfLiteInt8) {
       int low = has_value_range ? low_range : -127;
       int high = has_value_range ? high_range : 127;
-      t_data = InputTensorData::Create<int8_t>(num_elements, [=]() {
-        return static_cast<int8_t>(rand() % (high - low + 1) + low);
-      });
+      t_data = CreateInputTensorData<int8_t>(
+          num_elements, std::uniform_int_distribution<int8_t>(low, high));
     } else if (t->type == kTfLiteString) {
       // TODO(haoliang): No need to cache string tensors right now.
     } else {
diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h
index 491007f64c7..ca7731eed33 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h
+++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h
@@ -19,6 +19,7 @@ limitations under the License.
 #include <algorithm>
 #include <map>
 #include <memory>
+#include <random>
 #include <string>
 #include <vector>
@@ -47,8 +48,7 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
     int high;
   };
 
-  BenchmarkTfLiteModel();
-  explicit BenchmarkTfLiteModel(BenchmarkParams params);
+  explicit BenchmarkTfLiteModel(BenchmarkParams params = DefaultParams());
   ~BenchmarkTfLiteModel() override;
 
   std::vector<Flag> GetFlags() override;
@@ -80,30 +80,33 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
 
   struct InputTensorData {
     InputTensorData() : data(nullptr, nullptr) {}
 
-    template <typename T>
-    static InputTensorData Create(int num_elements,
-                                  const std::function<T()>& val_generator) {
-      InputTensorData tmp;
-      tmp.bytes = sizeof(T) * num_elements;
-      T* raw = new T[num_elements];
-      std::generate_n(raw, num_elements, val_generator);
-      // Now initialize the type-erased unique_ptr (with custom deleter) from
-      // 'raw'.
-      tmp.data = std::unique_ptr<void, void (*)(void*)>(
-          static_cast<void*>(raw),
-          [](void* ptr) { delete[] static_cast<T*>(ptr); });
-      return tmp;
-    }
-
     std::unique_ptr<void, void (*)(void*)> data;
     size_t bytes;
   };
 
+  template <typename T, typename Distribution>
+  inline InputTensorData CreateInputTensorData(int num_elements,
+                                               Distribution distribution) {
+    InputTensorData tmp;
+    tmp.bytes = sizeof(T) * num_elements;
+    T* raw = new T[num_elements];
+    std::generate_n(raw, num_elements,
+                    [&]() { return distribution(random_engine_); });
+    // Now initialize the type-erased unique_ptr (with custom deleter) from
+    // 'raw'.
+    tmp.data = std::unique_ptr<void, void (*)(void*)>(
+        static_cast<void*>(raw),
+        [](void* ptr) { delete[] static_cast<T*>(ptr); });
+    return tmp;
+  }
+
   std::vector<InputLayerInfo> inputs_;
   std::vector<InputTensorData> inputs_data_;
   std::unique_ptr<BenchmarkListener> profiling_listener_;
   std::unique_ptr<BenchmarkListener> gemmlowp_profiling_listener_;
   TfLiteDelegatePtrMap delegates_;
+
+  std::mt19937 random_engine_;
 };
 
 }  // namespace benchmark