Expose fusion configuration as part of HLO module's config and AOT compilation options.

PiperOrigin-RevId: 271656694
Authored by A. Unique TensorFlower on 2019-09-27 15:08:59 -07:00; committed by TensorFlower Gardener
parent 5c271fd23d
commit 610d05b1cf
10 changed files with 85 additions and 25 deletions
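
Example usage (illustrative only, not part of this change): an AOT client can attach a previously collected fusion configuration to any concrete AotCompilationOptions subclass, and the compile-only service forwards it into the resulting HloModuleConfig. A minimal sketch, assuming such a subclass instance is available and using header paths as in the tree at this revision:

#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"

// Sketch: `options` is any concrete backend subclass of xla::AotCompilationOptions.
void AttachFusionConfig(xla::AotCompilationOptions* options) {
  options->set_fusion_config_collection(xla::FusionConfigCollection::kPerEdge);
  // One inner vector per computation, one bool per fusion decision
  // (placeholder values shown here).
  options->set_fusion_config({{true, false, true}, {false, true}});
}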

@@ -21,6 +21,15 @@ limitations under the License.
namespace xla {
StatusOr<std::unique_ptr<HloModuleConfig>>
CompileOnlyClient::CreateModuleConfig(
const ProgramShape& program_shape,
absl::Span<const Shape* const> argument_shapes,
const ExecutionOptions* execution_options) {
return compiler_service_->CreateModuleConfig(program_shape, argument_shapes,
execution_options);
}
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileOnlyClient::CompileAheadOfTime(
const absl::Span<const AotXlaComputationInstance> computations,

@@ -56,6 +56,13 @@ class CompileOnlyClient : public Client {
const AotCompilationOptions& options,
std::unique_ptr<AotCompilationMetadata>* metadata = nullptr);
// Create an HLO module config for the given program shape and arguments.
// execution_options is optional; if not given, a default is used.
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
const ProgramShape& program_shape,
absl::Span<const Shape* const> argument_shapes,
const ExecutionOptions* execution_options);
// Returns the size of a pointer in bytes for a given triple.
static int64 PointerSizeForTriple(absl::string_view triple);

@@ -83,7 +83,8 @@ CompileOnlyService::CompileAheadOfTime(
std::unique_ptr<HloModuleConfig> module_config,
CreateModuleConfig(
ProgramShape(instance.computation.host_program_shape()),
instance.argument_layouts, &execution_options));
instance.argument_layouts, &execution_options,
options.fusion_config_collection(), options.fusion_config()));
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModule> hlo_module,

@@ -96,6 +96,21 @@ class AotCompilationOptions {
static_device_assignment_ = device_assignment;
}
FusionConfigCollection fusion_config_collection() const {
return fusion_config_collection_;
}
void set_fusion_config_collection(
FusionConfigCollection fusion_config_collection) {
fusion_config_collection_ = fusion_config_collection;
}
const std::vector<std::vector<bool>>& fusion_config() const {
return fusion_config_;
}
void set_fusion_config(const std::vector<std::vector<bool>>& fusion_config) {
fusion_config_ = fusion_config;
}
protected:
AotCompilationOptions();
@@ -103,6 +118,8 @@ class AotCompilationOptions {
se::DeviceMemoryAllocator* device_allocator_ = nullptr;
DebugOptions debug_options_;
absl::optional<DeviceAssignment> static_device_assignment_;
std::vector<std::vector<bool>> fusion_config_;
FusionConfigCollection fusion_config_collection_ = FusionConfigCollection::kOff;
};
// Abstract superclass describing metadata produced during ahead-of-time

@@ -204,7 +204,7 @@ class HloModule {
std::vector<HloComputation*> MakeNonfusionComputationsSorted() const;
const HloModuleConfig& config() const { return config_; }
void set_config(HloModuleConfig& config) { config_ = config; }
void set_config(const HloModuleConfig& config) { config_ = config; }
// Return a string representation of the module.
//
@@ -294,10 +294,6 @@ class HloModule {
Status CheckUniqueNamesAndIdsForComputationsAndInstructions() const;
std::vector<std::vector<bool>>* mutable_fusion_config() {
return &fusion_config_;
}
// Checks if this config has a list of entry parameters' HLO shardings for
// SPMD.
bool has_spmd_parameters_shardings() const {
@@ -370,9 +366,6 @@ class HloModule {
// Bindings for dynamic parameter mapping.
DynamicParameterBinding dynamic_parameter_binding_;
// Fusion configuration.
std::vector<std::vector<bool>> fusion_config_;
// The HLO shardings of the entry computation's parameters for
// SPMD-partitioned programs.
absl::optional<std::vector<HloSharding>> spmd_parameters_shardings_;

@@ -27,6 +27,12 @@ limitations under the License.
namespace xla {
enum class FusionConfigCollection {
kOff, // Do not collect configuration.
kPerEdge, // Collect per-edge configuration.
kPerNode, // Collect per-node configuration.
};
// This class gathers all settings and values which affect the compiled
// executable outside of the HLO code itself. This includes layouts of inputs and
// outputs to the module and settings such as HLO profiling. Together the
@@ -157,6 +163,21 @@ class HloModuleConfig {
alias_passthrough_params_ = alias_passthrough_params;
}
FusionConfigCollection fusion_config_collection() const {
return fusion_config_collection_;
}
void set_fusion_config_collection(
FusionConfigCollection fusion_config_collection) {
fusion_config_collection_ = fusion_config_collection;
}
const std::vector<std::vector<bool>>& fusion_config() const {
return fusion_config_;
}
std::vector<std::vector<bool>>* mutable_fusion_config() {
return &fusion_config_;
}
private:
// If you add new members, be sure to update compilation_cache_key.
@@ -180,6 +201,11 @@ class HloModuleConfig {
std::vector<ShardableValueUpdatePair> shardable_value_update_pairs_;
bool alias_passthrough_params_ = false;
FusionConfigCollection fusion_config_collection_ =
FusionConfigCollection::kOff;
std::vector<std::vector<bool>> fusion_config_;
};
} // namespace xla
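
As an illustrative aside (not part of the diff), the read side of the new HloModuleConfig accessors can look like this; the helper name is hypothetical:

#include <vector>
#include "tensorflow/compiler/xla/service/hlo_module_config.h"

// Sketch: count how many fusion decisions were recorded across computations.
int CountFusionDecisions(const xla::HloModuleConfig& config) {
  if (config.fusion_config_collection() == xla::FusionConfigCollection::kOff) {
    return 0;
  }
  int total = 0;
  // fusion_config() holds one vector<bool> per computation.
  for (const std::vector<bool>& per_computation : config.fusion_config()) {
    total += static_cast<int>(per_computation.size());
  }
  return total;
}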

@@ -469,8 +469,10 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
module_ = module;
int64 fuse_count = 0;
std::vector<std::vector<bool>>* fusion_config = nullptr;
HloModuleConfig module_config;
if (config_collection_mode_ != FusionConfigCollection::kOff) {
fusion_config = module->mutable_fusion_config();
module_config = module->config();
fusion_config = module_config.mutable_fusion_config();
fusion_config->clear();
}
@@ -569,7 +571,7 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
if (config_collection_mode_ != FusionConfigCollection::kOff) {
const std::vector<bool>* comp_fusion_config =
fusion_queue->FusionConfiguration();
if (comp_fusion_config && comp_fusion_config->size() > 0) {
if (comp_fusion_config && !comp_fusion_config->empty()) {
fusion_config->push_back(*comp_fusion_config);
}
}
@@ -587,6 +589,7 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
VLOG(1) << "There are " << fused_count << " fused bits that cause "
<< fuse_count << " fusion actions.";
VLOG(1) << FusionConfigToString(*fusion_config);
module->set_config(module_config);
}
VLOG(1) << "Fusion count: " << fuse_count;

@@ -27,12 +27,6 @@ limitations under the License.
namespace xla {
enum class FusionConfigCollection {
kOff, // Do not collect configuration.
kPerEdge, // Collect per-edge configuration.
kPerNode, // Collect per-node configuration.
};
// HLO pass which performs instruction fusion. Instructions are fused
// "vertically", meaning producing instructions are fused into their consumers
// with the intent that the loops which compute their values will be fused in

@@ -265,7 +265,9 @@ Service::ResolveAndValidateArguments(
StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
const ProgramShape& program_shape,
absl::Span<const Shape* const> argument_shapes,
const ExecutionOptions* execution_options) {
const ExecutionOptions* execution_options,
FusionConfigCollection fusion_config_collection,
const std::vector<std::vector<bool>>& fusion_config) {
auto config = absl::make_unique<HloModuleConfig>(program_shape);
ComputationLayout* computation_layout =
config->mutable_entry_computation_layout();
@@ -332,6 +334,11 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
config->set_alias_passthrough_params(
execution_options->alias_passthrough_params());
if (fusion_config_collection != FusionConfigCollection::kOff) {
config->set_fusion_config_collection(fusion_config_collection);
*config->mutable_fusion_config() = fusion_config;
}
return std::move(config);
}
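
For illustration, a hypothetical call site (not part of this change; `service`, `shape`, `arg_shapes`, and `fusion_config` are assumed to be available): the two added parameters default to FusionConfigCollection::kOff and an empty config, so existing callers compile unchanged, while the AOT path can forward a collected configuration.

#include <memory>
#include <vector>
#include "tensorflow/compiler/xla/service/service.h"

// Sketch: build a module config that carries a previously collected
// per-edge fusion configuration; execution options fall back to defaults.
xla::StatusOr<std::unique_ptr<xla::HloModuleConfig>> BuildConfig(
    xla::Service* service, const xla::ProgramShape& shape,
    absl::Span<const xla::Shape* const> arg_shapes,
    const std::vector<std::vector<bool>>& fusion_config) {
  return service->CreateModuleConfig(shape, arg_shapes,
                                     /*execution_options=*/nullptr,
                                     xla::FusionConfigCollection::kPerEdge,
                                     fusion_config);
}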

@@ -183,6 +183,16 @@ class Service : public ServiceInterface {
const Backend& backend() const { return *execute_backend_; }
Backend* mutable_backend() { return execute_backend_.get(); }
// Create an HLO module config for the given program shape and arguments.
// execution_options is optional; if not given, a default is used.
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
const ProgramShape& program_shape,
absl::Span<const Shape* const> argument_shapes,
const ExecutionOptions* execution_options,
FusionConfigCollection fusion_config_collection =
FusionConfigCollection::kOff,
const std::vector<std::vector<bool>>& fusion_config = {});
private:
// A private overload for Service itself, used by other methods within this
// class.
@@ -218,13 +228,6 @@ class Service : public ServiceInterface {
absl::Span<const GlobalDataHandle* const> arguments,
absl::Span<se::StreamExecutor* const> stream_executors) const;
// Create an HLO module config for the given program shape and arguments.
// execution_options is optional; if not given, a default is used.
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
const ProgramShape& program_shape,
absl::Span<const Shape* const> argument_shapes,
const ExecutionOptions* execution_options);
// Builds an Executable for the given parameters.
//
// If device_allocator is not null, the compiler may use it to allocate temp