Refactor PJRT client.

- Move GetParametersThatMustBeDonated to a free function according to interface
  principle.
https://www.fluentcpp.com/2017/06/20/interface-principle-cpp/

PiperOrigin-RevId: 338583883
Change-Id: I2e9c868d80d99512c02a3a553112656f254b5474
This commit is contained in:
Qiao Zhang 2020-10-22 17:43:04 -07:00 committed by TensorFlower Gardener
parent 2cea0006f2
commit 99cc6a2df0
2 changed files with 67 additions and 69 deletions

View File

@ -235,62 +235,6 @@ StatusOr<DeviceAssignment> PjRtClient::GetDefaultDeviceAssignment(
num_partitions);
}
StatusOr<absl::flat_hash_set<int>> PjRtClient::GetParametersThatMustBeDonated(
const LocalExecutable& executable, bool tuple_inputs) const {
HloComputation* computation =
executable.executable()->module().entry_computation();
int number_of_parameters = [&]() -> int {
if (tuple_inputs) {
CHECK_EQ(computation->num_parameters(), 1);
const Shape& input_tuple_shape =
computation->parameter_instruction(0)->shape();
CHECK(input_tuple_shape.IsTuple());
return input_tuple_shape.tuple_shapes_size();
} else {
return computation->num_parameters();
}
}();
// If any buffer in a parameter is aliased we will donate the entire input
// parameter.
absl::flat_hash_set<int> parameters_to_donate;
const HloInputOutputAliasConfig& config =
executable.executable()->module().input_output_alias_config();
TF_RETURN_IF_ERROR(config.ForEachAliasWithStatus(
[&](const ShapeIndex& output_index,
const HloInputOutputAliasConfig::Alias& alias) {
if (tuple_inputs) {
if (alias.parameter_number != 0) {
return InvalidArgument(
"Unexpected parameter number %d in alias config with tupled "
"inputs",
alias.parameter_number);
}
const ShapeIndex& index = alias.parameter_index;
if (!index.empty()) {
int this_parameter = index.data()[0];
if (this_parameter >= number_of_parameters) {
return InvalidArgument(
"Unexpected parameter index %s in alias config with tupled "
"inputs and %d parameters",
index.ToString(), number_of_parameters);
}
parameters_to_donate.insert(this_parameter);
}
} else {
int this_parameter = alias.parameter_number;
if (this_parameter >= number_of_parameters) {
return InvalidArgument(
"Unexpected parameter number %d in alias config without tupled "
"inputs and %d parameters",
this_parameter, number_of_parameters);
}
parameters_to_donate.insert(this_parameter);
}
return Status::OK();
}));
return parameters_to_donate;
}
std::unique_ptr<HloCostAnalysis> PjRtClient::GetHloCostAnalysis() {
return absl::make_unique<HloCostAnalysis>(
client_->backend().compiler()->ShapeSizeBytesFunction());
@ -1531,12 +1475,66 @@ PjRtExecutable::PjRtExecutable(
}
}
Status PjRtExecutable::SetUpDonation(PjRtClient* client, bool tuple_inputs) {
StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
const HloModule& module, bool tuple_inputs) {
HloComputation* computation = module.entry_computation();
int number_of_parameters = [&]() -> int {
if (tuple_inputs) {
CHECK_EQ(computation->num_parameters(), 1);
const Shape& input_tuple_shape =
computation->parameter_instruction(0)->shape();
CHECK(input_tuple_shape.IsTuple());
return input_tuple_shape.tuple_shapes_size();
} else {
return computation->num_parameters();
}
}();
// If any buffer in a parameter is aliased we will donate the entire input
// parameter.
absl::flat_hash_set<int> parameters_to_donate;
const HloInputOutputAliasConfig& config = module.input_output_alias_config();
TF_RETURN_IF_ERROR(config.ForEachAliasWithStatus(
[&](const ShapeIndex& output_index,
const HloInputOutputAliasConfig::Alias& alias) {
if (tuple_inputs) {
if (alias.parameter_number != 0) {
return InvalidArgument(
"Unexpected parameter number %d in alias config with tupled "
"inputs",
alias.parameter_number);
}
const ShapeIndex& index = alias.parameter_index;
if (!index.empty()) {
int this_parameter = index.data()[0];
if (this_parameter >= number_of_parameters) {
return InvalidArgument(
"Unexpected parameter index %s in alias config with tupled "
"inputs and %d parameters",
index.ToString(), number_of_parameters);
}
parameters_to_donate.insert(this_parameter);
}
} else {
int this_parameter = alias.parameter_number;
if (this_parameter >= number_of_parameters) {
return InvalidArgument(
"Unexpected parameter number %d in alias config without tupled "
"inputs and %d parameters",
this_parameter, number_of_parameters);
}
parameters_to_donate.insert(this_parameter);
}
return Status::OK();
}));
return parameters_to_donate;
}
Status PjRtExecutable::SetUpDonation(bool tuple_inputs) {
parameters_that_must_be_donated_.reserve(executables_.size());
for (auto& executable : executables_) {
TF_ASSIGN_OR_RETURN(
absl::flat_hash_set<int> parameters_to_donate,
client->GetParametersThatMustBeDonated(*executable, tuple_inputs));
TF_ASSIGN_OR_RETURN(absl::flat_hash_set<int> parameters_to_donate,
GetParametersThatMustBeDonated(
executable->executable()->module(), tuple_inputs));
parameters_that_must_be_donated_.emplace_back(
std::move(parameters_to_donate));
}
@ -2237,7 +2235,7 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
std::move(device_assignment), std::move(local_logical_device_ids),
std::move(local_devices), this);
TF_RETURN_IF_ERROR(
executable->SetUpDonation(this, options.parameter_is_tupled_arguments));
executable->SetUpDonation(options.parameter_is_tupled_arguments));
return executable;
}

View File

@ -251,13 +251,6 @@ class PjRtClient {
// function specifies which one the platform expects.
virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; }
// Some platforms allow executables to donate buffers so that they can be
// aliased from inputs to outputs. This function returns the list of
// parameters that must be donated when executable is run. tuple_inputs
// reflects the option that executable was compiled with.
virtual StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
const LocalExecutable& executable, bool tuple_inputs) const;
// Generates a unique fingerprint for `executable`. See
// PjRtExecutable::fingerprint_.
virtual StatusOr<absl::optional<std::string>> ExecutableFingerprint(
@ -854,7 +847,7 @@ class PjRtExecutable {
friend class PjRtClient;
// Initializes information about which arguments to which executables must be
// donated due to aliases that were specified by the computation.
Status SetUpDonation(PjRtClient* client, bool tuple_inputs);
Status SetUpDonation(bool tuple_inputs);
virtual bool MustDonateParameter(int executable_idx, int parameter) const;
@ -913,6 +906,13 @@ class PjRtExecutable {
std::vector<PjRtDevice*> local_devices_;
};
// Executables can donate buffers so that buffers can be aliased from inputs
// to outputs. This function returns the list of parameters that must be
// donated when executable is run. tuple_inputs reflects the option that
// executable was compiled with.
StatusOr<absl::flat_hash_set<int>> GetParametersThatMustBeDonated(
const HloModule& hlo_module, bool tuple_inputs);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_