Internal tests cleanup.
PiperOrigin-RevId: 339315357 Change-Id: I0245a2c37991f3a34c13828217bd0f4cce177705
This commit is contained in:
parent
44baa200d2
commit
dde0fe783c
@ -95,10 +95,10 @@ TEST(RecentRequestIds, Ordered3) { TestOrdered(3); }
|
|||||||
TEST(RecentRequestIds, Ordered4) { TestOrdered(4); }
|
TEST(RecentRequestIds, Ordered4) { TestOrdered(4); }
|
||||||
TEST(RecentRequestIds, Ordered5) { TestOrdered(5); }
|
TEST(RecentRequestIds, Ordered5) { TestOrdered(5); }
|
||||||
|
|
||||||
void BM_TrackUnique(int iters) {
|
static void BM_TrackUnique(::testing::benchmark::State& state) {
|
||||||
RecentRequestIds recent_request_ids(100000);
|
RecentRequestIds recent_request_ids(100000);
|
||||||
RecvTensorRequest request;
|
RecvTensorRequest request;
|
||||||
for (int i = 0; i < iters; ++i) {
|
for (auto s : state) {
|
||||||
TF_CHECK_OK(recent_request_ids.TrackUnique(GetUniqueRequestId(),
|
TF_CHECK_OK(recent_request_ids.TrackUnique(GetUniqueRequestId(),
|
||||||
"BM_TrackUnique", request));
|
"BM_TrackUnique", request));
|
||||||
}
|
}
|
||||||
|
@ -14,6 +14,7 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
|
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
|
||||||
|
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
#include "tensorflow/core/platform/test_benchmark.h"
|
#include "tensorflow/core/platform/test_benchmark.h"
|
||||||
#include "tensorflow/core/protobuf/worker.pb.h"
|
#include "tensorflow/core/protobuf/worker.pb.h"
|
||||||
@ -134,43 +135,41 @@ TEST(GrpcProto, ParseFromString) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void BM_UnparseGrpc(int iters, int size) {
|
static void BM_UnparseGrpc(::testing::benchmark::State& state) {
|
||||||
testing::StopTiming();
|
const int size = state.range(0);
|
||||||
|
|
||||||
auto proto = MakeProto(size);
|
auto proto = MakeProto(size);
|
||||||
testing::StartTiming();
|
for (auto s : state) {
|
||||||
for (int i = 0; i < iters; i++) {
|
|
||||||
grpc::ByteBuffer buf;
|
grpc::ByteBuffer buf;
|
||||||
CHECK(GrpcMaybeUnparseProto(proto, &buf).ok());
|
CHECK(GrpcMaybeUnparseProto(proto, &buf).ok());
|
||||||
}
|
}
|
||||||
testing::StopTiming();
|
|
||||||
}
|
}
|
||||||
BENCHMARK(BM_UnparseGrpc)->Arg(1)->Arg(1 << 10)->Arg(1 << 20);
|
BENCHMARK(BM_UnparseGrpc)->Arg(1)->Arg(1 << 10)->Arg(1 << 20);
|
||||||
|
|
||||||
static void BM_UnparseString(int iters, int size) {
|
static void BM_UnparseString(::testing::benchmark::State& state) {
|
||||||
testing::StopTiming();
|
const int size = state.range(0);
|
||||||
|
|
||||||
auto proto = MakeProto(size);
|
auto proto = MakeProto(size);
|
||||||
testing::StartTiming();
|
testing::StartTiming();
|
||||||
|
|
||||||
for (int i = 0; i < iters; i++) {
|
for (auto s : state) {
|
||||||
string buf;
|
string buf;
|
||||||
proto.SerializeToString(&buf);
|
proto.SerializeToString(&buf);
|
||||||
}
|
}
|
||||||
|
|
||||||
testing::StopTiming();
|
|
||||||
}
|
}
|
||||||
BENCHMARK(BM_UnparseString)->Arg(1)->Arg(1 << 10)->Arg(1 << 20);
|
BENCHMARK(BM_UnparseString)->Arg(1)->Arg(1 << 10)->Arg(1 << 20);
|
||||||
|
|
||||||
static void BM_ParseGrpc(int iters, int size, int num_slices) {
|
static void BM_ParseGrpc(::testing::benchmark::State& state) {
|
||||||
testing::StopTiming();
|
const int size = state.range(0);
|
||||||
|
const int num_slices = state.range(1);
|
||||||
|
|
||||||
CleanupAllRequest proto = MakeProto(size);
|
CleanupAllRequest proto = MakeProto(size);
|
||||||
auto buf = MakeBuffer(proto.SerializeAsString(), num_slices);
|
auto buf = MakeBuffer(proto.SerializeAsString(), num_slices);
|
||||||
testing::StartTiming();
|
testing::StartTiming();
|
||||||
|
|
||||||
for (int i = 0; i < iters; i++) {
|
for (auto s : state) {
|
||||||
CHECK(GrpcMaybeParseProto(&buf, &proto));
|
CHECK(GrpcMaybeParseProto(&buf, &proto));
|
||||||
}
|
}
|
||||||
|
|
||||||
testing::StopTiming();
|
|
||||||
}
|
}
|
||||||
BENCHMARK(BM_ParseGrpc)
|
BENCHMARK(BM_ParseGrpc)
|
||||||
->ArgPair(1, 1)
|
->ArgPair(1, 1)
|
||||||
@ -179,17 +178,16 @@ BENCHMARK(BM_ParseGrpc)
|
|||||||
->ArgPair(1 << 20, 1)
|
->ArgPair(1 << 20, 1)
|
||||||
->ArgPair(1 << 20, 4);
|
->ArgPair(1 << 20, 4);
|
||||||
|
|
||||||
static void BM_ParseString(int iters, int size) {
|
static void BM_ParseString(::testing::benchmark::State& state) {
|
||||||
testing::StopTiming();
|
const int size = state.range(0);
|
||||||
|
|
||||||
CleanupAllRequest proto = MakeProto(size);
|
CleanupAllRequest proto = MakeProto(size);
|
||||||
string serial = proto.SerializeAsString();
|
string serial = proto.SerializeAsString();
|
||||||
testing::StartTiming();
|
testing::StartTiming();
|
||||||
|
|
||||||
for (int i = 0; i < iters; i++) {
|
for (auto s : state) {
|
||||||
CHECK(proto.ParseFromString(serial));
|
CHECK(proto.ParseFromString(serial));
|
||||||
}
|
}
|
||||||
|
|
||||||
testing::StopTiming();
|
|
||||||
}
|
}
|
||||||
BENCHMARK(BM_ParseString)->Arg(1)->Arg(1 << 10)->Arg(1 << 20);
|
BENCHMARK(BM_ParseString)->Arg(1)->Arg(1 << 10)->Arg(1 << 20);
|
||||||
|
|
||||||
|
@ -166,9 +166,9 @@ string DebugString(const Tensor& x, const Tensor& y, int tensor_size) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Support sharding and depth.
|
// TODO: Support sharding and depth.
|
||||||
static void BM_Helper(int iters, int width, int num_stages, int tensor_size,
|
static void BM_Helper(::testing::benchmark::State& state, int width,
|
||||||
|
int num_stages, int tensor_size,
|
||||||
bool use_multiple_devices) {
|
bool use_multiple_devices) {
|
||||||
testing::StopTiming();
|
|
||||||
const Cluster* cluster = GetCluster();
|
const Cluster* cluster = GetCluster();
|
||||||
|
|
||||||
// Creates a session.
|
// Creates a session.
|
||||||
@ -203,17 +203,18 @@ static void BM_Helper(int iters, int width, int num_stages, int tensor_size,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Iterations.
|
// Iterations.
|
||||||
testing::StartTiming();
|
for (auto s : state) {
|
||||||
for (int i = 0; i < iters; i++) {
|
|
||||||
outputs.clear();
|
outputs.clear();
|
||||||
TF_CHECK_OK(session->Run({{"x", x}}, {"y:0"}, {}, &outputs));
|
TF_CHECK_OK(session->Run({{"x", x}}, {"y:0"}, {}, &outputs));
|
||||||
CHECK_EQ(size_t{1}, outputs.size());
|
CHECK_EQ(size_t{1}, outputs.size());
|
||||||
}
|
}
|
||||||
testing::StopTiming();
|
|
||||||
TF_CHECK_OK(session->Close());
|
TF_CHECK_OK(session->Close());
|
||||||
}
|
}
|
||||||
static void BM_ShardedProgram(int iters, int width, int num_stages) {
|
static void BM_ShardedProgram(::testing::benchmark::State& state) {
|
||||||
BM_Helper(iters, width, num_stages, 2 /*tensor_size*/, true /*multi-device*/);
|
const int width = state.range(0);
|
||||||
|
const int num_stages = state.range(1);
|
||||||
|
|
||||||
|
BM_Helper(state, width, num_stages, 2 /*tensor_size*/, true /*multi-device*/);
|
||||||
}
|
}
|
||||||
BENCHMARK(BM_ShardedProgram)
|
BENCHMARK(BM_ShardedProgram)
|
||||||
->ArgPair(1, 1)
|
->ArgPair(1, 1)
|
||||||
@ -232,13 +233,19 @@ BENCHMARK(BM_ShardedProgram)
|
|||||||
->ArgPair(60, 3)
|
->ArgPair(60, 3)
|
||||||
->ArgPair(60, 5);
|
->ArgPair(60, 5);
|
||||||
|
|
||||||
static void BM_RPC(int iters, int width, int tensor_size) {
|
static void BM_RPC(::testing::benchmark::State& state) {
|
||||||
BM_Helper(iters, width, 2 /*num_stages*/, tensor_size, true /*multi-device*/);
|
const int width = state.range(0);
|
||||||
|
const int tensor_size = state.range(1);
|
||||||
|
|
||||||
|
BM_Helper(state, width, 2 /*num_stages*/, tensor_size, true /*multi-device*/);
|
||||||
}
|
}
|
||||||
BENCHMARK(BM_RPC)->ArgPair(30, 2)->ArgPair(30, 1000)->ArgPair(30, 100000);
|
BENCHMARK(BM_RPC)->ArgPair(30, 2)->ArgPair(30, 1000)->ArgPair(30, 100000);
|
||||||
|
|
||||||
static void BM_SingleDevice(int iters, int width, int num_stages) {
|
static void BM_SingleDevice(::testing::benchmark::State& state) {
|
||||||
BM_Helper(iters, width, num_stages, 2 /*tensor_size*/,
|
const int width = state.range(0);
|
||||||
|
const int num_stages = state.range(1);
|
||||||
|
|
||||||
|
BM_Helper(state, width, num_stages, 2 /*tensor_size*/,
|
||||||
false /*not multi-device*/);
|
false /*not multi-device*/);
|
||||||
}
|
}
|
||||||
BENCHMARK(BM_SingleDevice)
|
BENCHMARK(BM_SingleDevice)
|
||||||
|
@ -173,37 +173,36 @@ string MakeFloatTensorTestCase(int num_elems) {
|
|||||||
return encoded;
|
return encoded;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void BM_TensorResponse(int iters, int arg) {
|
static void BM_TensorResponse(::testing::benchmark::State& state) {
|
||||||
testing::StopTiming();
|
const int arg = state.range(0);
|
||||||
|
|
||||||
string encoded = MakeFloatTensorTestCase(arg);
|
string encoded = MakeFloatTensorTestCase(arg);
|
||||||
DummyDevice cpu_device(Env::Default());
|
DummyDevice cpu_device(Env::Default());
|
||||||
testing::StartTiming();
|
size_t bytes = 0;
|
||||||
while (--iters > 0) {
|
for (auto i : state) {
|
||||||
TensorResponse response;
|
TensorResponse response;
|
||||||
response.InitAlloc(&cpu_device, AllocatorAttributes());
|
response.InitAlloc(&cpu_device, AllocatorAttributes());
|
||||||
StringSource source(&encoded, -1);
|
StringSource source(&encoded, -1);
|
||||||
Status s = response.ParseFrom(&source);
|
Status s = response.ParseFrom(&source);
|
||||||
if (iters == 1) {
|
bytes = response.tensor().TotalBytes();
|
||||||
testing::SetLabel(
|
|
||||||
strings::StrCat("Bytes: ", response.tensor().TotalBytes()));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
state.SetLabel(strings::StrCat("Bytes: ", bytes));
|
||||||
}
|
}
|
||||||
BENCHMARK(BM_TensorResponse)->Arg(0)->Arg(1000)->Arg(100000);
|
BENCHMARK(BM_TensorResponse)->Arg(0)->Arg(1000)->Arg(100000);
|
||||||
|
|
||||||
static void BM_TensorViaTensorProto(int iters, int arg) {
|
static void BM_TensorViaTensorProto(::testing::benchmark::State& state) {
|
||||||
testing::StopTiming();
|
const int arg = state.range(0);
|
||||||
string encoded = MakeFloatTensorTestCase(arg);
|
|
||||||
testing::StartTiming();
|
std::string encoded = MakeFloatTensorTestCase(arg);
|
||||||
while (--iters > 0) {
|
size_t bytes = 0;
|
||||||
|
for (auto s : state) {
|
||||||
RecvTensorResponse r;
|
RecvTensorResponse r;
|
||||||
r.ParseFromString(encoded);
|
r.ParseFromString(encoded);
|
||||||
Tensor t;
|
Tensor t;
|
||||||
CHECK(t.FromProto(r.tensor()));
|
CHECK(t.FromProto(r.tensor()));
|
||||||
if (iters == 1) {
|
bytes = t.TotalBytes();
|
||||||
testing::SetLabel(strings::StrCat("Bytes: ", t.TotalBytes()));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
state.SetLabel(strings::StrCat("Bytes: ", bytes));
|
||||||
}
|
}
|
||||||
BENCHMARK(BM_TensorViaTensorProto)->Arg(0)->Arg(1000)->Arg(100000);
|
BENCHMARK(BM_TensorViaTensorProto)->Arg(0)->Arg(1000)->Arg(100000);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user