In tensorflow/core/util/, introduce a IncrementalBarrier library.
PiperOrigin-RevId: 313250473 Change-Id: I0cbb2d263d1639b1ea444b05ae7f5ea29fa252ce
This commit is contained in:
parent
13de0f1c98
commit
c068a625c5
@ -505,6 +505,16 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "incremental_barrier",
|
||||
srcs = ["incremental_barrier.cc"],
|
||||
hdrs = ["incremental_barrier.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/functional:bind_front",
|
||||
],
|
||||
)
|
||||
|
||||
# Tests.
|
||||
|
||||
tf_cc_test(
|
||||
@ -632,6 +642,20 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "incremental_barrier_test",
|
||||
srcs = ["incremental_barrier_test.cc"],
|
||||
deps = [
|
||||
":incremental_barrier",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/platform",
|
||||
"@com_google_absl//absl/functional:bind_front",
|
||||
"@com_google_absl//absl/time",
|
||||
],
|
||||
)
|
||||
|
||||
# Proto libraries.
|
||||
tf_proto_library(
|
||||
name = "test_log_proto_impl",
|
||||
|
64
tensorflow/core/util/incremental_barrier.cc
Normal file
64
tensorflow/core/util/incremental_barrier.cc
Normal file
@ -0,0 +1,64 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/util/incremental_barrier.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <functional>
|
||||
|
||||
#include "absl/functional/bind_front.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class InternalIncrementalBarrier {
|
||||
public:
|
||||
explicit InternalIncrementalBarrier(IncrementalBarrier::DoneCallback callback)
|
||||
: left_(1), done_callback_(std::move(callback)) {}
|
||||
|
||||
void operator()() {
|
||||
DCHECK_GE(left_.load(std::memory_order_relaxed), 0);
|
||||
|
||||
if (left_.fetch_sub(1, std::memory_order_acq_rel) - 1 == 0) {
|
||||
IncrementalBarrier::DoneCallback done_callback =
|
||||
std::move(done_callback_);
|
||||
delete this;
|
||||
done_callback();
|
||||
}
|
||||
}
|
||||
|
||||
IncrementalBarrier::BarrierCallback Inc() {
|
||||
left_.fetch_add(1, std::memory_order_acq_rel);
|
||||
|
||||
// std::bind_front is only available ever since C++20.
|
||||
return absl::bind_front(&InternalIncrementalBarrier::operator(), this);
|
||||
}
|
||||
|
||||
private:
|
||||
std::atomic<int> left_;
|
||||
IncrementalBarrier::DoneCallback done_callback_;
|
||||
};
|
||||
|
||||
IncrementalBarrier::IncrementalBarrier(DoneCallback done_callback)
|
||||
: internal_barrier_(
|
||||
new InternalIncrementalBarrier(std::move(done_callback))) {}
|
||||
|
||||
IncrementalBarrier::~IncrementalBarrier() { (*internal_barrier_)(); }
|
||||
|
||||
IncrementalBarrier::BarrierCallback IncrementalBarrier::Inc() {
|
||||
return internal_barrier_->Inc();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
81
tensorflow/core/util/incremental_barrier.h
Normal file
81
tensorflow/core/util/incremental_barrier.h
Normal file
@ -0,0 +1,81 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_INCREMENTAL_BARRIER_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_INCREMENTAL_BARRIER_H_
|
||||
|
||||
#include <atomic>
|
||||
#include <functional>
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class InternalIncrementalBarrier;
|
||||
|
||||
// BarrierClosure (see
|
||||
// https://github.com/chromium/chromium/blob/master/base/barrier_closure.h)
|
||||
// executes a callback after it has been invoked |num_closures| times.
|
||||
// Plus, `BarrierClosure` is a continuation-passing style abstraction and self-
|
||||
// deleting.
|
||||
|
||||
// IncrementalBarrier is a convenience class to be used in place of a barrier
|
||||
// closure, which is particularly helpful (e.g. simplify code) because callers
|
||||
// don't need to calculate the |num_closures| beforehand.
|
||||
//
|
||||
// Example Usage:
|
||||
// void MakeCalls() {
|
||||
// typedef std::function<void()> Callback;
|
||||
// typedef std::function<void(Callback)> OtherCallback;
|
||||
// Callback done_callback = ...
|
||||
// OtherCallback cb1 = ...
|
||||
// OtherCallback cb2 = ...
|
||||
// std::thread threads[2];
|
||||
// {
|
||||
// IncrementalBarrier barrier(done_callback);
|
||||
// threads[0] = std::thread(cb1(barrier.Inc());
|
||||
// threads[1] = std::thread(cb2(barrier.Inc());
|
||||
// ... at this moment, `barrier` is incremented twice, and then
|
||||
// destructed....
|
||||
// }
|
||||
// threads[0].join();
|
||||
// threads[1].join();
|
||||
// }
|
||||
//
|
||||
// `done_callback` will be called when both conditions are true:
|
||||
// 1) after `barrier` is destructed.
|
||||
// 2) Each `BarrierCallback` returned by `Inc` is called.
|
||||
// This class is thread-safe.
|
||||
class IncrementalBarrier {
|
||||
public:
|
||||
typedef std::function<void()> DoneCallback;
|
||||
typedef std::function<void()> BarrierCallback;
|
||||
explicit IncrementalBarrier(DoneCallback callback);
|
||||
|
||||
~IncrementalBarrier();
|
||||
|
||||
// Returns a BarrierCallback (std::function) that individual task call to
|
||||
// signal its completeness.
|
||||
// The returned BarrierCallback outlives this `IncrementalBarrier` instance.
|
||||
// Furthermore, each task should eventually call the returned function, or
|
||||
// else done_callback wouldn't be called.
|
||||
BarrierCallback Inc();
|
||||
|
||||
private:
|
||||
// self-deleting, thereby not owned by 'IncrementalBarrier'.
|
||||
InternalIncrementalBarrier* internal_barrier_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_INCREMENTAL_BARRIER_H_
|
133
tensorflow/core/util/incremental_barrier_test.cc
Normal file
133
tensorflow/core/util/incremental_barrier_test.cc
Normal file
@ -0,0 +1,133 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/util/incremental_barrier.h"
|
||||
|
||||
#include <atomic>
|
||||
|
||||
#include "absl/functional/bind_front.h"
|
||||
#include "absl/time/time.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/platform/threadpool.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// A thread-safe counter class.
|
||||
class Counter {
|
||||
public:
|
||||
void Increment() TF_LOCKS_EXCLUDED(mu_) {
|
||||
mutex_lock l(mu_);
|
||||
++count_;
|
||||
}
|
||||
|
||||
int GetCount() TF_LOCKS_EXCLUDED(mu_) {
|
||||
mutex_lock l(mu_);
|
||||
return count_;
|
||||
}
|
||||
|
||||
private:
|
||||
mutex mu_;
|
||||
int count_ = 0;
|
||||
};
|
||||
|
||||
TEST(IncrementalBarrierTest, RunInstantlyWhenZeroClosure) {
|
||||
Counter counter;
|
||||
EXPECT_EQ(counter.GetCount(), 0);
|
||||
{
|
||||
IncrementalBarrier::DoneCallback done_callback =
|
||||
absl::bind_front(&Counter::Increment, &counter);
|
||||
IncrementalBarrier barrier(done_callback);
|
||||
EXPECT_EQ(counter.GetCount(), 0);
|
||||
}
|
||||
EXPECT_EQ(counter.GetCount(), 1);
|
||||
}
|
||||
|
||||
TEST(IncrementalBarrierTest, RunAfterNumClosuresOneNowTwoLater) {
|
||||
Counter counter;
|
||||
|
||||
IncrementalBarrier::BarrierCallback bc1, bc2;
|
||||
{
|
||||
IncrementalBarrier::DoneCallback done_callback =
|
||||
absl::bind_front(&Counter::Increment, &counter);
|
||||
IncrementalBarrier barrier(done_callback);
|
||||
|
||||
CHECK_EQ(counter.GetCount(), 0);
|
||||
|
||||
bc1 = barrier.Inc();
|
||||
bc2 = barrier.Inc();
|
||||
|
||||
IncrementalBarrier::BarrierCallback bc3 = barrier.Inc();
|
||||
bc3();
|
||||
|
||||
CHECK_EQ(counter.GetCount(), 0);
|
||||
}
|
||||
|
||||
CHECK_EQ(counter.GetCount(), 0);
|
||||
bc1();
|
||||
CHECK_EQ(counter.GetCount(), 0);
|
||||
bc2();
|
||||
CHECK_EQ(counter.GetCount(), 1);
|
||||
}
|
||||
|
||||
TEST(IncrementalBarrierTest, RunAfterNumClosuresConcurrency) {
|
||||
const int num_closure = 100, num_thread = 2;
|
||||
std::atomic<int> schedule_count{0};
|
||||
Counter counter;
|
||||
|
||||
{
|
||||
IncrementalBarrier::DoneCallback done_callback =
|
||||
absl::bind_front(&Counter::Increment, &counter);
|
||||
IncrementalBarrier barrier(done_callback);
|
||||
|
||||
CHECK_EQ(counter.GetCount(), 0);
|
||||
|
||||
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(),
|
||||
"BarrierClosure", num_thread);
|
||||
for (int i = 0; i < num_closure; ++i) {
|
||||
pool.Schedule([&barrier, &schedule_count]() {
|
||||
schedule_count.fetch_add(1);
|
||||
IncrementalBarrier::BarrierCallback bc = barrier.Inc();
|
||||
|
||||
Env::Default()->SleepForMicroseconds(100);
|
||||
bc();
|
||||
});
|
||||
}
|
||||
|
||||
CHECK_EQ(counter.GetCount(), 0);
|
||||
}
|
||||
|
||||
CHECK_EQ(schedule_count.load(std::memory_order_relaxed), 100);
|
||||
CHECK_EQ(counter.GetCount(), 1);
|
||||
}
|
||||
|
||||
#if defined(PLATFORM_GOOGLE)
|
||||
void BM_FunctionInc(benchmark::State& state) {
|
||||
IncrementalBarrier barrier([] {});
|
||||
for (auto _ : state) {
|
||||
barrier.Inc()();
|
||||
}
|
||||
}
|
||||
|
||||
BENCHMARK(BM_FunctionInc);
|
||||
#endif // PLATFORM_GOOGLE
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user