Expose fusion configuration as part of HLO module's config and AOT compilation options.
PiperOrigin-RevId: 271656694
This commit is contained in:
parent
5c271fd23d
commit
610d05b1cf
@ -21,6 +21,15 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
|
StatusOr<std::unique_ptr<HloModuleConfig>>
|
||||||
|
CompileOnlyClient::CreateModuleConfig(
|
||||||
|
const ProgramShape& program_shape,
|
||||||
|
absl::Span<const Shape* const> argument_shapes,
|
||||||
|
const ExecutionOptions* execution_options) {
|
||||||
|
return compiler_service_->CreateModuleConfig(program_shape, argument_shapes,
|
||||||
|
execution_options);
|
||||||
|
}
|
||||||
|
|
||||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||||
CompileOnlyClient::CompileAheadOfTime(
|
CompileOnlyClient::CompileAheadOfTime(
|
||||||
const absl::Span<const AotXlaComputationInstance> computations,
|
const absl::Span<const AotXlaComputationInstance> computations,
|
||||||
|
@ -56,6 +56,13 @@ class CompileOnlyClient : public Client {
|
|||||||
const AotCompilationOptions& options,
|
const AotCompilationOptions& options,
|
||||||
std::unique_ptr<AotCompilationMetadata>* metadata = nullptr);
|
std::unique_ptr<AotCompilationMetadata>* metadata = nullptr);
|
||||||
|
|
||||||
|
// Create an HLO module config for the given program shape and arguments.
|
||||||
|
// execution_options is optional; if not given, a default is used.
|
||||||
|
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
|
||||||
|
const ProgramShape& program_shape,
|
||||||
|
absl::Span<const Shape* const> argument_shapes,
|
||||||
|
const ExecutionOptions* execution_options);
|
||||||
|
|
||||||
// Returns the size of a pointer in bytes for a given triple.
|
// Returns the size of a pointer in bytes for a given triple.
|
||||||
static int64 PointerSizeForTriple(absl::string_view triple);
|
static int64 PointerSizeForTriple(absl::string_view triple);
|
||||||
|
|
||||||
|
@ -83,7 +83,8 @@ CompileOnlyService::CompileAheadOfTime(
|
|||||||
std::unique_ptr<HloModuleConfig> module_config,
|
std::unique_ptr<HloModuleConfig> module_config,
|
||||||
CreateModuleConfig(
|
CreateModuleConfig(
|
||||||
ProgramShape(instance.computation.host_program_shape()),
|
ProgramShape(instance.computation.host_program_shape()),
|
||||||
instance.argument_layouts, &execution_options));
|
instance.argument_layouts, &execution_options,
|
||||||
|
options.fusion_config_collection(), options.fusion_config()));
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
std::unique_ptr<HloModule> hlo_module,
|
std::unique_ptr<HloModule> hlo_module,
|
||||||
|
@ -96,6 +96,21 @@ class AotCompilationOptions {
|
|||||||
static_device_assignment_ = device_assignment;
|
static_device_assignment_ = device_assignment;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
FusionConfigCollection fusion_config_collection() const {
|
||||||
|
return fusion_config_collection_;
|
||||||
|
}
|
||||||
|
void set_fusion_config_collection(
|
||||||
|
FusionConfigCollection fusion_config_collection) {
|
||||||
|
fusion_config_collection_ = fusion_config_collection;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::vector<std::vector<bool>>& fusion_config() const {
|
||||||
|
return fusion_config_;
|
||||||
|
}
|
||||||
|
void set_fusion_config(const std::vector<std::vector<bool>>& fusion_config) {
|
||||||
|
fusion_config_ = fusion_config;
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
AotCompilationOptions();
|
AotCompilationOptions();
|
||||||
|
|
||||||
@ -103,6 +118,8 @@ class AotCompilationOptions {
|
|||||||
se::DeviceMemoryAllocator* device_allocator_ = nullptr;
|
se::DeviceMemoryAllocator* device_allocator_ = nullptr;
|
||||||
DebugOptions debug_options_;
|
DebugOptions debug_options_;
|
||||||
absl::optional<DeviceAssignment> static_device_assignment_;
|
absl::optional<DeviceAssignment> static_device_assignment_;
|
||||||
|
std::vector<std::vector<bool>> fusion_config_;
|
||||||
|
// Defaults to kOff so the getter never returns an indeterminate value.
FusionConfigCollection fusion_config_collection_ = FusionConfigCollection::kOff;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Abstract superclass describing metadata produced during ahead-of-time
|
// Abstract superclass describing metadata produced during ahead-of-time
|
||||||
|
@ -204,7 +204,7 @@ class HloModule {
|
|||||||
std::vector<HloComputation*> MakeNonfusionComputationsSorted() const;
|
std::vector<HloComputation*> MakeNonfusionComputationsSorted() const;
|
||||||
|
|
||||||
const HloModuleConfig& config() const { return config_; }
|
const HloModuleConfig& config() const { return config_; }
|
||||||
void set_config(HloModuleConfig& config) { config_ = config; }
|
void set_config(const HloModuleConfig& config) { config_ = config; }
|
||||||
|
|
||||||
// Return a string representation of the module.
|
// Return a string representation of the module.
|
||||||
//
|
//
|
||||||
@ -294,10 +294,6 @@ class HloModule {
|
|||||||
|
|
||||||
Status CheckUniqueNamesAndIdsForComputationsAndInstructions() const;
|
Status CheckUniqueNamesAndIdsForComputationsAndInstructions() const;
|
||||||
|
|
||||||
std::vector<std::vector<bool>>* mutable_fusion_config() {
|
|
||||||
return &fusion_config_;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Checks if this config has a list of entry parameters' HLO shardings for
|
// Checks if this config has a list of entry parameters' HLO shardings for
|
||||||
// SPMD.
|
// SPMD.
|
||||||
bool has_spmd_parameters_shardings() const {
|
bool has_spmd_parameters_shardings() const {
|
||||||
@ -370,9 +366,6 @@ class HloModule {
|
|||||||
// Bindings for dynamic parameter mapping.
|
// Bindings for dynamic parameter mapping.
|
||||||
DynamicParameterBinding dynamic_parameter_binding_;
|
DynamicParameterBinding dynamic_parameter_binding_;
|
||||||
|
|
||||||
// Fusion configuration.
|
|
||||||
std::vector<std::vector<bool>> fusion_config_;
|
|
||||||
|
|
||||||
// The HLO shardings of the entry computation's parameters for
|
// The HLO shardings of the entry computation's parameters for
|
||||||
// SPMD-partitioned programs.
|
// SPMD-partitioned programs.
|
||||||
absl::optional<std::vector<HloSharding>> spmd_parameters_shardings_;
|
absl::optional<std::vector<HloSharding>> spmd_parameters_shardings_;
|
||||||
|
@ -27,6 +27,12 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
|
enum class FusionConfigCollection {
|
||||||
|
kOff, // Do not collect configuration.
|
||||||
|
kPerEdge, // Collect per-edge configuration.
|
||||||
|
kPerNode, // Collect per-node configuration.
|
||||||
|
};
|
||||||
|
|
||||||
// This class gathers all settings and values which affect the compiled
|
// This class gathers all settings and values which affect the compiled
|
||||||
// executable outside of the HLO code itself. This includes layouts of inputs and
|
// executable outside of the HLO code itself. This includes layouts of inputs and
|
||||||
// outputs to the module and settings such as HLO profiling. Together the
|
// outputs to the module and settings such as HLO profiling. Together the
|
||||||
@ -157,6 +163,21 @@ class HloModuleConfig {
|
|||||||
alias_passthrough_params_ = alias_passthrough_params;
|
alias_passthrough_params_ = alias_passthrough_params;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
FusionConfigCollection fusion_config_collection() const {
|
||||||
|
return fusion_config_collection_;
|
||||||
|
}
|
||||||
|
void set_fusion_config_collection(
|
||||||
|
FusionConfigCollection fusion_config_collection) {
|
||||||
|
fusion_config_collection_ = fusion_config_collection;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::vector<std::vector<bool>>& fusion_config() const {
|
||||||
|
return fusion_config_;
|
||||||
|
}
|
||||||
|
std::vector<std::vector<bool>>* mutable_fusion_config() {
|
||||||
|
return &fusion_config_;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// If you add new members, be sure to update compilation_cache_key.
|
// If you add new members, be sure to update compilation_cache_key.
|
||||||
|
|
||||||
@ -180,6 +201,11 @@ class HloModuleConfig {
|
|||||||
std::vector<ShardableValueUpdatePair> shardable_value_update_pairs_;
|
std::vector<ShardableValueUpdatePair> shardable_value_update_pairs_;
|
||||||
|
|
||||||
bool alias_passthrough_params_ = false;
|
bool alias_passthrough_params_ = false;
|
||||||
|
|
||||||
|
FusionConfigCollection fusion_config_collection_ =
|
||||||
|
FusionConfigCollection::kOff;
|
||||||
|
|
||||||
|
std::vector<std::vector<bool>> fusion_config_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -469,8 +469,10 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
|
|||||||
module_ = module;
|
module_ = module;
|
||||||
int64 fuse_count = 0;
|
int64 fuse_count = 0;
|
||||||
std::vector<std::vector<bool>>* fusion_config = nullptr;
|
std::vector<std::vector<bool>>* fusion_config = nullptr;
|
||||||
|
HloModuleConfig module_config;
|
||||||
if (config_collection_mode_ != FusionConfigCollection::kOff) {
|
if (config_collection_mode_ != FusionConfigCollection::kOff) {
|
||||||
fusion_config = module->mutable_fusion_config();
|
module_config = module->config();
|
||||||
|
fusion_config = module_config.mutable_fusion_config();
|
||||||
fusion_config->clear();
|
fusion_config->clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -569,7 +571,7 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
|
|||||||
if (config_collection_mode_ != FusionConfigCollection::kOff) {
|
if (config_collection_mode_ != FusionConfigCollection::kOff) {
|
||||||
const std::vector<bool>* comp_fusion_config =
|
const std::vector<bool>* comp_fusion_config =
|
||||||
fusion_queue->FusionConfiguration();
|
fusion_queue->FusionConfiguration();
|
||||||
if (comp_fusion_config && comp_fusion_config->size() > 0) {
|
if (comp_fusion_config && !comp_fusion_config->empty()) {
|
||||||
fusion_config->push_back(*comp_fusion_config);
|
fusion_config->push_back(*comp_fusion_config);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -587,6 +589,7 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
|
|||||||
VLOG(1) << "There are " << fused_count << " fused bits that cause "
|
VLOG(1) << "There are " << fused_count << " fused bits that cause "
|
||||||
<< fuse_count << " fusion actions.";
|
<< fuse_count << " fusion actions.";
|
||||||
VLOG(1) << FusionConfigToString(*fusion_config);
|
VLOG(1) << FusionConfigToString(*fusion_config);
|
||||||
|
module->set_config(module_config);
|
||||||
}
|
}
|
||||||
VLOG(1) << "Fusion count: " << fuse_count;
|
VLOG(1) << "Fusion count: " << fuse_count;
|
||||||
|
|
||||||
|
@ -27,12 +27,6 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
enum class FusionConfigCollection {
|
|
||||||
kOff, // Do not collect configuration.
|
|
||||||
kPerEdge, // Collect per-edge configuration.
|
|
||||||
kPerNode, // Collect per-node configuration.
|
|
||||||
};
|
|
||||||
|
|
||||||
// HLO pass which performs instruction fusion. Instructions are fused
|
// HLO pass which performs instruction fusion. Instructions are fused
|
||||||
// "vertically", meaning producing instructions are fused into their consumers
|
// "vertically", meaning producing instructions are fused into their consumers
|
||||||
// with the intent that the loops which compute their values will be fused in
|
// with the intent that the loops which compute their values will be fused in
|
||||||
|
@ -265,7 +265,9 @@ Service::ResolveAndValidateArguments(
|
|||||||
StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
|
StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
|
||||||
const ProgramShape& program_shape,
|
const ProgramShape& program_shape,
|
||||||
absl::Span<const Shape* const> argument_shapes,
|
absl::Span<const Shape* const> argument_shapes,
|
||||||
const ExecutionOptions* execution_options) {
|
const ExecutionOptions* execution_options,
|
||||||
|
FusionConfigCollection fusion_config_collection,
|
||||||
|
const std::vector<std::vector<bool>>& fusion_config) {
|
||||||
auto config = absl::make_unique<HloModuleConfig>(program_shape);
|
auto config = absl::make_unique<HloModuleConfig>(program_shape);
|
||||||
ComputationLayout* computation_layout =
|
ComputationLayout* computation_layout =
|
||||||
config->mutable_entry_computation_layout();
|
config->mutable_entry_computation_layout();
|
||||||
@ -332,6 +334,11 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
|
|||||||
config->set_alias_passthrough_params(
|
config->set_alias_passthrough_params(
|
||||||
execution_options->alias_passthrough_params());
|
execution_options->alias_passthrough_params());
|
||||||
|
|
||||||
|
if (fusion_config_collection != FusionConfigCollection::kOff) {
|
||||||
|
config->set_fusion_config_collection(fusion_config_collection);
|
||||||
|
*config->mutable_fusion_config() = fusion_config;
|
||||||
|
}
|
||||||
|
|
||||||
return std::move(config);
|
return std::move(config);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -183,6 +183,16 @@ class Service : public ServiceInterface {
|
|||||||
const Backend& backend() const { return *execute_backend_; }
|
const Backend& backend() const { return *execute_backend_; }
|
||||||
Backend* mutable_backend() { return execute_backend_.get(); }
|
Backend* mutable_backend() { return execute_backend_.get(); }
|
||||||
|
|
||||||
|
// Create an HLO module config for the given program shape and arguments.
|
||||||
|
// execution_options is optional; if not given, a default is used.
|
||||||
|
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
|
||||||
|
const ProgramShape& program_shape,
|
||||||
|
absl::Span<const Shape* const> argument_shapes,
|
||||||
|
const ExecutionOptions* execution_options,
|
||||||
|
FusionConfigCollection fusion_config_collection =
|
||||||
|
FusionConfigCollection::kOff,
|
||||||
|
const std::vector<std::vector<bool>>& fusion_config = {});
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// A private overload for Service itself, used by other methods within this
|
// A private overload for Service itself, used by other methods within this
|
||||||
// class.
|
// class.
|
||||||
@ -218,13 +228,6 @@ class Service : public ServiceInterface {
|
|||||||
absl::Span<const GlobalDataHandle* const> arguments,
|
absl::Span<const GlobalDataHandle* const> arguments,
|
||||||
absl::Span<se::StreamExecutor* const> stream_executors) const;
|
absl::Span<se::StreamExecutor* const> stream_executors) const;
|
||||||
|
|
||||||
// Create a Hlo module config for the given program shape and arguments.
|
|
||||||
// execution_options is optional; if not given a default is used.
|
|
||||||
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
|
|
||||||
const ProgramShape& program_shape,
|
|
||||||
absl::Span<const Shape* const> argument_shapes,
|
|
||||||
const ExecutionOptions* execution_options);
|
|
||||||
|
|
||||||
// Builds an Executable for the given parameters.
|
// Builds an Executable for the given parameters.
|
||||||
//
|
//
|
||||||
// If device_allocator is not null, the compiler may use it to allocate temp
|
// If device_allocator is not null, the compiler may use it to allocate temp
|
||||||
|
Loading…
x
Reference in New Issue
Block a user