diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_client.cc index 18de057395b..83ed61cfe63 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.cc @@ -235,62 +235,6 @@ StatusOr PjRtClient::GetDefaultDeviceAssignment( num_partitions); } -StatusOr> PjRtClient::GetParametersThatMustBeDonated( - const LocalExecutable& executable, bool tuple_inputs) const { - HloComputation* computation = - executable.executable()->module().entry_computation(); - int number_of_parameters = [&]() -> int { - if (tuple_inputs) { - CHECK_EQ(computation->num_parameters(), 1); - const Shape& input_tuple_shape = - computation->parameter_instruction(0)->shape(); - CHECK(input_tuple_shape.IsTuple()); - return input_tuple_shape.tuple_shapes_size(); - } else { - return computation->num_parameters(); - } - }(); - // If any buffer in a parameter is aliased we will donate the entire input - // parameter. - absl::flat_hash_set parameters_to_donate; - const HloInputOutputAliasConfig& config = - executable.executable()->module().input_output_alias_config(); - TF_RETURN_IF_ERROR(config.ForEachAliasWithStatus( - [&](const ShapeIndex& output_index, - const HloInputOutputAliasConfig::Alias& alias) { - if (tuple_inputs) { - if (alias.parameter_number != 0) { - return InvalidArgument( - "Unexpected parameter number %d in alias config with tupled " - "inputs", - alias.parameter_number); - } - const ShapeIndex& index = alias.parameter_index; - if (!index.empty()) { - int this_parameter = index.data()[0]; - if (this_parameter >= number_of_parameters) { - return InvalidArgument( - "Unexpected parameter index %s in alias config with tupled " - "inputs and %d parameters", - index.ToString(), number_of_parameters); - } - parameters_to_donate.insert(this_parameter); - } - } else { - int this_parameter = alias.parameter_number; - if (this_parameter >= number_of_parameters) { - return InvalidArgument( - "Unexpected parameter number %d in alias config without tupled " - "inputs and %d parameters", - this_parameter, number_of_parameters); - } - parameters_to_donate.insert(this_parameter); - } - return Status::OK(); - })); - return parameters_to_donate; -} - std::unique_ptr PjRtClient::GetHloCostAnalysis() { return absl::make_unique( client_->backend().compiler()->ShapeSizeBytesFunction()); @@ -1531,12 +1475,66 @@ PjRtExecutable::PjRtExecutable( } } -Status PjRtExecutable::SetUpDonation(PjRtClient* client, bool tuple_inputs) { +StatusOr> GetParametersThatMustBeDonated( + const HloModule& module, bool tuple_inputs) { + HloComputation* computation = module.entry_computation(); + int number_of_parameters = [&]() -> int { + if (tuple_inputs) { + CHECK_EQ(computation->num_parameters(), 1); + const Shape& input_tuple_shape = + computation->parameter_instruction(0)->shape(); + CHECK(input_tuple_shape.IsTuple()); + return input_tuple_shape.tuple_shapes_size(); + } else { + return computation->num_parameters(); + } + }(); + // If any buffer in a parameter is aliased we will donate the entire input + // parameter. + absl::flat_hash_set parameters_to_donate; + const HloInputOutputAliasConfig& config = module.input_output_alias_config(); + TF_RETURN_IF_ERROR(config.ForEachAliasWithStatus( + [&](const ShapeIndex& output_index, + const HloInputOutputAliasConfig::Alias& alias) { + if (tuple_inputs) { + if (alias.parameter_number != 0) { + return InvalidArgument( + "Unexpected parameter number %d in alias config with tupled " + "inputs", + alias.parameter_number); + } + const ShapeIndex& index = alias.parameter_index; + if (!index.empty()) { + int this_parameter = index.data()[0]; + if (this_parameter >= number_of_parameters) { + return InvalidArgument( + "Unexpected parameter index %s in alias config with tupled " + "inputs and %d parameters", + index.ToString(), number_of_parameters); + } + parameters_to_donate.insert(this_parameter); + } + } else { + int this_parameter = alias.parameter_number; + if (this_parameter >= number_of_parameters) { + return InvalidArgument( + "Unexpected parameter number %d in alias config without tupled " + "inputs and %d parameters", + this_parameter, number_of_parameters); + } + parameters_to_donate.insert(this_parameter); + } + return Status::OK(); + })); + return parameters_to_donate; +} + +Status PjRtExecutable::SetUpDonation(bool tuple_inputs) { parameters_that_must_be_donated_.reserve(executables_.size()); for (auto& executable : executables_) { - TF_ASSIGN_OR_RETURN( - absl::flat_hash_set parameters_to_donate, - client->GetParametersThatMustBeDonated(*executable, tuple_inputs)); + TF_ASSIGN_OR_RETURN(absl::flat_hash_set parameters_to_donate, + GetParametersThatMustBeDonated( + executable->executable()->module(), tuple_inputs)); parameters_that_must_be_donated_.emplace_back( std::move(parameters_to_donate)); } @@ -2237,7 +2235,7 @@ StatusOr> PjRtClient::Compile( std::move(device_assignment), std::move(local_logical_device_ids), std::move(local_devices), this); TF_RETURN_IF_ERROR( - executable->SetUpDonation(this, options.parameter_is_tupled_arguments)); + executable->SetUpDonation(options.parameter_is_tupled_arguments)); return executable; } diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h index 4ec129eb49d..86805182525 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h @@ -251,13 +251,6 @@ class PjRtClient { // function specifies which one the platform expects. virtual bool EnqueueD2DTransfersOnSrcStream() const { return true; } - // Some platforms allow executables to donate buffers so that they can be - // aliased from inputs to outputs. This function returns the list of - // parameters that must be donated when executable is run. tuple_inputs - // reflects the option that executable was compiled with. - virtual StatusOr> GetParametersThatMustBeDonated( - const LocalExecutable& executable, bool tuple_inputs) const; - // Generates a unique fingerprint for `executable`. See // PjRtExecutable::fingerprint_. virtual StatusOr> ExecutableFingerprint( @@ -854,7 +847,7 @@ class PjRtExecutable { friend class PjRtClient; // Initializes information about which arguments to which executables must be // donated due to aliases that were specified by the computation. - Status SetUpDonation(PjRtClient* client, bool tuple_inputs); + Status SetUpDonation(bool tuple_inputs); virtual bool MustDonateParameter(int executable_idx, int parameter) const; @@ -913,6 +906,13 @@ class PjRtExecutable { std::vector local_devices_; }; +// Executables can donate buffers so that buffers can be aliased from inputs +// to outputs. This function returns the list of parameters that must be +// donated when executable is run. tuple_inputs reflects the option that +// executable was compiled with. +StatusOr> GetParametersThatMustBeDonated( + const HloModule& hlo_module, bool tuple_inputs); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_CLIENT_H_