internal cleanup of benchmark lib
PiperOrigin-RevId: 338324940 Change-Id: Ieb42e28f2ec6325f3ef63c892c5f71d29f72e485
This commit is contained in:
parent
ac2324c037
commit
38b03d8d05
@ -48,7 +48,8 @@ namespace test {
|
|||||||
// TODO(hongm): Convert `g` and `init` to using std::unique_ptr.
|
// TODO(hongm): Convert `g` and `init` to using std::unique_ptr.
|
||||||
Benchmark::Benchmark(const string& device, Graph* g,
|
Benchmark::Benchmark(const string& device, Graph* g,
|
||||||
const SessionOptions* options, Graph* init,
|
const SessionOptions* options, Graph* init,
|
||||||
Rendezvous* rendez, const char* executor_type) {
|
Rendezvous* rendez, const char* executor_type,
|
||||||
|
bool old_benchmark_api) {
|
||||||
auto cleanup = gtl::MakeCleanup([g, init]() {
|
auto cleanup = gtl::MakeCleanup([g, init]() {
|
||||||
delete g;
|
delete g;
|
||||||
delete init;
|
delete init;
|
||||||
@ -59,7 +60,8 @@ Benchmark::Benchmark(const string& device, Graph* g,
|
|||||||
options = &default_options;
|
options = &default_options;
|
||||||
}
|
}
|
||||||
|
|
||||||
testing::StopTiming();
|
old_benchmark_api_ = old_benchmark_api;
|
||||||
|
if (old_benchmark_api_) testing::StopTiming();
|
||||||
string t = absl::AsciiStrToUpper(device);
|
string t = absl::AsciiStrToUpper(device);
|
||||||
// Allow NewDevice to allocate a new threadpool with different number of
|
// Allow NewDevice to allocate a new threadpool with different number of
|
||||||
// threads for each new benchmark.
|
// threads for each new benchmark.
|
||||||
@ -135,6 +137,10 @@ Benchmark::~Benchmark() {
|
|||||||
|
|
||||||
void Benchmark::Run(int iters) { RunWithRendezvousArgs({}, {}, iters); }
|
void Benchmark::Run(int iters) { RunWithRendezvousArgs({}, {}, iters); }
|
||||||
|
|
||||||
|
void Benchmark::Run(::testing::benchmark::State& state) {
|
||||||
|
RunWithRendezvousArgs({}, {}, state);
|
||||||
|
}
|
||||||
|
|
||||||
string GetRendezvousKey(const Node* node) {
|
string GetRendezvousKey(const Node* node) {
|
||||||
string send_device;
|
string send_device;
|
||||||
TF_CHECK_OK(GetNodeAttr(node->attrs(), "send_device", &send_device));
|
TF_CHECK_OK(GetNodeAttr(node->attrs(), "send_device", &send_device));
|
||||||
@ -149,9 +155,63 @@ string GetRendezvousKey(const Node* node) {
|
|||||||
recv_device, tensor_name, FrameAndIter(0, 0));
|
recv_device, tensor_name, FrameAndIter(0, 0));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Benchmark::RunWithRendezvousArgs(
|
||||||
|
const std::vector<std::pair<string, Tensor>>& inputs,
|
||||||
|
const std::vector<string>& outputs, ::testing::benchmark::State& state) {
|
||||||
|
CHECK(!old_benchmark_api_)
|
||||||
|
<< "This method should only be called with new benchmark API";
|
||||||
|
if (!device_ || state.max_iterations == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Tensor unused; // In benchmark, we don't care the return value.
|
||||||
|
bool is_dead;
|
||||||
|
|
||||||
|
// Warm up
|
||||||
|
Executor::Args args;
|
||||||
|
args.rendezvous = rendez_;
|
||||||
|
args.runner = [this](std::function<void()> closure) {
|
||||||
|
pool_->Schedule(closure);
|
||||||
|
};
|
||||||
|
static const int kWarmupRuns = 3;
|
||||||
|
for (int i = 0; i < kWarmupRuns; ++i) {
|
||||||
|
for (const auto& p : inputs) {
|
||||||
|
Rendezvous::ParsedKey parsed;
|
||||||
|
TF_CHECK_OK(Rendezvous::ParseKey(p.first, &parsed));
|
||||||
|
TF_CHECK_OK(rendez_->Send(parsed, Rendezvous::Args(), p.second, false));
|
||||||
|
}
|
||||||
|
TF_CHECK_OK(exec_->Run(args));
|
||||||
|
for (const string& key : outputs) {
|
||||||
|
Rendezvous::ParsedKey parsed;
|
||||||
|
TF_CHECK_OK(Rendezvous::ParseKey(key, &parsed));
|
||||||
|
TF_CHECK_OK(rendez_->Recv(parsed, Rendezvous::Args(), &unused, &is_dead));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TF_CHECK_OK(device_->Sync());
|
||||||
|
VLOG(3) << kWarmupRuns << " warmup runs done.";
|
||||||
|
|
||||||
|
// Benchmark loop. Timer starts automatically at the beginning of the loop
|
||||||
|
// and ends automatically after the last iteration.
|
||||||
|
for (auto s : state) {
|
||||||
|
for (const auto& p : inputs) {
|
||||||
|
Rendezvous::ParsedKey parsed;
|
||||||
|
TF_CHECK_OK(Rendezvous::ParseKey(p.first, &parsed));
|
||||||
|
TF_CHECK_OK(rendez_->Send(parsed, Rendezvous::Args(), p.second, false));
|
||||||
|
}
|
||||||
|
TF_CHECK_OK(exec_->Run(args));
|
||||||
|
for (const string& key : outputs) {
|
||||||
|
Rendezvous::ParsedKey parsed;
|
||||||
|
TF_CHECK_OK(Rendezvous::ParseKey(key, &parsed));
|
||||||
|
TF_CHECK_OK(rendez_->Recv(parsed, Rendezvous::Args(), &unused, &is_dead));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TF_CHECK_OK(device_->Sync());
|
||||||
|
}
|
||||||
|
|
||||||
void Benchmark::RunWithRendezvousArgs(
|
void Benchmark::RunWithRendezvousArgs(
|
||||||
const std::vector<std::pair<string, Tensor>>& inputs,
|
const std::vector<std::pair<string, Tensor>>& inputs,
|
||||||
const std::vector<string>& outputs, int iters) {
|
const std::vector<string>& outputs, int iters) {
|
||||||
|
CHECK(old_benchmark_api_) << "This method should only be called when running "
|
||||||
|
"with old benchmark API";
|
||||||
if (!device_ || iters == 0) {
|
if (!device_ || iters == 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -26,6 +26,12 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
|
namespace testing {
|
||||||
|
namespace benchmark {
|
||||||
|
class State;
|
||||||
|
} // namespace benchmark
|
||||||
|
} // namespace testing
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
class Device;
|
class Device;
|
||||||
@ -40,23 +46,42 @@ class Benchmark {
|
|||||||
public:
|
public:
|
||||||
// "device" must be either "cpu" or "gpu". Takes ownership of "g",
|
// "device" must be either "cpu" or "gpu". Takes ownership of "g",
|
||||||
// "init", and one reference on "rendez" (if not null).
|
// "init", and one reference on "rendez" (if not null).
|
||||||
|
//
|
||||||
|
// old_benchmark_api: If true, the benchmark is running with older API
|
||||||
|
// * In the old API, the timer needs to be stopped/restarted
|
||||||
|
// by users.
|
||||||
|
// * In the new API, the timer starts automatically at the first
|
||||||
|
// iteration of the loop and stops after the last iteration.
|
||||||
|
// TODO(vyng) Remove this once we have migrated all code to newer API.
|
||||||
Benchmark(const string& device, Graph* g,
|
Benchmark(const string& device, Graph* g,
|
||||||
const SessionOptions* options = nullptr, Graph* init = nullptr,
|
const SessionOptions* options = nullptr, Graph* init = nullptr,
|
||||||
Rendezvous* rendez = nullptr, const char* executor_type = "");
|
Rendezvous* rendez = nullptr, const char* executor_type = "",
|
||||||
|
bool old_benchmark_api = true);
|
||||||
~Benchmark();
|
~Benchmark();
|
||||||
|
|
||||||
// Executes the graph for "iters" times.
|
// Executes the graph for "iters" times.
|
||||||
|
// This function is deprecated. Use the overload that takes
|
||||||
|
// `benchmark::State&`
|
||||||
|
// instead.
|
||||||
void Run(int iters);
|
void Run(int iters);
|
||||||
|
|
||||||
|
void Run(::testing::benchmark::State& state);
|
||||||
|
|
||||||
// If "g" contains send/recv nodes, before each execution, we send
|
// If "g" contains send/recv nodes, before each execution, we send
|
||||||
// inputs to the corresponding recv keys in the graph, after each
|
// inputs to the corresponding recv keys in the graph, after each
|
||||||
// execution, we recv outputs from the corresponding send keys in
|
// execution, we recv outputs from the corresponding send keys in
|
||||||
// the graph. In the benchmark, we throw away values returned by the
|
// the graph. In the benchmark, we throw away values returned by the
|
||||||
// graph.
|
// graph.
|
||||||
|
// This function is deprecated. Use the overload that takes
|
||||||
|
// `benchmark::State&` instead.
|
||||||
void RunWithRendezvousArgs(
|
void RunWithRendezvousArgs(
|
||||||
const std::vector<std::pair<string, Tensor>>& inputs,
|
const std::vector<std::pair<string, Tensor>>& inputs,
|
||||||
const std::vector<string>& outputs, int iters);
|
const std::vector<string>& outputs, int iters);
|
||||||
|
|
||||||
|
void RunWithRendezvousArgs(
|
||||||
|
const std::vector<std::pair<string, Tensor>>& inputs,
|
||||||
|
const std::vector<string>& outputs, ::testing::benchmark::State& state);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
thread::ThreadPool* pool_ = nullptr; // Not owned.
|
thread::ThreadPool* pool_ = nullptr; // Not owned.
|
||||||
Device* device_ = nullptr; // Not owned.
|
Device* device_ = nullptr; // Not owned.
|
||||||
@ -66,6 +91,7 @@ class Benchmark {
|
|||||||
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
|
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
|
||||||
FunctionLibraryRuntime* flr_; // Not owned.
|
FunctionLibraryRuntime* flr_; // Not owned.
|
||||||
std::unique_ptr<Executor> exec_;
|
std::unique_ptr<Executor> exec_;
|
||||||
|
bool old_benchmark_api_;
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(Benchmark);
|
TF_DISALLOW_COPY_AND_ASSIGN(Benchmark);
|
||||||
};
|
};
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# Tensorflow default + linux implementations of tensorflow/core/platform libraries.
|
# Tensorflow default + linux implementations of tensorflow/core/platform libraries.
|
||||||
load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
|
load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
|
||||||
load("//tensorflow:tensorflow.bzl", "filegroup")
|
load("//tensorflow:tensorflow.bzl", "filegroup")
|
||||||
load("//tensorflow:tensorflow.bzl", "if_not_windows", "tf_copts")
|
load("//tensorflow:tensorflow.bzl", "if_not_windows", "tf_cc_test", "tf_copts")
|
||||||
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
|
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
|
||||||
|
|
||||||
package(
|
package(
|
||||||
@ -429,12 +429,28 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/core/platform",
|
"//tensorflow/core/platform",
|
||||||
"//tensorflow/core/platform:env",
|
"//tensorflow/core/platform:env",
|
||||||
|
"//tensorflow/core/platform:logging",
|
||||||
"//tensorflow/core/platform:macros",
|
"//tensorflow/core/platform:macros",
|
||||||
|
"//tensorflow/core/platform:str_util",
|
||||||
"//tensorflow/core/platform:types",
|
"//tensorflow/core/platform:types",
|
||||||
"//tensorflow/core/util:reporter",
|
"//tensorflow/core/util:reporter",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "test_benchmark_test",
|
||||||
|
srcs = ["test_benchmark_test.cc"],
|
||||||
|
tags = [
|
||||||
|
"nobuilder",
|
||||||
|
"notap",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":test_benchmark",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "test",
|
name = "test",
|
||||||
testonly = True,
|
testonly = True,
|
||||||
|
@ -52,18 +52,48 @@ Benchmark::Benchmark(const char* name, void (*fn)(int, int, int))
|
|||||||
Register();
|
Register();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Benchmark::Benchmark(const char* name, void (*fn)(::testing::benchmark::State&))
|
||||||
|
: name_(name),
|
||||||
|
// -1 because the number of parameters is not part of the benchmark
|
||||||
|
// routine signature.
|
||||||
|
num_args_(-1),
|
||||||
|
fn_state_(fn) {
|
||||||
|
Register();
|
||||||
|
}
|
||||||
|
|
||||||
|
void Benchmark::CheckArgCount(int expected) {
|
||||||
|
if (num_args_ == expected) return;
|
||||||
|
|
||||||
|
// Number of args is not part of function signature.
|
||||||
|
// Verify that if benchmark instantiation has previously provided args, they
|
||||||
|
// match "args".
|
||||||
|
if (num_args_ < 0) {
|
||||||
|
if (args_.empty() || instantiated_num_args_ == expected) return;
|
||||||
|
}
|
||||||
|
CHECK(false) << "Expected " << expected << " args for benchmark, but got "
|
||||||
|
<< instantiated_num_args_;
|
||||||
|
}
|
||||||
|
|
||||||
Benchmark* Benchmark::Arg(int x) {
|
Benchmark* Benchmark::Arg(int x) {
|
||||||
CHECK_EQ(num_args_, 1);
|
CheckArgCount(/*expected=*/1);
|
||||||
args_.push_back(std::make_pair(x, -1));
|
args_.push_back(std::make_pair(x, -1));
|
||||||
|
instantiated_num_args_ = 1;
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
Benchmark* Benchmark::ArgPair(int x, int y) {
|
Benchmark* Benchmark::ArgPair(int x, int y) {
|
||||||
CHECK_EQ(num_args_, 2);
|
CheckArgCount(/*expected=*/2);
|
||||||
|
instantiated_num_args_ = 2;
|
||||||
args_.push_back(std::make_pair(x, y));
|
args_.push_back(std::make_pair(x, y));
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Benchmark* Benchmark::UseRealTime() {
|
||||||
|
// Do nothing.
|
||||||
|
// This only exists for API compatibility with internal benchmarks.
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
void AddRange(std::vector<int>* dst, int lo, int hi, int mult) {
|
void AddRange(std::vector<int>* dst, int lo, int hi, int mult) {
|
||||||
@ -210,6 +240,7 @@ void Benchmark::Run(int arg1, int arg2, int* run_count, double* run_seconds) {
|
|||||||
static const int64 kMaxIters = 1000000000;
|
static const int64 kMaxIters = 1000000000;
|
||||||
static const double kMinTime = 0.5;
|
static const double kMinTime = 0.5;
|
||||||
int64 iters = kMinIters;
|
int64 iters = kMinIters;
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
accum_time = 0;
|
accum_time = 0;
|
||||||
start_time = env->NowMicros();
|
start_time = env->NowMicros();
|
||||||
@ -220,8 +251,11 @@ void Benchmark::Run(int arg1, int arg2, int* run_count, double* run_seconds) {
|
|||||||
(*fn0_)(iters);
|
(*fn0_)(iters);
|
||||||
} else if (fn1_) {
|
} else if (fn1_) {
|
||||||
(*fn1_)(iters, arg1);
|
(*fn1_)(iters, arg1);
|
||||||
} else {
|
} else if (fn2_) {
|
||||||
(*fn2_)(iters, arg1, arg2);
|
(*fn2_)(iters, arg1, arg2);
|
||||||
|
} else if (fn_state_) {
|
||||||
|
::testing::benchmark::State state(iters, std::vector<int>(arg1, arg2));
|
||||||
|
(*fn_state_)(state);
|
||||||
}
|
}
|
||||||
StopTiming();
|
StopTiming();
|
||||||
const double seconds = accum_time * 1e-6;
|
const double seconds = accum_time * 1e-6;
|
||||||
@ -261,3 +295,38 @@ void UseRealTime() {}
|
|||||||
|
|
||||||
} // namespace testing
|
} // namespace testing
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
namespace testing {
|
||||||
|
namespace benchmark {
|
||||||
|
State::State(size_t max_iterations, const std::vector<int>& args)
|
||||||
|
: max_iterations(max_iterations), args_(args) {
|
||||||
|
completed_iterations_ = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void State::PauseTiming() { ::tensorflow::testing::StopTiming(); }
|
||||||
|
|
||||||
|
void State::ResumeTiming() { ::tensorflow::testing::StartTiming(); }
|
||||||
|
|
||||||
|
void State::SetBytesProcessed(::tensorflow::int64 bytes) {
|
||||||
|
::tensorflow::testing::BytesProcessed(bytes);
|
||||||
|
}
|
||||||
|
|
||||||
|
void State::SetItemsProcessed(::tensorflow::int64 items) {
|
||||||
|
::tensorflow::testing::ItemsProcessed(items);
|
||||||
|
}
|
||||||
|
|
||||||
|
void State::SetLabel(absl::string_view label) {
|
||||||
|
::tensorflow::testing::SetLabel(std::string(label));
|
||||||
|
}
|
||||||
|
|
||||||
|
int State::range(size_t i) const {
|
||||||
|
if (i >= args_.size()) {
|
||||||
|
LOG(FATAL) << "argument for range " << i << " is not set";
|
||||||
|
}
|
||||||
|
return args_[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
void RunSpecifiedBenchmarks() { ::tensorflow::testing::Benchmark::Run("all"); }
|
||||||
|
|
||||||
|
} // namespace benchmark
|
||||||
|
} // namespace testing
|
||||||
|
@ -32,6 +32,12 @@ limitations under the License.
|
|||||||
#define TF_BENCHMARK_CONCAT(a, b, c) TF_BENCHMARK_CONCAT2(a, b, c)
|
#define TF_BENCHMARK_CONCAT(a, b, c) TF_BENCHMARK_CONCAT2(a, b, c)
|
||||||
#define TF_BENCHMARK_CONCAT2(a, b, c) a##b##c
|
#define TF_BENCHMARK_CONCAT2(a, b, c) a##b##c
|
||||||
|
|
||||||
|
namespace testing {
|
||||||
|
namespace benchmark {
|
||||||
|
class State;
|
||||||
|
}
|
||||||
|
} // namespace testing
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace testing {
|
namespace testing {
|
||||||
|
|
||||||
@ -77,26 +83,41 @@ void DoNotOptimize(const T& var) {
|
|||||||
|
|
||||||
class Benchmark {
|
class Benchmark {
|
||||||
public:
|
public:
|
||||||
Benchmark(const char* name, void (*fn)(int));
|
[[deprecated("use `benchmark::State&` instead.")]] Benchmark(const char* name,
|
||||||
Benchmark(const char* name, void (*fn)(int, int));
|
void (*fn)(int));
|
||||||
Benchmark(const char* name, void (*fn)(int, int, int));
|
|
||||||
|
[[deprecated("use `benchmark::State&` instead.")]] Benchmark(const char* name,
|
||||||
|
void (*fn)(int,
|
||||||
|
int));
|
||||||
|
|
||||||
|
[[deprecated("use `benchmark::State&` instead.")]] Benchmark(
|
||||||
|
const char* name, void (*fn)(int, int, int));
|
||||||
|
|
||||||
|
Benchmark(const char* name, void (*fn)(::testing::benchmark::State&));
|
||||||
|
|
||||||
Benchmark* Arg(int x);
|
Benchmark* Arg(int x);
|
||||||
Benchmark* ArgPair(int x, int y);
|
Benchmark* ArgPair(int x, int y);
|
||||||
Benchmark* Range(int lo, int hi);
|
Benchmark* Range(int lo, int hi);
|
||||||
Benchmark* RangePair(int lo1, int hi1, int lo2, int hi2);
|
Benchmark* RangePair(int lo1, int hi1, int lo2, int hi2);
|
||||||
|
|
||||||
|
Benchmark* UseRealTime();
|
||||||
|
|
||||||
static void Run(const char* pattern);
|
static void Run(const char* pattern);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
string name_;
|
string name_;
|
||||||
int num_args_;
|
int num_args_;
|
||||||
|
int instantiated_num_args_ = -1;
|
||||||
std::vector<std::pair<int, int> > args_;
|
std::vector<std::pair<int, int> > args_;
|
||||||
void (*fn0_)(int) = nullptr;
|
void (*fn0_)(int) = nullptr;
|
||||||
void (*fn1_)(int, int) = nullptr;
|
void (*fn1_)(int, int) = nullptr;
|
||||||
void (*fn2_)(int, int, int) = nullptr;
|
void (*fn2_)(int, int, int) = nullptr;
|
||||||
|
void (*fn_state_)(::testing::benchmark::State&) = nullptr;
|
||||||
|
|
||||||
void Register();
|
void Register();
|
||||||
void Run(int arg1, int arg2, int* run_count, double* run_seconds);
|
void Run(int arg1, int arg2, int* run_count, double* run_seconds);
|
||||||
|
|
||||||
|
void CheckArgCount(int expected);
|
||||||
};
|
};
|
||||||
|
|
||||||
void RunBenchmarks();
|
void RunBenchmarks();
|
||||||
@ -110,4 +131,148 @@ void UseRealTime();
|
|||||||
} // namespace testing
|
} // namespace testing
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
// Support `void BM_Func(benchmark::State&)` interface so that the it is
|
||||||
|
// compatible with the internal version.
|
||||||
|
namespace testing {
|
||||||
|
namespace benchmark {
|
||||||
|
// State is passed as an argument to a benchmark function.
|
||||||
|
// Each thread in threaded benchmarks receives own object.
|
||||||
|
class State {
|
||||||
|
public:
|
||||||
|
// Incomplete iterator-like type with dummy value type so that
|
||||||
|
// benchmark::State can support iteration with a range-based for loop.
|
||||||
|
//
|
||||||
|
// The only supported usage:
|
||||||
|
//
|
||||||
|
// static void BM_Foo(benchmark::State& state) {
|
||||||
|
// for (auto s : state) {
|
||||||
|
// // perform single iteration
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// This is meant to replace the deprecated API :
|
||||||
|
//
|
||||||
|
// static void BM_Foo(int iters) {
|
||||||
|
// while (iters-- > 0) {
|
||||||
|
// // perform single iteration
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// See go/benchmark#old-benchmark-interface for more details.
|
||||||
|
class Iterator {
|
||||||
|
public:
|
||||||
|
struct Value {
|
||||||
|
// Non-trivial destructor to avoid warning for unused dummy variable in
|
||||||
|
// the range-based for loop.
|
||||||
|
~Value() {}
|
||||||
|
};
|
||||||
|
|
||||||
|
explicit Iterator(State* parent);
|
||||||
|
|
||||||
|
Iterator& operator++();
|
||||||
|
|
||||||
|
bool operator!=(const Iterator& other);
|
||||||
|
|
||||||
|
Value operator*();
|
||||||
|
|
||||||
|
private:
|
||||||
|
State* const parent_;
|
||||||
|
};
|
||||||
|
|
||||||
|
Iterator begin();
|
||||||
|
Iterator end();
|
||||||
|
|
||||||
|
void PauseTiming();
|
||||||
|
void ResumeTiming();
|
||||||
|
|
||||||
|
// Set the number of bytes processed by the current benchmark
|
||||||
|
// execution. This routine is typically called once at the end of a
|
||||||
|
// throughput oriented benchmark. If this routine is called with a
|
||||||
|
// value > 0, then bytes processed per second is also reported.
|
||||||
|
void SetBytesProcessed(::tensorflow::int64 bytes);
|
||||||
|
|
||||||
|
// If this routine is called with items > 0, then an items/s
|
||||||
|
// label is printed on the benchmark report line for the currently
|
||||||
|
// executing benchmark. It is typically called at the end of a processing
|
||||||
|
// benchmark where a processing items/second output is desired.
|
||||||
|
void SetItemsProcessed(::tensorflow::int64 items);
|
||||||
|
|
||||||
|
// If this method is called, the specified label is printed at the
|
||||||
|
// end of the benchmark report line for the currently executing
|
||||||
|
// benchmark. Example:
|
||||||
|
// static void BM_Compress(benchmark::State& state) {
|
||||||
|
// ...
|
||||||
|
// double compression = input_size / output_size;
|
||||||
|
// state.SetLabel(StringPrintf("compress:%.1f%%", 100.0*compression));
|
||||||
|
// }
|
||||||
|
// Produces output that looks like:
|
||||||
|
// BM_Compress 50 50 14115038 compress:27.3%
|
||||||
|
//
|
||||||
|
// REQUIRES: a benchmark is currently executing
|
||||||
|
void SetLabel(absl::string_view label);
|
||||||
|
|
||||||
|
// For parameterized benchmarks, range(i) returns the value of the ith
|
||||||
|
// parameter. Simple benchmarks are not parameterized and do not need to call
|
||||||
|
// range().
|
||||||
|
int range(size_t i) const;
|
||||||
|
|
||||||
|
// Total number of iterations processed so far.
|
||||||
|
size_t iterations() const;
|
||||||
|
|
||||||
|
const size_t
|
||||||
|
max_iterations; // NOLINT: for compatibility with OSS benchmark library
|
||||||
|
|
||||||
|
// Disallow copy and assign.
|
||||||
|
State(const State&) = delete;
|
||||||
|
State& operator=(const State&) = delete;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
friend class tensorflow::testing::Benchmark;
|
||||||
|
State(size_t max_iterations, const std::vector<int>& args);
|
||||||
|
|
||||||
|
private:
|
||||||
|
size_t completed_iterations_;
|
||||||
|
std::vector<int> args_;
|
||||||
|
};
|
||||||
|
|
||||||
|
inline State::Iterator::Iterator(State* parent) : parent_(parent) {}
|
||||||
|
|
||||||
|
inline size_t State::iterations() const { return completed_iterations_; }
|
||||||
|
|
||||||
|
inline bool State::Iterator::operator!=(const Iterator& other) {
|
||||||
|
DCHECK_EQ(other.parent_, nullptr);
|
||||||
|
DCHECK_NE(parent_, nullptr);
|
||||||
|
|
||||||
|
if (parent_->completed_iterations_ < parent_->max_iterations) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
++parent_->completed_iterations_;
|
||||||
|
// If this is the last iteration, stop the timer.
|
||||||
|
parent_->PauseTiming();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline State::Iterator& State::Iterator::operator++() {
|
||||||
|
DCHECK_LT(parent_->completed_iterations_, parent_->max_iterations);
|
||||||
|
++parent_->completed_iterations_;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline State::Iterator::Value State::Iterator::operator*() { return Value(); }
|
||||||
|
|
||||||
|
inline State::Iterator State::begin() {
|
||||||
|
// Starts the timer here because if the code uses this API, it expects
|
||||||
|
// the timer to starts at the beginning of this loop.
|
||||||
|
ResumeTiming();
|
||||||
|
return Iterator(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline State::Iterator State::end() { return Iterator(nullptr); }
|
||||||
|
|
||||||
|
void RunSpecifiedBenchmarks();
|
||||||
|
|
||||||
|
} // namespace benchmark
|
||||||
|
} // namespace testing
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_TEST_BENCHMARK_H_
|
#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_TEST_BENCHMARK_H_
|
||||||
|
40
tensorflow/core/platform/default/test_benchmark_test.cc
Normal file
40
tensorflow/core/platform/default/test_benchmark_test.cc
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/platform/default/test_benchmark.h"
|
||||||
|
|
||||||
|
// Test the new interface: BM_benchmark(benchmark::State& state)
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace testing {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
void BM_TestIterState(::testing::benchmark::State& state) {
|
||||||
|
int i = 0;
|
||||||
|
for (auto s : state) {
|
||||||
|
++i;
|
||||||
|
DoNotOptimize(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
BENCHMARK(BM_TestIterState);
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace testing
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
int main() {
|
||||||
|
::testing::benchmark::RunSpecifiedBenchmarks();
|
||||||
|
return 0;
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user