In tensorflow/core/util/, introduce a IncrementalBarrier library.
PiperOrigin-RevId: 313250473 Change-Id: I0cbb2d263d1639b1ea444b05ae7f5ea29fa252ce
This commit is contained in:
		
							parent
							
								
									13de0f1c98
								
							
						
					
					
						commit
						c068a625c5
					
				@ -505,6 +505,16 @@ cc_library(
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "incremental_barrier",
 | 
			
		||||
    srcs = ["incremental_barrier.cc"],
 | 
			
		||||
    hdrs = ["incremental_barrier.h"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//tensorflow/core:lib",
 | 
			
		||||
        "@com_google_absl//absl/functional:bind_front",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
# Tests.
 | 
			
		||||
 | 
			
		||||
tf_cc_test(
 | 
			
		||||
@ -632,6 +642,20 @@ tf_cc_test(
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
tf_cc_test(
 | 
			
		||||
    name = "incremental_barrier_test",
 | 
			
		||||
    srcs = ["incremental_barrier_test.cc"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":incremental_barrier",
 | 
			
		||||
        "//tensorflow/core:lib",
 | 
			
		||||
        "//tensorflow/core:test",
 | 
			
		||||
        "//tensorflow/core:test_main",
 | 
			
		||||
        "//tensorflow/core/platform",
 | 
			
		||||
        "@com_google_absl//absl/functional:bind_front",
 | 
			
		||||
        "@com_google_absl//absl/time",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
# Proto libraries.
 | 
			
		||||
tf_proto_library(
 | 
			
		||||
    name = "test_log_proto_impl",
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										64
									
								
								tensorflow/core/util/incremental_barrier.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										64
									
								
								tensorflow/core/util/incremental_barrier.cc
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,64 @@
 | 
			
		||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/core/util/incremental_barrier.h"
 | 
			
		||||
 | 
			
		||||
#include <atomic>
 | 
			
		||||
#include <functional>
 | 
			
		||||
 | 
			
		||||
#include "absl/functional/bind_front.h"
 | 
			
		||||
#include "tensorflow/core/platform/logging.h"
 | 
			
		||||
 | 
			
		||||
namespace tensorflow {
 | 
			
		||||
 | 
			
		||||
class InternalIncrementalBarrier {
 | 
			
		||||
 public:
 | 
			
		||||
  explicit InternalIncrementalBarrier(IncrementalBarrier::DoneCallback callback)
 | 
			
		||||
      : left_(1), done_callback_(std::move(callback)) {}
 | 
			
		||||
 | 
			
		||||
  void operator()() {
 | 
			
		||||
    DCHECK_GE(left_.load(std::memory_order_relaxed), 0);
 | 
			
		||||
 | 
			
		||||
    if (left_.fetch_sub(1, std::memory_order_acq_rel) - 1 == 0) {
 | 
			
		||||
      IncrementalBarrier::DoneCallback done_callback =
 | 
			
		||||
          std::move(done_callback_);
 | 
			
		||||
      delete this;
 | 
			
		||||
      done_callback();
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  IncrementalBarrier::BarrierCallback Inc() {
 | 
			
		||||
    left_.fetch_add(1, std::memory_order_acq_rel);
 | 
			
		||||
 | 
			
		||||
    // std::bind_front is only available ever since C++20.
 | 
			
		||||
    return absl::bind_front(&InternalIncrementalBarrier::operator(), this);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  std::atomic<int> left_;
 | 
			
		||||
  IncrementalBarrier::DoneCallback done_callback_;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
IncrementalBarrier::IncrementalBarrier(DoneCallback done_callback)
 | 
			
		||||
    : internal_barrier_(
 | 
			
		||||
          new InternalIncrementalBarrier(std::move(done_callback))) {}
 | 
			
		||||
 | 
			
		||||
IncrementalBarrier::~IncrementalBarrier() { (*internal_barrier_)(); }
 | 
			
		||||
 | 
			
		||||
IncrementalBarrier::BarrierCallback IncrementalBarrier::Inc() {
 | 
			
		||||
  return internal_barrier_->Inc();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace tensorflow
 | 
			
		||||
							
								
								
									
										81
									
								
								tensorflow/core/util/incremental_barrier.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										81
									
								
								tensorflow/core/util/incremental_barrier.h
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,81 @@
 | 
			
		||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_INCREMENTAL_BARRIER_H_
 | 
			
		||||
#define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_INCREMENTAL_BARRIER_H_
 | 
			
		||||
 | 
			
		||||
#include <atomic>
 | 
			
		||||
#include <functional>
 | 
			
		||||
 | 
			
		||||
namespace tensorflow {
 | 
			
		||||
 | 
			
		||||
class InternalIncrementalBarrier;
 | 
			
		||||
 | 
			
		||||
// BarrierClosure (see
 | 
			
		||||
// https://github.com/chromium/chromium/blob/master/base/barrier_closure.h)
 | 
			
		||||
// executes a callback after it has been invoked |num_closures| times.
 | 
			
		||||
// Plus, `BarrierClosure` is a continuation-passing style abstraction and self-
 | 
			
		||||
// deleting.
 | 
			
		||||
 | 
			
		||||
// IncrementalBarrier is a convenience class to be used in place of a barrier
 | 
			
		||||
// closure, which is particularly helpful (e.g. simplify code) because callers
 | 
			
		||||
// don't need to calculate the |num_closures| beforehand.
 | 
			
		||||
//
 | 
			
		||||
// Example Usage:
 | 
			
		||||
//   void MakeCalls() {
 | 
			
		||||
//     typedef std::function<void()> Callback;
 | 
			
		||||
//     typedef std::function<void(Callback)> OtherCallback;
 | 
			
		||||
//     Callback done_callback = ...
 | 
			
		||||
//     OtherCallback cb1 = ...
 | 
			
		||||
//     OtherCallback cb2 = ...
 | 
			
		||||
//     std::thread threads[2];
 | 
			
		||||
//     {
 | 
			
		||||
//         IncrementalBarrier barrier(done_callback);
 | 
			
		||||
//         threads[0] = std::thread(cb1(barrier.Inc());
 | 
			
		||||
//         threads[1] = std::thread(cb2(barrier.Inc());
 | 
			
		||||
//         ... at this moment, `barrier` is incremented twice, and then
 | 
			
		||||
//         destructed....
 | 
			
		||||
//     }
 | 
			
		||||
//     threads[0].join();
 | 
			
		||||
//     threads[1].join();
 | 
			
		||||
//   }
 | 
			
		||||
//
 | 
			
		||||
//  `done_callback` will be called when both conditions are true:
 | 
			
		||||
//  1) after `barrier` is destructed.
 | 
			
		||||
//  2) Each `BarrierCallback` returned by `Inc` is called.
 | 
			
		||||
// This class is thread-safe.
 | 
			
		||||
class IncrementalBarrier {
 | 
			
		||||
 public:
 | 
			
		||||
  typedef std::function<void()> DoneCallback;
 | 
			
		||||
  typedef std::function<void()> BarrierCallback;
 | 
			
		||||
  explicit IncrementalBarrier(DoneCallback callback);
 | 
			
		||||
 | 
			
		||||
  ~IncrementalBarrier();
 | 
			
		||||
 | 
			
		||||
  // Returns a BarrierCallback (std::function) that individual task call to
 | 
			
		||||
  // signal its completeness.
 | 
			
		||||
  // The returned BarrierCallback outlives this `IncrementalBarrier` instance.
 | 
			
		||||
  // Furthermore, each task should eventually call the returned function, or
 | 
			
		||||
  // else done_callback wouldn't be called.
 | 
			
		||||
  BarrierCallback Inc();
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  // self-deleting, thereby not owned by 'IncrementalBarrier'.
 | 
			
		||||
  InternalIncrementalBarrier* internal_barrier_;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
}  // namespace tensorflow
 | 
			
		||||
 | 
			
		||||
#endif  // TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_INCREMENTAL_BARRIER_H_
 | 
			
		||||
							
								
								
									
										133
									
								
								tensorflow/core/util/incremental_barrier_test.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										133
									
								
								tensorflow/core/util/incremental_barrier_test.cc
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,133 @@
 | 
			
		||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/core/util/incremental_barrier.h"
 | 
			
		||||
 | 
			
		||||
#include <atomic>
 | 
			
		||||
 | 
			
		||||
#include "absl/functional/bind_front.h"
 | 
			
		||||
#include "absl/time/time.h"
 | 
			
		||||
#include "tensorflow/core/platform/env.h"
 | 
			
		||||
#include "tensorflow/core/platform/mutex.h"
 | 
			
		||||
#include "tensorflow/core/platform/platform.h"
 | 
			
		||||
#include "tensorflow/core/platform/test.h"
 | 
			
		||||
#include "tensorflow/core/platform/test_benchmark.h"
 | 
			
		||||
#include "tensorflow/core/platform/thread_annotations.h"
 | 
			
		||||
#include "tensorflow/core/platform/threadpool.h"
 | 
			
		||||
 | 
			
		||||
namespace tensorflow {
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
// A thread-safe counter class.
 | 
			
		||||
class Counter {
 | 
			
		||||
 public:
 | 
			
		||||
  void Increment() TF_LOCKS_EXCLUDED(mu_) {
 | 
			
		||||
    mutex_lock l(mu_);
 | 
			
		||||
    ++count_;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  int GetCount() TF_LOCKS_EXCLUDED(mu_) {
 | 
			
		||||
    mutex_lock l(mu_);
 | 
			
		||||
    return count_;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  mutex mu_;
 | 
			
		||||
  int count_ = 0;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
TEST(IncrementalBarrierTest, RunInstantlyWhenZeroClosure) {
 | 
			
		||||
  Counter counter;
 | 
			
		||||
  EXPECT_EQ(counter.GetCount(), 0);
 | 
			
		||||
  {
 | 
			
		||||
    IncrementalBarrier::DoneCallback done_callback =
 | 
			
		||||
        absl::bind_front(&Counter::Increment, &counter);
 | 
			
		||||
    IncrementalBarrier barrier(done_callback);
 | 
			
		||||
    EXPECT_EQ(counter.GetCount(), 0);
 | 
			
		||||
  }
 | 
			
		||||
  EXPECT_EQ(counter.GetCount(), 1);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(IncrementalBarrierTest, RunAfterNumClosuresOneNowTwoLater) {
 | 
			
		||||
  Counter counter;
 | 
			
		||||
 | 
			
		||||
  IncrementalBarrier::BarrierCallback bc1, bc2;
 | 
			
		||||
  {
 | 
			
		||||
    IncrementalBarrier::DoneCallback done_callback =
 | 
			
		||||
        absl::bind_front(&Counter::Increment, &counter);
 | 
			
		||||
    IncrementalBarrier barrier(done_callback);
 | 
			
		||||
 | 
			
		||||
    CHECK_EQ(counter.GetCount(), 0);
 | 
			
		||||
 | 
			
		||||
    bc1 = barrier.Inc();
 | 
			
		||||
    bc2 = barrier.Inc();
 | 
			
		||||
 | 
			
		||||
    IncrementalBarrier::BarrierCallback bc3 = barrier.Inc();
 | 
			
		||||
    bc3();
 | 
			
		||||
 | 
			
		||||
    CHECK_EQ(counter.GetCount(), 0);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  CHECK_EQ(counter.GetCount(), 0);
 | 
			
		||||
  bc1();
 | 
			
		||||
  CHECK_EQ(counter.GetCount(), 0);
 | 
			
		||||
  bc2();
 | 
			
		||||
  CHECK_EQ(counter.GetCount(), 1);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(IncrementalBarrierTest, RunAfterNumClosuresConcurrency) {
 | 
			
		||||
  const int num_closure = 100, num_thread = 2;
 | 
			
		||||
  std::atomic<int> schedule_count{0};
 | 
			
		||||
  Counter counter;
 | 
			
		||||
 | 
			
		||||
  {
 | 
			
		||||
    IncrementalBarrier::DoneCallback done_callback =
 | 
			
		||||
        absl::bind_front(&Counter::Increment, &counter);
 | 
			
		||||
    IncrementalBarrier barrier(done_callback);
 | 
			
		||||
 | 
			
		||||
    CHECK_EQ(counter.GetCount(), 0);
 | 
			
		||||
 | 
			
		||||
    tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(),
 | 
			
		||||
                                        "BarrierClosure", num_thread);
 | 
			
		||||
    for (int i = 0; i < num_closure; ++i) {
 | 
			
		||||
      pool.Schedule([&barrier, &schedule_count]() {
 | 
			
		||||
        schedule_count.fetch_add(1);
 | 
			
		||||
        IncrementalBarrier::BarrierCallback bc = barrier.Inc();
 | 
			
		||||
 | 
			
		||||
        Env::Default()->SleepForMicroseconds(100);
 | 
			
		||||
        bc();
 | 
			
		||||
      });
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    CHECK_EQ(counter.GetCount(), 0);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  CHECK_EQ(schedule_count.load(std::memory_order_relaxed), 100);
 | 
			
		||||
  CHECK_EQ(counter.GetCount(), 1);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#if defined(PLATFORM_GOOGLE)
 | 
			
		||||
void BM_FunctionInc(benchmark::State& state) {
 | 
			
		||||
  IncrementalBarrier barrier([] {});
 | 
			
		||||
  for (auto _ : state) {
 | 
			
		||||
    barrier.Inc()();
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
BENCHMARK(BM_FunctionInc);
 | 
			
		||||
#endif  // PLATFORM_GOOGLE
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
}  // namespace tensorflow
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user