[XLA:PJRT] Add optional platform-specific fingerprint to PjRtExecutable.

If implemented by the client, this fingerprint is used as the executable's launch ID.

PiperOrigin-RevId: 324256013
Change-Id: I2288dad54dc5ba73d3d65cb71d7dd1e54e14b048
This commit is contained in:
Skye Wanderman-Milne 2020-07-31 12:05:20 -07:00 committed by TensorFlower Gardener
parent f3e7bc6a0b
commit 7e7d62b735
7 changed files with 63 additions and 8 deletions

View File

@ -1610,6 +1610,10 @@ StatusOr<ScopedShapedBuffer> PjRtExecutable::EnqueueExecution(
run_options.set_run_id(run_id);
run_options.set_rng_seed(device_state->GetNewPrngSeed());
run_options.set_gpu_executable_run_options(client_->gpu_run_options());
run_options.set_launch_id(options.launch_id);
if (run_options.launch_id() != 0) {
VLOG(1) << "launch id for " << name() << ": " << run_options.launch_id();
}
// The choice of where we wait is arbitrary; the reason for the wait is
// pacing to avoid problems such as memory fragmentation and running ahead
@ -2138,13 +2142,13 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
client->client()->Compile(computation, argument_layout_pointers,
build_options));
auto py_executable = absl::make_unique<PjRtExecutable>(
auto executable = absl::make_unique<PjRtExecutable>(
std::move(local_executables), options.parameter_is_tupled_arguments,
std::move(device_assignment), std::move(local_logical_device_ids),
std::move(local_devices), client);
TF_RETURN_IF_ERROR(py_executable->SetUpDonation(
client, options.parameter_is_tupled_arguments));
return py_executable;
TF_RETURN_IF_ERROR(
executable->SetUpDonation(client, options.parameter_is_tupled_arguments));
return executable;
}
} // namespace xla

View File

@ -119,6 +119,8 @@ struct PjRtCrossHostRecvBuffer {
using PjRtCrossHostRecvNotifier =
std::function<void(StatusOr<std::vector<PjRtCrossHostRecvBuffer>>&&)>;
class PjRtExecutable;
// Encapsulates the state of a Python session with XLA.
//
// It is the responsibility of the client of this API to keep the PjRtClient
@ -181,6 +183,13 @@ class PjRtClient {
virtual StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
const LocalExecutable& executable, bool tuple_inputs) const;
// Generates a unique fingerprint for `executable`. See
// PyExecutable::fingerprint_, which is where the result of this call is kept
// after compilation.
//
// Returns an empty optional when no fingerprint is available. This base
// implementation always returns an empty optional; the method is virtual so
// that a client can override it to supply a real fingerprint.
virtual StatusOr<absl::optional<std::string>> ExecutableFingerprint(
const PjRtExecutable& executable) const {
return absl::optional<std::string>();
}
protected:
friend class PjRtBuffer;
virtual void EnqueueCrossHostReceive(
@ -668,6 +677,11 @@ struct ExecuteOptions {
// If true, the computation must return a tuple, which will be destructured
// into its elements.
bool untuple_result = false;
// If non-zero, identifies this execution as part of a potentially
// multi-device launch. This can be used to detect scheduling errors, e.g. if
// multi-host programs are launched in different orders on different hosts,
// the launch IDs may be used by the runtime to detect the mismatch.
int32 launch_id = 0;
};
// Represents a compiled computation that can be executed given handles to
@ -687,6 +701,8 @@ class PjRtExecutable {
std::vector<std::pair<int, int>> local_logical_device_ids,
std::vector<Device*> local_devices, PjRtClient* client);
virtual ~PjRtExecutable() = default;
PjRtClient* client() const { return client_; }
int num_replicas() const {
@ -744,12 +760,14 @@ class PjRtExecutable {
// Initializes information about which arguments to which executables must be
// donated due to aliases that were specified by the computation.
Status SetUpDonation(PjRtClient* client, bool tuple_inputs);
StatusOr<ScopedShapedBuffer> EnqueueExecution(
absl::Span<PjRtBuffer* const> argument_handles, int replica,
int partition, int executable_idx, const RunId& run_id,
const ExecuteOptions& options, Device* device,
std::vector<PjRtBuffer::ScopedHold>* device_buffers,
std::shared_ptr<DeviceAssignment> device_assignment) const;
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteHelper(
absl::Span<PjRtBuffer* const> argument_handles, int replica,
int partition, const RunId& run_id, const ExecuteOptions& options,

View File

@ -202,6 +202,7 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/pjrt:pjrt_client",
"//tensorflow/core/platform:fingerprint",
"//tensorflow/core/profiler:protos_all_cc",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",

View File

@ -124,15 +124,19 @@ StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
// Compiles `computation` with the given `options` and wraps the resulting
// PjRtExecutable in a PyExecutable. The wrapper also receives the traceback
// returned by Traceback::Get() and the optional fingerprint that the
// underlying PjRtClient reports for the executable.
//
// Returns an error status if compilation or fingerprinting fails.
StatusOr<std::unique_ptr<PyExecutable>> PyClient::Compile(
const XlaComputation& computation, CompileOptions options) {
std::unique_ptr<PjRtExecutable> executable;
absl::optional<std::string> fingerprint;
{
// The GIL is released for the duration of this scope, which covers both
// compilation and fingerprint computation.
py::gil_scoped_release gil_release;
TF_ASSIGN_OR_RETURN(executable,
PjRtExecutable::Compile(computation, pjrt_client_.get(),
std::move(options)));
// `fingerprint` stays empty if the client does not provide one.
TF_ASSIGN_OR_RETURN(fingerprint,
pjrt_client_->ExecutableFingerprint(*executable));
}
// Taken after the scope above has ended, i.e. once gil_release is destroyed.
auto traceback = Traceback::Get();
return std::make_unique<PyExecutable>(
shared_from_this(), std::move(executable), std::move(traceback));
shared_from_this(), std::move(executable), std::move(traceback),
std::move(fingerprint));
}
class ProfileBuilder {

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/python/py_executable.h"
#include "absl/algorithm/container.h"
#include "tensorflow/core/platform/fingerprint.h"
namespace xla {
@ -23,10 +24,12 @@ namespace py = pybind11;
PyExecutable::PyExecutable(std::shared_ptr<PyClient> client,
std::unique_ptr<PjRtExecutable> executable,
std::shared_ptr<Traceback> traceback)
std::shared_ptr<Traceback> traceback,
absl::optional<std::string> fingerprint)
: client_(std::move(client)),
executable_(std::move(executable)),
traceback_(std::move(traceback)) {
traceback_(std::move(traceback)),
fingerprint_(std::move(fingerprint)) {
CHECK(PyGILState_Check());
next_ = client_->executables_;
client_->executables_ = this;
@ -34,6 +37,10 @@ PyExecutable::PyExecutable(std::shared_ptr<PyClient> client,
if (next_) {
next_->prev_ = this;
}
if (fingerprint_) {
VLOG(1) << "Fingerprint for executable " << executable_->name() << ": "
<< *fingerprint_;
}
}
PyExecutable::~PyExecutable() {
@ -65,6 +72,9 @@ StatusOr<std::vector<std::unique_ptr<PyBuffer>>> PyExecutable::Execute(
py::gil_scoped_release gil_release;
ExecuteOptions options;
options.untuple_result = true;
if (fingerprint_) {
options.launch_id = tensorflow::Fingerprint32(*fingerprint_);
}
std::vector<PjRtBuffer*> arg_buffers(args.size());
absl::c_transform(args, arg_buffers.begin(),
[](PyBuffer* buf) { return buf->buffer(); });
@ -89,6 +99,9 @@ PyExecutable::ExecuteOnLocalDevices(
py::gil_scoped_release gil_release;
ExecuteOptions options;
options.untuple_result = true;
if (fingerprint_) {
options.launch_id = tensorflow::Fingerprint32(*fingerprint_);
}
std::vector<std::vector<PjRtBuffer*>> arg_buffers(args.size());
for (int computation = 0; computation < args.size(); ++computation) {
arg_buffers[computation].resize(args[computation].size());

View File

@ -37,7 +37,8 @@ class PyExecutable {
public:
PyExecutable(std::shared_ptr<PyClient> client,
std::unique_ptr<PjRtExecutable> executable,
std::shared_ptr<Traceback> traceback);
std::shared_ptr<Traceback> traceback,
absl::optional<std::string> fingerprint);
~PyExecutable();
std::shared_ptr<PyClient> client() const { return client_; }
@ -71,6 +72,11 @@ class PyExecutable {
std::unique_ptr<PjRtExecutable> executable_;
std::shared_ptr<Traceback> traceback_;
// Identical executables (i.e. representing the same program) will have the
// same fingerprint. nullopt on platforms or executables where fingerprints
// aren't implemented.
absl::optional<std::string> fingerprint_;
// Doubly-linked list of all executables known to the client. Protected by the
// GIL.
PyExecutable* next_;

View File

@ -90,6 +90,15 @@ inline uint64 Fingerprint64(const StringPiece s) {
#endif
}
// Computes a 32-bit fingerprint of the bytes in `s`. This is the 32-bit
// counterpart of Fingerprint64 above (same properties and caveats apply).
inline uint32 Fingerprint32(const StringPiece s) {
  const char* const bytes = s.data();
  const auto length = s.size();
  // Only the namespace of the underlying farmhash entry point differs between
  // the two build configurations.
#ifdef USE_OSS_FARMHASH
  return ::util::Fingerprint32(bytes, length);
#else
  return farmhash::Fingerprint32(bytes, length);
#endif
}
// 128-bit variant of Fingerprint64 above (same properties and caveats apply).
inline Fprint128 Fingerprint128(const StringPiece s) {
#ifdef USE_OSS_FARMHASH