internal tests cleanup

PiperOrigin-RevId: 340568129
Change-Id: I521739fa73d00d4ef75a61556e54dc0c0805c2a9
This commit is contained in:
A. Unique TensorFlower 2020-11-03 19:16:04 -08:00 committed by TensorFlower Gardener
parent 2f960b4bc7
commit cce04e5308
7 changed files with 59 additions and 75 deletions

View File

@ -18,8 +18,8 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
static void BM_DisabledVlog(int iters) { static void BM_DisabledVlog(::testing::benchmark::State& state) {
for (int i = 0; i < iters; ++i) { for (auto s : state) {
VLOG(1) << "Testing VLOG(1)!"; VLOG(1) << "Testing VLOG(1)!";
} }
} }

View File

@ -673,15 +673,17 @@ TEST(BCastTest, BatchIndices) {
BCastBatchIndices({3, 1}, {2, 1, 2})); BCastBatchIndices({3, 1}, {2, 1, 2}));
} }
static void BM_BCastSetup(int iters, int same_shape) { void BM_BCastSetup(::testing::benchmark::State& state) {
const int same_shape = state.range(0);
if (same_shape) { if (same_shape) {
testing::SetLabel("same_shapes"); state.SetLabel("same_shapes");
while (--iters > 0) { for (auto s : state) {
class BCast b({1000, 100}, {1000, 100}); class BCast b({1000, 100}, {1000, 100});
} }
} else { } else {
testing::SetLabel("different_shapes"); state.SetLabel("different_shapes");
while (--iters > 0) { for (auto s : state) {
class BCast b({3, 1, 5}, {2, 0, 3, 0, 5}); class BCast b({3, 1, 5}, {2, 0, 3, 0, 5});
} }
} }

View File

@ -572,9 +572,9 @@ TEST(DeviceNameUtilsTest, CanonicalizeDeviceName) {
} }
} }
static void BM_ParseFullName(int iters) { static void BM_ParseFullName(::testing::benchmark::State& state) {
DeviceNameUtils::ParsedName p; DeviceNameUtils::ParsedName p;
while (iters--) { for (auto s : state) {
DeviceNameUtils::ParseFullName("/job:worker/replica:3/task:0/cpu:0", &p); DeviceNameUtils::ParseFullName("/job:worker/replica:3/task:0/cpu:0", &p);
} }
} }

View File

@ -164,13 +164,13 @@ static void CalculateKeys(uint64 num, std::vector<uint64> *dst) {
} }
} }
static void BM_CuckooFill(int iters, int arg) { void BM_CuckooFill(::testing::benchmark::State &state) {
const int arg = state.range(0);
uint64 table_size = arg; uint64 table_size = arg;
testing::StopTiming();
std::vector<uint64> calculated_keys; std::vector<uint64> calculated_keys;
CalculateKeys(table_size, &calculated_keys); CalculateKeys(table_size, &calculated_keys);
testing::StartTiming(); for (auto s : state) {
for (int iter = 0; iter < iters; iter++) {
PresizedCuckooMap<int> pscm(table_size); PresizedCuckooMap<int> pscm(table_size);
for (uint64 i = 0; i < table_size; i++) { for (uint64 i = 0; i < table_size; i++) {
pscm.InsertUnique(calculated_keys[i], i); pscm.InsertUnique(calculated_keys[i], i);
@ -180,25 +180,27 @@ static void BM_CuckooFill(int iters, int arg) {
BENCHMARK(BM_CuckooFill)->Arg(1000)->Arg(10000000); BENCHMARK(BM_CuckooFill)->Arg(1000)->Arg(10000000);
static void BM_CuckooRead(int iters, int arg) { void BM_CuckooRead(::testing::benchmark::State &state) {
const int arg = state.range(0);
uint64 table_size = arg; uint64 table_size = arg;
testing::StopTiming();
std::vector<uint64> calculated_keys; std::vector<uint64> calculated_keys;
CalculateKeys(table_size, &calculated_keys); CalculateKeys(table_size, &calculated_keys);
PresizedCuckooMap<int> pscm(table_size); PresizedCuckooMap<int> pscm(table_size);
for (uint64 i = 0; i < table_size; i++) { for (uint64 i = 0; i < table_size; i++) {
pscm.InsertUnique(calculated_keys[i], i); pscm.InsertUnique(calculated_keys[i], i);
} }
testing::StartTiming();
uint64_t defeat_optimization = 0; int i = 0;
for (int i = 0; i < iters; i++) { for (auto s : state) {
uint64 key_index = i % table_size; // May slow down bench! // Avoid using '%', which is expensive.
uint64 key_index = i;
++i;
if (i == table_size) i = 0;
int out = 0; int out = 0;
pscm.Find(calculated_keys[key_index], &out); pscm.Find(calculated_keys[key_index], &out);
defeat_optimization += out; tensorflow::testing::DoNotOptimize(out);
}
if (defeat_optimization == 0) {
printf("Preventing the compiler from eliding the inner loop\n");
} }
} }

View File

@ -1109,9 +1109,8 @@ TEST_F(TensorBundleAlignmentTest, AlignmentTest) {
} }
} }
static void BM_BundleAlignmentByteOff(int iters, int alignment, static void BM_BundleAlignmentByteOff(::testing::benchmark::State& state,
int tensor_size) { int alignment, int tensor_size) {
testing::StopTiming();
{ {
BundleWriter::Options opts; BundleWriter::Options opts;
opts.data_alignment = alignment; opts.data_alignment = alignment;
@ -1122,18 +1121,17 @@ static void BM_BundleAlignmentByteOff(int iters, int alignment,
} }
BundleReader reader(Env::Default(), Prefix("foo")); BundleReader reader(Env::Default(), Prefix("foo"));
TF_CHECK_OK(reader.status()); TF_CHECK_OK(reader.status());
testing::StartTiming(); for (auto s : state) {
for (int i = 0; i < iters; ++i) {
Tensor t; Tensor t;
TF_CHECK_OK(reader.Lookup("big", &t)); TF_CHECK_OK(reader.Lookup("big", &t));
} }
testing::StopTiming();
} }
#define BM_BundleAlignment(ALIGN, SIZE) \ #define BM_BundleAlignment(ALIGN, SIZE) \
static void BM_BundleAlignment_##ALIGN##_##SIZE(int iters) { \ static void BM_BundleAlignment_##ALIGN##_##SIZE( \
BM_BundleAlignmentByteOff(iters, ALIGN, SIZE); \ ::testing::benchmark::State& state) { \
} \ BM_BundleAlignmentByteOff(state, ALIGN, SIZE); \
} \
BENCHMARK(BM_BundleAlignment_##ALIGN##_##SIZE) BENCHMARK(BM_BundleAlignment_##ALIGN##_##SIZE)
BM_BundleAlignment(1, 512); BM_BundleAlignment(1, 512);

View File

@ -89,12 +89,14 @@ TEST(Shard, OverflowTest) {
} }
} }
void BM_Sharding(int iters, int arg) { void BM_Sharding(::testing::benchmark::State& state) {
const int arg = state.range(0);
thread::ThreadPool threads(Env::Default(), "test", 16); thread::ThreadPool threads(Env::Default(), "test", 16);
const int64 total = 1LL << 30; const int64 total = 1LL << 30;
auto lambda = [](int64 start, int64 limit) {}; auto lambda = [](int64 start, int64 limit) {};
auto work = std::cref(lambda); auto work = std::cref(lambda);
for (; iters > 0; iters -= arg) { for (auto s : state) {
Shard(arg - 1, &threads, total, 1, work); Shard(arg - 1, &threads, total, 1, work);
} }
} }

View File

@ -535,12 +535,10 @@ class BenchmarkType {
// Calibrate the amount of time spent just calling DoWork, since each of our // Calibrate the amount of time spent just calling DoWork, since each of our
// tests will do this, we can subtract this out of benchmark results. // tests will do this, we can subtract this out of benchmark results.
void BM_CalibrateWorkLoop(int iters) { void BM_CalibrateWorkLoop(::testing::benchmark::State& state) {
tensorflow::testing::StopTiming();
BenchmarkFactory<BenchmarkType> factory; BenchmarkFactory<BenchmarkType> factory;
BenchmarkType* result = factory.TrivialFactory(); BenchmarkType* result = factory.TrivialFactory();
tensorflow::testing::StartTiming(); for (auto s : state) {
for (int i = 0; i != iters; ++i) {
if (result != nullptr) { if (result != nullptr) {
result->DoWork(); result->DoWork();
} }
@ -550,11 +548,9 @@ BENCHMARK(BM_CalibrateWorkLoop);
// Measure the time taken to call into the factory, return the value, // Measure the time taken to call into the factory, return the value,
// determine that it is OK, and invoke a trivial function. // determine that it is OK, and invoke a trivial function.
void BM_TrivialFactory(int iters) { void BM_TrivialFactory(::testing::benchmark::State& state) {
tensorflow::testing::StopTiming();
BenchmarkFactory<BenchmarkType> factory; BenchmarkFactory<BenchmarkType> factory;
tensorflow::testing::StartTiming(); for (auto s : state) {
for (int i = 0; i != iters; ++i) {
BenchmarkType* result = factory.TrivialFactory(); BenchmarkType* result = factory.TrivialFactory();
if (result != nullptr) { if (result != nullptr) {
result->DoWork(); result->DoWork();
@ -566,11 +562,9 @@ BENCHMARK(BM_TrivialFactory);
// Measure the time taken to call into the factory, providing an // Measure the time taken to call into the factory, providing an
// out-param for the result, evaluating the status result and the // out-param for the result, evaluating the status result and the
// result pointer, and invoking the trivial function. // result pointer, and invoking the trivial function.
void BM_ArgumentFactory(int iters) { void BM_ArgumentFactory(::testing::benchmark::State& state) {
tensorflow::testing::StopTiming();
BenchmarkFactory<BenchmarkType> factory; BenchmarkFactory<BenchmarkType> factory;
tensorflow::testing::StartTiming(); for (auto s : state) {
for (int i = 0; i != iters; ++i) {
BenchmarkType* result = nullptr; BenchmarkType* result = nullptr;
Status status = factory.ArgumentFactory(&result); Status status = factory.ArgumentFactory(&result);
if (status.ok() && result != nullptr) { if (status.ok() && result != nullptr) {
@ -582,11 +576,9 @@ BENCHMARK(BM_ArgumentFactory);
// Measure the time to use the StatusOr<T*> factory, evaluate the result, // Measure the time to use the StatusOr<T*> factory, evaluate the result,
// and invoke the trivial function. // and invoke the trivial function.
void BM_StatusOrFactory(int iters) { void BM_StatusOrFactory(::testing::benchmark::State& state) {
tensorflow::testing::StopTiming();
BenchmarkFactory<BenchmarkType> factory; BenchmarkFactory<BenchmarkType> factory;
tensorflow::testing::StartTiming(); for (auto s : state) {
for (int i = 0; i != iters; ++i) {
StatusOr<BenchmarkType*> result = factory.StatusOrFactory(); StatusOr<BenchmarkType*> result = factory.StatusOrFactory();
if (result.ok()) { if (result.ok()) {
result.ValueOrDie()->DoWork(); result.ValueOrDie()->DoWork();
@ -598,11 +590,9 @@ BENCHMARK(BM_StatusOrFactory);
// Measure the time taken to call into the factory, providing an // Measure the time taken to call into the factory, providing an
// out-param for the result, evaluating the status result and the // out-param for the result, evaluating the status result and the
// result pointer, and invoking the trivial function. // result pointer, and invoking the trivial function.
void BM_ArgumentFactoryFail(int iters) { void BM_ArgumentFactoryFail(::testing::benchmark::State& state) {
tensorflow::testing::StopTiming();
BenchmarkFactory<BenchmarkType> factory; BenchmarkFactory<BenchmarkType> factory;
tensorflow::testing::StartTiming(); for (auto s : state) {
for (int i = 0; i != iters; ++i) {
BenchmarkType* result = nullptr; BenchmarkType* result = nullptr;
Status status = factory.ArgumentFactoryFail(&result); Status status = factory.ArgumentFactoryFail(&result);
if (status.ok() && result != nullptr) { if (status.ok() && result != nullptr) {
@ -614,11 +604,9 @@ BENCHMARK(BM_ArgumentFactoryFail);
// Measure the time to use the StatusOr<T*> factory, evaluate the result, // Measure the time to use the StatusOr<T*> factory, evaluate the result,
// and invoke the trivial function. // and invoke the trivial function.
void BM_StatusOrFactoryFail(int iters) { void BM_StatusOrFactoryFail(::testing::benchmark::State& state) {
tensorflow::testing::StopTiming();
BenchmarkFactory<BenchmarkType> factory; BenchmarkFactory<BenchmarkType> factory;
tensorflow::testing::StartTiming(); for (auto s : state) {
for (int i = 0; i != iters; ++i) {
StatusOr<BenchmarkType*> result = factory.StatusOrFactoryFail(); StatusOr<BenchmarkType*> result = factory.StatusOrFactoryFail();
if (result.ok()) { if (result.ok()) {
result.ValueOrDie()->DoWork(); result.ValueOrDie()->DoWork();
@ -630,11 +618,9 @@ BENCHMARK(BM_StatusOrFactoryFail);
// Measure the time taken to call into the factory, providing an // Measure the time taken to call into the factory, providing an
// out-param for the result, evaluating the status result and the // out-param for the result, evaluating the status result and the
// result pointer, and invoking the trivial function. // result pointer, and invoking the trivial function.
void BM_ArgumentFactoryFailShortMsg(int iters) { void BM_ArgumentFactoryFailShortMsg(::testing::benchmark::State& state) {
tensorflow::testing::StopTiming();
BenchmarkFactory<BenchmarkType> factory; BenchmarkFactory<BenchmarkType> factory;
tensorflow::testing::StartTiming(); for (auto s : state) {
for (int i = 0; i != iters; ++i) {
BenchmarkType* result = nullptr; BenchmarkType* result = nullptr;
Status status = factory.ArgumentFactoryFailShortMsg(&result); Status status = factory.ArgumentFactoryFailShortMsg(&result);
if (status.ok() && result != nullptr) { if (status.ok() && result != nullptr) {
@ -646,11 +632,9 @@ BENCHMARK(BM_ArgumentFactoryFailShortMsg);
// Measure the time to use the StatusOr<T*> factory, evaluate the result, // Measure the time to use the StatusOr<T*> factory, evaluate the result,
// and invoke the trivial function. // and invoke the trivial function.
void BM_StatusOrFactoryFailShortMsg(int iters) { void BM_StatusOrFactoryFailShortMsg(::testing::benchmark::State& state) {
tensorflow::testing::StopTiming();
BenchmarkFactory<BenchmarkType> factory; BenchmarkFactory<BenchmarkType> factory;
tensorflow::testing::StartTiming(); for (auto s : state) {
for (int i = 0; i != iters; ++i) {
StatusOr<BenchmarkType*> result = factory.StatusOrFactoryFailShortMsg(); StatusOr<BenchmarkType*> result = factory.StatusOrFactoryFailShortMsg();
if (result.ok()) { if (result.ok()) {
result.ValueOrDie()->DoWork(); result.ValueOrDie()->DoWork();
@ -662,11 +646,9 @@ BENCHMARK(BM_StatusOrFactoryFailShortMsg);
// Measure the time taken to call into the factory, providing an // Measure the time taken to call into the factory, providing an
// out-param for the result, evaluating the status result and the // out-param for the result, evaluating the status result and the
// result pointer, and invoking the trivial function. // result pointer, and invoking the trivial function.
void BM_ArgumentFactoryFailLongMsg(int iters) { void BM_ArgumentFactoryFailLongMsg(::testing::benchmark::State& state) {
tensorflow::testing::StopTiming();
BenchmarkFactory<BenchmarkType> factory; BenchmarkFactory<BenchmarkType> factory;
tensorflow::testing::StartTiming(); for (auto s : state) {
for (int i = 0; i != iters; ++i) {
BenchmarkType* result = nullptr; BenchmarkType* result = nullptr;
Status status = factory.ArgumentFactoryFailLongMsg(&result); Status status = factory.ArgumentFactoryFailLongMsg(&result);
if (status.ok() && result != nullptr) { if (status.ok() && result != nullptr) {
@ -678,11 +660,9 @@ BENCHMARK(BM_ArgumentFactoryFailLongMsg);
// Measure the time to use the StatusOr<T*> factory, evaluate the result, // Measure the time to use the StatusOr<T*> factory, evaluate the result,
// and invoke the trivial function. // and invoke the trivial function.
void BM_StatusOrFactoryFailLongMsg(int iters) { void BM_StatusOrFactoryFailLongMsg(::testing::benchmark::State& state) {
tensorflow::testing::StopTiming();
BenchmarkFactory<BenchmarkType> factory; BenchmarkFactory<BenchmarkType> factory;
tensorflow::testing::StartTiming(); for (auto s : state) {
for (int i = 0; i != iters; ++i) {
StatusOr<BenchmarkType*> result = factory.StatusOrFactoryFailLongMsg(); StatusOr<BenchmarkType*> result = factory.StatusOrFactoryFailLongMsg();
if (result.ok()) { if (result.ok()) {
result.ValueOrDie()->DoWork(); result.ValueOrDie()->DoWork();