diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc
index 0b6edf74daf..a6a589d4db3 100644
--- a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc
+++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc
@@ -48,7 +48,8 @@ namespace test {
 // TODO(hongm): Convert `g` and `init` to using std::unique_ptr.
 Benchmark::Benchmark(const string& device, Graph* g,
                      const SessionOptions* options, Graph* init,
-                     Rendezvous* rendez, const char* executor_type) {
+                     Rendezvous* rendez, const char* executor_type,
+                     bool old_benchmark_api) {
   auto cleanup = gtl::MakeCleanup([g, init]() {
     delete g;
     delete init;
@@ -59,7 +60,8 @@ Benchmark::Benchmark(const string& device, Graph* g,
     options = &default_options;
   }
 
-  testing::StopTiming();
+  old_benchmark_api_ = old_benchmark_api;
+  if (old_benchmark_api_) testing::StopTiming();
   string t = absl::AsciiStrToUpper(device);
   // Allow NewDevice to allocate a new threadpool with different number of
   // threads for each new benchmark.
@@ -135,6 +137,10 @@ Benchmark::~Benchmark() {
 
 void Benchmark::Run(int iters) { RunWithRendezvousArgs({}, {}, iters); }
 
+void Benchmark::Run(benchmark::State& state) {
+  RunWithRendezvousArgs({}, {}, state);
+}
+
 string GetRendezvousKey(const Node* node) {
   string send_device;
   TF_CHECK_OK(GetNodeAttr(node->attrs(), "send_device", &send_device));
@@ -149,9 +155,63 @@ string GetRendezvousKey(const Node* node) {
                                 recv_device, tensor_name, FrameAndIter(0, 0));
 }
 
+void Benchmark::RunWithRendezvousArgs(
+    const std::vector<std::pair<string, Tensor>>& inputs,
+    const std::vector<string>& outputs, benchmark::State& state) {
+  CHECK(!old_benchmark_api_)
+      << "This method should only be called with the new benchmark API";
+  if (!device_ || state.max_iterations == 0) {
+    return;
+  }
+  Tensor unused;  // In benchmarks, we don't care about the return value.
+  bool is_dead;
+
+  // Warm up.
+  Executor::Args args;
+  args.rendezvous = rendez_;
+  args.runner = [this](std::function<void()> closure) {
+    pool_->Schedule(closure);
+  };
+  static const int kWarmupRuns = 3;
+  for (int i = 0; i < kWarmupRuns; ++i) {
+    for (const auto& p : inputs) {
+      Rendezvous::ParsedKey parsed;
+      TF_CHECK_OK(Rendezvous::ParseKey(p.first, &parsed));
+      TF_CHECK_OK(rendez_->Send(parsed, Rendezvous::Args(), p.second, false));
+    }
+    TF_CHECK_OK(exec_->Run(args));
+    for (const string& key : outputs) {
+      Rendezvous::ParsedKey parsed;
+      TF_CHECK_OK(Rendezvous::ParseKey(key, &parsed));
+      TF_CHECK_OK(rendez_->Recv(parsed, Rendezvous::Args(), &unused, &is_dead));
+    }
+  }
+  TF_CHECK_OK(device_->Sync());
+  VLOG(3) << kWarmupRuns << " warmup runs done.";
+
+  // Benchmark loop. The timer starts automatically at the beginning of the
+  // loop and stops automatically after the last iteration.
+  for (auto s : state) {
+    for (const auto& p : inputs) {
+      Rendezvous::ParsedKey parsed;
+      TF_CHECK_OK(Rendezvous::ParseKey(p.first, &parsed));
+      TF_CHECK_OK(rendez_->Send(parsed, Rendezvous::Args(), p.second, false));
+    }
+    TF_CHECK_OK(exec_->Run(args));
+    for (const string& key : outputs) {
+      Rendezvous::ParsedKey parsed;
+      TF_CHECK_OK(Rendezvous::ParseKey(key, &parsed));
+      TF_CHECK_OK(rendez_->Recv(parsed, Rendezvous::Args(), &unused, &is_dead));
+    }
+  }
+  TF_CHECK_OK(device_->Sync());
+}
+
 void Benchmark::RunWithRendezvousArgs(
     const std::vector<std::pair<string, Tensor>>& inputs,
     const std::vector<string>& outputs, int iters) {
+  CHECK(old_benchmark_api_)
+      << "This method should only be called when running with the old "
+         "benchmark API";
   if (!device_ || iters == 0) {
     return;
   }
diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.h b/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
index 9c6b1eb088c..fe161b6b939 100644
--- a/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
+++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
@@ -26,6 +26,12 @@ limitations under the License.
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/types.h"
 
+namespace testing {
+namespace benchmark {
+class State;
+}  // namespace benchmark
+}  // namespace testing
+
 namespace tensorflow {
 
 class Device;
@@ -40,23 +46,42 @@ class Benchmark {
  public:
   // "device" must be either "cpu" or "gpu". Takes ownership of "g",
   // "init", and one reference on "rendez" (if not null).
+  //
+  // old_benchmark_api: If true, the benchmark is running with the older API.
+  //   * In the old API, the timer needs to be stopped/restarted
+  //     by users.
+  //   * In the new API, the timer starts automatically at the first
+  //     iteration of the loop and stops after the last iteration.
+  // TODO(vyng): Remove this once we have migrated all code to the newer API.
   Benchmark(const string& device, Graph* g,
             const SessionOptions* options = nullptr, Graph* init = nullptr,
-            Rendezvous* rendez = nullptr, const char* executor_type = "");
+            Rendezvous* rendez = nullptr, const char* executor_type = "",
+            bool old_benchmark_api = true);
 
   ~Benchmark();
 
   // Executes the graph for "iters" times.
+  // This function is deprecated. Use the overload that takes
+  // `benchmark::State&` instead.
   void Run(int iters);
 
+  void Run(::testing::benchmark::State& state);
+
   // If "g" contains send/recv nodes, before each execution, we send
   // inputs to the corresponding recv keys in the graph, after each
   // execution, we recv outputs from the corresponding send keys in
   // the graph. In the benchmark, we throw away values returned by the
   // graph.
+  // This function is deprecated. Use the overload that takes
+  // `benchmark::State&` instead.
   void RunWithRendezvousArgs(
       const std::vector<std::pair<string, Tensor>>& inputs,
       const std::vector<string>& outputs, int iters);
 
+  void RunWithRendezvousArgs(
+      const std::vector<std::pair<string, Tensor>>& inputs,
+      const std::vector<string>& outputs, ::testing::benchmark::State& state);
+
  private:
   thread::ThreadPool* pool_ = nullptr;  // Not owned.
   Device* device_ = nullptr;            // Not owned.
@@ -66,6 +91,7 @@ class Benchmark {
   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
   FunctionLibraryRuntime* flr_;  // Not owned.
   std::unique_ptr<Executor> exec_;
+  bool old_benchmark_api_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(Benchmark);
 };
diff --git a/tensorflow/core/platform/default/BUILD b/tensorflow/core/platform/default/BUILD
index 849048e99be..962a32a94f6 100644
--- a/tensorflow/core/platform/default/BUILD
+++ b/tensorflow/core/platform/default/BUILD
@@ -1,7 +1,7 @@
 # Tensorflow default + linux implementations of tensorflow/core/platform libraries.
 load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
 load("//tensorflow:tensorflow.bzl", "filegroup")
-load("//tensorflow:tensorflow.bzl", "if_not_windows", "tf_copts")
+load("//tensorflow:tensorflow.bzl", "if_not_windows", "tf_cc_test", "tf_copts")
 load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
 
 package(
@@ -429,12 +429,22 @@ cc_library(
     deps = [
         "//tensorflow/core/platform",
         "//tensorflow/core/platform:env",
+        "//tensorflow/core/platform:logging",
         "//tensorflow/core/platform:macros",
+        "//tensorflow/core/platform:str_util",
         "//tensorflow/core/platform:types",
         "//tensorflow/core/util:reporter",
     ],
 )
 
+tf_cc_test(
+    name = "test_benchmark_test",
+    srcs = ["test_benchmark_test.cc"],
+    deps = [
+        ":test_benchmark",
+    ],
+)
+
 cc_library(
     name = "test",
     testonly = True,
diff --git a/tensorflow/core/platform/default/test_benchmark.cc b/tensorflow/core/platform/default/test_benchmark.cc
index 533c4ac1df1..6b1bb57f0b6 100644
--- a/tensorflow/core/platform/default/test_benchmark.cc
+++ b/tensorflow/core/platform/default/test_benchmark.cc
@@ -52,18 +52,48 @@ Benchmark::Benchmark(const char* name, void (*fn)(int, int, int))
   Register();
 }
 
+Benchmark::Benchmark(const char* name,
+                     void (*fn)(::testing::benchmark::State&))
+    : name_(name),
+      // -1 because the number of parameters is not part of the benchmark
+      // routine signature.
+      num_args_(-1),
+      fn_state_(fn) {
+  Register();
+}
+
+void Benchmark::CheckArgCount(int expected) {
+  if (num_args_ == expected) return;
+
+  // The number of args is not part of the function signature.
+  // Verify that if the benchmark instantiation has previously provided args,
+  // their count matches `expected`.
+  if (num_args_ < 0) {
+    if (args_.empty() || instantiated_num_args_ == expected) return;
+  }
+  CHECK(false) << "Expected " << expected
+               << " args for the benchmark, but got " << instantiated_num_args_;
+}
+
 Benchmark* Benchmark::Arg(int x) {
-  CHECK_EQ(num_args_, 1);
+  CheckArgCount(/*expected=*/1);
   args_.push_back(std::make_pair(x, -1));
+  instantiated_num_args_ = 1;
   return this;
 }
 
 Benchmark* Benchmark::ArgPair(int x, int y) {
-  CHECK_EQ(num_args_, 2);
+  CheckArgCount(/*expected=*/2);
+  instantiated_num_args_ = 2;
   args_.push_back(std::make_pair(x, y));
   return this;
 }
 
+Benchmark* Benchmark::UseRealTime() {
+  // Do nothing.
+  // This only exists for API compatibility with internal benchmarks.
+  return this;
+}
+
 namespace {
 
 void AddRange(std::vector<int>* dst, int lo, int hi, int mult) {
@@ -210,6 +240,7 @@ void Benchmark::Run(int arg1, int arg2, int* run_count, double* run_seconds) {
   static const int64 kMaxIters = 1000000000;
   static const double kMinTime = 0.5;
   int64 iters = kMinIters;
+
   while (true) {
     accum_time = 0;
     start_time = env->NowMicros();
@@ -220,8 +251,11 @@ void Benchmark::Run(int arg1, int arg2, int* run_count, double* run_seconds) {
       (*fn0_)(iters);
     } else if (fn1_) {
       (*fn1_)(iters, arg1);
-    } else {
+    } else if (fn2_) {
       (*fn2_)(iters, arg1, arg2);
+    } else if (fn_state_) {
+      ::testing::benchmark::State state(iters, std::vector<int>{arg1, arg2});
+      (*fn_state_)(state);
     }
     StopTiming();
     const double seconds = accum_time * 1e-6;
@@ -261,3 +295,38 @@ void UseRealTime() {}
 
 }  // namespace testing
 }  // namespace tensorflow
+
+namespace testing {
+namespace benchmark {
+State::State(size_t max_iterations, const std::vector<int>& args)
+    : max_iterations(max_iterations), args_(args) {
+  completed_iterations_ = 0;
+}
+
+void State::PauseTiming() { ::tensorflow::testing::StopTiming(); }
+
+void State::ResumeTiming() { ::tensorflow::testing::StartTiming(); }
+
+void State::SetBytesProcessed(int64 bytes) {
+  ::tensorflow::testing::BytesProcessed(bytes);
+}
+
+void State::SetItemsProcessed(int64 items) {
+  ::tensorflow::testing::ItemsProcessed(items);
+}
+
+void State::SetLabel(absl::string_view label) {
+  ::tensorflow::testing::SetLabel(std::string(label));
+}
+
+int State::range(size_t i) const {
+  if (i >= args_.size()) {
+    LOG(FATAL) << "argument for range " << i << " is not set";
+  }
+  return args_[i];
+}
+
+void RunSpecifiedBenchmarks() { ::tensorflow::testing::Benchmark::Run("all"); }
+
+}  // namespace benchmark
+}  // namespace testing
diff --git a/tensorflow/core/platform/default/test_benchmark.h b/tensorflow/core/platform/default/test_benchmark.h
index 55149e5c050..a7cf674637c 100644
--- a/tensorflow/core/platform/default/test_benchmark.h
+++ b/tensorflow/core/platform/default/test_benchmark.h
@@ -32,6 +32,12 @@ limitations under the License.
 #define TF_BENCHMARK_CONCAT(a, b, c) TF_BENCHMARK_CONCAT2(a, b, c)
 #define TF_BENCHMARK_CONCAT2(a, b, c) a##b##c
 
+namespace testing {
+namespace benchmark {
+class State;
+}  // namespace benchmark
+}  // namespace testing
+
 namespace tensorflow {
 namespace testing {
 
@@ -77,26 +83,41 @@ void DoNotOptimize(const T& var) {
 
 class Benchmark {
  public:
-  Benchmark(const char* name, void (*fn)(int));
-  Benchmark(const char* name, void (*fn)(int, int));
-  Benchmark(const char* name, void (*fn)(int, int, int));
+  [[deprecated("use `benchmark::State&` instead.")]] Benchmark(
+      const char* name, void (*fn)(int));
+
+  [[deprecated("use `benchmark::State&` instead.")]] Benchmark(
+      const char* name, void (*fn)(int, int));
+
+  [[deprecated("use `benchmark::State&` instead.")]] Benchmark(
+      const char* name, void (*fn)(int, int, int));
+
+  Benchmark(const char* name, void (*fn)(::testing::benchmark::State&));
 
   Benchmark* Arg(int x);
   Benchmark* ArgPair(int x, int y);
   Benchmark* Range(int lo, int hi);
   Benchmark* RangePair(int lo1, int hi1, int lo2, int hi2);
+
+  Benchmark* UseRealTime();
+
   static void Run(const char* pattern);
 
  private:
   string name_;
   int num_args_;
+  int instantiated_num_args_ = -1;
   std::vector<std::pair<int, int> > args_;
   void (*fn0_)(int) = nullptr;
   void (*fn1_)(int, int) = nullptr;
   void (*fn2_)(int, int, int) = nullptr;
+  void (*fn_state_)(::testing::benchmark::State&) = nullptr;
 
   void Register();
   void Run(int arg1, int arg2, int* run_count, double* run_seconds);
+
+  void CheckArgCount(int expected);
 };
 
 void RunBenchmarks();
@@ -110,4 +131,151 @@ void UseRealTime();
 }  // namespace testing
 }  // namespace tensorflow
 
+// Support the `void BM_Func(benchmark::State&)` interface so that it is
+// compatible with the internal version.
+namespace testing {
+namespace benchmark {
+
+using namespace tensorflow::testing;  // NOLINT: for access to ::int64
+
+// State is passed as an argument to a benchmark function.
+// Each thread in threaded benchmarks receives its own object.
+class State {
+ public:
+  // Incomplete iterator-like type with a dummy value type so that
+  // benchmark::State can support iteration with a range-based for loop.
+  //
+  // The only supported usage:
+  //
+  //   static void BM_Foo(benchmark::State& state) {
+  //     for (auto s : state) {
+  //       // perform a single iteration
+  //     }
+  //   }
+  //
+  // This is meant to replace the deprecated API:
+  //
+  //   static void BM_Foo(int iters) {
+  //     while (iters-- > 0) {
+  //       // perform a single iteration
+  //     }
+  //   }
+  //
+  // See go/benchmark#old-benchmark-interface for more details.
+  class Iterator {
+   public:
+    struct Value {
+      // Non-trivial destructor to avoid a warning for the unused dummy
+      // variable in the range-based for loop.
+      ~Value() {}
+    };
+
+    explicit Iterator(State* parent);
+
+    Iterator& operator++();
+
+    bool operator!=(const Iterator& other);
+
+    Value operator*();
+
+   private:
+    State* const parent_;
+  };
+
+  Iterator begin();
+  Iterator end();
+
+  void PauseTiming();
+  void ResumeTiming();
+
+  // Set the number of bytes processed by the current benchmark
+  // execution. This routine is typically called once at the end of a
+  // throughput oriented benchmark. If this routine is called with a
+  // value > 0, then bytes processed per second is also reported.
+  void SetBytesProcessed(int64 bytes);
+
+  // If this routine is called with items > 0, then an items/s
+  // label is printed on the benchmark report line for the currently
+  // executing benchmark. It is typically called at the end of a processing
+  // benchmark where a processing items/second output is desired.
+  void SetItemsProcessed(int64 items);
+
+  // If this method is called, the specified label is printed at the
+  // end of the benchmark report line for the currently executing
+  // benchmark. Example:
+  //   static void BM_Compress(benchmark::State& state) {
+  //     ...
+  //     double compression = input_size / output_size;
+  //     state.SetLabel(StringPrintf("compress:%.1f%%", 100.0 * compression));
+  //   }
+  // Produces output that looks like:
+  //   BM_Compress   50   50   14115038   compress:27.3%
+  //
+  // REQUIRES: a benchmark is currently executing
+  void SetLabel(absl::string_view label);
+
+  // For parameterized benchmarks, range(i) returns the value of the ith
+  // parameter. Simple benchmarks are not parameterized and do not need to
+  // call range().
+  int range(size_t i) const;
+
+  // Total number of iterations processed so far.
+  size_t iterations() const;
+
+  const size_t
+      max_iterations;  // NOLINT: for compatibility with the OSS benchmark library
+
+  // Disallow copy and assign.
+  State(const State&) = delete;
+  State& operator=(const State&) = delete;
+
+ protected:
+  friend class tensorflow::testing::Benchmark;
+  State(size_t max_iterations, const std::vector<int>& args);
+
+ private:
+  size_t completed_iterations_;
+  std::vector<int> args_;
+};
+
+inline State::Iterator::Iterator(State* parent) : parent_(parent) {}
+
+inline size_t State::iterations() const { return completed_iterations_; }
+
+inline bool State::Iterator::operator!=(const Iterator& other) {
+  DCHECK_EQ(other.parent_, nullptr);
+  DCHECK_NE(parent_, nullptr);
+
+  if (parent_->completed_iterations_ < parent_->max_iterations) {
+    return true;
+  }
+
+  ++parent_->completed_iterations_;
+  // If this is the last iteration, stop the timer.
+  parent_->PauseTiming();
+  return false;
+}
+
+inline State::Iterator& State::Iterator::operator++() {
+  DCHECK_LT(parent_->completed_iterations_, parent_->max_iterations);
+  ++parent_->completed_iterations_;
+  return *this;
+}
+
+inline State::Iterator::Value State::Iterator::operator*() { return Value(); }
+
+inline State::Iterator State::begin() {
+  // Start the timer here because code using this API expects the timer
+  // to start at the beginning of this loop.
+  ResumeTiming();
+  return Iterator(this);
+}
+
+inline State::Iterator State::end() { return Iterator(nullptr); }
+
+void RunSpecifiedBenchmarks();
+
+}  // namespace benchmark
+}  // namespace testing
+
 #endif  // TENSORFLOW_CORE_PLATFORM_DEFAULT_TEST_BENCHMARK_H_
diff --git a/tensorflow/core/platform/default/test_benchmark_test.cc b/tensorflow/core/platform/default/test_benchmark_test.cc
new file mode 100644
index 00000000000..2c692b2af7a
--- /dev/null
+++ b/tensorflow/core/platform/default/test_benchmark_test.cc
@@ -0,0 +1,40 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/default/test_benchmark.h"
+
+// Test the new interface: BM_benchmark(benchmark::State& state)
+namespace tensorflow {
+namespace testing {
+namespace {
+
+void BM_TestIterState(::testing::benchmark::State& state) {
+  int i = 0;
+  for (auto s : state) {
+    ++i;
+    DoNotOptimize(i);
+  }
+}
+
+BENCHMARK(BM_TestIterState);
+
+}  // namespace
+}  // namespace testing
+}  // namespace tensorflow
+
+int main() {
+  ::testing::benchmark::RunSpecifiedBenchmarks();
+  return 0;
+}
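
Usage note: below is a minimal sketch of what a kernel benchmark could look like once it opts into the new API via `old_benchmark_api = false`. The benchmarked op and the graph helpers (`test::graph::Constant`, `test::graph::Identity`) are illustrative choices from the existing testlib, not part of this patch:

    // Hypothetical example, for illustration only; not part of this change.
    #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
    #include "tensorflow/core/framework/tensor.h"
    #include "tensorflow/core/graph/graph.h"
    #include "tensorflow/core/graph/testlib.h"
    #include "tensorflow/core/platform/test_benchmark.h"

    namespace tensorflow {

    static void BM_IdentityCPU(::testing::benchmark::State& state) {
      const int dim = state.range(0);  // supplied via ->Arg(...) below
      Graph* g = new Graph(OpRegistry::Global());
      Tensor data(DT_FLOAT, TensorShape({dim, dim}));
      data.flat<float>().setZero();
      test::graph::Identity(g, test::graph::Constant(g, data));

      // old_benchmark_api = false selects the new timing behavior: the timer
      // runs only while RunWithRendezvousArgs iterates over `state`.
      test::Benchmark("cpu", g, /*options=*/nullptr, /*init=*/nullptr,
                      /*rendez=*/nullptr, /*executor_type=*/"",
                      /*old_benchmark_api=*/false)
          .Run(state);

      state.SetItemsProcessed(static_cast<int64>(state.iterations()) * dim * dim);
    }
    BENCHMARK(BM_IdentityCPU)->Arg(64)->Arg(256);

    }  // namespace tensorflow

With the default `old_benchmark_api = true`, `test::Benchmark` keeps the legacy behavior of calling `testing::StopTiming()` in its constructor, so existing callers that pair `Run(int iters)` with manual `StartTiming()`/`StopTiming()` are unaffected.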