Refactor PJRT client.
- Move GetParametersThatMustBeDonated to a free function according to interface principle. https://www.fluentcpp.com/2017/06/20/interface-principle-cpp/ PiperOrigin-RevId: 338583883 Change-Id: I2e9c868d80d99512c02a3a553112656f254b5474
This commit is contained in:
parent
2cea0006f2
commit
99cc6a2df0
@ -235,62 +235,6 @@ StatusOr<DeviceAssignment> PjRtClient::GetDefaultDeviceAssignment(
|
||||
num_partitions);
|
||||
}
|
||||
|
||||
StatusOr<absl::flat_hash_set<int>> PjRtClient::GetParametersThatMustBeDonated(
|
||||
const LocalExecutable& executable, bool tuple_inputs) const {
|
||||
HloComputation* computation =
|
||||
executable.executable()->module().entry_computation();
|
||||
int number_of_parameters = [&]() -> int {
|
||||
if (tuple_inputs) {
|
||||
CHECK_EQ(computation->num_parameters(), 1);
|
||||
const Shape& input_tuple_shape =
|
||||
computation->parameter_instruction(0)->shape();
|
||||
CHECK(input_tuple_shape.IsTuple());
|
||||
return input_tuple_shape.tuple_shapes_size();
|
||||
} else {
|
||||
return computation->num_parameters();
|
||||
}
|
||||
}();
|
||||
// If any buffer in a parameter is aliased we will donate the entire input
|
||||
// parameter.
|
||||
absl::flat_hash_set<int> parameters_to_donate;
|
||||
const HloInputOutputAliasConfig& config =
|
||||
executable.executable()->module().input_output_alias_config();
|
||||
TF_RETURN_IF_ERROR(config.ForEachAliasWithStatus(
|
||||
[&](const ShapeIndex& output_index,
|
||||
const HloInputOutputAliasConfig::Alias& alias) {
|
||||
if (tuple_inputs) {
|
||||
if (alias.parameter_number != 0) {
|
||||
return InvalidArgument(
|
||||
"Unexpected parameter number %d in alias config with tupled "
|
||||
"inputs",
|
||||
alias.parameter_number);
|
||||
}
|
||||
const ShapeIndex& index = alias.parameter_index;
|
||||
if (!index.empty()) {
|
||||
int this_parameter = index.data()[0];
|
||||
if (this_parameter >= number_of_parameters) {
|
||||
return InvalidArgument(
|
||||
"Unexpected parameter index %s in alias config with tupled "
|
||||
"inputs and %d parameters",
|
||||
index.ToString(), number_of_parameters);
|
||||
}
|
||||
parameters_to_donate.insert(this_parameter);
|
||||
}
|
||||
} else {
|
||||
int this_parameter = alias.parameter_number;
|
||||
if (this_parameter >= number_of_parameters) {
|
||||
return InvalidArgument(
|
||||
"Unexpected parameter number %d in alias config without tupled "
|
||||
"inputs and %d parameters",
|
||||
this_parameter, number_of_parameters);
|
||||
}
|
||||
parameters_to_donate.insert(this_parameter);
|
||||
}
|
||||
return Status::OK();
|
||||
}));
|
||||
return parameters_to_donate;
|
||||
}
|
||||
|
||||
std::unique_ptr<HloCostAnalysis> PjRtClient::GetHloCostAnalysis() {
|
||||
return absl::make_unique<HloCostAnalysis>(
|
||||
client_->backend().compiler()->ShapeSizeBytesFunction());
|
||||
@ -1531,12 +1475,66 @@ PjRtExecutable::PjRtExecutable(
|
||||
}
|
||||
}
|
||||
|
||||
Status PjRtExecutable::SetUpDonation(PjRtClient* client, bool tuple_inputs) {
|
||||
StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
|
||||
const HloModule& module, bool tuple_inputs) {
|
||||
HloComputation* computation = module.entry_computation();
|
||||
int number_of_parameters = [&]() -> int {
|
||||
if (tuple_inputs) {
|
||||
CHECK_EQ(computation->num_parameters(), 1);
|
||||
const Shape& input_tuple_shape =
|
||||
computation->parameter_instruction(0)->shape();
|
||||
CHECK(input_tuple_shape.IsTuple());
|
||||
return input_tuple_shape.tuple_shapes_size();
|
||||
} else {
|
||||
return computation->num_parameters();
|
||||
}
|
||||
}();
|
||||
// If any buffer in a parameter is aliased we will donate the entire input
|
||||
// parameter.
|
||||
absl::flat_hash_set<int> parameters_to_donate;
|
||||
const HloInputOutputAliasConfig& config = module.input_output_alias_config();
|
||||
TF_RETURN_IF_ERROR(config.ForEachAliasWithStatus(
|
||||
[&](const ShapeIndex& output_index,
|
||||
const HloInputOutputAliasConfig::Alias& alias) {
|
||||
if (tuple_inputs) {
|
||||
if (alias.parameter_number != 0) {
|
||||
return InvalidArgument(
|
||||
"Unexpected parameter number %d in alias config with tupled "
|
||||
"inputs",
|
||||
alias.parameter_number);
|
||||
}
|
||||
const ShapeIndex& index = alias.parameter_index;
|
||||
if (!index.empty()) {
|
||||
int this_parameter = index.data()[0];
|
||||
if (this_parameter >= number_of_parameters) {
|
||||
return InvalidArgument(
|
||||
"Unexpected parameter index %s in alias config with tupled "
|
||||
"inputs and %d parameters",
|
||||
index.ToString(), number_of_parameters);
|
||||
}
|
||||
parameters_to_donate.insert(this_parameter);
|
||||
}
|
||||
} else {
|
||||
int this_parameter = alias.parameter_number;
|
||||
if (this_parameter >= number_of_parameters) {
|
||||
return InvalidArgument(
|
||||
"Unexpected parameter number %d in alias config without tupled "
|
||||
"inputs and %d parameters",
|
||||
this_parameter, number_of_parameters);
|
||||
}
|
||||
parameters_to_donate.insert(this_parameter);
|
||||
}
|
||||
return Status::OK();
|
||||
}));
|
||||
return parameters_to_donate;
|
||||
}
|
||||
|
||||
Status PjRtExecutable::SetUpDonation(bool tuple_inputs) {
|
||||
parameters_that_must_be_donated_.reserve(executables_.size());
|
||||
for (auto& executable : executables_) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
absl::flat_hash_set<int> parameters_to_donate,
|
||||
client->GetParametersThatMustBeDonated(*executable, tuple_inputs));
|
||||
TF_ASSIGN_OR_RETURN(absl::flat_hash_set<int> parameters_to_donate,
|
||||
GetParametersThatMustBeDonated(
|
||||
executable->executable()->module(), tuple_inputs));
|
||||
parameters_that_must_be_donated_.emplace_back(
|
||||
std::move(parameters_to_donate));
|
||||
}
|
||||
@ -2237,7 +2235,7 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
|
||||
std::move(device_assignment), std::move(local_logical_device_ids),
|
||||
std::move(local_devices), this);
|
||||
TF_RETURN_IF_ERROR(
|
||||
executable->SetUpDonation(this, options.parameter_is_tupled_arguments));
|
||||
executable->SetUpDonation(options.parameter_is_tupled_arguments));
|
||||
return executable;
|
||||
}
|
||||
|
||||
|
@ -251,13 +251,6 @@ class PjRtClient {
|
||||
// function specifies which one the platform expects.
|
||||
virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }
|
||||
|
||||
// Some platforms allow executables to donate buffers so that they can be
|
||||
// aliased from inputs to outputs. This function returns the list of
|
||||
// parameters that must be donated when executable is run. tuple_inputs
|
||||
// reflects the option that executable was compiled with.
|
||||
virtual StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
|
||||
const LocalExecutable& executable, bool tuple_inputs) const;
|
||||
|
||||
// Generates a unique fingerprint for `executable`. See
|
||||
// PjRtExecutable::fingerprint_.
|
||||
virtual StatusOr<absl::optional<std::string>> ExecutableFingerprint(
|
||||
@ -854,7 +847,7 @@ class PjRtExecutable {
|
||||
friend class PjRtClient;
|
||||
// Initializes information about which arguments to which executables must be
|
||||
// donated due to aliases that were specified by the computation.
|
||||
Status SetUpDonation(PjRtClient* client, bool tuple_inputs);
|
||||
Status SetUpDonation(bool tuple_inputs);
|
||||
|
||||
virtual bool MustDonateParameter(int executable_idx, int parameter) const;
|
||||
|
||||
@ -913,6 +906,13 @@ class PjRtExecutable {
|
||||
std::vector<PjRtDevice*> local_devices_;
|
||||
};
|
||||
|
||||
// Executables can donate buffers so that buffers can be aliased from inputs
|
||||
// to outputs. This function returns the list of parameters that must be
|
||||
// donated when executable is run. tuple_inputs reflects the option that
|
||||
// executable was compiled with.
|
||||
StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
|
||||
const HloModule& hlo_module, bool tuple_inputs);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_
|
||||
|
Loading…
Reference in New Issue
Block a user