diff --git a/tensorflow/core/util/BUILD b/tensorflow/core/util/BUILD index de2dce9c0c2..8e878c2464d 100644 --- a/tensorflow/core/util/BUILD +++ b/tensorflow/core/util/BUILD @@ -505,6 +505,16 @@ cc_library( ], ) +cc_library( + name = "incremental_barrier", + srcs = ["incremental_barrier.cc"], + hdrs = ["incremental_barrier.h"], + deps = [ + "//tensorflow/core:lib", + "@com_google_absl//absl/functional:bind_front", + ], +) + # Tests. tf_cc_test( @@ -632,6 +642,20 @@ tf_cc_test( ], ) +tf_cc_test( + name = "incremental_barrier_test", + srcs = ["incremental_barrier_test.cc"], + deps = [ + ":incremental_barrier", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/platform", + "@com_google_absl//absl/functional:bind_front", + "@com_google_absl//absl/time", + ], +) + # Proto libraries. tf_proto_library( name = "test_log_proto_impl", diff --git a/tensorflow/core/util/incremental_barrier.cc b/tensorflow/core/util/incremental_barrier.cc new file mode 100644 index 00000000000..cbea7f25cc5 --- /dev/null +++ b/tensorflow/core/util/incremental_barrier.cc @@ -0,0 +1,64 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/util/incremental_barrier.h" + +#include +#include + +#include "absl/functional/bind_front.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +class InternalIncrementalBarrier { + public: + explicit InternalIncrementalBarrier(IncrementalBarrier::DoneCallback callback) + : left_(1), done_callback_(std::move(callback)) {} + + void operator()() { + DCHECK_GE(left_.load(std::memory_order_relaxed), 0); + + if (left_.fetch_sub(1, std::memory_order_acq_rel) - 1 == 0) { + IncrementalBarrier::DoneCallback done_callback = + std::move(done_callback_); + delete this; + done_callback(); + } + } + + IncrementalBarrier::BarrierCallback Inc() { + left_.fetch_add(1, std::memory_order_acq_rel); + + // std::bind_front is only available ever since C++20. + return absl::bind_front(&InternalIncrementalBarrier::operator(), this); + } + + private: + std::atomic left_; + IncrementalBarrier::DoneCallback done_callback_; +}; + +IncrementalBarrier::IncrementalBarrier(DoneCallback done_callback) + : internal_barrier_( + new InternalIncrementalBarrier(std::move(done_callback))) {} + +IncrementalBarrier::~IncrementalBarrier() { (*internal_barrier_)(); } + +IncrementalBarrier::BarrierCallback IncrementalBarrier::Inc() { + return internal_barrier_->Inc(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/util/incremental_barrier.h b/tensorflow/core/util/incremental_barrier.h new file mode 100644 index 00000000000..be45e9d4d8b --- /dev/null +++ b/tensorflow/core/util/incremental_barrier.h @@ -0,0 +1,81 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_INCREMENTAL_BARRIER_H_ +#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_INCREMENTAL_BARRIER_H_ + +#include +#include + +namespace tensorflow { + +class InternalIncrementalBarrier; + +// BarrierClosure (see +// https://github.com/chromium/chromium/blob/master/base/barrier_closure.h) +// executes a callback after it has been invoked |num_closures| times. +// Plus, `BarrierClosure` is a continuation-passing style abstraction and self- +// deleting. + +// IncrementalBarrier is a convenience class to be used in place of a barrier +// closure, which is particularly helpful (e.g. simplify code) because callers +// don't need to calculate the |num_closures| beforehand. +// +// Example Usage: +// void MakeCalls() { +// typedef std::function Callback; +// typedef std::function OtherCallback; +// Callback done_callback = ... +// OtherCallback cb1 = ... +// OtherCallback cb2 = ... +// std::thread threads[2]; +// { +// IncrementalBarrier barrier(done_callback); +// threads[0] = std::thread(cb1(barrier.Inc()); +// threads[1] = std::thread(cb2(barrier.Inc()); +// ... at this moment, `barrier` is incremented twice, and then +// destructed.... +// } +// threads[0].join(); +// threads[1].join(); +// } +// +// `done_callback` will be called when both conditions are true: +// 1) after `barrier` is destructed. +// 2) Each `BarrierCallback` returned by `Inc` is called. +// This class is thread-safe. +class IncrementalBarrier { + public: + typedef std::function DoneCallback; + typedef std::function BarrierCallback; + explicit IncrementalBarrier(DoneCallback callback); + + ~IncrementalBarrier(); + + // Returns a BarrierCallback (std::function) that individual task call to + // signal its completeness. + // The returned BarrierCallback outlives this `IncrementalBarrier` instance. + // Furthermore, each task should eventually call the returned function, or + // else done_callback wouldn't be called. + BarrierCallback Inc(); + + private: + // self-deleting, thereby not owned by 'IncrementalBarrier'. + InternalIncrementalBarrier* internal_barrier_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_INCREMENTAL_BARRIER_H_ diff --git a/tensorflow/core/util/incremental_barrier_test.cc b/tensorflow/core/util/incremental_barrier_test.cc new file mode 100644 index 00000000000..020cb9ece32 --- /dev/null +++ b/tensorflow/core/util/incremental_barrier_test.cc @@ -0,0 +1,133 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/util/incremental_barrier.h" + +#include + +#include "absl/functional/bind_front.h" +#include "absl/time/time.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/platform.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/threadpool.h" + +namespace tensorflow { +namespace { + +// A thread-safe counter class. +class Counter { + public: + void Increment() TF_LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + ++count_; + } + + int GetCount() TF_LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + return count_; + } + + private: + mutex mu_; + int count_ = 0; +}; + +TEST(IncrementalBarrierTest, RunInstantlyWhenZeroClosure) { + Counter counter; + EXPECT_EQ(counter.GetCount(), 0); + { + IncrementalBarrier::DoneCallback done_callback = + absl::bind_front(&Counter::Increment, &counter); + IncrementalBarrier barrier(done_callback); + EXPECT_EQ(counter.GetCount(), 0); + } + EXPECT_EQ(counter.GetCount(), 1); +} + +TEST(IncrementalBarrierTest, RunAfterNumClosuresOneNowTwoLater) { + Counter counter; + + IncrementalBarrier::BarrierCallback bc1, bc2; + { + IncrementalBarrier::DoneCallback done_callback = + absl::bind_front(&Counter::Increment, &counter); + IncrementalBarrier barrier(done_callback); + + CHECK_EQ(counter.GetCount(), 0); + + bc1 = barrier.Inc(); + bc2 = barrier.Inc(); + + IncrementalBarrier::BarrierCallback bc3 = barrier.Inc(); + bc3(); + + CHECK_EQ(counter.GetCount(), 0); + } + + CHECK_EQ(counter.GetCount(), 0); + bc1(); + CHECK_EQ(counter.GetCount(), 0); + bc2(); + CHECK_EQ(counter.GetCount(), 1); +} + +TEST(IncrementalBarrierTest, RunAfterNumClosuresConcurrency) { + const int num_closure = 100, num_thread = 2; + std::atomic schedule_count{0}; + Counter counter; + + { + IncrementalBarrier::DoneCallback done_callback = + absl::bind_front(&Counter::Increment, &counter); + IncrementalBarrier barrier(done_callback); + + CHECK_EQ(counter.GetCount(), 0); + + tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), + "BarrierClosure", num_thread); + for (int i = 0; i < num_closure; ++i) { + pool.Schedule([&barrier, &schedule_count]() { + schedule_count.fetch_add(1); + IncrementalBarrier::BarrierCallback bc = barrier.Inc(); + + Env::Default()->SleepForMicroseconds(100); + bc(); + }); + } + + CHECK_EQ(counter.GetCount(), 0); + } + + CHECK_EQ(schedule_count.load(std::memory_order_relaxed), 100); + CHECK_EQ(counter.GetCount(), 1); +} + +#if defined(PLATFORM_GOOGLE) +void BM_FunctionInc(benchmark::State& state) { + IncrementalBarrier barrier([] {}); + for (auto _ : state) { + barrier.Inc()(); + } +} + +BENCHMARK(BM_FunctionInc); +#endif // PLATFORM_GOOGLE + +} // namespace +} // namespace tensorflow