internal tests cleanup
PiperOrigin-RevId: 340568129 Change-Id: I521739fa73d00d4ef75a61556e54dc0c0805c2a9
This commit is contained in:
parent
2f960b4bc7
commit
cce04e5308
@ -18,8 +18,8 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
static void BM_DisabledVlog(int iters) {
|
static void BM_DisabledVlog(::testing::benchmark::State& state) {
|
||||||
for (int i = 0; i < iters; ++i) {
|
for (auto s : state) {
|
||||||
VLOG(1) << "Testing VLOG(1)!";
|
VLOG(1) << "Testing VLOG(1)!";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -673,15 +673,17 @@ TEST(BCastTest, BatchIndices) {
|
|||||||
BCastBatchIndices({3, 1}, {2, 1, 2}));
|
BCastBatchIndices({3, 1}, {2, 1, 2}));
|
||||||
}
|
}
|
||||||
|
|
||||||
static void BM_BCastSetup(int iters, int same_shape) {
|
void BM_BCastSetup(::testing::benchmark::State& state) {
|
||||||
|
const int same_shape = state.range(0);
|
||||||
|
|
||||||
if (same_shape) {
|
if (same_shape) {
|
||||||
testing::SetLabel("same_shapes");
|
state.SetLabel("same_shapes");
|
||||||
while (--iters > 0) {
|
for (auto s : state) {
|
||||||
class BCast b({1000, 100}, {1000, 100});
|
class BCast b({1000, 100}, {1000, 100});
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
testing::SetLabel("different_shapes");
|
state.SetLabel("different_shapes");
|
||||||
while (--iters > 0) {
|
for (auto s : state) {
|
||||||
class BCast b({3, 1, 5}, {2, 0, 3, 0, 5});
|
class BCast b({3, 1, 5}, {2, 0, 3, 0, 5});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -572,9 +572,9 @@ TEST(DeviceNameUtilsTest, CanonicalizeDeviceName) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void BM_ParseFullName(int iters) {
|
static void BM_ParseFullName(::testing::benchmark::State& state) {
|
||||||
DeviceNameUtils::ParsedName p;
|
DeviceNameUtils::ParsedName p;
|
||||||
while (iters--) {
|
for (auto s : state) {
|
||||||
DeviceNameUtils::ParseFullName("/job:worker/replica:3/task:0/cpu:0", &p);
|
DeviceNameUtils::ParseFullName("/job:worker/replica:3/task:0/cpu:0", &p);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -164,13 +164,13 @@ static void CalculateKeys(uint64 num, std::vector<uint64> *dst) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void BM_CuckooFill(int iters, int arg) {
|
void BM_CuckooFill(::testing::benchmark::State &state) {
|
||||||
|
const int arg = state.range(0);
|
||||||
|
|
||||||
uint64 table_size = arg;
|
uint64 table_size = arg;
|
||||||
testing::StopTiming();
|
|
||||||
std::vector<uint64> calculated_keys;
|
std::vector<uint64> calculated_keys;
|
||||||
CalculateKeys(table_size, &calculated_keys);
|
CalculateKeys(table_size, &calculated_keys);
|
||||||
testing::StartTiming();
|
for (auto s : state) {
|
||||||
for (int iter = 0; iter < iters; iter++) {
|
|
||||||
PresizedCuckooMap<int> pscm(table_size);
|
PresizedCuckooMap<int> pscm(table_size);
|
||||||
for (uint64 i = 0; i < table_size; i++) {
|
for (uint64 i = 0; i < table_size; i++) {
|
||||||
pscm.InsertUnique(calculated_keys[i], i);
|
pscm.InsertUnique(calculated_keys[i], i);
|
||||||
@ -180,25 +180,27 @@ static void BM_CuckooFill(int iters, int arg) {
|
|||||||
|
|
||||||
BENCHMARK(BM_CuckooFill)->Arg(1000)->Arg(10000000);
|
BENCHMARK(BM_CuckooFill)->Arg(1000)->Arg(10000000);
|
||||||
|
|
||||||
static void BM_CuckooRead(int iters, int arg) {
|
void BM_CuckooRead(::testing::benchmark::State &state) {
|
||||||
|
const int arg = state.range(0);
|
||||||
|
|
||||||
uint64 table_size = arg;
|
uint64 table_size = arg;
|
||||||
testing::StopTiming();
|
|
||||||
std::vector<uint64> calculated_keys;
|
std::vector<uint64> calculated_keys;
|
||||||
CalculateKeys(table_size, &calculated_keys);
|
CalculateKeys(table_size, &calculated_keys);
|
||||||
PresizedCuckooMap<int> pscm(table_size);
|
PresizedCuckooMap<int> pscm(table_size);
|
||||||
for (uint64 i = 0; i < table_size; i++) {
|
for (uint64 i = 0; i < table_size; i++) {
|
||||||
pscm.InsertUnique(calculated_keys[i], i);
|
pscm.InsertUnique(calculated_keys[i], i);
|
||||||
}
|
}
|
||||||
testing::StartTiming();
|
|
||||||
uint64_t defeat_optimization = 0;
|
int i = 0;
|
||||||
for (int i = 0; i < iters; i++) {
|
for (auto s : state) {
|
||||||
uint64 key_index = i % table_size; // May slow down bench!
|
// Avoid using '%', which is expensive.
|
||||||
|
uint64 key_index = i;
|
||||||
|
++i;
|
||||||
|
if (i == table_size) i = 0;
|
||||||
|
|
||||||
int out = 0;
|
int out = 0;
|
||||||
pscm.Find(calculated_keys[key_index], &out);
|
pscm.Find(calculated_keys[key_index], &out);
|
||||||
defeat_optimization += out;
|
tensorflow::testing::DoNotOptimize(out);
|
||||||
}
|
|
||||||
if (defeat_optimization == 0) {
|
|
||||||
printf("Preventing the compiler from eliding the inner loop\n");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1109,9 +1109,8 @@ TEST_F(TensorBundleAlignmentTest, AlignmentTest) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void BM_BundleAlignmentByteOff(int iters, int alignment,
|
static void BM_BundleAlignmentByteOff(::testing::benchmark::State& state,
|
||||||
int tensor_size) {
|
int alignment, int tensor_size) {
|
||||||
testing::StopTiming();
|
|
||||||
{
|
{
|
||||||
BundleWriter::Options opts;
|
BundleWriter::Options opts;
|
||||||
opts.data_alignment = alignment;
|
opts.data_alignment = alignment;
|
||||||
@ -1122,18 +1121,17 @@ static void BM_BundleAlignmentByteOff(int iters, int alignment,
|
|||||||
}
|
}
|
||||||
BundleReader reader(Env::Default(), Prefix("foo"));
|
BundleReader reader(Env::Default(), Prefix("foo"));
|
||||||
TF_CHECK_OK(reader.status());
|
TF_CHECK_OK(reader.status());
|
||||||
testing::StartTiming();
|
for (auto s : state) {
|
||||||
for (int i = 0; i < iters; ++i) {
|
|
||||||
Tensor t;
|
Tensor t;
|
||||||
TF_CHECK_OK(reader.Lookup("big", &t));
|
TF_CHECK_OK(reader.Lookup("big", &t));
|
||||||
}
|
}
|
||||||
testing::StopTiming();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#define BM_BundleAlignment(ALIGN, SIZE) \
|
#define BM_BundleAlignment(ALIGN, SIZE) \
|
||||||
static void BM_BundleAlignment_##ALIGN##_##SIZE(int iters) { \
|
static void BM_BundleAlignment_##ALIGN##_##SIZE( \
|
||||||
BM_BundleAlignmentByteOff(iters, ALIGN, SIZE); \
|
::testing::benchmark::State& state) { \
|
||||||
} \
|
BM_BundleAlignmentByteOff(state, ALIGN, SIZE); \
|
||||||
|
} \
|
||||||
BENCHMARK(BM_BundleAlignment_##ALIGN##_##SIZE)
|
BENCHMARK(BM_BundleAlignment_##ALIGN##_##SIZE)
|
||||||
|
|
||||||
BM_BundleAlignment(1, 512);
|
BM_BundleAlignment(1, 512);
|
||||||
|
@ -89,12 +89,14 @@ TEST(Shard, OverflowTest) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void BM_Sharding(int iters, int arg) {
|
void BM_Sharding(::testing::benchmark::State& state) {
|
||||||
|
const int arg = state.range(0);
|
||||||
|
|
||||||
thread::ThreadPool threads(Env::Default(), "test", 16);
|
thread::ThreadPool threads(Env::Default(), "test", 16);
|
||||||
const int64 total = 1LL << 30;
|
const int64 total = 1LL << 30;
|
||||||
auto lambda = [](int64 start, int64 limit) {};
|
auto lambda = [](int64 start, int64 limit) {};
|
||||||
auto work = std::cref(lambda);
|
auto work = std::cref(lambda);
|
||||||
for (; iters > 0; iters -= arg) {
|
for (auto s : state) {
|
||||||
Shard(arg - 1, &threads, total, 1, work);
|
Shard(arg - 1, &threads, total, 1, work);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -535,12 +535,10 @@ class BenchmarkType {
|
|||||||
|
|
||||||
// Calibrate the amount of time spent just calling DoWork, since each of our
|
// Calibrate the amount of time spent just calling DoWork, since each of our
|
||||||
// tests will do this, we can subtract this out of benchmark results.
|
// tests will do this, we can subtract this out of benchmark results.
|
||||||
void BM_CalibrateWorkLoop(int iters) {
|
void BM_CalibrateWorkLoop(::testing::benchmark::State& state) {
|
||||||
tensorflow::testing::StopTiming();
|
|
||||||
BenchmarkFactory<BenchmarkType> factory;
|
BenchmarkFactory<BenchmarkType> factory;
|
||||||
BenchmarkType* result = factory.TrivialFactory();
|
BenchmarkType* result = factory.TrivialFactory();
|
||||||
tensorflow::testing::StartTiming();
|
for (auto s : state) {
|
||||||
for (int i = 0; i != iters; ++i) {
|
|
||||||
if (result != nullptr) {
|
if (result != nullptr) {
|
||||||
result->DoWork();
|
result->DoWork();
|
||||||
}
|
}
|
||||||
@ -550,11 +548,9 @@ BENCHMARK(BM_CalibrateWorkLoop);
|
|||||||
|
|
||||||
// Measure the time taken to call into the factory, return the value,
|
// Measure the time taken to call into the factory, return the value,
|
||||||
// determine that it is OK, and invoke a trivial function.
|
// determine that it is OK, and invoke a trivial function.
|
||||||
void BM_TrivialFactory(int iters) {
|
void BM_TrivialFactory(::testing::benchmark::State& state) {
|
||||||
tensorflow::testing::StopTiming();
|
|
||||||
BenchmarkFactory<BenchmarkType> factory;
|
BenchmarkFactory<BenchmarkType> factory;
|
||||||
tensorflow::testing::StartTiming();
|
for (auto s : state) {
|
||||||
for (int i = 0; i != iters; ++i) {
|
|
||||||
BenchmarkType* result = factory.TrivialFactory();
|
BenchmarkType* result = factory.TrivialFactory();
|
||||||
if (result != nullptr) {
|
if (result != nullptr) {
|
||||||
result->DoWork();
|
result->DoWork();
|
||||||
@ -566,11 +562,9 @@ BENCHMARK(BM_TrivialFactory);
|
|||||||
// Measure the time taken to call into the factory, providing an
|
// Measure the time taken to call into the factory, providing an
|
||||||
// out-param for the result, evaluating the status result and the
|
// out-param for the result, evaluating the status result and the
|
||||||
// result pointer, and invoking the trivial function.
|
// result pointer, and invoking the trivial function.
|
||||||
void BM_ArgumentFactory(int iters) {
|
void BM_ArgumentFactory(::testing::benchmark::State& state) {
|
||||||
tensorflow::testing::StopTiming();
|
|
||||||
BenchmarkFactory<BenchmarkType> factory;
|
BenchmarkFactory<BenchmarkType> factory;
|
||||||
tensorflow::testing::StartTiming();
|
for (auto s : state) {
|
||||||
for (int i = 0; i != iters; ++i) {
|
|
||||||
BenchmarkType* result = nullptr;
|
BenchmarkType* result = nullptr;
|
||||||
Status status = factory.ArgumentFactory(&result);
|
Status status = factory.ArgumentFactory(&result);
|
||||||
if (status.ok() && result != nullptr) {
|
if (status.ok() && result != nullptr) {
|
||||||
@ -582,11 +576,9 @@ BENCHMARK(BM_ArgumentFactory);
|
|||||||
|
|
||||||
// Measure the time to use the StatusOr<T*> factory, evaluate the result,
|
// Measure the time to use the StatusOr<T*> factory, evaluate the result,
|
||||||
// and invoke the trivial function.
|
// and invoke the trivial function.
|
||||||
void BM_StatusOrFactory(int iters) {
|
void BM_StatusOrFactory(::testing::benchmark::State& state) {
|
||||||
tensorflow::testing::StopTiming();
|
|
||||||
BenchmarkFactory<BenchmarkType> factory;
|
BenchmarkFactory<BenchmarkType> factory;
|
||||||
tensorflow::testing::StartTiming();
|
for (auto s : state) {
|
||||||
for (int i = 0; i != iters; ++i) {
|
|
||||||
StatusOr<BenchmarkType*> result = factory.StatusOrFactory();
|
StatusOr<BenchmarkType*> result = factory.StatusOrFactory();
|
||||||
if (result.ok()) {
|
if (result.ok()) {
|
||||||
result.ValueOrDie()->DoWork();
|
result.ValueOrDie()->DoWork();
|
||||||
@ -598,11 +590,9 @@ BENCHMARK(BM_StatusOrFactory);
|
|||||||
// Measure the time taken to call into the factory, providing an
|
// Measure the time taken to call into the factory, providing an
|
||||||
// out-param for the result, evaluating the status result and the
|
// out-param for the result, evaluating the status result and the
|
||||||
// result pointer, and invoking the trivial function.
|
// result pointer, and invoking the trivial function.
|
||||||
void BM_ArgumentFactoryFail(int iters) {
|
void BM_ArgumentFactoryFail(::testing::benchmark::State& state) {
|
||||||
tensorflow::testing::StopTiming();
|
|
||||||
BenchmarkFactory<BenchmarkType> factory;
|
BenchmarkFactory<BenchmarkType> factory;
|
||||||
tensorflow::testing::StartTiming();
|
for (auto s : state) {
|
||||||
for (int i = 0; i != iters; ++i) {
|
|
||||||
BenchmarkType* result = nullptr;
|
BenchmarkType* result = nullptr;
|
||||||
Status status = factory.ArgumentFactoryFail(&result);
|
Status status = factory.ArgumentFactoryFail(&result);
|
||||||
if (status.ok() && result != nullptr) {
|
if (status.ok() && result != nullptr) {
|
||||||
@ -614,11 +604,9 @@ BENCHMARK(BM_ArgumentFactoryFail);
|
|||||||
|
|
||||||
// Measure the time to use the StatusOr<T*> factory, evaluate the result,
|
// Measure the time to use the StatusOr<T*> factory, evaluate the result,
|
||||||
// and invoke the trivial function.
|
// and invoke the trivial function.
|
||||||
void BM_StatusOrFactoryFail(int iters) {
|
void BM_StatusOrFactoryFail(::testing::benchmark::State& state) {
|
||||||
tensorflow::testing::StopTiming();
|
|
||||||
BenchmarkFactory<BenchmarkType> factory;
|
BenchmarkFactory<BenchmarkType> factory;
|
||||||
tensorflow::testing::StartTiming();
|
for (auto s : state) {
|
||||||
for (int i = 0; i != iters; ++i) {
|
|
||||||
StatusOr<BenchmarkType*> result = factory.StatusOrFactoryFail();
|
StatusOr<BenchmarkType*> result = factory.StatusOrFactoryFail();
|
||||||
if (result.ok()) {
|
if (result.ok()) {
|
||||||
result.ValueOrDie()->DoWork();
|
result.ValueOrDie()->DoWork();
|
||||||
@ -630,11 +618,9 @@ BENCHMARK(BM_StatusOrFactoryFail);
|
|||||||
// Measure the time taken to call into the factory, providing an
|
// Measure the time taken to call into the factory, providing an
|
||||||
// out-param for the result, evaluating the status result and the
|
// out-param for the result, evaluating the status result and the
|
||||||
// result pointer, and invoking the trivial function.
|
// result pointer, and invoking the trivial function.
|
||||||
void BM_ArgumentFactoryFailShortMsg(int iters) {
|
void BM_ArgumentFactoryFailShortMsg(::testing::benchmark::State& state) {
|
||||||
tensorflow::testing::StopTiming();
|
|
||||||
BenchmarkFactory<BenchmarkType> factory;
|
BenchmarkFactory<BenchmarkType> factory;
|
||||||
tensorflow::testing::StartTiming();
|
for (auto s : state) {
|
||||||
for (int i = 0; i != iters; ++i) {
|
|
||||||
BenchmarkType* result = nullptr;
|
BenchmarkType* result = nullptr;
|
||||||
Status status = factory.ArgumentFactoryFailShortMsg(&result);
|
Status status = factory.ArgumentFactoryFailShortMsg(&result);
|
||||||
if (status.ok() && result != nullptr) {
|
if (status.ok() && result != nullptr) {
|
||||||
@ -646,11 +632,9 @@ BENCHMARK(BM_ArgumentFactoryFailShortMsg);
|
|||||||
|
|
||||||
// Measure the time to use the StatusOr<T*> factory, evaluate the result,
|
// Measure the time to use the StatusOr<T*> factory, evaluate the result,
|
||||||
// and invoke the trivial function.
|
// and invoke the trivial function.
|
||||||
void BM_StatusOrFactoryFailShortMsg(int iters) {
|
void BM_StatusOrFactoryFailShortMsg(::testing::benchmark::State& state) {
|
||||||
tensorflow::testing::StopTiming();
|
|
||||||
BenchmarkFactory<BenchmarkType> factory;
|
BenchmarkFactory<BenchmarkType> factory;
|
||||||
tensorflow::testing::StartTiming();
|
for (auto s : state) {
|
||||||
for (int i = 0; i != iters; ++i) {
|
|
||||||
StatusOr<BenchmarkType*> result = factory.StatusOrFactoryFailShortMsg();
|
StatusOr<BenchmarkType*> result = factory.StatusOrFactoryFailShortMsg();
|
||||||
if (result.ok()) {
|
if (result.ok()) {
|
||||||
result.ValueOrDie()->DoWork();
|
result.ValueOrDie()->DoWork();
|
||||||
@ -662,11 +646,9 @@ BENCHMARK(BM_StatusOrFactoryFailShortMsg);
|
|||||||
// Measure the time taken to call into the factory, providing an
|
// Measure the time taken to call into the factory, providing an
|
||||||
// out-param for the result, evaluating the status result and the
|
// out-param for the result, evaluating the status result and the
|
||||||
// result pointer, and invoking the trivial function.
|
// result pointer, and invoking the trivial function.
|
||||||
void BM_ArgumentFactoryFailLongMsg(int iters) {
|
void BM_ArgumentFactoryFailLongMsg(::testing::benchmark::State& state) {
|
||||||
tensorflow::testing::StopTiming();
|
|
||||||
BenchmarkFactory<BenchmarkType> factory;
|
BenchmarkFactory<BenchmarkType> factory;
|
||||||
tensorflow::testing::StartTiming();
|
for (auto s : state) {
|
||||||
for (int i = 0; i != iters; ++i) {
|
|
||||||
BenchmarkType* result = nullptr;
|
BenchmarkType* result = nullptr;
|
||||||
Status status = factory.ArgumentFactoryFailLongMsg(&result);
|
Status status = factory.ArgumentFactoryFailLongMsg(&result);
|
||||||
if (status.ok() && result != nullptr) {
|
if (status.ok() && result != nullptr) {
|
||||||
@ -678,11 +660,9 @@ BENCHMARK(BM_ArgumentFactoryFailLongMsg);
|
|||||||
|
|
||||||
// Measure the time to use the StatusOr<T*> factory, evaluate the result,
|
// Measure the time to use the StatusOr<T*> factory, evaluate the result,
|
||||||
// and invoke the trivial function.
|
// and invoke the trivial function.
|
||||||
void BM_StatusOrFactoryFailLongMsg(int iters) {
|
void BM_StatusOrFactoryFailLongMsg(::testing::benchmark::State& state) {
|
||||||
tensorflow::testing::StopTiming();
|
|
||||||
BenchmarkFactory<BenchmarkType> factory;
|
BenchmarkFactory<BenchmarkType> factory;
|
||||||
tensorflow::testing::StartTiming();
|
for (auto s : state) {
|
||||||
for (int i = 0; i != iters; ++i) {
|
|
||||||
StatusOr<BenchmarkType*> result = factory.StatusOrFactoryFailLongMsg();
|
StatusOr<BenchmarkType*> result = factory.StatusOrFactoryFailLongMsg();
|
||||||
if (result.ok()) {
|
if (result.ok()) {
|
||||||
result.ValueOrDie()->DoWork();
|
result.ValueOrDie()->DoWork();
|
||||||
|
Loading…
x
Reference in New Issue
Block a user