From cce04e5308388ca3412ff8f89d4a6a742cd3edee Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 3 Nov 2020 19:16:04 -0800 Subject: [PATCH] internal tests cleanup PiperOrigin-RevId: 340568129 Change-Id: I521739fa73d00d4ef75a61556e54dc0c0805c2a9 --- .../core/platform/vmodule_benchmark_test.cc | 4 +- tensorflow/core/util/bcast_test.cc | 12 ++-- .../core/util/device_name_utils_test.cc | 4 +- .../core/util/presized_cuckoo_map_test.cc | 30 +++++----- .../util/tensor_bundle/tensor_bundle_test.cc | 18 +++--- tensorflow/core/util/work_sharder_test.cc | 6 +- .../stream_executor/lib/statusor_test.cc | 60 +++++++------------ 7 files changed, 59 insertions(+), 75 deletions(-) diff --git a/tensorflow/core/platform/vmodule_benchmark_test.cc b/tensorflow/core/platform/vmodule_benchmark_test.cc index 0f9e75bf9cd..f164ece93a8 100644 --- a/tensorflow/core/platform/vmodule_benchmark_test.cc +++ b/tensorflow/core/platform/vmodule_benchmark_test.cc @@ -18,8 +18,8 @@ limitations under the License. namespace tensorflow { -static void BM_DisabledVlog(int iters) { - for (int i = 0; i < iters; ++i) { +static void BM_DisabledVlog(::testing::benchmark::State& state) { + for (auto s : state) { VLOG(1) << "Testing VLOG(1)!"; } } diff --git a/tensorflow/core/util/bcast_test.cc b/tensorflow/core/util/bcast_test.cc index b6e8bcd706b..63133fc5d4c 100644 --- a/tensorflow/core/util/bcast_test.cc +++ b/tensorflow/core/util/bcast_test.cc @@ -673,15 +673,17 @@ TEST(BCastTest, BatchIndices) { BCastBatchIndices({3, 1}, {2, 1, 2})); } -static void BM_BCastSetup(int iters, int same_shape) { +void BM_BCastSetup(::testing::benchmark::State& state) { + const int same_shape = state.range(0); + if (same_shape) { - testing::SetLabel("same_shapes"); - while (--iters > 0) { + state.SetLabel("same_shapes"); + for (auto s : state) { class BCast b({1000, 100}, {1000, 100}); } } else { - testing::SetLabel("different_shapes"); - while (--iters > 0) { + state.SetLabel("different_shapes"); + for (auto s : state) { class BCast b({3, 1, 5}, {2, 0, 3, 0, 5}); } } diff --git a/tensorflow/core/util/device_name_utils_test.cc b/tensorflow/core/util/device_name_utils_test.cc index 065fcfbf2ce..7824d9f11c9 100644 --- a/tensorflow/core/util/device_name_utils_test.cc +++ b/tensorflow/core/util/device_name_utils_test.cc @@ -572,9 +572,9 @@ TEST(DeviceNameUtilsTest, CanonicalizeDeviceName) { } } -static void BM_ParseFullName(int iters) { +static void BM_ParseFullName(::testing::benchmark::State& state) { DeviceNameUtils::ParsedName p; - while (iters--) { + for (auto s : state) { DeviceNameUtils::ParseFullName("/job:worker/replica:3/task:0/cpu:0", &p); } } diff --git a/tensorflow/core/util/presized_cuckoo_map_test.cc b/tensorflow/core/util/presized_cuckoo_map_test.cc index f2c7904b004..36a764272da 100644 --- a/tensorflow/core/util/presized_cuckoo_map_test.cc +++ b/tensorflow/core/util/presized_cuckoo_map_test.cc @@ -164,13 +164,13 @@ static void CalculateKeys(uint64 num, std::vector *dst) { } } -static void BM_CuckooFill(int iters, int arg) { +void BM_CuckooFill(::testing::benchmark::State &state) { + const int arg = state.range(0); + uint64 table_size = arg; - testing::StopTiming(); std::vector calculated_keys; CalculateKeys(table_size, &calculated_keys); - testing::StartTiming(); - for (int iter = 0; iter < iters; iter++) { + for (auto s : state) { PresizedCuckooMap pscm(table_size); for (uint64 i = 0; i < table_size; i++) { pscm.InsertUnique(calculated_keys[i], i); @@ -180,25 +180,27 @@ static void BM_CuckooFill(int iters, int arg) { BENCHMARK(BM_CuckooFill)->Arg(1000)->Arg(10000000); -static void BM_CuckooRead(int iters, int arg) { +void BM_CuckooRead(::testing::benchmark::State &state) { + const int arg = state.range(0); + uint64 table_size = arg; - testing::StopTiming(); std::vector calculated_keys; CalculateKeys(table_size, &calculated_keys); PresizedCuckooMap pscm(table_size); for (uint64 i = 0; i < table_size; i++) { pscm.InsertUnique(calculated_keys[i], i); } - testing::StartTiming(); - uint64_t defeat_optimization = 0; - for (int i = 0; i < iters; i++) { - uint64 key_index = i % table_size; // May slow down bench! + + int i = 0; + for (auto s : state) { + // Avoid using '%', which is expensive. + uint64 key_index = i; + ++i; + if (i == table_size) i = 0; + int out = 0; pscm.Find(calculated_keys[key_index], &out); - defeat_optimization += out; - } - if (defeat_optimization == 0) { - printf("Preventing the compiler from eliding the inner loop\n"); + tensorflow::testing::DoNotOptimize(out); } } diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc index a2ac7c30073..dea55f3acbd 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc @@ -1109,9 +1109,8 @@ TEST_F(TensorBundleAlignmentTest, AlignmentTest) { } } -static void BM_BundleAlignmentByteOff(int iters, int alignment, - int tensor_size) { - testing::StopTiming(); +static void BM_BundleAlignmentByteOff(::testing::benchmark::State& state, + int alignment, int tensor_size) { { BundleWriter::Options opts; opts.data_alignment = alignment; @@ -1122,18 +1121,17 @@ static void BM_BundleAlignmentByteOff(int iters, int alignment, } BundleReader reader(Env::Default(), Prefix("foo")); TF_CHECK_OK(reader.status()); - testing::StartTiming(); - for (int i = 0; i < iters; ++i) { + for (auto s : state) { Tensor t; TF_CHECK_OK(reader.Lookup("big", &t)); } - testing::StopTiming(); } -#define BM_BundleAlignment(ALIGN, SIZE) \ - static void BM_BundleAlignment_##ALIGN##_##SIZE(int iters) { \ - BM_BundleAlignmentByteOff(iters, ALIGN, SIZE); \ - } \ +#define BM_BundleAlignment(ALIGN, SIZE) \ + static void BM_BundleAlignment_##ALIGN##_##SIZE( \ + ::testing::benchmark::State& state) { \ + BM_BundleAlignmentByteOff(state, ALIGN, SIZE); \ + } \ BENCHMARK(BM_BundleAlignment_##ALIGN##_##SIZE) BM_BundleAlignment(1, 512); diff --git a/tensorflow/core/util/work_sharder_test.cc b/tensorflow/core/util/work_sharder_test.cc index bc5a1d221fb..f69572d6d7d 100644 --- a/tensorflow/core/util/work_sharder_test.cc +++ b/tensorflow/core/util/work_sharder_test.cc @@ -89,12 +89,14 @@ TEST(Shard, OverflowTest) { } } -void BM_Sharding(int iters, int arg) { +void BM_Sharding(::testing::benchmark::State& state) { + const int arg = state.range(0); + thread::ThreadPool threads(Env::Default(), "test", 16); const int64 total = 1LL << 30; auto lambda = [](int64 start, int64 limit) {}; auto work = std::cref(lambda); - for (; iters > 0; iters -= arg) { + for (auto s : state) { Shard(arg - 1, &threads, total, 1, work); } } diff --git a/tensorflow/stream_executor/lib/statusor_test.cc b/tensorflow/stream_executor/lib/statusor_test.cc index 46bdb9d208f..6b59eaa4029 100644 --- a/tensorflow/stream_executor/lib/statusor_test.cc +++ b/tensorflow/stream_executor/lib/statusor_test.cc @@ -535,12 +535,10 @@ class BenchmarkType { // Calibrate the amount of time spent just calling DoWork, since each of our // tests will do this, we can subtract this out of benchmark results. -void BM_CalibrateWorkLoop(int iters) { - tensorflow::testing::StopTiming(); +void BM_CalibrateWorkLoop(::testing::benchmark::State& state) { BenchmarkFactory factory; BenchmarkType* result = factory.TrivialFactory(); - tensorflow::testing::StartTiming(); - for (int i = 0; i != iters; ++i) { + for (auto s : state) { if (result != nullptr) { result->DoWork(); } @@ -550,11 +548,9 @@ BENCHMARK(BM_CalibrateWorkLoop); // Measure the time taken to call into the factory, return the value, // determine that it is OK, and invoke a trivial function. -void BM_TrivialFactory(int iters) { - tensorflow::testing::StopTiming(); +void BM_TrivialFactory(::testing::benchmark::State& state) { BenchmarkFactory factory; - tensorflow::testing::StartTiming(); - for (int i = 0; i != iters; ++i) { + for (auto s : state) { BenchmarkType* result = factory.TrivialFactory(); if (result != nullptr) { result->DoWork(); @@ -566,11 +562,9 @@ BENCHMARK(BM_TrivialFactory); // Measure the time taken to call into the factory, providing an // out-param for the result, evaluating the status result and the // result pointer, and invoking the trivial function. -void BM_ArgumentFactory(int iters) { - tensorflow::testing::StopTiming(); +void BM_ArgumentFactory(::testing::benchmark::State& state) { BenchmarkFactory factory; - tensorflow::testing::StartTiming(); - for (int i = 0; i != iters; ++i) { + for (auto s : state) { BenchmarkType* result = nullptr; Status status = factory.ArgumentFactory(&result); if (status.ok() && result != nullptr) { @@ -582,11 +576,9 @@ BENCHMARK(BM_ArgumentFactory); // Measure the time to use the StatusOr factory, evaluate the result, // and invoke the trivial function. -void BM_StatusOrFactory(int iters) { - tensorflow::testing::StopTiming(); +void BM_StatusOrFactory(::testing::benchmark::State& state) { BenchmarkFactory factory; - tensorflow::testing::StartTiming(); - for (int i = 0; i != iters; ++i) { + for (auto s : state) { StatusOr result = factory.StatusOrFactory(); if (result.ok()) { result.ValueOrDie()->DoWork(); @@ -598,11 +590,9 @@ BENCHMARK(BM_StatusOrFactory); // Measure the time taken to call into the factory, providing an // out-param for the result, evaluating the status result and the // result pointer, and invoking the trivial function. -void BM_ArgumentFactoryFail(int iters) { - tensorflow::testing::StopTiming(); +void BM_ArgumentFactoryFail(::testing::benchmark::State& state) { BenchmarkFactory factory; - tensorflow::testing::StartTiming(); - for (int i = 0; i != iters; ++i) { + for (auto s : state) { BenchmarkType* result = nullptr; Status status = factory.ArgumentFactoryFail(&result); if (status.ok() && result != nullptr) { @@ -614,11 +604,9 @@ BENCHMARK(BM_ArgumentFactoryFail); // Measure the time to use the StatusOr factory, evaluate the result, // and invoke the trivial function. -void BM_StatusOrFactoryFail(int iters) { - tensorflow::testing::StopTiming(); +void BM_StatusOrFactoryFail(::testing::benchmark::State& state) { BenchmarkFactory factory; - tensorflow::testing::StartTiming(); - for (int i = 0; i != iters; ++i) { + for (auto s : state) { StatusOr result = factory.StatusOrFactoryFail(); if (result.ok()) { result.ValueOrDie()->DoWork(); @@ -630,11 +618,9 @@ BENCHMARK(BM_StatusOrFactoryFail); // Measure the time taken to call into the factory, providing an // out-param for the result, evaluating the status result and the // result pointer, and invoking the trivial function. -void BM_ArgumentFactoryFailShortMsg(int iters) { - tensorflow::testing::StopTiming(); +void BM_ArgumentFactoryFailShortMsg(::testing::benchmark::State& state) { BenchmarkFactory factory; - tensorflow::testing::StartTiming(); - for (int i = 0; i != iters; ++i) { + for (auto s : state) { BenchmarkType* result = nullptr; Status status = factory.ArgumentFactoryFailShortMsg(&result); if (status.ok() && result != nullptr) { @@ -646,11 +632,9 @@ BENCHMARK(BM_ArgumentFactoryFailShortMsg); // Measure the time to use the StatusOr factory, evaluate the result, // and invoke the trivial function. -void BM_StatusOrFactoryFailShortMsg(int iters) { - tensorflow::testing::StopTiming(); +void BM_StatusOrFactoryFailShortMsg(::testing::benchmark::State& state) { BenchmarkFactory factory; - tensorflow::testing::StartTiming(); - for (int i = 0; i != iters; ++i) { + for (auto s : state) { StatusOr result = factory.StatusOrFactoryFailShortMsg(); if (result.ok()) { result.ValueOrDie()->DoWork(); @@ -662,11 +646,9 @@ BENCHMARK(BM_StatusOrFactoryFailShortMsg); // Measure the time taken to call into the factory, providing an // out-param for the result, evaluating the status result and the // result pointer, and invoking the trivial function. -void BM_ArgumentFactoryFailLongMsg(int iters) { - tensorflow::testing::StopTiming(); +void BM_ArgumentFactoryFailLongMsg(::testing::benchmark::State& state) { BenchmarkFactory factory; - tensorflow::testing::StartTiming(); - for (int i = 0; i != iters; ++i) { + for (auto s : state) { BenchmarkType* result = nullptr; Status status = factory.ArgumentFactoryFailLongMsg(&result); if (status.ok() && result != nullptr) { @@ -678,11 +660,9 @@ BENCHMARK(BM_ArgumentFactoryFailLongMsg); // Measure the time to use the StatusOr factory, evaluate the result, // and invoke the trivial function. -void BM_StatusOrFactoryFailLongMsg(int iters) { - tensorflow::testing::StopTiming(); +void BM_StatusOrFactoryFailLongMsg(::testing::benchmark::State& state) { BenchmarkFactory factory; - tensorflow::testing::StartTiming(); - for (int i = 0; i != iters; ++i) { + for (auto s : state) { StatusOr result = factory.StatusOrFactoryFailLongMsg(); if (result.ok()) { result.ValueOrDie()->DoWork();