Refactor PJRT client.
- Move GetParametersThatMustBeDonated to a free function according to interface principle. https://www.fluentcpp.com/2017/06/20/interface-principle-cpp/ PiperOrigin-RevId: 338583883 Change-Id: I2e9c868d80d99512c02a3a553112656f254b5474
This commit is contained in:
parent
2cea0006f2
commit
99cc6a2df0
@ -235,62 +235,6 @@ StatusOr<DeviceAssignment> PjRtClient::GetDefaultDeviceAssignment(
|
|||||||
num_partitions);
|
num_partitions);
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<absl::flat_hash_set<int>> PjRtClient::GetParametersThatMustBeDonated(
|
|
||||||
const LocalExecutable& executable, bool tuple_inputs) const {
|
|
||||||
HloComputation* computation =
|
|
||||||
executable.executable()->module().entry_computation();
|
|
||||||
int number_of_parameters = [&]() -> int {
|
|
||||||
if (tuple_inputs) {
|
|
||||||
CHECK_EQ(computation->num_parameters(), 1);
|
|
||||||
const Shape& input_tuple_shape =
|
|
||||||
computation->parameter_instruction(0)->shape();
|
|
||||||
CHECK(input_tuple_shape.IsTuple());
|
|
||||||
return input_tuple_shape.tuple_shapes_size();
|
|
||||||
} else {
|
|
||||||
return computation->num_parameters();
|
|
||||||
}
|
|
||||||
}();
|
|
||||||
// If any buffer in a parameter is aliased we will donate the entire input
|
|
||||||
// parameter.
|
|
||||||
absl::flat_hash_set<int> parameters_to_donate;
|
|
||||||
const HloInputOutputAliasConfig& config =
|
|
||||||
executable.executable()->module().input_output_alias_config();
|
|
||||||
TF_RETURN_IF_ERROR(config.ForEachAliasWithStatus(
|
|
||||||
[&](const ShapeIndex& output_index,
|
|
||||||
const HloInputOutputAliasConfig::Alias& alias) {
|
|
||||||
if (tuple_inputs) {
|
|
||||||
if (alias.parameter_number != 0) {
|
|
||||||
return InvalidArgument(
|
|
||||||
"Unexpected parameter number %d in alias config with tupled "
|
|
||||||
"inputs",
|
|
||||||
alias.parameter_number);
|
|
||||||
}
|
|
||||||
const ShapeIndex& index = alias.parameter_index;
|
|
||||||
if (!index.empty()) {
|
|
||||||
int this_parameter = index.data()[0];
|
|
||||||
if (this_parameter >= number_of_parameters) {
|
|
||||||
return InvalidArgument(
|
|
||||||
"Unexpected parameter index %s in alias config with tupled "
|
|
||||||
"inputs and %d parameters",
|
|
||||||
index.ToString(), number_of_parameters);
|
|
||||||
}
|
|
||||||
parameters_to_donate.insert(this_parameter);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
int this_parameter = alias.parameter_number;
|
|
||||||
if (this_parameter >= number_of_parameters) {
|
|
||||||
return InvalidArgument(
|
|
||||||
"Unexpected parameter number %d in alias config without tupled "
|
|
||||||
"inputs and %d parameters",
|
|
||||||
this_parameter, number_of_parameters);
|
|
||||||
}
|
|
||||||
parameters_to_donate.insert(this_parameter);
|
|
||||||
}
|
|
||||||
return Status::OK();
|
|
||||||
}));
|
|
||||||
return parameters_to_donate;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unique_ptr<HloCostAnalysis> PjRtClient::GetHloCostAnalysis() {
|
std::unique_ptr<HloCostAnalysis> PjRtClient::GetHloCostAnalysis() {
|
||||||
return absl::make_unique<HloCostAnalysis>(
|
return absl::make_unique<HloCostAnalysis>(
|
||||||
client_->backend().compiler()->ShapeSizeBytesFunction());
|
client_->backend().compiler()->ShapeSizeBytesFunction());
|
||||||
@ -1531,12 +1475,66 @@ PjRtExecutable::PjRtExecutable(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Status PjRtExecutable::SetUpDonation(PjRtClient* client, bool tuple_inputs) {
|
StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
|
||||||
|
const HloModule& module, bool tuple_inputs) {
|
||||||
|
HloComputation* computation = module.entry_computation();
|
||||||
|
int number_of_parameters = [&]() -> int {
|
||||||
|
if (tuple_inputs) {
|
||||||
|
CHECK_EQ(computation->num_parameters(), 1);
|
||||||
|
const Shape& input_tuple_shape =
|
||||||
|
computation->parameter_instruction(0)->shape();
|
||||||
|
CHECK(input_tuple_shape.IsTuple());
|
||||||
|
return input_tuple_shape.tuple_shapes_size();
|
||||||
|
} else {
|
||||||
|
return computation->num_parameters();
|
||||||
|
}
|
||||||
|
}();
|
||||||
|
// If any buffer in a parameter is aliased we will donate the entire input
|
||||||
|
// parameter.
|
||||||
|
absl::flat_hash_set<int> parameters_to_donate;
|
||||||
|
const HloInputOutputAliasConfig& config = module.input_output_alias_config();
|
||||||
|
TF_RETURN_IF_ERROR(config.ForEachAliasWithStatus(
|
||||||
|
[&](const ShapeIndex& output_index,
|
||||||
|
const HloInputOutputAliasConfig::Alias& alias) {
|
||||||
|
if (tuple_inputs) {
|
||||||
|
if (alias.parameter_number != 0) {
|
||||||
|
return InvalidArgument(
|
||||||
|
"Unexpected parameter number %d in alias config with tupled "
|
||||||
|
"inputs",
|
||||||
|
alias.parameter_number);
|
||||||
|
}
|
||||||
|
const ShapeIndex& index = alias.parameter_index;
|
||||||
|
if (!index.empty()) {
|
||||||
|
int this_parameter = index.data()[0];
|
||||||
|
if (this_parameter >= number_of_parameters) {
|
||||||
|
return InvalidArgument(
|
||||||
|
"Unexpected parameter index %s in alias config with tupled "
|
||||||
|
"inputs and %d parameters",
|
||||||
|
index.ToString(), number_of_parameters);
|
||||||
|
}
|
||||||
|
parameters_to_donate.insert(this_parameter);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
int this_parameter = alias.parameter_number;
|
||||||
|
if (this_parameter >= number_of_parameters) {
|
||||||
|
return InvalidArgument(
|
||||||
|
"Unexpected parameter number %d in alias config without tupled "
|
||||||
|
"inputs and %d parameters",
|
||||||
|
this_parameter, number_of_parameters);
|
||||||
|
}
|
||||||
|
parameters_to_donate.insert(this_parameter);
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}));
|
||||||
|
return parameters_to_donate;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status PjRtExecutable::SetUpDonation(bool tuple_inputs) {
|
||||||
parameters_that_must_be_donated_.reserve(executables_.size());
|
parameters_that_must_be_donated_.reserve(executables_.size());
|
||||||
for (auto& executable : executables_) {
|
for (auto& executable : executables_) {
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(absl::flat_hash_set<int> parameters_to_donate,
|
||||||
absl::flat_hash_set<int> parameters_to_donate,
|
GetParametersThatMustBeDonated(
|
||||||
client->GetParametersThatMustBeDonated(*executable, tuple_inputs));
|
executable->executable()->module(), tuple_inputs));
|
||||||
parameters_that_must_be_donated_.emplace_back(
|
parameters_that_must_be_donated_.emplace_back(
|
||||||
std::move(parameters_to_donate));
|
std::move(parameters_to_donate));
|
||||||
}
|
}
|
||||||
@ -2237,7 +2235,7 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
|
|||||||
std::move(device_assignment), std::move(local_logical_device_ids),
|
std::move(device_assignment), std::move(local_logical_device_ids),
|
||||||
std::move(local_devices), this);
|
std::move(local_devices), this);
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
executable->SetUpDonation(this, options.parameter_is_tupled_arguments));
|
executable->SetUpDonation(options.parameter_is_tupled_arguments));
|
||||||
return executable;
|
return executable;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -251,13 +251,6 @@ class PjRtClient {
|
|||||||
// function specifies which one the platform expects.
|
// function specifies which one the platform expects.
|
||||||
virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }
|
virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }
|
||||||
|
|
||||||
// Some platforms allow executables to donate buffers so that they can be
|
|
||||||
// aliased from inputs to outputs. This function returns the list of
|
|
||||||
// parameters that must be donated when executable is run. tuple_inputs
|
|
||||||
// reflects the option that executable was compiled with.
|
|
||||||
virtual StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
|
|
||||||
const LocalExecutable& executable, bool tuple_inputs) const;
|
|
||||||
|
|
||||||
// Generates a unique fingerprint for `executable`. See
|
// Generates a unique fingerprint for `executable`. See
|
||||||
// PjRtExecutable::fingerprint_.
|
// PjRtExecutable::fingerprint_.
|
||||||
virtual StatusOr<absl::optional<std::string>> ExecutableFingerprint(
|
virtual StatusOr<absl::optional<std::string>> ExecutableFingerprint(
|
||||||
@ -854,7 +847,7 @@ class PjRtExecutable {
|
|||||||
friend class PjRtClient;
|
friend class PjRtClient;
|
||||||
// Initializes information about which arguments to which executables must be
|
// Initializes information about which arguments to which executables must be
|
||||||
// donated due to aliases that were specified by the computation.
|
// donated due to aliases that were specified by the computation.
|
||||||
Status SetUpDonation(PjRtClient* client, bool tuple_inputs);
|
Status SetUpDonation(bool tuple_inputs);
|
||||||
|
|
||||||
virtual bool MustDonateParameter(int executable_idx, int parameter) const;
|
virtual bool MustDonateParameter(int executable_idx, int parameter) const;
|
||||||
|
|
||||||
@ -913,6 +906,13 @@ class PjRtExecutable {
|
|||||||
std::vector<PjRtDevice*> local_devices_;
|
std::vector<PjRtDevice*> local_devices_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Executables can donate buffers so that buffers can be aliased from inputs
|
||||||
|
// to outputs. This function returns the list of parameters that must be
|
||||||
|
// donated when executable is run. tuple_inputs reflects the option that
|
||||||
|
// executable was compiled with.
|
||||||
|
StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
|
||||||
|
const HloModule& hlo_module, bool tuple_inputs);
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_
|
#endif // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_
|
||||||
|
Loading…
Reference in New Issue
Block a user