Internal tests cleanup.

PiperOrigin-RevId: 339315357
Change-Id: I0245a2c37991f3a34c13828217bd0f4cce177705
This commit is contained in:
A. Unique TensorFlower 2020-10-27 13:19:28 -07:00 committed by TensorFlower Gardener
parent 44baa200d2
commit dde0fe783c
4 changed files with 53 additions and 49 deletions

View File

@ -95,10 +95,10 @@ TEST(RecentRequestIds, Ordered3) { TestOrdered(3); }
TEST(RecentRequestIds, Ordered4) { TestOrdered(4); }
TEST(RecentRequestIds, Ordered5) { TestOrdered(5); }
void BM_TrackUnique(int iters) {
static void BM_TrackUnique(::testing::benchmark::State& state) {
RecentRequestIds recent_request_ids(100000);
RecvTensorRequest request;
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
TF_CHECK_OK(recent_request_ids.TrackUnique(GetUniqueRequestId(),
"BM_TrackUnique", request));
}

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/protobuf/worker.pb.h"
@ -134,43 +135,41 @@ TEST(GrpcProto, ParseFromString) {
}
}
static void BM_UnparseGrpc(int iters, int size) {
testing::StopTiming();
static void BM_UnparseGrpc(::testing::benchmark::State& state) {
const int size = state.range(0);
auto proto = MakeProto(size);
testing::StartTiming();
for (int i = 0; i < iters; i++) {
for (auto s : state) {
grpc::ByteBuffer buf;
CHECK(GrpcMaybeUnparseProto(proto, &buf).ok());
}
testing::StopTiming();
}
BENCHMARK(BM_UnparseGrpc)->Arg(1)->Arg(1 << 10)->Arg(1 << 20);
static void BM_UnparseString(int iters, int size) {
testing::StopTiming();
static void BM_UnparseString(::testing::benchmark::State& state) {
const int size = state.range(0);
auto proto = MakeProto(size);
testing::StartTiming();
for (int i = 0; i < iters; i++) {
for (auto s : state) {
string buf;
proto.SerializeToString(&buf);
}
testing::StopTiming();
}
BENCHMARK(BM_UnparseString)->Arg(1)->Arg(1 << 10)->Arg(1 << 20);
static void BM_ParseGrpc(int iters, int size, int num_slices) {
testing::StopTiming();
static void BM_ParseGrpc(::testing::benchmark::State& state) {
const int size = state.range(0);
const int num_slices = state.range(1);
CleanupAllRequest proto = MakeProto(size);
auto buf = MakeBuffer(proto.SerializeAsString(), num_slices);
testing::StartTiming();
for (int i = 0; i < iters; i++) {
for (auto s : state) {
CHECK(GrpcMaybeParseProto(&buf, &proto));
}
testing::StopTiming();
}
BENCHMARK(BM_ParseGrpc)
->ArgPair(1, 1)
@ -179,17 +178,16 @@ BENCHMARK(BM_ParseGrpc)
->ArgPair(1 << 20, 1)
->ArgPair(1 << 20, 4);
static void BM_ParseString(int iters, int size) {
testing::StopTiming();
static void BM_ParseString(::testing::benchmark::State& state) {
const int size = state.range(0);
CleanupAllRequest proto = MakeProto(size);
string serial = proto.SerializeAsString();
testing::StartTiming();
for (int i = 0; i < iters; i++) {
for (auto s : state) {
CHECK(proto.ParseFromString(serial));
}
testing::StopTiming();
}
BENCHMARK(BM_ParseString)->Arg(1)->Arg(1 << 10)->Arg(1 << 20);

View File

@ -166,9 +166,9 @@ string DebugString(const Tensor& x, const Tensor& y, int tensor_size) {
}
// TODO: Support sharding and depth.
static void BM_Helper(int iters, int width, int num_stages, int tensor_size,
static void BM_Helper(::testing::benchmark::State& state, int width,
int num_stages, int tensor_size,
bool use_multiple_devices) {
testing::StopTiming();
const Cluster* cluster = GetCluster();
// Creates a session.
@ -203,17 +203,18 @@ static void BM_Helper(int iters, int width, int num_stages, int tensor_size,
}
// Iterations.
testing::StartTiming();
for (int i = 0; i < iters; i++) {
for (auto s : state) {
outputs.clear();
TF_CHECK_OK(session->Run({{"x", x}}, {"y:0"}, {}, &outputs));
CHECK_EQ(size_t{1}, outputs.size());
}
testing::StopTiming();
TF_CHECK_OK(session->Close());
}
static void BM_ShardedProgram(int iters, int width, int num_stages) {
BM_Helper(iters, width, num_stages, 2 /*tensor_size*/, true /*multi-device*/);
static void BM_ShardedProgram(::testing::benchmark::State& state) {
const int width = state.range(0);
const int num_stages = state.range(1);
BM_Helper(state, width, num_stages, 2 /*tensor_size*/, true /*multi-device*/);
}
BENCHMARK(BM_ShardedProgram)
->ArgPair(1, 1)
@ -232,13 +233,19 @@ BENCHMARK(BM_ShardedProgram)
->ArgPair(60, 3)
->ArgPair(60, 5);
static void BM_RPC(int iters, int width, int tensor_size) {
BM_Helper(iters, width, 2 /*num_stages*/, tensor_size, true /*multi-device*/);
static void BM_RPC(::testing::benchmark::State& state) {
const int width = state.range(0);
const int tensor_size = state.range(1);
BM_Helper(state, width, 2 /*num_stages*/, tensor_size, true /*multi-device*/);
}
BENCHMARK(BM_RPC)->ArgPair(30, 2)->ArgPair(30, 1000)->ArgPair(30, 100000);
static void BM_SingleDevice(int iters, int width, int num_stages) {
BM_Helper(iters, width, num_stages, 2 /*tensor_size*/,
static void BM_SingleDevice(::testing::benchmark::State& state) {
const int width = state.range(0);
const int num_stages = state.range(1);
BM_Helper(state, width, num_stages, 2 /*tensor_size*/,
false /*not multi-device*/);
}
BENCHMARK(BM_SingleDevice)

View File

@ -173,37 +173,36 @@ string MakeFloatTensorTestCase(int num_elems) {
return encoded;
}
static void BM_TensorResponse(int iters, int arg) {
testing::StopTiming();
static void BM_TensorResponse(::testing::benchmark::State& state) {
const int arg = state.range(0);
string encoded = MakeFloatTensorTestCase(arg);
DummyDevice cpu_device(Env::Default());
testing::StartTiming();
while (--iters > 0) {
size_t bytes = 0;
for (auto i : state) {
TensorResponse response;
response.InitAlloc(&cpu_device, AllocatorAttributes());
StringSource source(&encoded, -1);
Status s = response.ParseFrom(&source);
if (iters == 1) {
testing::SetLabel(
strings::StrCat("Bytes: ", response.tensor().TotalBytes()));
}
bytes = response.tensor().TotalBytes();
}
state.SetLabel(strings::StrCat("Bytes: ", bytes));
}
BENCHMARK(BM_TensorResponse)->Arg(0)->Arg(1000)->Arg(100000);
static void BM_TensorViaTensorProto(int iters, int arg) {
testing::StopTiming();
string encoded = MakeFloatTensorTestCase(arg);
testing::StartTiming();
while (--iters > 0) {
static void BM_TensorViaTensorProto(::testing::benchmark::State& state) {
const int arg = state.range(0);
std::string encoded = MakeFloatTensorTestCase(arg);
size_t bytes = 0;
for (auto s : state) {
RecvTensorResponse r;
r.ParseFromString(encoded);
Tensor t;
CHECK(t.FromProto(r.tensor()));
if (iters == 1) {
testing::SetLabel(strings::StrCat("Bytes: ", t.TotalBytes()));
}
bytes = t.TotalBytes();
}
state.SetLabel(strings::StrCat("Bytes: ", bytes));
}
BENCHMARK(BM_TensorViaTensorProto)->Arg(0)->Arg(1000)->Arg(100000);