diff --git a/tensorflow/core/tpu/tpu_on_demand_compiler.cc b/tensorflow/core/tpu/tpu_on_demand_compiler.cc index 069cf0d37e1..a0e38de2a2d 100644 --- a/tensorflow/core/tpu/tpu_on_demand_compiler.cc +++ b/tensorflow/core/tpu/tpu_on_demand_compiler.cc @@ -203,42 +203,6 @@ class TpuExecutable : public TpuExecutableInterface { SE_Executable* se_executable_; }; -XLA_HloModuleConfig HloModuleConfigToC(const xla::HloModuleConfig& config) { - XLA_HloModuleConfig hlo_config; - - hlo_config.seed = config.seed(); - hlo_config.launch_id = config.launch_id(); - hlo_config.replica_count = config.replica_count(); - hlo_config.num_partitions = config.num_partitions(); - hlo_config.use_spmd_partitioning = config.use_spmd_partitioning(); - hlo_config.has_static_device_assignment = - config.has_static_device_assignment(); - hlo_config.has_entry_computation_layout = - config.has_entry_computation_layout(); - - if (config.has_static_device_assignment()) { - DeviceAssignmentProto dev_proto; - config.static_device_assignment().Serialize(&dev_proto).IgnoreError(); - hlo_config.static_device_assignment = - stream_executor::tpu::SerializeProto(dev_proto); - } - if (config.has_entry_computation_layout()) { - auto layout = config.entry_computation_layout(); - ApiConverter::ToC(layout.result_layout().shape(), - &hlo_config.entry_computation_layout.result_layout); - hlo_config.entry_computation_layout.parameter_layouts = - new XLA_Shape[layout.parameter_count()]; - for (int i = 0; i < layout.parameter_count(); ++i) { - ApiConverter::ToC( - layout.parameter_layout(i).shape(), - &hlo_config.entry_computation_layout.parameter_layouts[i]); - } - hlo_config.entry_computation_layout.parameter_count = - layout.parameter_count(); - } - return hlo_config; -} - class TpuCompiler : public Compiler { public: TpuCompiler() { compiler_ = ExecutorApiFn()->TpuCompiler_NewFn(); } @@ -259,7 +223,7 @@ class TpuCompiler : public Compiler { stream_executor::tpu::SerializedProto_Free(result.proto); ApiConverter::Free(&hlo_module.module_config); }); - hlo_module.module_config = HloModuleConfigToC(module->config()); + hlo_module.module_config = ApiConverter::ToC(module->config()); hlo_module.proto = stream_executor::tpu::SerializeProto(module->ToProto()); auto allocator = ApiConverter::ToC(options.device_allocator); StatusHelper status; @@ -286,7 +250,7 @@ class TpuCompiler : public Compiler { ApiConverter::Free(&hlo_module.module_config); }); SE_Executable* result; - hlo_module.module_config = HloModuleConfigToC(module->config()); + hlo_module.module_config = ApiConverter::ToC(module->config()); hlo_module.proto = stream_executor::tpu::SerializeProto(module->ToProto()); auto allocator = ApiConverter::ToC(options.device_allocator); @@ -324,7 +288,7 @@ class TpuCompiler : public Compiler { }); for (int i = 0; i < module_group->size(); ++i) { const auto& config = module_group->module(i).config(); - se_module_group.module_config[i] = HloModuleConfigToC(config); + se_module_group.module_config[i] = ApiConverter::ToC(config); } std::vector se_lists(stream_exec.size()); std::vector> se_lists_storage; diff --git a/tensorflow/stream_executor/tpu/c_api_conversions.cc b/tensorflow/stream_executor/tpu/c_api_conversions.cc index 896a2b6a1f8..3a507ab1680 100644 --- a/tensorflow/stream_executor/tpu/c_api_conversions.cc +++ b/tensorflow/stream_executor/tpu/c_api_conversions.cc @@ -429,6 +429,10 @@ XLA_HloModuleConfig ToC(const xla::HloModuleConfig& config) { hlo_config.static_device_assignment = stream_executor::tpu::SerializeProto(dev_proto); } + + hlo_config.debug_options = + stream_executor::tpu::SerializeProto(config.debug_options()); + if (config.has_entry_computation_layout()) { const auto& layout = config.entry_computation_layout(); ApiConverter::ToC(layout.result_layout().shape(), @@ -462,6 +466,9 @@ xla::HloModuleConfig FromC(const XLA_HloModuleConfig& c_config) { config.set_static_device_assignment( *(device_assignment.ConsumeValueOrDie())); } + config.set_debug_options( + stream_executor::tpu::DeserializeProto( + c_config.debug_options)); return config; } @@ -477,6 +484,7 @@ void Free(XLA_HloModuleConfig* c_config) { stream_executor::tpu::SerializedProto_Free( c_config->static_device_assignment); } + stream_executor::tpu::SerializedProto_Free(c_config->debug_options); } } // namespace ApiConverter diff --git a/tensorflow/stream_executor/tpu/c_api_decl.h b/tensorflow/stream_executor/tpu/c_api_decl.h index 71a725f5886..a9144ef2c4a 100644 --- a/tensorflow/stream_executor/tpu/c_api_decl.h +++ b/tensorflow/stream_executor/tpu/c_api_decl.h @@ -269,6 +269,7 @@ typedef struct XLA_HloModuleConfig { int64_t replica_count; int64_t num_partitions; bool use_spmd_partitioning; + TpuSerializedProto debug_options; bool has_static_device_assignment; TpuSerializedProto static_device_assignment; bool has_entry_computation_layout;