STT-tensorflow/tensorflow/compiler/xla/refcounting_hash_map.h
Peter Hawkins 8a72c4466a [XLA:GPU] Add experimental, lightly tested support for multi-host and multi-process NCCL AllReduce.
This change introduces several API changes:
* we allow the client to provide a mapping from the local device ordinals on the machine to global device IDs. If provided, we interpret the device IDs in the DeviceAssignment provided by the client as global IDs, not as local device ordinals. This allows us to describe computations that cross a host boundary.
* we allow the client to provide a callback for manufacturing an ncclUniqueId for a particular subset of global devices. The idea is that the client uses some other distributed system of its own (e.g., MPI) to share the ncclUniqueId values needed for a computation; NCCL allows cross-host/process collectives iff the same ncclUniqueId value is used by every participant (a sketch of this sharing pattern appears below).
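
A minimal sketch of that sharing pattern, using raw NCCL and MPI directly; this is illustrative only and does not show the XLA callback API itself:

  #include <mpi.h>
  #include <nccl.h>

  // Illustrative sketch; assumes MPI_Init has already been called.
  // Rank 0 manufactures the ncclUniqueId, and every process must receive the
  // same value before creating the communicator.
  ncclUniqueId id;
  int rank, nranks;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &nranks);
  if (rank == 0) ncclGetUniqueId(&id);
  MPI_Bcast(&id, sizeof(id), MPI_BYTE, /*root=*/0, MPI_COMM_WORLD);
  ncclComm_t comm;
  ncclCommInitRank(&comm, nranks, id, rank);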

Refactors the common collective logic and the NCCL collective logic in particular to support a local/global distinction.

PiperOrigin-RevId: 296505571
Change-Id: I5ed42d65597b0960df78890745421f77e9789ba3
2020-02-21 14:00:29 -08:00


/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_REFCOUNTING_HASH_MAP_H_
#define TENSORFLOW_COMPILER_XLA_REFCOUNTING_HASH_MAP_H_
#include <functional>
#include <memory>
#include "absl/base/thread_annotations.h"
#include "absl/container/node_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/synchronization/mutex.h"
namespace xla {
// RefcountingHashMap is an "eager, thread-safe cache".
//
// Given a key k you can retrieve a shared_ptr to a value v. If k is not
// already in the map, we construct a new V; if it is already in the map, we'll
// return the existing v. Once all shared_ptrs are destroyed, the entry is
// removed from the map.
//
// This class is thread-safe.
//
// Word to the wise: You might want an erase() function here that removes a
// value from the map but leaves existing shared_ptrs intact. My experience is,
// this is extremely complicated to implement correctly.
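//
// Example usage (a hedged sketch; `Widget` is a hypothetical value type):
//
//   RefcountingHashMap<int, Widget> cache;
//   auto factory = [](const int& k) { return absl::make_unique<Widget>(k); };
//   std::shared_ptr<Widget> a = cache.GetOrCreateIfAbsent(42, factory);
//   std::shared_ptr<Widget> b = cache.GetOrCreateIfAbsent(42, factory);
//   // `a` and `b` point to the same Widget instance.
//   a.reset();
//   b.reset();  // Last reference gone; the entry for key 42 is removed.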
template <typename K, typename V>
class RefcountingHashMap {
public:
// Default-constructs new values.
RefcountingHashMap() = default;
// Not copyable or movable because this contains internal pointers (namely,
// instances of Deleter contain pointers to `this` and into `map_`).
RefcountingHashMap(const RefcountingHashMap&) = delete;
RefcountingHashMap(RefcountingHashMap&&) = delete;
RefcountingHashMap& operator=(const RefcountingHashMap&) = delete;
RefcountingHashMap& operator=(RefcountingHashMap&&) = delete;
// Gets the value for the given key.
//
// If the map doesn't contain a live value for the key, constructs one
// using `value_factory`.
std::shared_ptr<V> GetOrCreateIfAbsent(
const K& key,
const std::function<std::unique_ptr<V>(const K&)>& value_factory) {
absl::MutexLock lock(&mu_);
auto it = map_.find(key);
// Make sure the entry has not expired; the deleter may have been running
// concurrently when we entered this block, leaving a stale weak_ptr behind.
if (it != map_.end()) {
if (std::shared_ptr<V> value = it->second.lock()) {
return value;
}
map_.erase(it);
}
// Create the entry in the map first and then set its value, so the value can
// hold a pointer back into the map (the Deleter keys off &it->first).
it = map_.emplace(key, std::weak_ptr<V>()).first;
std::shared_ptr<V> value(value_factory(key).release(),
Deleter{&it->first, this});
it->second = value; // Set the weak ptr to the shared ptr.
return value;
}
// Runs a function over every key/value in the map.
//
// Touching the map from within this function may deadlock; don't do it.
//
// Function signature must be compatible with
// void fn(const K&, std::shared_ptr<V>)
//
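// Example (a hedged sketch, reusing the hypothetical `cache` and `Widget`
// from the usage example above):
//
//   cache.ForEach([](const int& k, std::shared_ptr<Widget> v) {
//     // `v` may be null if the last reference died but the Deleter has not
//     // yet erased the entry.
//     if (v != nullptr) {
//       /* inspect *v */
//     }
//   });
//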
template <typename Fn>
void ForEach(Fn&& fn) {
absl::MutexLock lock(&mu_);
for (const auto& kv : map_) {
fn(kv.first, kv.second.lock());
}
}
private:
struct Deleter {
const K* key; // Points into parent->map_.
RefcountingHashMap* parent;
void operator()(V* v) {
delete v;
absl::MutexLock lock(&parent->mu_);
auto it = parent->map_.find(*key);
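// Only erase the entry if it is still expired: another thread may have
// re-created a live value for this key via GetOrCreateIfAbsent after the
// refcount hit zero but before we acquired the lock.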
if (it != parent->map_.end() && it->second.expired()) {
parent->map_.erase(it);
}
}
};
absl::Mutex mu_;
absl::node_hash_map<K, std::weak_ptr<V>> map_ ABSL_GUARDED_BY(mu_);
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_REFCOUNTING_HASH_MAP_H_