In TpuCompiler::Compile, initialize TpuExecutables with correct HloModules.
Prior to this change, the compiler in tpu_on_demand_compiler.cc would initialize TpuExecutables with the same HloModules passed to Compile(). This isn't always correct — for example, when XLA partitions the computation and changes the input/output shapes. This change has the compiler initialize the TpuExecutables with the HloModules from the underlying executables returned from the TPU C API's compile function.

Other changes included as part of this:
* Adds a C API function for getting an executable's HLO module.
* Moves some functions to ApiConverter. The implementations added to c_api_conversions.cc are copies with no additional changes.
* Removes some duplicate function declarations from ApiConverter.
* Fixes a double-free in TpuCompiler::Compile.

PiperOrigin-RevId: 332900516
Change-Id: I7a8b6f66921259fe7c50ab1f36f4af236c77f955
This commit is contained in:
parent
14696f7e09
commit
f50646565f
@ -140,6 +140,7 @@ tensorflow::Status SetExecutorStructFn(void* library_handle) {
|
||||
TFTPU_SET_FN(executor_fn, TpuCompiler_ShapeSize);
|
||||
TFTPU_SET_FN(executor_fn, TpuExecutable_ExecuteAsyncOnStream);
|
||||
TFTPU_SET_FN(executor_fn, TpuExecutable_Fingerprint);
|
||||
TFTPU_SET_FN(executor_fn, TpuExecutable_HloModule);
|
||||
TFTPU_SET_FN(executor_fn, TpuExecutable_Free);
|
||||
|
||||
TFTPU_SET_FN(executor_fn, XlaShapeToTpuShapeRepresentation);
|
||||
|
@ -84,20 +84,6 @@ namespace {
|
||||
|
||||
using ::tensorflow::tpu::ExecutorApiFn;
|
||||
|
||||
void XLA_HloModuleConfig_Free(XLA_HloModuleConfig* module_config) {
|
||||
for (auto i = 0; i < module_config->entry_computation_layout.parameter_count;
|
||||
++i) {
|
||||
ApiConverter::Free(
|
||||
&module_config->entry_computation_layout.parameter_layouts[i]);
|
||||
}
|
||||
delete[] module_config->entry_computation_layout.parameter_layouts;
|
||||
ApiConverter::Free(&module_config->entry_computation_layout.result_layout);
|
||||
if (module_config->has_static_device_assignment) {
|
||||
stream_executor::tpu::SerializedProto_Free(
|
||||
module_config->static_device_assignment);
|
||||
}
|
||||
}
|
||||
|
||||
class TpuExecutable : public TpuExecutableInterface {
|
||||
public:
|
||||
TpuExecutable(SE_Executable* se_executable,
|
||||
@ -273,7 +259,7 @@ class TpuCompiler : public Compiler {
|
||||
auto cleanup = xla::MakeCleanup([&hlo_module, &result]() {
|
||||
stream_executor::tpu::SerializedProto_Free(hlo_module.proto);
|
||||
stream_executor::tpu::SerializedProto_Free(result.proto);
|
||||
XLA_HloModuleConfig_Free(&hlo_module.module_config);
|
||||
ApiConverter::Free(&hlo_module.module_config);
|
||||
});
|
||||
hlo_module.module_config = HloModuleConfigToC(module->config());
|
||||
hlo_module.proto = stream_executor::tpu::SerializeProto(module->ToProto());
|
||||
@ -309,7 +295,7 @@ class TpuCompiler : public Compiler {
|
||||
XLA_HloModule hlo_module;
|
||||
auto cleanup = xla::MakeCleanup([&hlo_module]() {
|
||||
stream_executor::tpu::SerializedProto_Free(hlo_module.proto);
|
||||
XLA_HloModuleConfig_Free(&hlo_module.module_config);
|
||||
ApiConverter::Free(&hlo_module.module_config);
|
||||
});
|
||||
SE_Executable* result;
|
||||
hlo_module.module_config = HloModuleConfigToC(module->config());
|
||||
@ -344,7 +330,7 @@ class TpuCompiler : public Compiler {
|
||||
auto cleanup_config =
|
||||
xla::MakeCleanup([&se_module_group, module_group_size]() {
|
||||
for (auto i = 0; i < module_group_size; ++i) {
|
||||
XLA_HloModuleConfig_Free(&se_module_group.module_config[i]);
|
||||
ApiConverter::Free(&se_module_group.module_config[i]);
|
||||
}
|
||||
delete[] se_module_group.module_config;
|
||||
});
|
||||
@ -378,15 +364,24 @@ class TpuCompiler : public Compiler {
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<Executable>> executables;
|
||||
std::vector<std::unique_ptr<HloModule>> modules =
|
||||
module_group->ConsumeModules();
|
||||
for (int i = 0; i < module_group->size(); ++i) {
|
||||
executables[i] = absl::make_unique<TpuExecutable>(se_executables[i],
|
||||
std::move(modules[i]));
|
||||
// We get the HloModule from the compiled executable, rather than reusing
|
||||
// the input module from 'module_group', in case the module changed in
|
||||
// some way. For example, if the computation is automatically partitioned
|
||||
// via XLA, the executable's module may have different input/output shapes
|
||||
// than the input module.
|
||||
XLA_HloModule c_module =
|
||||
ExecutorApiFn()->TpuExecutable_HloModuleFn(se_executables[i]);
|
||||
auto cleanup_c_module =
|
||||
xla::MakeCleanup([&c_module]() { ApiConverter::Free(&c_module); });
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
|
||||
ApiConverter::FromC(c_module));
|
||||
std::shared_ptr<HloModule> module_shared(module.release());
|
||||
executables.emplace_back(absl::make_unique<TpuExecutable>(
|
||||
se_executables[i], std::move(module_shared)));
|
||||
}
|
||||
|
||||
stream_executor::tpu::SerializedProto_Free(se_module_group.proto);
|
||||
delete se_module_group.module_config;
|
||||
delete[] se_executables;
|
||||
|
||||
return executables;
|
||||
|
@ -45,6 +45,8 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla/service:executable",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_module_config",
|
||||
"//tensorflow/compiler/xla/service:maybe_owning_device_memory",
|
||||
"//tensorflow/compiler/xla/service:shaped_buffer",
|
||||
"//tensorflow/core/tpu:tpu_api",
|
||||
|
@ -150,7 +150,7 @@ stream_executor::DeviceMemoryBase FromC(const SE_DeviceMemoryBase& se_base) {
|
||||
return base;
|
||||
}
|
||||
|
||||
xla::Shape FromC(XLA_Shape* shape) {
|
||||
xla::Shape FromC(const XLA_Shape* shape) {
|
||||
xla::ShapeProto p;
|
||||
p.ParseFromArray(shape->bytes, shape->size);
|
||||
return xla::Shape(p);
|
||||
@ -230,4 +230,106 @@ void Free(XLA_ShapedBuffer* c_buffer) {
|
||||
delete[] c_buffer->bases;
|
||||
}
|
||||
|
||||
XLA_HloModule ToC(const xla::HloModule& module) {
|
||||
XLA_HloModule c_module;
|
||||
c_module.proto = stream_executor::tpu::SerializeProto(module.ToProto());
|
||||
c_module.module_config = ApiConverter::ToC(module.config());
|
||||
return c_module;
|
||||
}
|
||||
|
||||
xla::StatusOr<std::unique_ptr<xla::HloModule>> FromC(
|
||||
const XLA_HloModule& c_module) {
|
||||
xla::HloModuleProto module_proto =
|
||||
stream_executor::tpu::DeserializeProto<xla::HloModuleProto>(
|
||||
c_module.proto);
|
||||
return xla::HloModule::CreateFromProto(
|
||||
module_proto, ApiConverter::FromC(c_module.module_config));
|
||||
}
|
||||
|
||||
void Free(XLA_HloModule* c_module) {
|
||||
stream_executor::tpu::SerializedProto_Free(c_module->proto);
|
||||
Free(&c_module->module_config);
|
||||
}
|
||||
|
||||
static xla::HloModuleConfig ConfigWithLayout(
|
||||
const XLA_HloModuleConfig& se_config) {
|
||||
xla::ShapeLayout result_layout(
|
||||
FromC(&se_config.entry_computation_layout.result_layout));
|
||||
xla::ComputationLayout layout(result_layout);
|
||||
for (int i = 0; i < se_config.entry_computation_layout.parameter_count; ++i) {
|
||||
layout.add_parameter_layout(xla::ShapeLayout(
|
||||
FromC(&se_config.entry_computation_layout.parameter_layouts[i])));
|
||||
}
|
||||
return xla::HloModuleConfig(layout);
|
||||
}
|
||||
|
||||
XLA_HloModuleConfig ToC(const xla::HloModuleConfig& config) {
|
||||
XLA_HloModuleConfig hlo_config;
|
||||
|
||||
hlo_config.seed = config.seed();
|
||||
hlo_config.launch_id = config.launch_id();
|
||||
hlo_config.replica_count = config.replica_count();
|
||||
hlo_config.num_partitions = config.num_partitions();
|
||||
hlo_config.use_spmd_partitioning = config.use_spmd_partitioning();
|
||||
hlo_config.has_static_device_assignment =
|
||||
config.has_static_device_assignment();
|
||||
hlo_config.has_entry_computation_layout =
|
||||
config.has_entry_computation_layout();
|
||||
|
||||
if (config.has_static_device_assignment()) {
|
||||
xla::DeviceAssignmentProto dev_proto;
|
||||
config.static_device_assignment().Serialize(&dev_proto).IgnoreError();
|
||||
hlo_config.static_device_assignment =
|
||||
stream_executor::tpu::SerializeProto(dev_proto);
|
||||
}
|
||||
if (config.has_entry_computation_layout()) {
|
||||
const auto& layout = config.entry_computation_layout();
|
||||
ApiConverter::ToC(layout.result_layout().shape(),
|
||||
&hlo_config.entry_computation_layout.result_layout);
|
||||
hlo_config.entry_computation_layout.parameter_layouts =
|
||||
new XLA_Shape[layout.parameter_count()];
|
||||
for (int i = 0; i < layout.parameter_count(); ++i) {
|
||||
ApiConverter::ToC(
|
||||
layout.parameter_layout(i).shape(),
|
||||
&hlo_config.entry_computation_layout.parameter_layouts[i]);
|
||||
}
|
||||
hlo_config.entry_computation_layout.parameter_count =
|
||||
layout.parameter_count();
|
||||
}
|
||||
return hlo_config;
|
||||
}
|
||||
|
||||
xla::HloModuleConfig FromC(const XLA_HloModuleConfig& c_config) {
|
||||
xla::HloModuleConfig config = c_config.has_entry_computation_layout
|
||||
? ConfigWithLayout(c_config)
|
||||
: xla::HloModuleConfig();
|
||||
config.set_launch_id(c_config.launch_id);
|
||||
config.set_seed(c_config.seed);
|
||||
config.set_replica_count(c_config.replica_count);
|
||||
config.set_num_partitions(c_config.num_partitions);
|
||||
config.set_use_spmd_partitioning(c_config.use_spmd_partitioning);
|
||||
if (c_config.has_static_device_assignment) {
|
||||
auto device_assignment = xla::DeviceAssignment::Deserialize(
|
||||
stream_executor::tpu::DeserializeProto<xla::DeviceAssignmentProto>(
|
||||
c_config.static_device_assignment));
|
||||
config.set_static_device_assignment(
|
||||
*(device_assignment.ConsumeValueOrDie()));
|
||||
}
|
||||
return config;
|
||||
}
|
||||
|
||||
void Free(XLA_HloModuleConfig* c_config) {
|
||||
for (auto i = 0; i < c_config->entry_computation_layout.parameter_count;
|
||||
++i) {
|
||||
ApiConverter::Free(
|
||||
&c_config->entry_computation_layout.parameter_layouts[i]);
|
||||
}
|
||||
delete[] c_config->entry_computation_layout.parameter_layouts;
|
||||
ApiConverter::Free(&c_config->entry_computation_layout.result_layout);
|
||||
if (c_config->has_static_device_assignment) {
|
||||
stream_executor::tpu::SerializedProto_Free(
|
||||
c_config->static_device_assignment);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ApiConverter
|
||||
|
@ -19,6 +19,8 @@ limitations under the License.
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "tensorflow/compiler/xla/executable_run_options.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
||||
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
|
||||
#include "tensorflow/compiler/xla/service/service_executable_run_options.h"
|
||||
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
|
||||
@ -41,7 +43,7 @@ stream_executor::DeviceMemoryBase FromC(const SE_DeviceMemoryBase& se_base);
|
||||
void Free(SE_DeviceMemoryBase*);
|
||||
|
||||
// xla::Shape
|
||||
xla::Shape FromC(XLA_Shape* shape);
|
||||
xla::Shape FromC(const XLA_Shape* shape);
|
||||
void ToC(const xla::Shape& xla_shape, XLA_Shape* c_shape);
|
||||
void Free(XLA_Shape* shape);
|
||||
|
||||
@ -65,11 +67,6 @@ SE_DeviceMemoryBase ToC(const stream_executor::DeviceMemoryBase& base);
|
||||
stream_executor::DeviceMemoryBase FromC(const SE_DeviceMemoryBase& se_base);
|
||||
void Free(SE_DeviceMemoryBase*);
|
||||
|
||||
// xla::Shape
|
||||
xla::Shape FromC(XLA_Shape* shape);
|
||||
void ToC(const xla::Shape& xla_shape, XLA_Shape* c_shape);
|
||||
void Free(XLA_Shape* shape);
|
||||
|
||||
// Literal
|
||||
void ToC(const xla::LiteralSlice& literal, XLA_Literal* c_literal);
|
||||
xla::MutableBorrowingLiteral FromC(XLA_Literal* c_literal);
|
||||
@ -94,6 +91,17 @@ SE_MaybeOwningDeviceMemory ToC(stream_executor::OwningDeviceMemory* mem);
|
||||
// 'mem' is unowned.
|
||||
SE_MaybeOwningDeviceMemory ToC(xla::MaybeOwningDeviceMemory& mem, bool aliased);
|
||||
|
||||
// HloModule
|
||||
XLA_HloModule ToC(const xla::HloModule& module);
|
||||
xla::StatusOr<std::unique_ptr<xla::HloModule>> FromC(
|
||||
const XLA_HloModule& c_module);
|
||||
void Free(XLA_HloModule* c_module);
|
||||
|
||||
// HloModuleConfig
|
||||
XLA_HloModuleConfig ToC(const xla::HloModuleConfig& config);
|
||||
xla::HloModuleConfig FromC(const XLA_HloModuleConfig& c_config);
|
||||
void Free(XLA_HloModuleConfig* c_config);
|
||||
|
||||
// Helper for managing stack based C -> C++ conversions.
|
||||
template <class CType>
|
||||
struct StackHelper {
|
||||
|
@ -310,6 +310,11 @@ TFTPU_CAPI_EXPORT void TpuExecutable_Fingerprint(SE_Executable* executable,
|
||||
const char** fingerprint,
|
||||
size_t* size);
|
||||
|
||||
// Caller is responsible for freeing the returned module's proto and its
|
||||
// config's proto.
|
||||
TFTPU_CAPI_EXPORT XLA_HloModule
|
||||
TpuExecutable_HloModule(SE_Executable* executable);
|
||||
|
||||
TFTPU_CAPI_EXPORT void TpuExecutable_Free(SE_Executable*);
|
||||
|
||||
// Converts an XLA `Shape` into its equivalent TPU `Shape` representation.
|
||||
@ -458,6 +463,7 @@ struct TfTpu_ExecutorApiFn {
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_ShapeSize);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_ExecuteAsyncOnStream);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_Fingerprint);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_HloModule);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_Free);
|
||||
|
||||
TFTPU_ADD_FN_IN_STRUCT(XlaShapeToTpuShapeRepresentation);
|
||||
|
Loading…
Reference in New Issue
Block a user