In tensorflow/core/util/, introduce a IncrementalBarrier library.

PiperOrigin-RevId: 313250473
Change-Id: I0cbb2d263d1639b1ea444b05ae7f5ea29fa252ce
This commit is contained in:
Mingming Liu 2020-05-26 12:51:59 -07:00 committed by TensorFlower Gardener
parent 13de0f1c98
commit c068a625c5
4 changed files with 302 additions and 0 deletions

View File

@ -505,6 +505,16 @@ cc_library(
], ],
) )
cc_library(
name = "incremental_barrier",
srcs = ["incremental_barrier.cc"],
hdrs = ["incremental_barrier.h"],
deps = [
"//tensorflow/core:lib",
"@com_google_absl//absl/functional:bind_front",
],
)
# Tests. # Tests.
tf_cc_test( tf_cc_test(
@ -632,6 +642,20 @@ tf_cc_test(
], ],
) )
tf_cc_test(
name = "incremental_barrier_test",
srcs = ["incremental_barrier_test.cc"],
deps = [
":incremental_barrier",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/platform",
"@com_google_absl//absl/functional:bind_front",
"@com_google_absl//absl/time",
],
)
# Proto libraries. # Proto libraries.
tf_proto_library( tf_proto_library(
name = "test_log_proto_impl", name = "test_log_proto_impl",

View File

@ -0,0 +1,64 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/util/incremental_barrier.h"
#include <atomic>
#include <functional>
#include "absl/functional/bind_front.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
class InternalIncrementalBarrier {
public:
explicit InternalIncrementalBarrier(IncrementalBarrier::DoneCallback callback)
: left_(1), done_callback_(std::move(callback)) {}
void operator()() {
DCHECK_GE(left_.load(std::memory_order_relaxed), 0);
if (left_.fetch_sub(1, std::memory_order_acq_rel) - 1 == 0) {
IncrementalBarrier::DoneCallback done_callback =
std::move(done_callback_);
delete this;
done_callback();
}
}
IncrementalBarrier::BarrierCallback Inc() {
left_.fetch_add(1, std::memory_order_acq_rel);
// std::bind_front is only available ever since C++20.
return absl::bind_front(&InternalIncrementalBarrier::operator(), this);
}
private:
std::atomic<int> left_;
IncrementalBarrier::DoneCallback done_callback_;
};
IncrementalBarrier::IncrementalBarrier(DoneCallback done_callback)
: internal_barrier_(
new InternalIncrementalBarrier(std::move(done_callback))) {}
IncrementalBarrier::~IncrementalBarrier() { (*internal_barrier_)(); }
IncrementalBarrier::BarrierCallback IncrementalBarrier::Inc() {
return internal_barrier_->Inc();
}
} // namespace tensorflow

View File

@ -0,0 +1,81 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_INCREMENTAL_BARRIER_H_
#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_INCREMENTAL_BARRIER_H_
#include <atomic>
#include <functional>
namespace tensorflow {
class InternalIncrementalBarrier;
// BarrierClosure (see
// https://github.com/chromium/chromium/blob/master/base/barrier_closure.h)
// executes a callback after it has been invoked |num_closures| times.
// Plus, `BarrierClosure` is a continuation-passing style abstraction and self-
// deleting.
// IncrementalBarrier is a convenience class to be used in place of a barrier
// closure, which is particularly helpful (e.g. simplify code) because callers
// don't need to calculate the |num_closures| beforehand.
//
// Example Usage:
// void MakeCalls() {
// typedef std::function<void()> Callback;
// typedef std::function<void(Callback)> OtherCallback;
// Callback done_callback = ...
// OtherCallback cb1 = ...
// OtherCallback cb2 = ...
// std::thread threads[2];
// {
// IncrementalBarrier barrier(done_callback);
// threads[0] = std::thread(cb1(barrier.Inc());
// threads[1] = std::thread(cb2(barrier.Inc());
// ... at this moment, `barrier` is incremented twice, and then
// destructed....
// }
// threads[0].join();
// threads[1].join();
// }
//
// `done_callback` will be called when both conditions are true:
// 1) after `barrier` is destructed.
// 2) Each `BarrierCallback` returned by `Inc` is called.
// This class is thread-safe.
class IncrementalBarrier {
public:
typedef std::function<void()> DoneCallback;
typedef std::function<void()> BarrierCallback;
explicit IncrementalBarrier(DoneCallback callback);
~IncrementalBarrier();
// Returns a BarrierCallback (std::function) that individual task call to
// signal its completeness.
// The returned BarrierCallback outlives this `IncrementalBarrier` instance.
// Furthermore, each task should eventually call the returned function, or
// else done_callback wouldn't be called.
BarrierCallback Inc();
private:
// self-deleting, thereby not owned by 'IncrementalBarrier'.
InternalIncrementalBarrier* internal_barrier_;
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_INCREMENTAL_BARRIER_H_

View File

@ -0,0 +1,133 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/util/incremental_barrier.h"
#include <atomic>
#include "absl/functional/bind_front.h"
#include "absl/time/time.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/threadpool.h"
namespace tensorflow {
namespace {
// A thread-safe counter class.
class Counter {
public:
void Increment() TF_LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
++count_;
}
int GetCount() TF_LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
return count_;
}
private:
mutex mu_;
int count_ = 0;
};
TEST(IncrementalBarrierTest, RunInstantlyWhenZeroClosure) {
Counter counter;
EXPECT_EQ(counter.GetCount(), 0);
{
IncrementalBarrier::DoneCallback done_callback =
absl::bind_front(&Counter::Increment, &counter);
IncrementalBarrier barrier(done_callback);
EXPECT_EQ(counter.GetCount(), 0);
}
EXPECT_EQ(counter.GetCount(), 1);
}
TEST(IncrementalBarrierTest, RunAfterNumClosuresOneNowTwoLater) {
Counter counter;
IncrementalBarrier::BarrierCallback bc1, bc2;
{
IncrementalBarrier::DoneCallback done_callback =
absl::bind_front(&Counter::Increment, &counter);
IncrementalBarrier barrier(done_callback);
CHECK_EQ(counter.GetCount(), 0);
bc1 = barrier.Inc();
bc2 = barrier.Inc();
IncrementalBarrier::BarrierCallback bc3 = barrier.Inc();
bc3();
CHECK_EQ(counter.GetCount(), 0);
}
CHECK_EQ(counter.GetCount(), 0);
bc1();
CHECK_EQ(counter.GetCount(), 0);
bc2();
CHECK_EQ(counter.GetCount(), 1);
}
TEST(IncrementalBarrierTest, RunAfterNumClosuresConcurrency) {
const int num_closure = 100, num_thread = 2;
std::atomic<int> schedule_count{0};
Counter counter;
{
IncrementalBarrier::DoneCallback done_callback =
absl::bind_front(&Counter::Increment, &counter);
IncrementalBarrier barrier(done_callback);
CHECK_EQ(counter.GetCount(), 0);
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(),
"BarrierClosure", num_thread);
for (int i = 0; i < num_closure; ++i) {
pool.Schedule([&barrier, &schedule_count]() {
schedule_count.fetch_add(1);
IncrementalBarrier::BarrierCallback bc = barrier.Inc();
Env::Default()->SleepForMicroseconds(100);
bc();
});
}
CHECK_EQ(counter.GetCount(), 0);
}
CHECK_EQ(schedule_count.load(std::memory_order_relaxed), 100);
CHECK_EQ(counter.GetCount(), 1);
}
#if defined(PLATFORM_GOOGLE)
void BM_FunctionInc(benchmark::State& state) {
IncrementalBarrier barrier([] {});
for (auto _ : state) {
barrier.Inc()();
}
}
BENCHMARK(BM_FunctionInc);
#endif // PLATFORM_GOOGLE
} // namespace
} // namespace tensorflow