[libtpu] Plumb HloModule::debug_options_ through the C API.
PiperOrigin-RevId: 351196886 Change-Id: Idfa6fe5a01969efb888cc9a22718e72d881f1835
This commit is contained in:
parent
bdea938786
commit
96fdbb1a8d
@ -203,42 +203,6 @@ class TpuExecutable : public TpuExecutableInterface {
|
||||
SE_Executable* se_executable_;
|
||||
};
|
||||
|
||||
XLA_HloModuleConfig HloModuleConfigToC(const xla::HloModuleConfig& config) {
|
||||
XLA_HloModuleConfig hlo_config;
|
||||
|
||||
hlo_config.seed = config.seed();
|
||||
hlo_config.launch_id = config.launch_id();
|
||||
hlo_config.replica_count = config.replica_count();
|
||||
hlo_config.num_partitions = config.num_partitions();
|
||||
hlo_config.use_spmd_partitioning = config.use_spmd_partitioning();
|
||||
hlo_config.has_static_device_assignment =
|
||||
config.has_static_device_assignment();
|
||||
hlo_config.has_entry_computation_layout =
|
||||
config.has_entry_computation_layout();
|
||||
|
||||
if (config.has_static_device_assignment()) {
|
||||
DeviceAssignmentProto dev_proto;
|
||||
config.static_device_assignment().Serialize(&dev_proto).IgnoreError();
|
||||
hlo_config.static_device_assignment =
|
||||
stream_executor::tpu::SerializeProto(dev_proto);
|
||||
}
|
||||
if (config.has_entry_computation_layout()) {
|
||||
auto layout = config.entry_computation_layout();
|
||||
ApiConverter::ToC(layout.result_layout().shape(),
|
||||
&hlo_config.entry_computation_layout.result_layout);
|
||||
hlo_config.entry_computation_layout.parameter_layouts =
|
||||
new XLA_Shape[layout.parameter_count()];
|
||||
for (int i = 0; i < layout.parameter_count(); ++i) {
|
||||
ApiConverter::ToC(
|
||||
layout.parameter_layout(i).shape(),
|
||||
&hlo_config.entry_computation_layout.parameter_layouts[i]);
|
||||
}
|
||||
hlo_config.entry_computation_layout.parameter_count =
|
||||
layout.parameter_count();
|
||||
}
|
||||
return hlo_config;
|
||||
}
|
||||
|
||||
class TpuCompiler : public Compiler {
|
||||
public:
|
||||
TpuCompiler() { compiler_ = ExecutorApiFn()->TpuCompiler_NewFn(); }
|
||||
@ -259,7 +223,7 @@ class TpuCompiler : public Compiler {
|
||||
stream_executor::tpu::SerializedProto_Free(result.proto);
|
||||
ApiConverter::Free(&hlo_module.module_config);
|
||||
});
|
||||
hlo_module.module_config = HloModuleConfigToC(module->config());
|
||||
hlo_module.module_config = ApiConverter::ToC(module->config());
|
||||
hlo_module.proto = stream_executor::tpu::SerializeProto(module->ToProto());
|
||||
auto allocator = ApiConverter::ToC(options.device_allocator);
|
||||
StatusHelper status;
|
||||
@ -286,7 +250,7 @@ class TpuCompiler : public Compiler {
|
||||
ApiConverter::Free(&hlo_module.module_config);
|
||||
});
|
||||
SE_Executable* result;
|
||||
hlo_module.module_config = HloModuleConfigToC(module->config());
|
||||
hlo_module.module_config = ApiConverter::ToC(module->config());
|
||||
hlo_module.proto = stream_executor::tpu::SerializeProto(module->ToProto());
|
||||
auto allocator = ApiConverter::ToC(options.device_allocator);
|
||||
|
||||
@ -324,7 +288,7 @@ class TpuCompiler : public Compiler {
|
||||
});
|
||||
for (int i = 0; i < module_group->size(); ++i) {
|
||||
const auto& config = module_group->module(i).config();
|
||||
se_module_group.module_config[i] = HloModuleConfigToC(config);
|
||||
se_module_group.module_config[i] = ApiConverter::ToC(config);
|
||||
}
|
||||
std::vector<SE_StreamExecutorList> se_lists(stream_exec.size());
|
||||
std::vector<std::vector<SE_StreamExecutor*>> se_lists_storage;
|
||||
|
@ -429,6 +429,10 @@ XLA_HloModuleConfig ToC(const xla::HloModuleConfig& config) {
|
||||
hlo_config.static_device_assignment =
|
||||
stream_executor::tpu::SerializeProto(dev_proto);
|
||||
}
|
||||
|
||||
hlo_config.debug_options =
|
||||
stream_executor::tpu::SerializeProto(config.debug_options());
|
||||
|
||||
if (config.has_entry_computation_layout()) {
|
||||
const auto& layout = config.entry_computation_layout();
|
||||
ApiConverter::ToC(layout.result_layout().shape(),
|
||||
@ -462,6 +466,9 @@ xla::HloModuleConfig FromC(const XLA_HloModuleConfig& c_config) {
|
||||
config.set_static_device_assignment(
|
||||
*(device_assignment.ConsumeValueOrDie()));
|
||||
}
|
||||
config.set_debug_options(
|
||||
stream_executor::tpu::DeserializeProto<xla::DebugOptions>(
|
||||
c_config.debug_options));
|
||||
return config;
|
||||
}
|
||||
|
||||
@ -477,6 +484,7 @@ void Free(XLA_HloModuleConfig* c_config) {
|
||||
stream_executor::tpu::SerializedProto_Free(
|
||||
c_config->static_device_assignment);
|
||||
}
|
||||
stream_executor::tpu::SerializedProto_Free(c_config->debug_options);
|
||||
}
|
||||
|
||||
} // namespace ApiConverter
|
||||
|
@ -269,6 +269,7 @@ typedef struct XLA_HloModuleConfig {
|
||||
int64_t replica_count;
|
||||
int64_t num_partitions;
|
||||
bool use_spmd_partitioning;
|
||||
TpuSerializedProto debug_options;
|
||||
bool has_static_device_assignment;
|
||||
TpuSerializedProto static_device_assignment;
|
||||
bool has_entry_computation_layout;
|
||||
|
Loading…
Reference in New Issue
Block a user