[libtpu] Plumb HloModule::debug_options_ through the C API.
PiperOrigin-RevId: 351196886 Change-Id: Idfa6fe5a01969efb888cc9a22718e72d881f1835
This commit is contained in:
parent
bdea938786
commit
96fdbb1a8d
@ -203,42 +203,6 @@ class TpuExecutable : public TpuExecutableInterface {
|
|||||||
SE_Executable* se_executable_;
|
SE_Executable* se_executable_;
|
||||||
};
|
};
|
||||||
|
|
||||||
XLA_HloModuleConfig HloModuleConfigToC(const xla::HloModuleConfig& config) {
|
|
||||||
XLA_HloModuleConfig hlo_config;
|
|
||||||
|
|
||||||
hlo_config.seed = config.seed();
|
|
||||||
hlo_config.launch_id = config.launch_id();
|
|
||||||
hlo_config.replica_count = config.replica_count();
|
|
||||||
hlo_config.num_partitions = config.num_partitions();
|
|
||||||
hlo_config.use_spmd_partitioning = config.use_spmd_partitioning();
|
|
||||||
hlo_config.has_static_device_assignment =
|
|
||||||
config.has_static_device_assignment();
|
|
||||||
hlo_config.has_entry_computation_layout =
|
|
||||||
config.has_entry_computation_layout();
|
|
||||||
|
|
||||||
if (config.has_static_device_assignment()) {
|
|
||||||
DeviceAssignmentProto dev_proto;
|
|
||||||
config.static_device_assignment().Serialize(&dev_proto).IgnoreError();
|
|
||||||
hlo_config.static_device_assignment =
|
|
||||||
stream_executor::tpu::SerializeProto(dev_proto);
|
|
||||||
}
|
|
||||||
if (config.has_entry_computation_layout()) {
|
|
||||||
auto layout = config.entry_computation_layout();
|
|
||||||
ApiConverter::ToC(layout.result_layout().shape(),
|
|
||||||
&hlo_config.entry_computation_layout.result_layout);
|
|
||||||
hlo_config.entry_computation_layout.parameter_layouts =
|
|
||||||
new XLA_Shape[layout.parameter_count()];
|
|
||||||
for (int i = 0; i < layout.parameter_count(); ++i) {
|
|
||||||
ApiConverter::ToC(
|
|
||||||
layout.parameter_layout(i).shape(),
|
|
||||||
&hlo_config.entry_computation_layout.parameter_layouts[i]);
|
|
||||||
}
|
|
||||||
hlo_config.entry_computation_layout.parameter_count =
|
|
||||||
layout.parameter_count();
|
|
||||||
}
|
|
||||||
return hlo_config;
|
|
||||||
}
|
|
||||||
|
|
||||||
class TpuCompiler : public Compiler {
|
class TpuCompiler : public Compiler {
|
||||||
public:
|
public:
|
||||||
TpuCompiler() { compiler_ = ExecutorApiFn()->TpuCompiler_NewFn(); }
|
TpuCompiler() { compiler_ = ExecutorApiFn()->TpuCompiler_NewFn(); }
|
||||||
@ -259,7 +223,7 @@ class TpuCompiler : public Compiler {
|
|||||||
stream_executor::tpu::SerializedProto_Free(result.proto);
|
stream_executor::tpu::SerializedProto_Free(result.proto);
|
||||||
ApiConverter::Free(&hlo_module.module_config);
|
ApiConverter::Free(&hlo_module.module_config);
|
||||||
});
|
});
|
||||||
hlo_module.module_config = HloModuleConfigToC(module->config());
|
hlo_module.module_config = ApiConverter::ToC(module->config());
|
||||||
hlo_module.proto = stream_executor::tpu::SerializeProto(module->ToProto());
|
hlo_module.proto = stream_executor::tpu::SerializeProto(module->ToProto());
|
||||||
auto allocator = ApiConverter::ToC(options.device_allocator);
|
auto allocator = ApiConverter::ToC(options.device_allocator);
|
||||||
StatusHelper status;
|
StatusHelper status;
|
||||||
@ -286,7 +250,7 @@ class TpuCompiler : public Compiler {
|
|||||||
ApiConverter::Free(&hlo_module.module_config);
|
ApiConverter::Free(&hlo_module.module_config);
|
||||||
});
|
});
|
||||||
SE_Executable* result;
|
SE_Executable* result;
|
||||||
hlo_module.module_config = HloModuleConfigToC(module->config());
|
hlo_module.module_config = ApiConverter::ToC(module->config());
|
||||||
hlo_module.proto = stream_executor::tpu::SerializeProto(module->ToProto());
|
hlo_module.proto = stream_executor::tpu::SerializeProto(module->ToProto());
|
||||||
auto allocator = ApiConverter::ToC(options.device_allocator);
|
auto allocator = ApiConverter::ToC(options.device_allocator);
|
||||||
|
|
||||||
@ -324,7 +288,7 @@ class TpuCompiler : public Compiler {
|
|||||||
});
|
});
|
||||||
for (int i = 0; i < module_group->size(); ++i) {
|
for (int i = 0; i < module_group->size(); ++i) {
|
||||||
const auto& config = module_group->module(i).config();
|
const auto& config = module_group->module(i).config();
|
||||||
se_module_group.module_config[i] = HloModuleConfigToC(config);
|
se_module_group.module_config[i] = ApiConverter::ToC(config);
|
||||||
}
|
}
|
||||||
std::vector<SE_StreamExecutorList> se_lists(stream_exec.size());
|
std::vector<SE_StreamExecutorList> se_lists(stream_exec.size());
|
||||||
std::vector<std::vector<SE_StreamExecutor*>> se_lists_storage;
|
std::vector<std::vector<SE_StreamExecutor*>> se_lists_storage;
|
||||||
|
@ -429,6 +429,10 @@ XLA_HloModuleConfig ToC(const xla::HloModuleConfig& config) {
|
|||||||
hlo_config.static_device_assignment =
|
hlo_config.static_device_assignment =
|
||||||
stream_executor::tpu::SerializeProto(dev_proto);
|
stream_executor::tpu::SerializeProto(dev_proto);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
hlo_config.debug_options =
|
||||||
|
stream_executor::tpu::SerializeProto(config.debug_options());
|
||||||
|
|
||||||
if (config.has_entry_computation_layout()) {
|
if (config.has_entry_computation_layout()) {
|
||||||
const auto& layout = config.entry_computation_layout();
|
const auto& layout = config.entry_computation_layout();
|
||||||
ApiConverter::ToC(layout.result_layout().shape(),
|
ApiConverter::ToC(layout.result_layout().shape(),
|
||||||
@ -462,6 +466,9 @@ xla::HloModuleConfig FromC(const XLA_HloModuleConfig& c_config) {
|
|||||||
config.set_static_device_assignment(
|
config.set_static_device_assignment(
|
||||||
*(device_assignment.ConsumeValueOrDie()));
|
*(device_assignment.ConsumeValueOrDie()));
|
||||||
}
|
}
|
||||||
|
config.set_debug_options(
|
||||||
|
stream_executor::tpu::DeserializeProto<xla::DebugOptions>(
|
||||||
|
c_config.debug_options));
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -477,6 +484,7 @@ void Free(XLA_HloModuleConfig* c_config) {
|
|||||||
stream_executor::tpu::SerializedProto_Free(
|
stream_executor::tpu::SerializedProto_Free(
|
||||||
c_config->static_device_assignment);
|
c_config->static_device_assignment);
|
||||||
}
|
}
|
||||||
|
stream_executor::tpu::SerializedProto_Free(c_config->debug_options);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace ApiConverter
|
} // namespace ApiConverter
|
||||||
|
@ -269,6 +269,7 @@ typedef struct XLA_HloModuleConfig {
|
|||||||
int64_t replica_count;
|
int64_t replica_count;
|
||||||
int64_t num_partitions;
|
int64_t num_partitions;
|
||||||
bool use_spmd_partitioning;
|
bool use_spmd_partitioning;
|
||||||
|
TpuSerializedProto debug_options;
|
||||||
bool has_static_device_assignment;
|
bool has_static_device_assignment;
|
||||||
TpuSerializedProto static_device_assignment;
|
TpuSerializedProto static_device_assignment;
|
||||||
bool has_entry_computation_layout;
|
bool has_entry_computation_layout;
|
||||||
|
Loading…
Reference in New Issue
Block a user