[XLA:PJRT] Add optional platform-specific fingerprint to PjRtExecutable.
If implemented by the client, a 32-bit hash of this fingerprint is used as the executable's launch ID.
PiperOrigin-RevId: 324256013
Change-Id: I2288dad54dc5ba73d3d65cb71d7dd1e54e14b048
This commit is contained in:
parent
f3e7bc6a0b
commit
7e7d62b735
@ -1610,6 +1610,10 @@ StatusOr<ScopedShapedBuffer> PjRtExecutable::EnqueueExecution(
|
||||
run_options.set_run_id(run_id);
|
||||
run_options.set_rng_seed(device_state->GetNewPrngSeed());
|
||||
run_options.set_gpu_executable_run_options(client_->gpu_run_options());
|
||||
run_options.set_launch_id(options.launch_id);
|
||||
if (run_options.launch_id() != 0) {
|
||||
VLOG(1) << "launch id for " << name() << ": " << run_options.launch_id();
|
||||
}
|
||||
|
||||
// The choice of where we wait is arbitrary; the reason for the wait is
|
||||
// pacing to avoid problems such as memory fragmentation and running ahead
|
||||
@ -2138,13 +2142,13 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
|
||||
client->client()->Compile(computation, argument_layout_pointers,
|
||||
build_options));
|
||||
|
||||
auto py_executable = absl::make_unique<PjRtExecutable>(
|
||||
auto executable = absl::make_unique<PjRtExecutable>(
|
||||
std::move(local_executables), options.parameter_is_tupled_arguments,
|
||||
std::move(device_assignment), std::move(local_logical_device_ids),
|
||||
std::move(local_devices), client);
|
||||
TF_RETURN_IF_ERROR(py_executable->SetUpDonation(
|
||||
client, options.parameter_is_tupled_arguments));
|
||||
return py_executable;
|
||||
TF_RETURN_IF_ERROR(
|
||||
executable->SetUpDonation(client, options.parameter_is_tupled_arguments));
|
||||
return executable;
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
||||
@ -119,6 +119,8 @@ struct PjRtCrossHostRecvBuffer {
|
||||
using PjRtCrossHostRecvNotifier =
|
||||
std::function<void(StatusOr<std::vector<PjRtCrossHostRecvBuffer>>&&)>;
|
||||
|
||||
class PjRtExecutable;
|
||||
|
||||
// Encapsulates the state of Python session with XLA.
|
||||
//
|
||||
// It is the responsibility of the client of this API to keep the PjRtClient
|
||||
@ -181,6 +183,13 @@ class PjRtClient {
|
||||
virtual StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
|
||||
const LocalExecutable& executable, bool tuple_inputs) const;
|
||||
|
||||
// Generates a unique fingerprint for `executable`. See
|
||||
// PyExecutable::fingerprint_. The default implementation returns nullopt.
|
||||
virtual StatusOr<absl::optional<std::string>> ExecutableFingerprint(
|
||||
const PjRtExecutable& executable) const {
|
||||
return absl::optional<std::string>();
|
||||
}
|
||||
|
||||
protected:
|
||||
friend class PjRtBuffer;
|
||||
virtual void EnqueueCrossHostReceive(
|
||||
@ -668,6 +677,11 @@ struct ExecuteOptions {
|
||||
// If true, the computation must return a tuple, which will be destructured
|
||||
// into its elements.
|
||||
bool untuple_result = false;
|
||||
// If non-zero, identifies this execution as part of a potentially
|
||||
// multi-device launch. This can be used to detect scheduling errors, e.g. if
|
||||
// multi-host programs are launched in different orders on different hosts,
|
||||
// the launch IDs may be used by the runtime to detect the mismatch.
|
||||
int32 launch_id = 0;
|
||||
};
|
||||
|
||||
// Represents a compiled computation that can be executed given handles to
|
||||
@ -687,6 +701,8 @@ class PjRtExecutable {
|
||||
std::vector<std::pair<int, int>> local_logical_device_ids,
|
||||
std::vector<Device*> local_devices, PjRtClient* client);
|
||||
|
||||
virtual ~PjRtExecutable() = default;
|
||||
|
||||
PjRtClient* client() const { return client_; }
|
||||
|
||||
int num_replicas() const {
|
||||
@ -744,12 +760,14 @@ class PjRtExecutable {
|
||||
// Initializes information about which arguments to which executables must be
|
||||
// donated due to aliases that were specified by the computation.
|
||||
Status SetUpDonation(PjRtClient* client, bool tuple_inputs);
|
||||
|
||||
StatusOr<ScopedShapedBuffer> EnqueueExecution(
|
||||
absl::Span<PjRtBuffer* const> argument_handles, int replica,
|
||||
int partition, int executable_idx, const RunId& run_id,
|
||||
const ExecuteOptions& options, Device* device,
|
||||
std::vector<PjRtBuffer::ScopedHold>* device_buffers,
|
||||
std::shared_ptr<DeviceAssignment> device_assignment) const;
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<PjRtBuffer>>> ExecuteHelper(
|
||||
absl::Span<PjRtBuffer* const> argument_handles, int replica,
|
||||
int partition, const RunId& run_id, const ExecuteOptions& options,
|
||||
|
||||
@ -202,6 +202,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla/pjrt:pjrt_client",
|
||||
"//tensorflow/core/platform:fingerprint",
|
||||
"//tensorflow/core/profiler:protos_all_cc",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
|
||||
@ -124,15 +124,19 @@ StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
|
||||
StatusOr<std::unique_ptr<PyExecutable>> PyClient::Compile(
|
||||
const XlaComputation& computation, CompileOptions options) {
|
||||
std::unique_ptr<PjRtExecutable> executable;
|
||||
absl::optional<std::string> fingerprint;
|
||||
{
|
||||
py::gil_scoped_release gil_release;
|
||||
TF_ASSIGN_OR_RETURN(executable,
|
||||
PjRtExecutable::Compile(computation, pjrt_client_.get(),
|
||||
std::move(options)));
|
||||
TF_ASSIGN_OR_RETURN(fingerprint,
|
||||
pjrt_client_->ExecutableFingerprint(*executable));
|
||||
}
|
||||
auto traceback = Traceback::Get();
|
||||
return std::make_unique<PyExecutable>(
|
||||
shared_from_this(), std::move(executable), std::move(traceback));
|
||||
shared_from_this(), std::move(executable), std::move(traceback),
|
||||
std::move(fingerprint));
|
||||
}
|
||||
|
||||
class ProfileBuilder {
|
||||
|
||||
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/python/py_executable.h"
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "tensorflow/core/platform/fingerprint.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
@ -23,10 +24,12 @@ namespace py = pybind11;
|
||||
|
||||
PyExecutable::PyExecutable(std::shared_ptr<PyClient> client,
|
||||
std::unique_ptr<PjRtExecutable> executable,
|
||||
std::shared_ptr<Traceback> traceback)
|
||||
std::shared_ptr<Traceback> traceback,
|
||||
absl::optional<std::string> fingerprint)
|
||||
: client_(std::move(client)),
|
||||
executable_(std::move(executable)),
|
||||
traceback_(std::move(traceback)) {
|
||||
traceback_(std::move(traceback)),
|
||||
fingerprint_(std::move(fingerprint)) {
|
||||
CHECK(PyGILState_Check());
|
||||
next_ = client_->executables_;
|
||||
client_->executables_ = this;
|
||||
@ -34,6 +37,10 @@ PyExecutable::PyExecutable(std::shared_ptr<PyClient> client,
|
||||
if (next_) {
|
||||
next_->prev_ = this;
|
||||
}
|
||||
if (fingerprint_) {
|
||||
VLOG(1) << "Fingerprint for executable " << executable_->name() << ": "
|
||||
<< *fingerprint_;
|
||||
}
|
||||
}
|
||||
|
||||
PyExecutable::~PyExecutable() {
|
||||
@ -65,6 +72,9 @@ StatusOr<std::vector<std::unique_ptr<PyBuffer>>> PyExecutable::Execute(
|
||||
py::gil_scoped_release gil_release;
|
||||
ExecuteOptions options;
|
||||
options.untuple_result = true;
|
||||
if (fingerprint_) {
|
||||
options.launch_id = tensorflow::Fingerprint32(*fingerprint_);
|
||||
}
|
||||
std::vector<PjRtBuffer*> arg_buffers(args.size());
|
||||
absl::c_transform(args, arg_buffers.begin(),
|
||||
[](PyBuffer* buf) { return buf->buffer(); });
|
||||
@ -89,6 +99,9 @@ PyExecutable::ExecuteOnLocalDevices(
|
||||
py::gil_scoped_release gil_release;
|
||||
ExecuteOptions options;
|
||||
options.untuple_result = true;
|
||||
if (fingerprint_) {
|
||||
options.launch_id = tensorflow::Fingerprint32(*fingerprint_);
|
||||
}
|
||||
std::vector<std::vector<PjRtBuffer*>> arg_buffers(args.size());
|
||||
for (int computation = 0; computation < args.size(); ++computation) {
|
||||
arg_buffers[computation].resize(args[computation].size());
|
||||
|
||||
@ -37,7 +37,8 @@ class PyExecutable {
|
||||
public:
|
||||
PyExecutable(std::shared_ptr<PyClient> client,
|
||||
std::unique_ptr<PjRtExecutable> executable,
|
||||
std::shared_ptr<Traceback> traceback);
|
||||
std::shared_ptr<Traceback> traceback,
|
||||
absl::optional<std::string> fingerprint);
|
||||
~PyExecutable();
|
||||
|
||||
std::shared_ptr<PyClient> client() const { return client_; }
|
||||
@ -71,6 +72,11 @@ class PyExecutable {
|
||||
std::unique_ptr<PjRtExecutable> executable_;
|
||||
std::shared_ptr<Traceback> traceback_;
|
||||
|
||||
// Identical executables (i.e. representing the same program) will have the
|
||||
// same fingerprint. nullopt on platforms or executables where fingerprints
|
||||
// aren't implemented. When present, its 32-bit hash is passed as the launch ID.
|
||||
absl::optional<std::string> fingerprint_;
|
||||
|
||||
// Doubly-linked list of all executables known to the client. Protected by the
|
||||
// GIL.
|
||||
PyExecutable* next_;
|
||||
|
||||
@ -90,6 +90,15 @@ inline uint64 Fingerprint64(const StringPiece s) {
|
||||
#endif
|
||||
}
|
||||
|
||||
// 32-bit variant of Fingerprint64 above (same properties and caveats apply).
|
||||
inline uint32 Fingerprint32(const StringPiece s) {
|
||||
#ifdef USE_OSS_FARMHASH
|
||||
return ::util::Fingerprint32(s.data(), s.size());
|
||||
#else
|
||||
return farmhash::Fingerprint32(s.data(), s.size());
|
||||
#endif
|
||||
}
|
||||
|
||||
// 128-bit variant of Fingerprint64 above (same properties and caveats apply).
|
||||
inline Fprint128 Fingerprint128(const StringPiece s) {
|
||||
#ifdef USE_OSS_FARMHASH
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user