[libtpu] Plumb HloModule::debug_options_ through the C API.

PiperOrigin-RevId: 351196886
Change-Id: Idfa6fe5a01969efb888cc9a22718e72d881f1835
This commit is contained in:
Skye Wanderman-Milne 2021-01-11 11:24:35 -08:00 committed by TensorFlower Gardener
parent bdea938786
commit 96fdbb1a8d
3 changed files with 12 additions and 39 deletions

View File

@ -203,42 +203,6 @@ class TpuExecutable : public TpuExecutableInterface {
SE_Executable* se_executable_;
};
XLA_HloModuleConfig HloModuleConfigToC(const xla::HloModuleConfig& config) {
XLA_HloModuleConfig hlo_config;
hlo_config.seed = config.seed();
hlo_config.launch_id = config.launch_id();
hlo_config.replica_count = config.replica_count();
hlo_config.num_partitions = config.num_partitions();
hlo_config.use_spmd_partitioning = config.use_spmd_partitioning();
hlo_config.has_static_device_assignment =
config.has_static_device_assignment();
hlo_config.has_entry_computation_layout =
config.has_entry_computation_layout();
if (config.has_static_device_assignment()) {
DeviceAssignmentProto dev_proto;
config.static_device_assignment().Serialize(&dev_proto).IgnoreError();
hlo_config.static_device_assignment =
stream_executor::tpu::SerializeProto(dev_proto);
}
if (config.has_entry_computation_layout()) {
auto layout = config.entry_computation_layout();
ApiConverter::ToC(layout.result_layout().shape(),
&hlo_config.entry_computation_layout.result_layout);
hlo_config.entry_computation_layout.parameter_layouts =
new XLA_Shape[layout.parameter_count()];
for (int i = 0; i < layout.parameter_count(); ++i) {
ApiConverter::ToC(
layout.parameter_layout(i).shape(),
&hlo_config.entry_computation_layout.parameter_layouts[i]);
}
hlo_config.entry_computation_layout.parameter_count =
layout.parameter_count();
}
return hlo_config;
}
class TpuCompiler : public Compiler {
public:
TpuCompiler() { compiler_ = ExecutorApiFn()->TpuCompiler_NewFn(); }
@ -259,7 +223,7 @@ class TpuCompiler : public Compiler {
stream_executor::tpu::SerializedProto_Free(result.proto);
ApiConverter::Free(&hlo_module.module_config);
});
hlo_module.module_config = HloModuleConfigToC(module->config());
hlo_module.module_config = ApiConverter::ToC(module->config());
hlo_module.proto = stream_executor::tpu::SerializeProto(module->ToProto());
auto allocator = ApiConverter::ToC(options.device_allocator);
StatusHelper status;
@ -286,7 +250,7 @@ class TpuCompiler : public Compiler {
ApiConverter::Free(&hlo_module.module_config);
});
SE_Executable* result;
hlo_module.module_config = HloModuleConfigToC(module->config());
hlo_module.module_config = ApiConverter::ToC(module->config());
hlo_module.proto = stream_executor::tpu::SerializeProto(module->ToProto());
auto allocator = ApiConverter::ToC(options.device_allocator);
@ -324,7 +288,7 @@ class TpuCompiler : public Compiler {
});
for (int i = 0; i < module_group->size(); ++i) {
const auto& config = module_group->module(i).config();
se_module_group.module_config[i] = HloModuleConfigToC(config);
se_module_group.module_config[i] = ApiConverter::ToC(config);
}
std::vector<SE_StreamExecutorList> se_lists(stream_exec.size());
std::vector<std::vector<SE_StreamExecutor*>> se_lists_storage;

View File

@ -429,6 +429,10 @@ XLA_HloModuleConfig ToC(const xla::HloModuleConfig& config) {
hlo_config.static_device_assignment =
stream_executor::tpu::SerializeProto(dev_proto);
}
hlo_config.debug_options =
stream_executor::tpu::SerializeProto(config.debug_options());
if (config.has_entry_computation_layout()) {
const auto& layout = config.entry_computation_layout();
ApiConverter::ToC(layout.result_layout().shape(),
@ -462,6 +466,9 @@ xla::HloModuleConfig FromC(const XLA_HloModuleConfig& c_config) {
config.set_static_device_assignment(
*(device_assignment.ConsumeValueOrDie()));
}
config.set_debug_options(
stream_executor::tpu::DeserializeProto<xla::DebugOptions>(
c_config.debug_options));
return config;
}
@ -477,6 +484,7 @@ void Free(XLA_HloModuleConfig* c_config) {
stream_executor::tpu::SerializedProto_Free(
c_config->static_device_assignment);
}
stream_executor::tpu::SerializedProto_Free(c_config->debug_options);
}
} // namespace ApiConverter

View File

@ -269,6 +269,7 @@ typedef struct XLA_HloModuleConfig {
int64_t replica_count;
int64_t num_partitions;
bool use_spmd_partitioning;
TpuSerializedProto debug_options;
bool has_static_device_assignment;
TpuSerializedProto static_device_assignment;
bool has_entry_computation_layout;