internal tests cleanup
PiperOrigin-RevId: 340568129 Change-Id: I521739fa73d00d4ef75a61556e54dc0c0805c2a9
This commit is contained in:
parent
2f960b4bc7
commit
cce04e5308
@ -18,8 +18,8 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
static void BM_DisabledVlog(int iters) {
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
static void BM_DisabledVlog(::testing::benchmark::State& state) {
|
||||
for (auto s : state) {
|
||||
VLOG(1) << "Testing VLOG(1)!";
|
||||
}
|
||||
}
|
||||
|
@ -673,15 +673,17 @@ TEST(BCastTest, BatchIndices) {
|
||||
BCastBatchIndices({3, 1}, {2, 1, 2}));
|
||||
}
|
||||
|
||||
static void BM_BCastSetup(int iters, int same_shape) {
|
||||
void BM_BCastSetup(::testing::benchmark::State& state) {
|
||||
const int same_shape = state.range(0);
|
||||
|
||||
if (same_shape) {
|
||||
testing::SetLabel("same_shapes");
|
||||
while (--iters > 0) {
|
||||
state.SetLabel("same_shapes");
|
||||
for (auto s : state) {
|
||||
class BCast b({1000, 100}, {1000, 100});
|
||||
}
|
||||
} else {
|
||||
testing::SetLabel("different_shapes");
|
||||
while (--iters > 0) {
|
||||
state.SetLabel("different_shapes");
|
||||
for (auto s : state) {
|
||||
class BCast b({3, 1, 5}, {2, 0, 3, 0, 5});
|
||||
}
|
||||
}
|
||||
|
@ -572,9 +572,9 @@ TEST(DeviceNameUtilsTest, CanonicalizeDeviceName) {
|
||||
}
|
||||
}
|
||||
|
||||
static void BM_ParseFullName(int iters) {
|
||||
static void BM_ParseFullName(::testing::benchmark::State& state) {
|
||||
DeviceNameUtils::ParsedName p;
|
||||
while (iters--) {
|
||||
for (auto s : state) {
|
||||
DeviceNameUtils::ParseFullName("/job:worker/replica:3/task:0/cpu:0", &p);
|
||||
}
|
||||
}
|
||||
|
@ -164,13 +164,13 @@ static void CalculateKeys(uint64 num, std::vector<uint64> *dst) {
|
||||
}
|
||||
}
|
||||
|
||||
static void BM_CuckooFill(int iters, int arg) {
|
||||
void BM_CuckooFill(::testing::benchmark::State &state) {
|
||||
const int arg = state.range(0);
|
||||
|
||||
uint64 table_size = arg;
|
||||
testing::StopTiming();
|
||||
std::vector<uint64> calculated_keys;
|
||||
CalculateKeys(table_size, &calculated_keys);
|
||||
testing::StartTiming();
|
||||
for (int iter = 0; iter < iters; iter++) {
|
||||
for (auto s : state) {
|
||||
PresizedCuckooMap<int> pscm(table_size);
|
||||
for (uint64 i = 0; i < table_size; i++) {
|
||||
pscm.InsertUnique(calculated_keys[i], i);
|
||||
@ -180,25 +180,27 @@ static void BM_CuckooFill(int iters, int arg) {
|
||||
|
||||
BENCHMARK(BM_CuckooFill)->Arg(1000)->Arg(10000000);
|
||||
|
||||
static void BM_CuckooRead(int iters, int arg) {
|
||||
void BM_CuckooRead(::testing::benchmark::State &state) {
|
||||
const int arg = state.range(0);
|
||||
|
||||
uint64 table_size = arg;
|
||||
testing::StopTiming();
|
||||
std::vector<uint64> calculated_keys;
|
||||
CalculateKeys(table_size, &calculated_keys);
|
||||
PresizedCuckooMap<int> pscm(table_size);
|
||||
for (uint64 i = 0; i < table_size; i++) {
|
||||
pscm.InsertUnique(calculated_keys[i], i);
|
||||
}
|
||||
testing::StartTiming();
|
||||
uint64_t defeat_optimization = 0;
|
||||
for (int i = 0; i < iters; i++) {
|
||||
uint64 key_index = i % table_size; // May slow down bench!
|
||||
|
||||
int i = 0;
|
||||
for (auto s : state) {
|
||||
// Avoid using '%', which is expensive.
|
||||
uint64 key_index = i;
|
||||
++i;
|
||||
if (i == table_size) i = 0;
|
||||
|
||||
int out = 0;
|
||||
pscm.Find(calculated_keys[key_index], &out);
|
||||
defeat_optimization += out;
|
||||
}
|
||||
if (defeat_optimization == 0) {
|
||||
printf("Preventing the compiler from eliding the inner loop\n");
|
||||
tensorflow::testing::DoNotOptimize(out);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1109,9 +1109,8 @@ TEST_F(TensorBundleAlignmentTest, AlignmentTest) {
|
||||
}
|
||||
}
|
||||
|
||||
static void BM_BundleAlignmentByteOff(int iters, int alignment,
|
||||
int tensor_size) {
|
||||
testing::StopTiming();
|
||||
static void BM_BundleAlignmentByteOff(::testing::benchmark::State& state,
|
||||
int alignment, int tensor_size) {
|
||||
{
|
||||
BundleWriter::Options opts;
|
||||
opts.data_alignment = alignment;
|
||||
@ -1122,18 +1121,17 @@ static void BM_BundleAlignmentByteOff(int iters, int alignment,
|
||||
}
|
||||
BundleReader reader(Env::Default(), Prefix("foo"));
|
||||
TF_CHECK_OK(reader.status());
|
||||
testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
for (auto s : state) {
|
||||
Tensor t;
|
||||
TF_CHECK_OK(reader.Lookup("big", &t));
|
||||
}
|
||||
testing::StopTiming();
|
||||
}
|
||||
|
||||
#define BM_BundleAlignment(ALIGN, SIZE) \
|
||||
static void BM_BundleAlignment_##ALIGN##_##SIZE(int iters) { \
|
||||
BM_BundleAlignmentByteOff(iters, ALIGN, SIZE); \
|
||||
} \
|
||||
#define BM_BundleAlignment(ALIGN, SIZE) \
|
||||
static void BM_BundleAlignment_##ALIGN##_##SIZE( \
|
||||
::testing::benchmark::State& state) { \
|
||||
BM_BundleAlignmentByteOff(state, ALIGN, SIZE); \
|
||||
} \
|
||||
BENCHMARK(BM_BundleAlignment_##ALIGN##_##SIZE)
|
||||
|
||||
BM_BundleAlignment(1, 512);
|
||||
|
@ -89,12 +89,14 @@ TEST(Shard, OverflowTest) {
|
||||
}
|
||||
}
|
||||
|
||||
void BM_Sharding(int iters, int arg) {
|
||||
void BM_Sharding(::testing::benchmark::State& state) {
|
||||
const int arg = state.range(0);
|
||||
|
||||
thread::ThreadPool threads(Env::Default(), "test", 16);
|
||||
const int64 total = 1LL << 30;
|
||||
auto lambda = [](int64 start, int64 limit) {};
|
||||
auto work = std::cref(lambda);
|
||||
for (; iters > 0; iters -= arg) {
|
||||
for (auto s : state) {
|
||||
Shard(arg - 1, &threads, total, 1, work);
|
||||
}
|
||||
}
|
||||
|
@ -535,12 +535,10 @@ class BenchmarkType {
|
||||
|
||||
// Calibrate the amount of time spent just calling DoWork, since each of our
|
||||
// tests will do this, we can subtract this out of benchmark results.
|
||||
void BM_CalibrateWorkLoop(int iters) {
|
||||
tensorflow::testing::StopTiming();
|
||||
void BM_CalibrateWorkLoop(::testing::benchmark::State& state) {
|
||||
BenchmarkFactory<BenchmarkType> factory;
|
||||
BenchmarkType* result = factory.TrivialFactory();
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i != iters; ++i) {
|
||||
for (auto s : state) {
|
||||
if (result != nullptr) {
|
||||
result->DoWork();
|
||||
}
|
||||
@ -550,11 +548,9 @@ BENCHMARK(BM_CalibrateWorkLoop);
|
||||
|
||||
// Measure the time taken to call into the factory, return the value,
|
||||
// determine that it is OK, and invoke a trivial function.
|
||||
void BM_TrivialFactory(int iters) {
|
||||
tensorflow::testing::StopTiming();
|
||||
void BM_TrivialFactory(::testing::benchmark::State& state) {
|
||||
BenchmarkFactory<BenchmarkType> factory;
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i != iters; ++i) {
|
||||
for (auto s : state) {
|
||||
BenchmarkType* result = factory.TrivialFactory();
|
||||
if (result != nullptr) {
|
||||
result->DoWork();
|
||||
@ -566,11 +562,9 @@ BENCHMARK(BM_TrivialFactory);
|
||||
// Measure the time taken to call into the factory, providing an
|
||||
// out-param for the result, evaluating the status result and the
|
||||
// result pointer, and invoking the trivial function.
|
||||
void BM_ArgumentFactory(int iters) {
|
||||
tensorflow::testing::StopTiming();
|
||||
void BM_ArgumentFactory(::testing::benchmark::State& state) {
|
||||
BenchmarkFactory<BenchmarkType> factory;
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i != iters; ++i) {
|
||||
for (auto s : state) {
|
||||
BenchmarkType* result = nullptr;
|
||||
Status status = factory.ArgumentFactory(&result);
|
||||
if (status.ok() && result != nullptr) {
|
||||
@ -582,11 +576,9 @@ BENCHMARK(BM_ArgumentFactory);
|
||||
|
||||
// Measure the time to use the StatusOr<T*> factory, evaluate the result,
|
||||
// and invoke the trivial function.
|
||||
void BM_StatusOrFactory(int iters) {
|
||||
tensorflow::testing::StopTiming();
|
||||
void BM_StatusOrFactory(::testing::benchmark::State& state) {
|
||||
BenchmarkFactory<BenchmarkType> factory;
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i != iters; ++i) {
|
||||
for (auto s : state) {
|
||||
StatusOr<BenchmarkType*> result = factory.StatusOrFactory();
|
||||
if (result.ok()) {
|
||||
result.ValueOrDie()->DoWork();
|
||||
@ -598,11 +590,9 @@ BENCHMARK(BM_StatusOrFactory);
|
||||
// Measure the time taken to call into the factory, providing an
|
||||
// out-param for the result, evaluating the status result and the
|
||||
// result pointer, and invoking the trivial function.
|
||||
void BM_ArgumentFactoryFail(int iters) {
|
||||
tensorflow::testing::StopTiming();
|
||||
void BM_ArgumentFactoryFail(::testing::benchmark::State& state) {
|
||||
BenchmarkFactory<BenchmarkType> factory;
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i != iters; ++i) {
|
||||
for (auto s : state) {
|
||||
BenchmarkType* result = nullptr;
|
||||
Status status = factory.ArgumentFactoryFail(&result);
|
||||
if (status.ok() && result != nullptr) {
|
||||
@ -614,11 +604,9 @@ BENCHMARK(BM_ArgumentFactoryFail);
|
||||
|
||||
// Measure the time to use the StatusOr<T*> factory, evaluate the result,
|
||||
// and invoke the trivial function.
|
||||
void BM_StatusOrFactoryFail(int iters) {
|
||||
tensorflow::testing::StopTiming();
|
||||
void BM_StatusOrFactoryFail(::testing::benchmark::State& state) {
|
||||
BenchmarkFactory<BenchmarkType> factory;
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i != iters; ++i) {
|
||||
for (auto s : state) {
|
||||
StatusOr<BenchmarkType*> result = factory.StatusOrFactoryFail();
|
||||
if (result.ok()) {
|
||||
result.ValueOrDie()->DoWork();
|
||||
@ -630,11 +618,9 @@ BENCHMARK(BM_StatusOrFactoryFail);
|
||||
// Measure the time taken to call into the factory, providing an
|
||||
// out-param for the result, evaluating the status result and the
|
||||
// result pointer, and invoking the trivial function.
|
||||
void BM_ArgumentFactoryFailShortMsg(int iters) {
|
||||
tensorflow::testing::StopTiming();
|
||||
void BM_ArgumentFactoryFailShortMsg(::testing::benchmark::State& state) {
|
||||
BenchmarkFactory<BenchmarkType> factory;
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i != iters; ++i) {
|
||||
for (auto s : state) {
|
||||
BenchmarkType* result = nullptr;
|
||||
Status status = factory.ArgumentFactoryFailShortMsg(&result);
|
||||
if (status.ok() && result != nullptr) {
|
||||
@ -646,11 +632,9 @@ BENCHMARK(BM_ArgumentFactoryFailShortMsg);
|
||||
|
||||
// Measure the time to use the StatusOr<T*> factory, evaluate the result,
|
||||
// and invoke the trivial function.
|
||||
void BM_StatusOrFactoryFailShortMsg(int iters) {
|
||||
tensorflow::testing::StopTiming();
|
||||
void BM_StatusOrFactoryFailShortMsg(::testing::benchmark::State& state) {
|
||||
BenchmarkFactory<BenchmarkType> factory;
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i != iters; ++i) {
|
||||
for (auto s : state) {
|
||||
StatusOr<BenchmarkType*> result = factory.StatusOrFactoryFailShortMsg();
|
||||
if (result.ok()) {
|
||||
result.ValueOrDie()->DoWork();
|
||||
@ -662,11 +646,9 @@ BENCHMARK(BM_StatusOrFactoryFailShortMsg);
|
||||
// Measure the time taken to call into the factory, providing an
|
||||
// out-param for the result, evaluating the status result and the
|
||||
// result pointer, and invoking the trivial function.
|
||||
void BM_ArgumentFactoryFailLongMsg(int iters) {
|
||||
tensorflow::testing::StopTiming();
|
||||
void BM_ArgumentFactoryFailLongMsg(::testing::benchmark::State& state) {
|
||||
BenchmarkFactory<BenchmarkType> factory;
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i != iters; ++i) {
|
||||
for (auto s : state) {
|
||||
BenchmarkType* result = nullptr;
|
||||
Status status = factory.ArgumentFactoryFailLongMsg(&result);
|
||||
if (status.ok() && result != nullptr) {
|
||||
@ -678,11 +660,9 @@ BENCHMARK(BM_ArgumentFactoryFailLongMsg);
|
||||
|
||||
// Measure the time to use the StatusOr<T*> factory, evaluate the result,
|
||||
// and invoke the trivial function.
|
||||
void BM_StatusOrFactoryFailLongMsg(int iters) {
|
||||
tensorflow::testing::StopTiming();
|
||||
void BM_StatusOrFactoryFailLongMsg(::testing::benchmark::State& state) {
|
||||
BenchmarkFactory<BenchmarkType> factory;
|
||||
tensorflow::testing::StartTiming();
|
||||
for (int i = 0; i != iters; ++i) {
|
||||
for (auto s : state) {
|
||||
StatusOr<BenchmarkType*> result = factory.StatusOrFactoryFailLongMsg();
|
||||
if (result.ok()) {
|
||||
result.ValueOrDie()->DoWork();
|
||||
|
Loading…
Reference in New Issue
Block a user