Internal tests cleanup.

PiperOrigin-RevId: 339315357
Change-Id: I0245a2c37991f3a34c13828217bd0f4cce177705
This commit is contained in:
A. Unique TensorFlower 2020-10-27 13:19:28 -07:00 committed by TensorFlower Gardener
parent 44baa200d2
commit dde0fe783c
4 changed files with 53 additions and 49 deletions

View File

@ -95,10 +95,10 @@ TEST(RecentRequestIds, Ordered3) { TestOrdered(3); }
TEST(RecentRequestIds, Ordered4) { TestOrdered(4); } TEST(RecentRequestIds, Ordered4) { TestOrdered(4); }
TEST(RecentRequestIds, Ordered5) { TestOrdered(5); } TEST(RecentRequestIds, Ordered5) { TestOrdered(5); }
void BM_TrackUnique(int iters) { static void BM_TrackUnique(::testing::benchmark::State& state) {
RecentRequestIds recent_request_ids(100000); RecentRequestIds recent_request_ids(100000);
RecvTensorRequest request; RecvTensorRequest request;
for (int i = 0; i < iters; ++i) { for (auto s : state) {
TF_CHECK_OK(recent_request_ids.TrackUnique(GetUniqueRequestId(), TF_CHECK_OK(recent_request_ids.TrackUnique(GetUniqueRequestId(),
"BM_TrackUnique", request)); "BM_TrackUnique", request));
} }

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/protobuf/worker.pb.h" #include "tensorflow/core/protobuf/worker.pb.h"
@ -134,43 +135,41 @@ TEST(GrpcProto, ParseFromString) {
} }
} }
static void BM_UnparseGrpc(int iters, int size) { static void BM_UnparseGrpc(::testing::benchmark::State& state) {
testing::StopTiming(); const int size = state.range(0);
auto proto = MakeProto(size); auto proto = MakeProto(size);
testing::StartTiming(); for (auto s : state) {
for (int i = 0; i < iters; i++) {
grpc::ByteBuffer buf; grpc::ByteBuffer buf;
CHECK(GrpcMaybeUnparseProto(proto, &buf).ok()); CHECK(GrpcMaybeUnparseProto(proto, &buf).ok());
} }
testing::StopTiming();
} }
BENCHMARK(BM_UnparseGrpc)->Arg(1)->Arg(1 << 10)->Arg(1 << 20); BENCHMARK(BM_UnparseGrpc)->Arg(1)->Arg(1 << 10)->Arg(1 << 20);
static void BM_UnparseString(int iters, int size) { static void BM_UnparseString(::testing::benchmark::State& state) {
testing::StopTiming(); const int size = state.range(0);
auto proto = MakeProto(size); auto proto = MakeProto(size);
testing::StartTiming(); testing::StartTiming();
for (int i = 0; i < iters; i++) { for (auto s : state) {
string buf; string buf;
proto.SerializeToString(&buf); proto.SerializeToString(&buf);
} }
testing::StopTiming();
} }
BENCHMARK(BM_UnparseString)->Arg(1)->Arg(1 << 10)->Arg(1 << 20); BENCHMARK(BM_UnparseString)->Arg(1)->Arg(1 << 10)->Arg(1 << 20);
static void BM_ParseGrpc(int iters, int size, int num_slices) { static void BM_ParseGrpc(::testing::benchmark::State& state) {
testing::StopTiming(); const int size = state.range(0);
const int num_slices = state.range(1);
CleanupAllRequest proto = MakeProto(size); CleanupAllRequest proto = MakeProto(size);
auto buf = MakeBuffer(proto.SerializeAsString(), num_slices); auto buf = MakeBuffer(proto.SerializeAsString(), num_slices);
testing::StartTiming(); testing::StartTiming();
for (int i = 0; i < iters; i++) { for (auto s : state) {
CHECK(GrpcMaybeParseProto(&buf, &proto)); CHECK(GrpcMaybeParseProto(&buf, &proto));
} }
testing::StopTiming();
} }
BENCHMARK(BM_ParseGrpc) BENCHMARK(BM_ParseGrpc)
->ArgPair(1, 1) ->ArgPair(1, 1)
@ -179,17 +178,16 @@ BENCHMARK(BM_ParseGrpc)
->ArgPair(1 << 20, 1) ->ArgPair(1 << 20, 1)
->ArgPair(1 << 20, 4); ->ArgPair(1 << 20, 4);
static void BM_ParseString(int iters, int size) { static void BM_ParseString(::testing::benchmark::State& state) {
testing::StopTiming(); const int size = state.range(0);
CleanupAllRequest proto = MakeProto(size); CleanupAllRequest proto = MakeProto(size);
string serial = proto.SerializeAsString(); string serial = proto.SerializeAsString();
testing::StartTiming(); testing::StartTiming();
for (int i = 0; i < iters; i++) { for (auto s : state) {
CHECK(proto.ParseFromString(serial)); CHECK(proto.ParseFromString(serial));
} }
testing::StopTiming();
} }
BENCHMARK(BM_ParseString)->Arg(1)->Arg(1 << 10)->Arg(1 << 20); BENCHMARK(BM_ParseString)->Arg(1)->Arg(1 << 10)->Arg(1 << 20);

View File

@ -166,9 +166,9 @@ string DebugString(const Tensor& x, const Tensor& y, int tensor_size) {
} }
// TODO: Support sharding and depth. // TODO: Support sharding and depth.
static void BM_Helper(int iters, int width, int num_stages, int tensor_size, static void BM_Helper(::testing::benchmark::State& state, int width,
int num_stages, int tensor_size,
bool use_multiple_devices) { bool use_multiple_devices) {
testing::StopTiming();
const Cluster* cluster = GetCluster(); const Cluster* cluster = GetCluster();
// Creates a session. // Creates a session.
@ -203,17 +203,18 @@ static void BM_Helper(int iters, int width, int num_stages, int tensor_size,
} }
// Iterations. // Iterations.
testing::StartTiming(); for (auto s : state) {
for (int i = 0; i < iters; i++) {
outputs.clear(); outputs.clear();
TF_CHECK_OK(session->Run({{"x", x}}, {"y:0"}, {}, &outputs)); TF_CHECK_OK(session->Run({{"x", x}}, {"y:0"}, {}, &outputs));
CHECK_EQ(size_t{1}, outputs.size()); CHECK_EQ(size_t{1}, outputs.size());
} }
testing::StopTiming();
TF_CHECK_OK(session->Close()); TF_CHECK_OK(session->Close());
} }
static void BM_ShardedProgram(int iters, int width, int num_stages) { static void BM_ShardedProgram(::testing::benchmark::State& state) {
BM_Helper(iters, width, num_stages, 2 /*tensor_size*/, true /*multi-device*/); const int width = state.range(0);
const int num_stages = state.range(1);
BM_Helper(state, width, num_stages, 2 /*tensor_size*/, true /*multi-device*/);
} }
BENCHMARK(BM_ShardedProgram) BENCHMARK(BM_ShardedProgram)
->ArgPair(1, 1) ->ArgPair(1, 1)
@ -232,13 +233,19 @@ BENCHMARK(BM_ShardedProgram)
->ArgPair(60, 3) ->ArgPair(60, 3)
->ArgPair(60, 5); ->ArgPair(60, 5);
static void BM_RPC(int iters, int width, int tensor_size) { static void BM_RPC(::testing::benchmark::State& state) {
BM_Helper(iters, width, 2 /*num_stages*/, tensor_size, true /*multi-device*/); const int width = state.range(0);
const int tensor_size = state.range(1);
BM_Helper(state, width, 2 /*num_stages*/, tensor_size, true /*multi-device*/);
} }
BENCHMARK(BM_RPC)->ArgPair(30, 2)->ArgPair(30, 1000)->ArgPair(30, 100000); BENCHMARK(BM_RPC)->ArgPair(30, 2)->ArgPair(30, 1000)->ArgPair(30, 100000);
static void BM_SingleDevice(int iters, int width, int num_stages) { static void BM_SingleDevice(::testing::benchmark::State& state) {
BM_Helper(iters, width, num_stages, 2 /*tensor_size*/, const int width = state.range(0);
const int num_stages = state.range(1);
BM_Helper(state, width, num_stages, 2 /*tensor_size*/,
false /*not multi-device*/); false /*not multi-device*/);
} }
BENCHMARK(BM_SingleDevice) BENCHMARK(BM_SingleDevice)

View File

@ -173,37 +173,36 @@ string MakeFloatTensorTestCase(int num_elems) {
return encoded; return encoded;
} }
static void BM_TensorResponse(int iters, int arg) { static void BM_TensorResponse(::testing::benchmark::State& state) {
testing::StopTiming(); const int arg = state.range(0);
string encoded = MakeFloatTensorTestCase(arg); string encoded = MakeFloatTensorTestCase(arg);
DummyDevice cpu_device(Env::Default()); DummyDevice cpu_device(Env::Default());
testing::StartTiming(); size_t bytes = 0;
while (--iters > 0) { for (auto i : state) {
TensorResponse response; TensorResponse response;
response.InitAlloc(&cpu_device, AllocatorAttributes()); response.InitAlloc(&cpu_device, AllocatorAttributes());
StringSource source(&encoded, -1); StringSource source(&encoded, -1);
Status s = response.ParseFrom(&source); Status s = response.ParseFrom(&source);
if (iters == 1) { bytes = response.tensor().TotalBytes();
testing::SetLabel(
strings::StrCat("Bytes: ", response.tensor().TotalBytes()));
}
} }
state.SetLabel(strings::StrCat("Bytes: ", bytes));
} }
BENCHMARK(BM_TensorResponse)->Arg(0)->Arg(1000)->Arg(100000); BENCHMARK(BM_TensorResponse)->Arg(0)->Arg(1000)->Arg(100000);
static void BM_TensorViaTensorProto(int iters, int arg) { static void BM_TensorViaTensorProto(::testing::benchmark::State& state) {
testing::StopTiming(); const int arg = state.range(0);
string encoded = MakeFloatTensorTestCase(arg);
testing::StartTiming(); std::string encoded = MakeFloatTensorTestCase(arg);
while (--iters > 0) { size_t bytes = 0;
for (auto s : state) {
RecvTensorResponse r; RecvTensorResponse r;
r.ParseFromString(encoded); r.ParseFromString(encoded);
Tensor t; Tensor t;
CHECK(t.FromProto(r.tensor())); CHECK(t.FromProto(r.tensor()));
if (iters == 1) { bytes = t.TotalBytes();
testing::SetLabel(strings::StrCat("Bytes: ", t.TotalBytes()));
}
} }
state.SetLabel(strings::StrCat("Bytes: ", bytes));
} }
BENCHMARK(BM_TensorViaTensorProto)->Arg(0)->Arg(1000)->Arg(100000); BENCHMARK(BM_TensorViaTensorProto)->Arg(0)->Arg(1000)->Arg(100000);