Expose fusion configuration as part of HLO module's config and AOT compilation options.
PiperOrigin-RevId: 271656694
This commit is contained in:
parent
5c271fd23d
commit
610d05b1cf
@ -21,6 +21,15 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Builds an HloModuleConfig by forwarding program_shape, argument_shapes and
// execution_options unchanged to compiler_service_->CreateModuleConfig and
// returning that call's result.
StatusOr<std::unique_ptr<HloModuleConfig>>
|
||||
CompileOnlyClient::CreateModuleConfig(
|
||||
const ProgramShape& program_shape,
|
||||
absl::Span<const Shape* const> argument_shapes,
|
||||
const ExecutionOptions* execution_options) {
|
||||
return compiler_service_->CreateModuleConfig(program_shape, argument_shapes,
|
||||
execution_options);
|
||||
}
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||
CompileOnlyClient::CompileAheadOfTime(
|
||||
const absl::Span<const AotXlaComputationInstance> computations,
|
||||
|
@ -56,6 +56,13 @@ class CompileOnlyClient : public Client {
|
||||
const AotCompilationOptions& options,
|
||||
std::unique_ptr<AotCompilationMetadata>* metadata = nullptr);
|
||||
|
||||
// Creates an HLO module config for the given program shape and arguments.
|
||||
// `execution_options` is optional; if not given, a default is used.
// The config is returned as a unique_ptr wrapped in a StatusOr.
|
||||
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
|
||||
const ProgramShape& program_shape,
|
||||
absl::Span<const Shape* const> argument_shapes,
|
||||
const ExecutionOptions* execution_options);
|
||||
|
||||
// Returns the size of a pointer in bytes for a given triple.
|
||||
static int64 PointerSizeForTriple(absl::string_view triple);
|
||||
|
||||
|
@ -83,7 +83,8 @@ CompileOnlyService::CompileAheadOfTime(
|
||||
std::unique_ptr<HloModuleConfig> module_config,
|
||||
CreateModuleConfig(
|
||||
ProgramShape(instance.computation.host_program_shape()),
|
||||
instance.argument_layouts, &execution_options));
|
||||
instance.argument_layouts, &execution_options,
|
||||
options.fusion_config_collection(), options.fusion_config()));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<HloModule> hlo_module,
|
||||
|
@ -96,6 +96,21 @@ class AotCompilationOptions {
|
||||
static_device_assignment_ = device_assignment;
|
||||
}
|
||||
|
||||
// Returns the fusion-config collection setting held by these options.
FusionConfigCollection fusion_config_collection() const {
|
||||
return fusion_config_collection_;
|
||||
}
|
||||
// Stores `fusion_config_collection` as the collection setting of these options.
void set_fusion_config_collection(
|
||||
FusionConfigCollection fusion_config_collection) {
|
||||
fusion_config_collection_ = fusion_config_collection;
|
||||
}
|
||||
|
||||
// Returns a read-only reference to the fusion configuration held by these
// options.
const std::vector<std::vector<bool>>& fusion_config() const {
|
||||
return fusion_config_;
|
||||
}
|
||||
// Replaces the stored fusion configuration with a copy of `fusion_config`.
void set_fusion_config(const std::vector<std::vector<bool>>& fusion_config) {
|
||||
fusion_config_ = fusion_config;
|
||||
}
|
||||
|
||||
protected:
|
||||
AotCompilationOptions();
|
||||
|
||||
@ -103,6 +118,8 @@ class AotCompilationOptions {
|
||||
se::DeviceMemoryAllocator* device_allocator_ = nullptr;
|
||||
DebugOptions debug_options_;
|
||||
absl::optional<DeviceAssignment> static_device_assignment_;
|
||||
std::vector<std::vector<bool>> fusion_config_;
|
||||
// Defaults to kOff so that fusion_config_collection() never yields an
// indeterminate value when set_fusion_config_collection() was not called.
// This mirrors the default of the same-named member in HloModuleConfig.
FusionConfigCollection fusion_config_collection_ =
    FusionConfigCollection::kOff;
|
||||
};
|
||||
|
||||
// Abstract superclass describing metadata produced during ahead-of-time
|
||||
|
@ -204,7 +204,7 @@ class HloModule {
|
||||
std::vector<HloComputation*> MakeNonfusionComputationsSorted() const;
|
||||
|
||||
const HloModuleConfig& config() const { return config_; }
|
||||
void set_config(HloModuleConfig& config) { config_ = config; }
|
||||
void set_config(const HloModuleConfig& config) { config_ = config; }
|
||||
|
||||
// Return a string representation of the module.
|
||||
//
|
||||
@ -294,10 +294,6 @@ class HloModule {
|
||||
|
||||
Status CheckUniqueNamesAndIdsForComputationsAndInstructions() const;
|
||||
|
||||
std::vector<std::vector<bool>>* mutable_fusion_config() {
|
||||
return &fusion_config_;
|
||||
}
|
||||
|
||||
// Checks if this config has a list of entry parameters' HLO shardings for
|
||||
// SPMD.
|
||||
bool has_spmd_parameters_shardings() const {
|
||||
@ -370,9 +366,6 @@ class HloModule {
|
||||
// Bindings for dynamic parameter mapping.
|
||||
DynamicParameterBinding dynamic_parameter_binding_;
|
||||
|
||||
// Fusion configuration.
|
||||
std::vector<std::vector<bool>> fusion_config_;
|
||||
|
||||
// The HLO shardings of the entry computation's parameters for
|
||||
// SPMD-partitioned programs.
|
||||
absl::optional<std::vector<HloSharding>> spmd_parameters_shardings_;
|
||||
|
@ -27,6 +27,12 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Selects whether fusion configuration is collected and, if so, whether it is
// collected per edge or per node.
enum class FusionConfigCollection {
|
||||
kOff, // Do not collect configuration.
|
||||
kPerEdge, // Collect per-edge configuration.
|
||||
kPerNode, // Collect per-node configuration.
|
||||
};
|
||||
|
||||
// This class gathers all settings and values which affect the compiled
|
||||
// executable outside of the HLO code itself. This includes layouts of inputs and
|
||||
// outputs to the module and settings such as HLO profiling. Together the
|
||||
@ -157,6 +163,21 @@ class HloModuleConfig {
|
||||
alias_passthrough_params_ = alias_passthrough_params;
|
||||
}
|
||||
|
||||
// Returns the fusion-config collection setting stored in this module config.
FusionConfigCollection fusion_config_collection() const {
|
||||
return fusion_config_collection_;
|
||||
}
|
||||
// Stores `fusion_config_collection` as this module config's collection setting.
void set_fusion_config_collection(
|
||||
FusionConfigCollection fusion_config_collection) {
|
||||
fusion_config_collection_ = fusion_config_collection;
|
||||
}
|
||||
|
||||
// Returns a read-only reference to the fusion configuration stored in this
// module config.
const std::vector<std::vector<bool>>& fusion_config() const {
|
||||
return fusion_config_;
|
||||
}
|
||||
// Returns a pointer to the stored fusion configuration so that callers can
// modify it in place. The pointer refers to a member of this object.
std::vector<std::vector<bool>>* mutable_fusion_config() {
|
||||
return &fusion_config_;
|
||||
}
|
||||
|
||||
private:
|
||||
// If you add new members, be sure to update compilation_cache_key.
|
||||
|
||||
@ -180,6 +201,11 @@ class HloModuleConfig {
|
||||
std::vector<ShardableValueUpdatePair> shardable_value_update_pairs_;
|
||||
|
||||
bool alias_passthrough_params_ = false;
|
||||
|
||||
FusionConfigCollection fusion_config_collection_ =
|
||||
FusionConfigCollection::kOff;
|
||||
|
||||
std::vector<std::vector<bool>> fusion_config_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
@ -469,8 +469,10 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
|
||||
module_ = module;
|
||||
int64 fuse_count = 0;
|
||||
std::vector<std::vector<bool>>* fusion_config = nullptr;
|
||||
HloModuleConfig module_config;
|
||||
if (config_collection_mode_ != FusionConfigCollection::kOff) {
|
||||
fusion_config = module->mutable_fusion_config();
|
||||
module_config = module->config();
|
||||
fusion_config = module_config.mutable_fusion_config();
|
||||
fusion_config->clear();
|
||||
}
|
||||
|
||||
@ -569,7 +571,7 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
|
||||
if (config_collection_mode_ != FusionConfigCollection::kOff) {
|
||||
const std::vector<bool>* comp_fusion_config =
|
||||
fusion_queue->FusionConfiguration();
|
||||
if (comp_fusion_config && comp_fusion_config->size() > 0) {
|
||||
if (comp_fusion_config && !comp_fusion_config->empty()) {
|
||||
fusion_config->push_back(*comp_fusion_config);
|
||||
}
|
||||
}
|
||||
@ -587,6 +589,7 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
|
||||
VLOG(1) << "There are " << fused_count << " fused bits that cause "
|
||||
<< fuse_count << " fusion actions.";
|
||||
VLOG(1) << FusionConfigToString(*fusion_config);
|
||||
module->set_config(module_config);
|
||||
}
|
||||
VLOG(1) << "Fusion count: " << fuse_count;
|
||||
|
||||
|
@ -27,12 +27,6 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
enum class FusionConfigCollection {
|
||||
kOff, // Do not collect configuration.
|
||||
kPerEdge, // Collect per-edge configuration.
|
||||
kPerNode, // Collect per-node configuration.
|
||||
};
|
||||
|
||||
// HLO pass which performs instruction fusion. Instructions are fused
|
||||
// "vertically", meaning producing instructions are fused into their consumers
|
||||
// with the intent that the loops which compute their values will be fused in
|
||||
|
@ -265,7 +265,9 @@ Service::ResolveAndValidateArguments(
|
||||
StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
|
||||
const ProgramShape& program_shape,
|
||||
absl::Span<const Shape* const> argument_shapes,
|
||||
const ExecutionOptions* execution_options) {
|
||||
const ExecutionOptions* execution_options,
|
||||
FusionConfigCollection fusion_config_collection,
|
||||
const std::vector<std::vector<bool>>& fusion_config) {
|
||||
auto config = absl::make_unique<HloModuleConfig>(program_shape);
|
||||
ComputationLayout* computation_layout =
|
||||
config->mutable_entry_computation_layout();
|
||||
@ -332,6 +334,11 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
|
||||
config->set_alias_passthrough_params(
|
||||
execution_options->alias_passthrough_params());
|
||||
|
||||
if (fusion_config_collection != FusionConfigCollection::kOff) {
|
||||
config->set_fusion_config_collection(fusion_config_collection);
|
||||
*config->mutable_fusion_config() = fusion_config;
|
||||
}
|
||||
|
||||
return std::move(config);
|
||||
}
|
||||
|
||||
|
@ -183,6 +183,16 @@ class Service : public ServiceInterface {
|
||||
const Backend& backend() const { return *execute_backend_; }
|
||||
Backend* mutable_backend() { return execute_backend_.get(); }
|
||||
|
||||
// Creates an HLO module config for the given program shape and arguments.
|
||||
// `execution_options` is optional; if not given, a default is used.
// `fusion_config_collection` defaults to kOff and `fusion_config` defaults to
// an empty configuration, so callers that pass only the first three arguments
// keep working unchanged.
|
||||
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
|
||||
const ProgramShape& program_shape,
|
||||
absl::Span<const Shape* const> argument_shapes,
|
||||
const ExecutionOptions* execution_options,
|
||||
FusionConfigCollection fusion_config_collection =
|
||||
FusionConfigCollection::kOff,
|
||||
const std::vector<std::vector<bool>>& fusion_config = {});
|
||||
|
||||
private:
|
||||
// A private overload for Service itself, used by other methods within this
|
||||
// class.
|
||||
@ -218,13 +228,6 @@ class Service : public ServiceInterface {
|
||||
absl::Span<const GlobalDataHandle* const> arguments,
|
||||
absl::Span<se::StreamExecutor* const> stream_executors) const;
|
||||
|
||||
// Create a Hlo module config for the given program shape and arguments.
|
||||
// execution_options is optional; if not given a default is used.
|
||||
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
|
||||
const ProgramShape& program_shape,
|
||||
absl::Span<const Shape* const> argument_shapes,
|
||||
const ExecutionOptions* execution_options);
|
||||
|
||||
// Builds an Executable for the given parameters.
|
||||
//
|
||||
// If device_allocator is not null, the compiler may use it to allocate temp
|
||||
|
Loading…
Reference in New Issue
Block a user