internal tests cleanup

PiperOrigin-RevId: 340568129
Change-Id: I521739fa73d00d4ef75a61556e54dc0c0805c2a9
This commit is contained in:
A. Unique TensorFlower 2020-11-03 19:16:04 -08:00 committed by TensorFlower Gardener
parent 2f960b4bc7
commit cce04e5308
7 changed files with 59 additions and 75 deletions

View File

@ -18,8 +18,8 @@ limitations under the License.
namespace tensorflow {
static void BM_DisabledVlog(int iters) {
for (int i = 0; i < iters; ++i) {
static void BM_DisabledVlog(::testing::benchmark::State& state) {
for (auto s : state) {
VLOG(1) << "Testing VLOG(1)!";
}
}

View File

@ -673,15 +673,17 @@ TEST(BCastTest, BatchIndices) {
BCastBatchIndices({3, 1}, {2, 1, 2}));
}
static void BM_BCastSetup(int iters, int same_shape) {
void BM_BCastSetup(::testing::benchmark::State& state) {
const int same_shape = state.range(0);
if (same_shape) {
testing::SetLabel("same_shapes");
while (--iters > 0) {
state.SetLabel("same_shapes");
for (auto s : state) {
class BCast b({1000, 100}, {1000, 100});
}
} else {
testing::SetLabel("different_shapes");
while (--iters > 0) {
state.SetLabel("different_shapes");
for (auto s : state) {
class BCast b({3, 1, 5}, {2, 0, 3, 0, 5});
}
}

View File

@ -572,9 +572,9 @@ TEST(DeviceNameUtilsTest, CanonicalizeDeviceName) {
}
}
static void BM_ParseFullName(int iters) {
static void BM_ParseFullName(::testing::benchmark::State& state) {
DeviceNameUtils::ParsedName p;
while (iters--) {
for (auto s : state) {
DeviceNameUtils::ParseFullName("/job:worker/replica:3/task:0/cpu:0", &p);
}
}

View File

@ -164,13 +164,13 @@ static void CalculateKeys(uint64 num, std::vector<uint64> *dst) {
}
}
static void BM_CuckooFill(int iters, int arg) {
void BM_CuckooFill(::testing::benchmark::State &state) {
const int arg = state.range(0);
uint64 table_size = arg;
testing::StopTiming();
std::vector<uint64> calculated_keys;
CalculateKeys(table_size, &calculated_keys);
testing::StartTiming();
for (int iter = 0; iter < iters; iter++) {
for (auto s : state) {
PresizedCuckooMap<int> pscm(table_size);
for (uint64 i = 0; i < table_size; i++) {
pscm.InsertUnique(calculated_keys[i], i);
@ -180,25 +180,27 @@ static void BM_CuckooFill(int iters, int arg) {
BENCHMARK(BM_CuckooFill)->Arg(1000)->Arg(10000000);
static void BM_CuckooRead(int iters, int arg) {
void BM_CuckooRead(::testing::benchmark::State &state) {
const int arg = state.range(0);
uint64 table_size = arg;
testing::StopTiming();
std::vector<uint64> calculated_keys;
CalculateKeys(table_size, &calculated_keys);
PresizedCuckooMap<int> pscm(table_size);
for (uint64 i = 0; i < table_size; i++) {
pscm.InsertUnique(calculated_keys[i], i);
}
testing::StartTiming();
uint64_t defeat_optimization = 0;
for (int i = 0; i < iters; i++) {
uint64 key_index = i % table_size; // May slow down bench!
int i = 0;
for (auto s : state) {
// Avoid using '%', which is expensive.
uint64 key_index = i;
++i;
if (i == table_size) i = 0;
int out = 0;
pscm.Find(calculated_keys[key_index], &out);
defeat_optimization += out;
}
if (defeat_optimization == 0) {
printf("Preventing the compiler from eliding the inner loop\n");
tensorflow::testing::DoNotOptimize(out);
}
}

View File

@ -1109,9 +1109,8 @@ TEST_F(TensorBundleAlignmentTest, AlignmentTest) {
}
}
static void BM_BundleAlignmentByteOff(int iters, int alignment,
int tensor_size) {
testing::StopTiming();
static void BM_BundleAlignmentByteOff(::testing::benchmark::State& state,
int alignment, int tensor_size) {
{
BundleWriter::Options opts;
opts.data_alignment = alignment;
@ -1122,18 +1121,17 @@ static void BM_BundleAlignmentByteOff(int iters, int alignment,
}
BundleReader reader(Env::Default(), Prefix("foo"));
TF_CHECK_OK(reader.status());
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
Tensor t;
TF_CHECK_OK(reader.Lookup("big", &t));
}
testing::StopTiming();
}
#define BM_BundleAlignment(ALIGN, SIZE) \
static void BM_BundleAlignment_##ALIGN##_##SIZE(int iters) { \
BM_BundleAlignmentByteOff(iters, ALIGN, SIZE); \
} \
#define BM_BundleAlignment(ALIGN, SIZE) \
static void BM_BundleAlignment_##ALIGN##_##SIZE( \
::testing::benchmark::State& state) { \
BM_BundleAlignmentByteOff(state, ALIGN, SIZE); \
} \
BENCHMARK(BM_BundleAlignment_##ALIGN##_##SIZE)
BM_BundleAlignment(1, 512);

View File

@ -89,12 +89,14 @@ TEST(Shard, OverflowTest) {
}
}
void BM_Sharding(int iters, int arg) {
void BM_Sharding(::testing::benchmark::State& state) {
const int arg = state.range(0);
thread::ThreadPool threads(Env::Default(), "test", 16);
const int64 total = 1LL << 30;
auto lambda = [](int64 start, int64 limit) {};
auto work = std::cref(lambda);
for (; iters > 0; iters -= arg) {
for (auto s : state) {
Shard(arg - 1, &threads, total, 1, work);
}
}

View File

@ -535,12 +535,10 @@ class BenchmarkType {
// Calibrate the amount of time spent just calling DoWork, since each of our
// tests will do this, we can subtract this out of benchmark results.
void BM_CalibrateWorkLoop(int iters) {
tensorflow::testing::StopTiming();
void BM_CalibrateWorkLoop(::testing::benchmark::State& state) {
BenchmarkFactory<BenchmarkType> factory;
BenchmarkType* result = factory.TrivialFactory();
tensorflow::testing::StartTiming();
for (int i = 0; i != iters; ++i) {
for (auto s : state) {
if (result != nullptr) {
result->DoWork();
}
@ -550,11 +548,9 @@ BENCHMARK(BM_CalibrateWorkLoop);
// Measure the time taken to call into the factory, return the value,
// determine that it is OK, and invoke a trivial function.
void BM_TrivialFactory(int iters) {
tensorflow::testing::StopTiming();
void BM_TrivialFactory(::testing::benchmark::State& state) {
BenchmarkFactory<BenchmarkType> factory;
tensorflow::testing::StartTiming();
for (int i = 0; i != iters; ++i) {
for (auto s : state) {
BenchmarkType* result = factory.TrivialFactory();
if (result != nullptr) {
result->DoWork();
@ -566,11 +562,9 @@ BENCHMARK(BM_TrivialFactory);
// Measure the time taken to call into the factory, providing an
// out-param for the result, evaluating the status result and the
// result pointer, and invoking the trivial function.
void BM_ArgumentFactory(int iters) {
tensorflow::testing::StopTiming();
void BM_ArgumentFactory(::testing::benchmark::State& state) {
BenchmarkFactory<BenchmarkType> factory;
tensorflow::testing::StartTiming();
for (int i = 0; i != iters; ++i) {
for (auto s : state) {
BenchmarkType* result = nullptr;
Status status = factory.ArgumentFactory(&result);
if (status.ok() && result != nullptr) {
@ -582,11 +576,9 @@ BENCHMARK(BM_ArgumentFactory);
// Measure the time to use the StatusOr<T*> factory, evaluate the result,
// and invoke the trivial function.
void BM_StatusOrFactory(int iters) {
tensorflow::testing::StopTiming();
void BM_StatusOrFactory(::testing::benchmark::State& state) {
BenchmarkFactory<BenchmarkType> factory;
tensorflow::testing::StartTiming();
for (int i = 0; i != iters; ++i) {
for (auto s : state) {
StatusOr<BenchmarkType*> result = factory.StatusOrFactory();
if (result.ok()) {
result.ValueOrDie()->DoWork();
@ -598,11 +590,9 @@ BENCHMARK(BM_StatusOrFactory);
// Measure the time taken to call into the factory, providing an
// out-param for the result, evaluating the status result and the
// result pointer, and invoking the trivial function.
void BM_ArgumentFactoryFail(int iters) {
tensorflow::testing::StopTiming();
void BM_ArgumentFactoryFail(::testing::benchmark::State& state) {
BenchmarkFactory<BenchmarkType> factory;
tensorflow::testing::StartTiming();
for (int i = 0; i != iters; ++i) {
for (auto s : state) {
BenchmarkType* result = nullptr;
Status status = factory.ArgumentFactoryFail(&result);
if (status.ok() && result != nullptr) {
@ -614,11 +604,9 @@ BENCHMARK(BM_ArgumentFactoryFail);
// Measure the time to use the StatusOr<T*> factory, evaluate the result,
// and invoke the trivial function.
void BM_StatusOrFactoryFail(int iters) {
tensorflow::testing::StopTiming();
void BM_StatusOrFactoryFail(::testing::benchmark::State& state) {
BenchmarkFactory<BenchmarkType> factory;
tensorflow::testing::StartTiming();
for (int i = 0; i != iters; ++i) {
for (auto s : state) {
StatusOr<BenchmarkType*> result = factory.StatusOrFactoryFail();
if (result.ok()) {
result.ValueOrDie()->DoWork();
@ -630,11 +618,9 @@ BENCHMARK(BM_StatusOrFactoryFail);
// Measure the time taken to call into the factory, providing an
// out-param for the result, evaluating the status result and the
// result pointer, and invoking the trivial function.
void BM_ArgumentFactoryFailShortMsg(int iters) {
tensorflow::testing::StopTiming();
void BM_ArgumentFactoryFailShortMsg(::testing::benchmark::State& state) {
BenchmarkFactory<BenchmarkType> factory;
tensorflow::testing::StartTiming();
for (int i = 0; i != iters; ++i) {
for (auto s : state) {
BenchmarkType* result = nullptr;
Status status = factory.ArgumentFactoryFailShortMsg(&result);
if (status.ok() && result != nullptr) {
@ -646,11 +632,9 @@ BENCHMARK(BM_ArgumentFactoryFailShortMsg);
// Measure the time to use the StatusOr<T*> factory, evaluate the result,
// and invoke the trivial function.
void BM_StatusOrFactoryFailShortMsg(int iters) {
tensorflow::testing::StopTiming();
void BM_StatusOrFactoryFailShortMsg(::testing::benchmark::State& state) {
BenchmarkFactory<BenchmarkType> factory;
tensorflow::testing::StartTiming();
for (int i = 0; i != iters; ++i) {
for (auto s : state) {
StatusOr<BenchmarkType*> result = factory.StatusOrFactoryFailShortMsg();
if (result.ok()) {
result.ValueOrDie()->DoWork();
@ -662,11 +646,9 @@ BENCHMARK(BM_StatusOrFactoryFailShortMsg);
// Measure the time taken to call into the factory, providing an
// out-param for the result, evaluating the status result and the
// result pointer, and invoking the trivial function.
void BM_ArgumentFactoryFailLongMsg(int iters) {
tensorflow::testing::StopTiming();
void BM_ArgumentFactoryFailLongMsg(::testing::benchmark::State& state) {
BenchmarkFactory<BenchmarkType> factory;
tensorflow::testing::StartTiming();
for (int i = 0; i != iters; ++i) {
for (auto s : state) {
BenchmarkType* result = nullptr;
Status status = factory.ArgumentFactoryFailLongMsg(&result);
if (status.ok() && result != nullptr) {
@ -678,11 +660,9 @@ BENCHMARK(BM_ArgumentFactoryFailLongMsg);
// Measure the time to use the StatusOr<T*> factory, evaluate the result,
// and invoke the trivial function.
void BM_StatusOrFactoryFailLongMsg(int iters) {
tensorflow::testing::StopTiming();
void BM_StatusOrFactoryFailLongMsg(::testing::benchmark::State& state) {
BenchmarkFactory<BenchmarkType> factory;
tensorflow::testing::StartTiming();
for (int i = 0; i != iters; ++i) {
for (auto s : state) {
StatusOr<BenchmarkType*> result = factory.StatusOrFactoryFailLongMsg();
if (result.ok()) {
result.ValueOrDie()->DoWork();