In tensorflow/core/util/, introduce a IncrementalBarrier library.
PiperOrigin-RevId: 313250473 Change-Id: I0cbb2d263d1639b1ea444b05ae7f5ea29fa252ce
This commit is contained in:
parent
13de0f1c98
commit
c068a625c5
@ -505,6 +505,16 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "incremental_barrier",
|
||||||
|
srcs = ["incremental_barrier.cc"],
|
||||||
|
hdrs = ["incremental_barrier.h"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/functional:bind_front",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
# Tests.
|
# Tests.
|
||||||
|
|
||||||
tf_cc_test(
|
tf_cc_test(
|
||||||
@ -632,6 +642,20 @@ tf_cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "incremental_barrier_test",
|
||||||
|
srcs = ["incremental_barrier_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":incremental_barrier",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core/platform",
|
||||||
|
"@com_google_absl//absl/functional:bind_front",
|
||||||
|
"@com_google_absl//absl/time",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
# Proto libraries.
|
# Proto libraries.
|
||||||
tf_proto_library(
|
tf_proto_library(
|
||||||
name = "test_log_proto_impl",
|
name = "test_log_proto_impl",
|
||||||
|
64
tensorflow/core/util/incremental_barrier.cc
Normal file
64
tensorflow/core/util/incremental_barrier.cc
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/util/incremental_barrier.h"
|
||||||
|
|
||||||
|
#include <atomic>
|
||||||
|
#include <functional>
|
||||||
|
|
||||||
|
#include "absl/functional/bind_front.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
class InternalIncrementalBarrier {
|
||||||
|
public:
|
||||||
|
explicit InternalIncrementalBarrier(IncrementalBarrier::DoneCallback callback)
|
||||||
|
: left_(1), done_callback_(std::move(callback)) {}
|
||||||
|
|
||||||
|
void operator()() {
|
||||||
|
DCHECK_GE(left_.load(std::memory_order_relaxed), 0);
|
||||||
|
|
||||||
|
if (left_.fetch_sub(1, std::memory_order_acq_rel) - 1 == 0) {
|
||||||
|
IncrementalBarrier::DoneCallback done_callback =
|
||||||
|
std::move(done_callback_);
|
||||||
|
delete this;
|
||||||
|
done_callback();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
IncrementalBarrier::BarrierCallback Inc() {
|
||||||
|
left_.fetch_add(1, std::memory_order_acq_rel);
|
||||||
|
|
||||||
|
// std::bind_front is only available ever since C++20.
|
||||||
|
return absl::bind_front(&InternalIncrementalBarrier::operator(), this);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::atomic<int> left_;
|
||||||
|
IncrementalBarrier::DoneCallback done_callback_;
|
||||||
|
};
|
||||||
|
|
||||||
|
IncrementalBarrier::IncrementalBarrier(DoneCallback done_callback)
|
||||||
|
: internal_barrier_(
|
||||||
|
new InternalIncrementalBarrier(std::move(done_callback))) {}
|
||||||
|
|
||||||
|
IncrementalBarrier::~IncrementalBarrier() { (*internal_barrier_)(); }
|
||||||
|
|
||||||
|
IncrementalBarrier::BarrierCallback IncrementalBarrier::Inc() {
|
||||||
|
return internal_barrier_->Inc();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
81
tensorflow/core/util/incremental_barrier.h
Normal file
81
tensorflow/core/util/incremental_barrier.h
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_INCREMENTAL_BARRIER_H_
|
||||||
|
#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_INCREMENTAL_BARRIER_H_
|
||||||
|
|
||||||
|
#include <atomic>
|
||||||
|
#include <functional>
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
class InternalIncrementalBarrier;
|
||||||
|
|
||||||
|
// BarrierClosure (see
|
||||||
|
// https://github.com/chromium/chromium/blob/master/base/barrier_closure.h)
|
||||||
|
// executes a callback after it has been invoked |num_closures| times.
|
||||||
|
// Plus, `BarrierClosure` is a continuation-passing style abstraction and self-
|
||||||
|
// deleting.
|
||||||
|
|
||||||
|
// IncrementalBarrier is a convenience class to be used in place of a barrier
|
||||||
|
// closure, which is particularly helpful (e.g. simplify code) because callers
|
||||||
|
// don't need to calculate the |num_closures| beforehand.
|
||||||
|
//
|
||||||
|
// Example Usage:
|
||||||
|
// void MakeCalls() {
|
||||||
|
// typedef std::function<void()> Callback;
|
||||||
|
// typedef std::function<void(Callback)> OtherCallback;
|
||||||
|
// Callback done_callback = ...
|
||||||
|
// OtherCallback cb1 = ...
|
||||||
|
// OtherCallback cb2 = ...
|
||||||
|
// std::thread threads[2];
|
||||||
|
// {
|
||||||
|
// IncrementalBarrier barrier(done_callback);
|
||||||
|
// threads[0] = std::thread(cb1(barrier.Inc());
|
||||||
|
// threads[1] = std::thread(cb2(barrier.Inc());
|
||||||
|
// ... at this moment, `barrier` is incremented twice, and then
|
||||||
|
// destructed....
|
||||||
|
// }
|
||||||
|
// threads[0].join();
|
||||||
|
// threads[1].join();
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// `done_callback` will be called when both conditions are true:
|
||||||
|
// 1) after `barrier` is destructed.
|
||||||
|
// 2) Each `BarrierCallback` returned by `Inc` is called.
|
||||||
|
// This class is thread-safe.
|
||||||
|
class IncrementalBarrier {
|
||||||
|
public:
|
||||||
|
typedef std::function<void()> DoneCallback;
|
||||||
|
typedef std::function<void()> BarrierCallback;
|
||||||
|
explicit IncrementalBarrier(DoneCallback callback);
|
||||||
|
|
||||||
|
~IncrementalBarrier();
|
||||||
|
|
||||||
|
// Returns a BarrierCallback (std::function) that individual task call to
|
||||||
|
// signal its completeness.
|
||||||
|
// The returned BarrierCallback outlives this `IncrementalBarrier` instance.
|
||||||
|
// Furthermore, each task should eventually call the returned function, or
|
||||||
|
// else done_callback wouldn't be called.
|
||||||
|
BarrierCallback Inc();
|
||||||
|
|
||||||
|
private:
|
||||||
|
// self-deleting, thereby not owned by 'IncrementalBarrier'.
|
||||||
|
InternalIncrementalBarrier* internal_barrier_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_INCREMENTAL_BARRIER_H_
|
133
tensorflow/core/util/incremental_barrier_test.cc
Normal file
133
tensorflow/core/util/incremental_barrier_test.cc
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/util/incremental_barrier.h"
|
||||||
|
|
||||||
|
#include <atomic>
|
||||||
|
|
||||||
|
#include "absl/functional/bind_front.h"
|
||||||
|
#include "absl/time/time.h"
|
||||||
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
|
#include "tensorflow/core/platform/platform.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/platform/test_benchmark.h"
|
||||||
|
#include "tensorflow/core/platform/thread_annotations.h"
|
||||||
|
#include "tensorflow/core/platform/threadpool.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// A thread-safe counter class.
|
||||||
|
class Counter {
|
||||||
|
public:
|
||||||
|
void Increment() TF_LOCKS_EXCLUDED(mu_) {
|
||||||
|
mutex_lock l(mu_);
|
||||||
|
++count_;
|
||||||
|
}
|
||||||
|
|
||||||
|
int GetCount() TF_LOCKS_EXCLUDED(mu_) {
|
||||||
|
mutex_lock l(mu_);
|
||||||
|
return count_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
mutex mu_;
|
||||||
|
int count_ = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST(IncrementalBarrierTest, RunInstantlyWhenZeroClosure) {
|
||||||
|
Counter counter;
|
||||||
|
EXPECT_EQ(counter.GetCount(), 0);
|
||||||
|
{
|
||||||
|
IncrementalBarrier::DoneCallback done_callback =
|
||||||
|
absl::bind_front(&Counter::Increment, &counter);
|
||||||
|
IncrementalBarrier barrier(done_callback);
|
||||||
|
EXPECT_EQ(counter.GetCount(), 0);
|
||||||
|
}
|
||||||
|
EXPECT_EQ(counter.GetCount(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(IncrementalBarrierTest, RunAfterNumClosuresOneNowTwoLater) {
|
||||||
|
Counter counter;
|
||||||
|
|
||||||
|
IncrementalBarrier::BarrierCallback bc1, bc2;
|
||||||
|
{
|
||||||
|
IncrementalBarrier::DoneCallback done_callback =
|
||||||
|
absl::bind_front(&Counter::Increment, &counter);
|
||||||
|
IncrementalBarrier barrier(done_callback);
|
||||||
|
|
||||||
|
CHECK_EQ(counter.GetCount(), 0);
|
||||||
|
|
||||||
|
bc1 = barrier.Inc();
|
||||||
|
bc2 = barrier.Inc();
|
||||||
|
|
||||||
|
IncrementalBarrier::BarrierCallback bc3 = barrier.Inc();
|
||||||
|
bc3();
|
||||||
|
|
||||||
|
CHECK_EQ(counter.GetCount(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
CHECK_EQ(counter.GetCount(), 0);
|
||||||
|
bc1();
|
||||||
|
CHECK_EQ(counter.GetCount(), 0);
|
||||||
|
bc2();
|
||||||
|
CHECK_EQ(counter.GetCount(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(IncrementalBarrierTest, RunAfterNumClosuresConcurrency) {
|
||||||
|
const int num_closure = 100, num_thread = 2;
|
||||||
|
std::atomic<int> schedule_count{0};
|
||||||
|
Counter counter;
|
||||||
|
|
||||||
|
{
|
||||||
|
IncrementalBarrier::DoneCallback done_callback =
|
||||||
|
absl::bind_front(&Counter::Increment, &counter);
|
||||||
|
IncrementalBarrier barrier(done_callback);
|
||||||
|
|
||||||
|
CHECK_EQ(counter.GetCount(), 0);
|
||||||
|
|
||||||
|
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(),
|
||||||
|
"BarrierClosure", num_thread);
|
||||||
|
for (int i = 0; i < num_closure; ++i) {
|
||||||
|
pool.Schedule([&barrier, &schedule_count]() {
|
||||||
|
schedule_count.fetch_add(1);
|
||||||
|
IncrementalBarrier::BarrierCallback bc = barrier.Inc();
|
||||||
|
|
||||||
|
Env::Default()->SleepForMicroseconds(100);
|
||||||
|
bc();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
CHECK_EQ(counter.GetCount(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
CHECK_EQ(schedule_count.load(std::memory_order_relaxed), 100);
|
||||||
|
CHECK_EQ(counter.GetCount(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#if defined(PLATFORM_GOOGLE)
|
||||||
|
void BM_FunctionInc(benchmark::State& state) {
|
||||||
|
IncrementalBarrier barrier([] {});
|
||||||
|
for (auto _ : state) {
|
||||||
|
barrier.Inc()();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
BENCHMARK(BM_FunctionInc);
|
||||||
|
#endif // PLATFORM_GOOGLE
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user