internal cleanup of benchmark lib
PiperOrigin-RevId: 337979979 Change-Id: Ife832c41e2c9cd738d950da5746bc45a8a33624a
This commit is contained in:
parent
09d8e7548c
commit
76bf7d80a0
@ -48,7 +48,8 @@ namespace test {
|
||||
// TODO(hongm): Convert `g` and `init` to using std::unique_ptr.
|
||||
Benchmark::Benchmark(const string& device, Graph* g,
|
||||
const SessionOptions* options, Graph* init,
|
||||
Rendezvous* rendez, const char* executor_type) {
|
||||
Rendezvous* rendez, const char* executor_type,
|
||||
bool old_benchmark_api) {
|
||||
auto cleanup = gtl::MakeCleanup([g, init]() {
|
||||
delete g;
|
||||
delete init;
|
||||
@ -59,7 +60,8 @@ Benchmark::Benchmark(const string& device, Graph* g,
|
||||
options = &default_options;
|
||||
}
|
||||
|
||||
testing::StopTiming();
|
||||
old_benchmark_api_ = old_benchmark_api;
|
||||
if (old_benchmark_api_) testing::StopTiming();
|
||||
string t = absl::AsciiStrToUpper(device);
|
||||
// Allow NewDevice to allocate a new threadpool with different number of
|
||||
// threads for each new benchmark.
|
||||
@ -135,6 +137,10 @@ Benchmark::~Benchmark() {
|
||||
|
||||
void Benchmark::Run(int iters) { RunWithRendezvousArgs({}, {}, iters); }
|
||||
|
||||
// Runs the graph under the new benchmark API. No tensors are sent to the
// rendezvous before an execution and none are received afterwards; the number
// of executions is driven by `state`.
void Benchmark::Run(benchmark::State& state) {
  const std::vector<std::pair<string, Tensor>> no_inputs;
  const std::vector<string> no_outputs;
  RunWithRendezvousArgs(no_inputs, no_outputs, state);
}
|
||||
|
||||
string GetRendezvousKey(const Node* node) {
|
||||
string send_device;
|
||||
TF_CHECK_OK(GetNodeAttr(node->attrs(), "send_device", &send_device));
|
||||
@ -149,9 +155,63 @@ string GetRendezvousKey(const Node* node) {
|
||||
recv_device, tensor_name, FrameAndIter(0, 0));
|
||||
}
|
||||
|
||||
void Benchmark::RunWithRendezvousArgs(
|
||||
const std::vector<std::pair<string, Tensor>>& inputs,
|
||||
const std::vector<string>& outputs, benchmark::State& state) {
|
||||
CHECK(!old_benchmark_api_)
|
||||
<< "This method should only be called with new benchmark API";
|
||||
if (!device_ || state.max_iterations == 0) {
|
||||
return;
|
||||
}
|
||||
Tensor unused; // In benchmark, we don't care the return value.
|
||||
bool is_dead;
|
||||
|
||||
// Warm up
|
||||
Executor::Args args;
|
||||
args.rendezvous = rendez_;
|
||||
args.runner = [this](std::function<void()> closure) {
|
||||
pool_->Schedule(closure);
|
||||
};
|
||||
static const int kWarmupRuns = 3;
|
||||
for (int i = 0; i < kWarmupRuns; ++i) {
|
||||
for (const auto& p : inputs) {
|
||||
Rendezvous::ParsedKey parsed;
|
||||
TF_CHECK_OK(Rendezvous::ParseKey(p.first, &parsed));
|
||||
TF_CHECK_OK(rendez_->Send(parsed, Rendezvous::Args(), p.second, false));
|
||||
}
|
||||
TF_CHECK_OK(exec_->Run(args));
|
||||
for (const string& key : outputs) {
|
||||
Rendezvous::ParsedKey parsed;
|
||||
TF_CHECK_OK(Rendezvous::ParseKey(key, &parsed));
|
||||
TF_CHECK_OK(rendez_->Recv(parsed, Rendezvous::Args(), &unused, &is_dead));
|
||||
}
|
||||
}
|
||||
TF_CHECK_OK(device_->Sync());
|
||||
VLOG(3) << kWarmupRuns << " warmup runs done.";
|
||||
|
||||
// Benchmark loop. Timer starts automatically at the beginning of the loop
|
||||
// and ends automatically after the last iteration.
|
||||
for (auto s : state) {
|
||||
for (const auto& p : inputs) {
|
||||
Rendezvous::ParsedKey parsed;
|
||||
TF_CHECK_OK(Rendezvous::ParseKey(p.first, &parsed));
|
||||
TF_CHECK_OK(rendez_->Send(parsed, Rendezvous::Args(), p.second, false));
|
||||
}
|
||||
TF_CHECK_OK(exec_->Run(args));
|
||||
for (const string& key : outputs) {
|
||||
Rendezvous::ParsedKey parsed;
|
||||
TF_CHECK_OK(Rendezvous::ParseKey(key, &parsed));
|
||||
TF_CHECK_OK(rendez_->Recv(parsed, Rendezvous::Args(), &unused, &is_dead));
|
||||
}
|
||||
}
|
||||
TF_CHECK_OK(device_->Sync());
|
||||
}
|
||||
|
||||
void Benchmark::RunWithRendezvousArgs(
|
||||
const std::vector<std::pair<string, Tensor>>& inputs,
|
||||
const std::vector<string>& outputs, int iters) {
|
||||
CHECK(old_benchmark_api_) << "This method should only be called when running "
|
||||
"with old benchmark API";
|
||||
if (!device_ || iters == 0) {
|
||||
return;
|
||||
}
|
||||
|
@ -26,6 +26,12 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace testing {
|
||||
namespace benchmark {
|
||||
class State;
|
||||
} // namespace benchmark
|
||||
} // namespace testing
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class Device;
|
||||
@ -40,23 +46,42 @@ class Benchmark {
|
||||
public:
|
||||
// "device" must be either "cpu" or "gpu". Takes ownership of "g",
|
||||
// "init", and one reference on "rendez" (if not null).
|
||||
//
|
||||
// old_benchmark_api: If true, the benchmark is running with older API
|
||||
// * In the old API, the timer needs to be stopped/restarted
|
||||
// by users.
|
||||
// * In the new API, the timer starts automatically at the first
|
||||
// iteration of the loop and stops after the last iteration.
|
||||
// TODO(vyng) Remove this once we have migrated all code to newer API.
|
||||
Benchmark(const string& device, Graph* g,
|
||||
const SessionOptions* options = nullptr, Graph* init = nullptr,
|
||||
Rendezvous* rendez = nullptr, const char* executor_type = "");
|
||||
Rendezvous* rendez = nullptr, const char* executor_type = "",
|
||||
bool old_benchmark_api = true);
|
||||
~Benchmark();
|
||||
|
||||
// Executes the graph for "iters" times.
|
||||
// This function is deprecated. Use the overload that takes
|
||||
// `benchmark::State&`
|
||||
// instead.
|
||||
void Run(int iters);
|
||||
|
||||
void Run(::testing::benchmark::State& state);
|
||||
|
||||
// If "g" contains send/recv nodes, before each execution, we send
|
||||
// inputs to the corresponding recv keys in the graph, after each
|
||||
// execution, we recv outputs from the corresponding send keys in
|
||||
// the graph. In the benchmark, we throw away values returned by the
|
||||
// graph.
|
||||
// This function is deprecated. Use the overload that takes
|
||||
// `benchmark::State&` instead.
|
||||
void RunWithRendezvousArgs(
|
||||
const std::vector<std::pair<string, Tensor>>& inputs,
|
||||
const std::vector<string>& outputs, int iters);
|
||||
|
||||
void RunWithRendezvousArgs(
|
||||
const std::vector<std::pair<string, Tensor>>& inputs,
|
||||
const std::vector<string>& outputs, ::testing::benchmark::State& state);
|
||||
|
||||
private:
|
||||
thread::ThreadPool* pool_ = nullptr; // Not owned.
|
||||
Device* device_ = nullptr; // Not owned.
|
||||
@ -66,6 +91,7 @@ class Benchmark {
|
||||
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
|
||||
FunctionLibraryRuntime* flr_; // Not owned.
|
||||
std::unique_ptr<Executor> exec_;
|
||||
bool old_benchmark_api_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(Benchmark);
|
||||
};
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Tensorflow default + linux implementations of tensorflow/core/platform libraries.
|
||||
load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
|
||||
load("//tensorflow:tensorflow.bzl", "filegroup")
|
||||
load("//tensorflow:tensorflow.bzl", "if_not_windows", "tf_copts")
|
||||
load("//tensorflow:tensorflow.bzl", "if_not_windows", "tf_cc_test", "tf_copts")
|
||||
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
|
||||
|
||||
package(
|
||||
@ -429,12 +429,22 @@ cc_library(
|
||||
deps = [
|
||||
"//tensorflow/core/platform",
|
||||
"//tensorflow/core/platform:env",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"//tensorflow/core/platform:macros",
|
||||
"//tensorflow/core/platform:str_util",
|
||||
"//tensorflow/core/platform:types",
|
||||
"//tensorflow/core/util:reporter",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "test_benchmark_test",
|
||||
srcs = ["test_benchmark_test.cc"],
|
||||
deps = [
|
||||
":test_benchmark",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "test",
|
||||
testonly = True,
|
||||
|
@ -52,18 +52,48 @@ Benchmark::Benchmark(const char* name, void (*fn)(int, int, int))
|
||||
Register();
|
||||
}
|
||||
|
||||
Benchmark::Benchmark(const char* name, void (*fn)(::testing::benchmark::State&))
|
||||
: name_(name),
|
||||
// -1 because the number of parameters is not part of the benchmark
|
||||
// routine signature.
|
||||
num_args_(-1),
|
||||
fn_state_(fn) {
|
||||
Register();
|
||||
}
|
||||
|
||||
void Benchmark::CheckArgCount(int expected) {
|
||||
if (num_args_ == expected) return;
|
||||
|
||||
// Number of args is not part of function signature.
|
||||
// Verify that if benchmark instantiation has previously provided args, they
|
||||
// match "args".
|
||||
if (num_args_ < 0) {
|
||||
if (args_.empty() || instantiated_num_args_ == expected) return;
|
||||
}
|
||||
CHECK(false) << "Expected " << expected << " args for benchmark, but got "
|
||||
<< instantiated_num_args_;
|
||||
}
|
||||
|
||||
Benchmark* Benchmark::Arg(int x) {
|
||||
CHECK_EQ(num_args_, 1);
|
||||
CheckArgCount(/*expected=*/1);
|
||||
args_.push_back(std::make_pair(x, -1));
|
||||
instantiated_num_args_ = 1;
|
||||
return this;
|
||||
}
|
||||
|
||||
Benchmark* Benchmark::ArgPair(int x, int y) {
|
||||
CHECK_EQ(num_args_, 2);
|
||||
CheckArgCount(/*expected=*/2);
|
||||
instantiated_num_args_ = 2;
|
||||
args_.push_back(std::make_pair(x, y));
|
||||
return this;
|
||||
}
|
||||
|
||||
Benchmark* Benchmark::UseRealTime() {
|
||||
// Do nothing.
|
||||
// This only exists for API compatibility with internal benchmarks.
|
||||
return this;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
void AddRange(std::vector<int>* dst, int lo, int hi, int mult) {
|
||||
@ -210,6 +240,7 @@ void Benchmark::Run(int arg1, int arg2, int* run_count, double* run_seconds) {
|
||||
static const int64 kMaxIters = 1000000000;
|
||||
static const double kMinTime = 0.5;
|
||||
int64 iters = kMinIters;
|
||||
|
||||
while (true) {
|
||||
accum_time = 0;
|
||||
start_time = env->NowMicros();
|
||||
@ -220,8 +251,11 @@ void Benchmark::Run(int arg1, int arg2, int* run_count, double* run_seconds) {
|
||||
(*fn0_)(iters);
|
||||
} else if (fn1_) {
|
||||
(*fn1_)(iters, arg1);
|
||||
} else {
|
||||
} else if (fn2_) {
|
||||
(*fn2_)(iters, arg1, arg2);
|
||||
} else if (fn_state_) {
|
||||
::testing::benchmark::State state(iters, std::vector<int>(arg1, arg2));
|
||||
(*fn_state_)(state);
|
||||
}
|
||||
StopTiming();
|
||||
const double seconds = accum_time * 1e-6;
|
||||
@ -261,3 +295,38 @@ void UseRealTime() {}
|
||||
|
||||
} // namespace testing
|
||||
} // namespace tensorflow
|
||||
|
||||
namespace testing {
|
||||
namespace benchmark {
|
||||
// Creates a State that will run `max_iterations` iterations and exposes `args`
// through range(). The completed-iteration counter starts at zero.
State::State(size_t max_iterations, const std::vector<int>& args)
    : max_iterations(max_iterations), completed_iterations_(0), args_(args) {}
|
||||
|
||||
void State::PauseTiming() { ::tensorflow::testing::StopTiming(); }
|
||||
|
||||
void State::ResumeTiming() { ::tensorflow::testing::StartTiming(); }
|
||||
|
||||
void State::SetBytesProcessed(int64 bytes) {
|
||||
::tensorflow::testing::BytesProcessed(bytes);
|
||||
}
|
||||
|
||||
void State::SetItemsProcessed(int64 items) {
|
||||
::tensorflow::testing::ItemsProcessed(items);
|
||||
}
|
||||
|
||||
void State::SetLabel(absl::string_view label) {
|
||||
::tensorflow::testing::SetLabel(std::string(label));
|
||||
}
|
||||
|
||||
int State::range(size_t i) const {
|
||||
if (i >= args_.size()) {
|
||||
LOG(FATAL) << "argument for range " << i << " is not set";
|
||||
}
|
||||
return args_[i];
|
||||
}
|
||||
|
||||
void RunSpecifiedBenchmarks() { ::tensorflow::testing::Benchmark::Run("all"); }
|
||||
|
||||
} // namespace benchmark
|
||||
} // namespace testing
|
||||
|
@ -32,6 +32,12 @@ limitations under the License.
|
||||
#define TF_BENCHMARK_CONCAT(a, b, c) TF_BENCHMARK_CONCAT2(a, b, c)
|
||||
#define TF_BENCHMARK_CONCAT2(a, b, c) a##b##c
|
||||
|
||||
namespace testing {
|
||||
namespace benchmark {
|
||||
class State;
|
||||
}
|
||||
} // namespace testing
|
||||
|
||||
namespace tensorflow {
|
||||
namespace testing {
|
||||
|
||||
@ -77,26 +83,41 @@ void DoNotOptimize(const T& var) {
|
||||
|
||||
class Benchmark {
|
||||
public:
|
||||
Benchmark(const char* name, void (*fn)(int));
|
||||
Benchmark(const char* name, void (*fn)(int, int));
|
||||
Benchmark(const char* name, void (*fn)(int, int, int));
|
||||
[[deprecated("use `benchmark::State&` instead.")]] Benchmark(const char* name,
|
||||
void (*fn)(int));
|
||||
|
||||
[[deprecated("use `benchmark::State&` instead.")]] Benchmark(const char* name,
|
||||
void (*fn)(int,
|
||||
int));
|
||||
|
||||
[[deprecated("use `benchmark::State&` instead.")]] Benchmark(
|
||||
const char* name, void (*fn)(int, int, int));
|
||||
|
||||
Benchmark(const char* name, void (*fn)(::testing::benchmark::State&));
|
||||
|
||||
Benchmark* Arg(int x);
|
||||
Benchmark* ArgPair(int x, int y);
|
||||
Benchmark* Range(int lo, int hi);
|
||||
Benchmark* RangePair(int lo1, int hi1, int lo2, int hi2);
|
||||
|
||||
Benchmark* UseRealTime();
|
||||
|
||||
static void Run(const char* pattern);
|
||||
|
||||
private:
|
||||
string name_;
|
||||
int num_args_;
|
||||
int instantiated_num_args_ = -1;
|
||||
std::vector<std::pair<int, int> > args_;
|
||||
void (*fn0_)(int) = nullptr;
|
||||
void (*fn1_)(int, int) = nullptr;
|
||||
void (*fn2_)(int, int, int) = nullptr;
|
||||
void (*fn_state_)(::testing::benchmark::State&) = nullptr;
|
||||
|
||||
void Register();
|
||||
void Run(int arg1, int arg2, int* run_count, double* run_seconds);
|
||||
|
||||
void CheckArgCount(int expected);
|
||||
};
|
||||
|
||||
void RunBenchmarks();
|
||||
@ -110,4 +131,151 @@ void UseRealTime();
|
||||
} // namespace testing
|
||||
} // namespace tensorflow
|
||||
|
||||
// Support the `void BM_Func(benchmark::State&)` interface so that it is
|
||||
// compatible with the internal version.
|
||||
namespace testing {
|
||||
namespace benchmark {
|
||||
|
||||
using namespace tensorflow::testing; // NOLINT: for access to ::int64
|
||||
|
||||
// State is passed as an argument to a benchmark function.
|
||||
// Each thread in threaded benchmarks receives own object.
|
||||
class State {
|
||||
public:
|
||||
// Incomplete iterator-like type with dummy value type so that
|
||||
// benchmark::State can support iteration with a range-based for loop.
|
||||
//
|
||||
// The only supported usage:
|
||||
//
|
||||
// static void BM_Foo(benchmark::State& state) {
|
||||
// for (auto s : state) {
|
||||
// // perform single iteration
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// This is meant to replace the deprecated API :
|
||||
//
|
||||
// static void BM_Foo(int iters) {
|
||||
// while (iters-- > 0) {
|
||||
// // perform single iteration
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// See go/benchmark#old-benchmark-interface for more details.
|
||||
class Iterator {
|
||||
public:
|
||||
struct Value {
|
||||
// Non-trivial destructor to avoid warning for unused dummy variable in
|
||||
// the range-based for loop.
|
||||
~Value() {}
|
||||
};
|
||||
|
||||
explicit Iterator(State* parent);
|
||||
|
||||
Iterator& operator++();
|
||||
|
||||
bool operator!=(const Iterator& other);
|
||||
|
||||
Value operator*();
|
||||
|
||||
private:
|
||||
State* const parent_;
|
||||
};
|
||||
|
||||
Iterator begin();
|
||||
Iterator end();
|
||||
|
||||
void PauseTiming();
|
||||
void ResumeTiming();
|
||||
|
||||
// Set the number of bytes processed by the current benchmark
|
||||
// execution. This routine is typically called once at the end of a
|
||||
// throughput oriented benchmark. If this routine is called with a
|
||||
// value > 0, then bytes processed per second is also reported.
|
||||
void SetBytesProcessed(int64 bytes);
|
||||
|
||||
// If this routine is called with items > 0, then an items/s
|
||||
// label is printed on the benchmark report line for the currently
|
||||
// executing benchmark. It is typically called at the end of a processing
|
||||
// benchmark where a processing items/second output is desired.
|
||||
void SetItemsProcessed(int64 items);
|
||||
|
||||
// If this method is called, the specified label is printed at the
|
||||
// end of the benchmark report line for the currently executing
|
||||
// benchmark. Example:
|
||||
// static void BM_Compress(benchmark::State& state) {
|
||||
// ...
|
||||
// double compression = input_size / output_size;
|
||||
// state.SetLabel(StringPrintf("compress:%.1f%%", 100.0*compression));
|
||||
// }
|
||||
// Produces output that looks like:
|
||||
// BM_Compress 50 50 14115038 compress:27.3%
|
||||
//
|
||||
// REQUIRES: a benchmark is currently executing
|
||||
void SetLabel(absl::string_view label);
|
||||
|
||||
// For parameterized benchmarks, range(i) returns the value of the ith
|
||||
// parameter. Simple benchmarks are not parameterized and do not need to call
|
||||
// range().
|
||||
int range(size_t i) const;
|
||||
|
||||
// Total number of iterations processed so far.
|
||||
size_t iterations() const;
|
||||
|
||||
const size_t
|
||||
max_iterations; // NOLINT: for compatibility with OSS benchmark library
|
||||
|
||||
// Disallow copy and assign.
|
||||
State(const State&) = delete;
|
||||
State& operator=(const State&) = delete;
|
||||
|
||||
protected:
|
||||
friend class tensorflow::testing::Benchmark;
|
||||
State(size_t max_iterations, const std::vector<int>& args);
|
||||
|
||||
private:
|
||||
size_t completed_iterations_;
|
||||
std::vector<int> args_;
|
||||
};
|
||||
|
||||
inline State::Iterator::Iterator(State* parent) : parent_(parent) {}
|
||||
|
||||
inline size_t State::iterations() const { return completed_iterations_; }
|
||||
|
||||
// Loop condition of the range-based for loop over a State.
//
// `other` must be the end() sentinel (null parent) and this iterator must be
// bound to a State. Returns true while the parent still has iterations left to
// run. Once all `max_iterations` iterations have completed, pauses the timer
// and returns false.
//
// The completed-iteration counter is advanced only by operator++, once per
// executed iteration. It is deliberately NOT incremented here: doing so would
// make State::iterations() report one more iteration than actually ran.
inline bool State::Iterator::operator!=(const Iterator& other) {
  DCHECK_EQ(other.parent_, nullptr);
  DCHECK_NE(parent_, nullptr);

  if (parent_->completed_iterations_ < parent_->max_iterations) {
    return true;
  }

  // The last iteration has finished: stop the timer.
  parent_->PauseTiming();
  return false;
}
|
||||
|
||||
// Advances the loop by one step: counts one more completed iteration on the
// parent State. In debug builds, checks that the iteration budget has not
// already been used up.
inline State::Iterator& State::Iterator::operator++() {
  DCHECK_LT(parent_->completed_iterations_, parent_->max_iterations);
  parent_->completed_iterations_ += 1;
  return *this;
}
|
||||
|
||||
inline State::Iterator::Value State::Iterator::operator*() { return Value(); }
|
||||
|
||||
// Returns the iterator that drives the benchmark loop.
inline State::Iterator State::begin() {
  // Code written against this API expects the timer to start at the beginning
  // of the loop, so the timer is started here.
  ResumeTiming();
  Iterator first(this);
  return first;
}
|
||||
|
||||
inline State::Iterator State::end() { return Iterator(nullptr); }
|
||||
|
||||
void RunSpecifiedBenchmarks();
|
||||
|
||||
} // namespace benchmark
|
||||
} // namespace testing
|
||||
|
||||
#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_TEST_BENCHMARK_H_
|
||||
|
40
tensorflow/core/platform/default/test_benchmark_test.cc
Normal file
40
tensorflow/core/platform/default/test_benchmark_test.cc
Normal file
@ -0,0 +1,40 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/platform/default/test_benchmark.h"
|
||||
|
||||
// Test the new interface: BM_benchmark(benchmark::State& state)
|
||||
namespace tensorflow {
|
||||
namespace testing {
|
||||
namespace {
|
||||
|
||||
// Smallest possible benchmark written against the State-based interface:
// bumps a counter once per iteration and keeps the compiler from optimizing
// the counter away.
void BM_TestIterState(::testing::benchmark::State& state) {
  int counter = 0;
  for (auto s : state) {
    counter += 1;
    DoNotOptimize(counter);
  }
}
|
||||
|
||||
BENCHMARK(BM_TestIterState);
|
||||
|
||||
} // namespace
|
||||
} // namespace testing
|
||||
} // namespace tensorflow
|
||||
|
||||
int main() {
|
||||
::testing::benchmark::RunSpecifiedBenchmarks();
|
||||
return 0;
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user