In TpuCompiler::Compile, initialize TpuExecutables with correct HloModules.
Prior to this change, the compiler in tpu_on_demand_compiler.cc would initialize TpuExecutables with the same HloModules passed to Compile(). This isn't always correct — for example, when XLA partitions the computation and changes the input/output shapes. This change has the compiler initialize the TpuExecutables with the HloModules from the underlying executables returned from the TPU C API's compile function.

Other changes included as part of this:
* Adds C API function for getting an executable's HLO module.
* Moves some functions to ApiConverter. The implementations added to c_api_conversions.cc are copies with no additional changes.
* Removes some duplicate function declarations from ApiConverter.
* Fixes a double-free in TpuCompiler::Compile.

PiperOrigin-RevId: 332900516
Change-Id: I7a8b6f66921259fe7c50ab1f36f4af236c77f955
This commit is contained in:
parent
14696f7e09
commit
f50646565f
@ -140,6 +140,7 @@ tensorflow::Status SetExecutorStructFn(void* library_handle) {
|
|||||||
TFTPU_SET_FN(executor_fn, TpuCompiler_ShapeSize);
|
TFTPU_SET_FN(executor_fn, TpuCompiler_ShapeSize);
|
||||||
TFTPU_SET_FN(executor_fn, TpuExecutable_ExecuteAsyncOnStream);
|
TFTPU_SET_FN(executor_fn, TpuExecutable_ExecuteAsyncOnStream);
|
||||||
TFTPU_SET_FN(executor_fn, TpuExecutable_Fingerprint);
|
TFTPU_SET_FN(executor_fn, TpuExecutable_Fingerprint);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutable_HloModule);
|
||||||
TFTPU_SET_FN(executor_fn, TpuExecutable_Free);
|
TFTPU_SET_FN(executor_fn, TpuExecutable_Free);
|
||||||
|
|
||||||
TFTPU_SET_FN(executor_fn, XlaShapeToTpuShapeRepresentation);
|
TFTPU_SET_FN(executor_fn, XlaShapeToTpuShapeRepresentation);
|
||||||
|
|||||||
@ -84,20 +84,6 @@ namespace {
|
|||||||
|
|
||||||
using ::tensorflow::tpu::ExecutorApiFn;
|
using ::tensorflow::tpu::ExecutorApiFn;
|
||||||
|
|
||||||
void XLA_HloModuleConfig_Free(XLA_HloModuleConfig* module_config) {
|
|
||||||
for (auto i = 0; i < module_config->entry_computation_layout.parameter_count;
|
|
||||||
++i) {
|
|
||||||
ApiConverter::Free(
|
|
||||||
&module_config->entry_computation_layout.parameter_layouts[i]);
|
|
||||||
}
|
|
||||||
delete[] module_config->entry_computation_layout.parameter_layouts;
|
|
||||||
ApiConverter::Free(&module_config->entry_computation_layout.result_layout);
|
|
||||||
if (module_config->has_static_device_assignment) {
|
|
||||||
stream_executor::tpu::SerializedProto_Free(
|
|
||||||
module_config->static_device_assignment);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
class TpuExecutable : public TpuExecutableInterface {
|
class TpuExecutable : public TpuExecutableInterface {
|
||||||
public:
|
public:
|
||||||
TpuExecutable(SE_Executable* se_executable,
|
TpuExecutable(SE_Executable* se_executable,
|
||||||
@ -273,7 +259,7 @@ class TpuCompiler : public Compiler {
|
|||||||
auto cleanup = xla::MakeCleanup([&hlo_module, &result]() {
|
auto cleanup = xla::MakeCleanup([&hlo_module, &result]() {
|
||||||
stream_executor::tpu::SerializedProto_Free(hlo_module.proto);
|
stream_executor::tpu::SerializedProto_Free(hlo_module.proto);
|
||||||
stream_executor::tpu::SerializedProto_Free(result.proto);
|
stream_executor::tpu::SerializedProto_Free(result.proto);
|
||||||
XLA_HloModuleConfig_Free(&hlo_module.module_config);
|
ApiConverter::Free(&hlo_module.module_config);
|
||||||
});
|
});
|
||||||
hlo_module.module_config = HloModuleConfigToC(module->config());
|
hlo_module.module_config = HloModuleConfigToC(module->config());
|
||||||
hlo_module.proto = stream_executor::tpu::SerializeProto(module->ToProto());
|
hlo_module.proto = stream_executor::tpu::SerializeProto(module->ToProto());
|
||||||
@ -309,7 +295,7 @@ class TpuCompiler : public Compiler {
|
|||||||
XLA_HloModule hlo_module;
|
XLA_HloModule hlo_module;
|
||||||
auto cleanup = xla::MakeCleanup([&hlo_module]() {
|
auto cleanup = xla::MakeCleanup([&hlo_module]() {
|
||||||
stream_executor::tpu::SerializedProto_Free(hlo_module.proto);
|
stream_executor::tpu::SerializedProto_Free(hlo_module.proto);
|
||||||
XLA_HloModuleConfig_Free(&hlo_module.module_config);
|
ApiConverter::Free(&hlo_module.module_config);
|
||||||
});
|
});
|
||||||
SE_Executable* result;
|
SE_Executable* result;
|
||||||
hlo_module.module_config = HloModuleConfigToC(module->config());
|
hlo_module.module_config = HloModuleConfigToC(module->config());
|
||||||
@ -344,7 +330,7 @@ class TpuCompiler : public Compiler {
|
|||||||
auto cleanup_config =
|
auto cleanup_config =
|
||||||
xla::MakeCleanup([&se_module_group, module_group_size]() {
|
xla::MakeCleanup([&se_module_group, module_group_size]() {
|
||||||
for (auto i = 0; i < module_group_size; ++i) {
|
for (auto i = 0; i < module_group_size; ++i) {
|
||||||
XLA_HloModuleConfig_Free(&se_module_group.module_config[i]);
|
ApiConverter::Free(&se_module_group.module_config[i]);
|
||||||
}
|
}
|
||||||
delete[] se_module_group.module_config;
|
delete[] se_module_group.module_config;
|
||||||
});
|
});
|
||||||
@ -378,15 +364,24 @@ class TpuCompiler : public Compiler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::unique_ptr<Executable>> executables;
|
std::vector<std::unique_ptr<Executable>> executables;
|
||||||
std::vector<std::unique_ptr<HloModule>> modules =
|
|
||||||
module_group->ConsumeModules();
|
|
||||||
for (int i = 0; i < module_group->size(); ++i) {
|
for (int i = 0; i < module_group->size(); ++i) {
|
||||||
executables[i] = absl::make_unique<TpuExecutable>(se_executables[i],
|
// We get the HloModule from the compiled executable, rather than reusing
|
||||||
std::move(modules[i]));
|
// the input module from 'module_group', in case the module changed in
|
||||||
|
// some way. For example, if the computation is automatically partitioned
|
||||||
|
// via XLA, the executable's module may have different input/output shapes
|
||||||
|
// than the input module.
|
||||||
|
XLA_HloModule c_module =
|
||||||
|
ExecutorApiFn()->TpuExecutable_HloModuleFn(se_executables[i]);
|
||||||
|
auto cleanup_c_module =
|
||||||
|
xla::MakeCleanup([&c_module]() { ApiConverter::Free(&c_module); });
|
||||||
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
|
||||||
|
ApiConverter::FromC(c_module));
|
||||||
|
std::shared_ptr<HloModule> module_shared(module.release());
|
||||||
|
executables.emplace_back(absl::make_unique<TpuExecutable>(
|
||||||
|
se_executables[i], std::move(module_shared)));
|
||||||
}
|
}
|
||||||
|
|
||||||
stream_executor::tpu::SerializedProto_Free(se_module_group.proto);
|
stream_executor::tpu::SerializedProto_Free(se_module_group.proto);
|
||||||
delete se_module_group.module_config;
|
|
||||||
delete[] se_executables;
|
delete[] se_executables;
|
||||||
|
|
||||||
return executables;
|
return executables;
|
||||||
|
|||||||
@ -45,6 +45,8 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:literal",
|
"//tensorflow/compiler/xla:literal",
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/compiler/xla/service:executable",
|
"//tensorflow/compiler/xla/service:executable",
|
||||||
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
|
"//tensorflow/compiler/xla/service:hlo_module_config",
|
||||||
"//tensorflow/compiler/xla/service:maybe_owning_device_memory",
|
"//tensorflow/compiler/xla/service:maybe_owning_device_memory",
|
||||||
"//tensorflow/compiler/xla/service:shaped_buffer",
|
"//tensorflow/compiler/xla/service:shaped_buffer",
|
||||||
"//tensorflow/core/tpu:tpu_api",
|
"//tensorflow/core/tpu:tpu_api",
|
||||||
|
|||||||
@ -150,7 +150,7 @@ stream_executor::DeviceMemoryBase FromC(const SE_DeviceMemoryBase& se_base) {
|
|||||||
return base;
|
return base;
|
||||||
}
|
}
|
||||||
|
|
||||||
xla::Shape FromC(XLA_Shape* shape) {
|
xla::Shape FromC(const XLA_Shape* shape) {
|
||||||
xla::ShapeProto p;
|
xla::ShapeProto p;
|
||||||
p.ParseFromArray(shape->bytes, shape->size);
|
p.ParseFromArray(shape->bytes, shape->size);
|
||||||
return xla::Shape(p);
|
return xla::Shape(p);
|
||||||
@ -230,4 +230,106 @@ void Free(XLA_ShapedBuffer* c_buffer) {
|
|||||||
delete[] c_buffer->bases;
|
delete[] c_buffer->bases;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XLA_HloModule ToC(const xla::HloModule& module) {
|
||||||
|
XLA_HloModule c_module;
|
||||||
|
c_module.proto = stream_executor::tpu::SerializeProto(module.ToProto());
|
||||||
|
c_module.module_config = ApiConverter::ToC(module.config());
|
||||||
|
return c_module;
|
||||||
|
}
|
||||||
|
|
||||||
|
xla::StatusOr<std::unique_ptr<xla::HloModule>> FromC(
|
||||||
|
const XLA_HloModule& c_module) {
|
||||||
|
xla::HloModuleProto module_proto =
|
||||||
|
stream_executor::tpu::DeserializeProto<xla::HloModuleProto>(
|
||||||
|
c_module.proto);
|
||||||
|
return xla::HloModule::CreateFromProto(
|
||||||
|
module_proto, ApiConverter::FromC(c_module.module_config));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Free(XLA_HloModule* c_module) {
|
||||||
|
stream_executor::tpu::SerializedProto_Free(c_module->proto);
|
||||||
|
Free(&c_module->module_config);
|
||||||
|
}
|
||||||
|
|
||||||
|
static xla::HloModuleConfig ConfigWithLayout(
|
||||||
|
const XLA_HloModuleConfig& se_config) {
|
||||||
|
xla::ShapeLayout result_layout(
|
||||||
|
FromC(&se_config.entry_computation_layout.result_layout));
|
||||||
|
xla::ComputationLayout layout(result_layout);
|
||||||
|
for (int i = 0; i < se_config.entry_computation_layout.parameter_count; ++i) {
|
||||||
|
layout.add_parameter_layout(xla::ShapeLayout(
|
||||||
|
FromC(&se_config.entry_computation_layout.parameter_layouts[i])));
|
||||||
|
}
|
||||||
|
return xla::HloModuleConfig(layout);
|
||||||
|
}
|
||||||
|
|
||||||
|
XLA_HloModuleConfig ToC(const xla::HloModuleConfig& config) {
|
||||||
|
XLA_HloModuleConfig hlo_config;
|
||||||
|
|
||||||
|
hlo_config.seed = config.seed();
|
||||||
|
hlo_config.launch_id = config.launch_id();
|
||||||
|
hlo_config.replica_count = config.replica_count();
|
||||||
|
hlo_config.num_partitions = config.num_partitions();
|
||||||
|
hlo_config.use_spmd_partitioning = config.use_spmd_partitioning();
|
||||||
|
hlo_config.has_static_device_assignment =
|
||||||
|
config.has_static_device_assignment();
|
||||||
|
hlo_config.has_entry_computation_layout =
|
||||||
|
config.has_entry_computation_layout();
|
||||||
|
|
||||||
|
if (config.has_static_device_assignment()) {
|
||||||
|
xla::DeviceAssignmentProto dev_proto;
|
||||||
|
config.static_device_assignment().Serialize(&dev_proto).IgnoreError();
|
||||||
|
hlo_config.static_device_assignment =
|
||||||
|
stream_executor::tpu::SerializeProto(dev_proto);
|
||||||
|
}
|
||||||
|
if (config.has_entry_computation_layout()) {
|
||||||
|
const auto& layout = config.entry_computation_layout();
|
||||||
|
ApiConverter::ToC(layout.result_layout().shape(),
|
||||||
|
&hlo_config.entry_computation_layout.result_layout);
|
||||||
|
hlo_config.entry_computation_layout.parameter_layouts =
|
||||||
|
new XLA_Shape[layout.parameter_count()];
|
||||||
|
for (int i = 0; i < layout.parameter_count(); ++i) {
|
||||||
|
ApiConverter::ToC(
|
||||||
|
layout.parameter_layout(i).shape(),
|
||||||
|
&hlo_config.entry_computation_layout.parameter_layouts[i]);
|
||||||
|
}
|
||||||
|
hlo_config.entry_computation_layout.parameter_count =
|
||||||
|
layout.parameter_count();
|
||||||
|
}
|
||||||
|
return hlo_config;
|
||||||
|
}
|
||||||
|
|
||||||
|
xla::HloModuleConfig FromC(const XLA_HloModuleConfig& c_config) {
|
||||||
|
xla::HloModuleConfig config = c_config.has_entry_computation_layout
|
||||||
|
? ConfigWithLayout(c_config)
|
||||||
|
: xla::HloModuleConfig();
|
||||||
|
config.set_launch_id(c_config.launch_id);
|
||||||
|
config.set_seed(c_config.seed);
|
||||||
|
config.set_replica_count(c_config.replica_count);
|
||||||
|
config.set_num_partitions(c_config.num_partitions);
|
||||||
|
config.set_use_spmd_partitioning(c_config.use_spmd_partitioning);
|
||||||
|
if (c_config.has_static_device_assignment) {
|
||||||
|
auto device_assignment = xla::DeviceAssignment::Deserialize(
|
||||||
|
stream_executor::tpu::DeserializeProto<xla::DeviceAssignmentProto>(
|
||||||
|
c_config.static_device_assignment));
|
||||||
|
config.set_static_device_assignment(
|
||||||
|
*(device_assignment.ConsumeValueOrDie()));
|
||||||
|
}
|
||||||
|
return config;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Free(XLA_HloModuleConfig* c_config) {
|
||||||
|
for (auto i = 0; i < c_config->entry_computation_layout.parameter_count;
|
||||||
|
++i) {
|
||||||
|
ApiConverter::Free(
|
||||||
|
&c_config->entry_computation_layout.parameter_layouts[i]);
|
||||||
|
}
|
||||||
|
delete[] c_config->entry_computation_layout.parameter_layouts;
|
||||||
|
ApiConverter::Free(&c_config->entry_computation_layout.result_layout);
|
||||||
|
if (c_config->has_static_device_assignment) {
|
||||||
|
stream_executor::tpu::SerializedProto_Free(
|
||||||
|
c_config->static_device_assignment);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace ApiConverter
|
} // namespace ApiConverter
|
||||||
|
|||||||
@ -19,6 +19,8 @@ limitations under the License.
|
|||||||
#include "absl/container/inlined_vector.h"
|
#include "absl/container/inlined_vector.h"
|
||||||
#include "tensorflow/compiler/xla/executable_run_options.h"
|
#include "tensorflow/compiler/xla/executable_run_options.h"
|
||||||
#include "tensorflow/compiler/xla/literal.h"
|
#include "tensorflow/compiler/xla/literal.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
||||||
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
|
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
|
||||||
#include "tensorflow/compiler/xla/service/service_executable_run_options.h"
|
#include "tensorflow/compiler/xla/service/service_executable_run_options.h"
|
||||||
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
|
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
|
||||||
@ -41,7 +43,7 @@ stream_executor::DeviceMemoryBase FromC(const SE_DeviceMemoryBase& se_base);
|
|||||||
void Free(SE_DeviceMemoryBase*);
|
void Free(SE_DeviceMemoryBase*);
|
||||||
|
|
||||||
// xla::Shape
|
// xla::Shape
|
||||||
xla::Shape FromC(XLA_Shape* shape);
|
xla::Shape FromC(const XLA_Shape* shape);
|
||||||
void ToC(const xla::Shape& xla_shape, XLA_Shape* c_shape);
|
void ToC(const xla::Shape& xla_shape, XLA_Shape* c_shape);
|
||||||
void Free(XLA_Shape* shape);
|
void Free(XLA_Shape* shape);
|
||||||
|
|
||||||
@ -65,11 +67,6 @@ SE_DeviceMemoryBase ToC(const stream_executor::DeviceMemoryBase& base);
|
|||||||
stream_executor::DeviceMemoryBase FromC(const SE_DeviceMemoryBase& se_base);
|
stream_executor::DeviceMemoryBase FromC(const SE_DeviceMemoryBase& se_base);
|
||||||
void Free(SE_DeviceMemoryBase*);
|
void Free(SE_DeviceMemoryBase*);
|
||||||
|
|
||||||
// xla::Shape
|
|
||||||
xla::Shape FromC(XLA_Shape* shape);
|
|
||||||
void ToC(const xla::Shape& xla_shape, XLA_Shape* c_shape);
|
|
||||||
void Free(XLA_Shape* shape);
|
|
||||||
|
|
||||||
// Literal
|
// Literal
|
||||||
void ToC(const xla::LiteralSlice& literal, XLA_Literal* c_literal);
|
void ToC(const xla::LiteralSlice& literal, XLA_Literal* c_literal);
|
||||||
xla::MutableBorrowingLiteral FromC(XLA_Literal* c_literal);
|
xla::MutableBorrowingLiteral FromC(XLA_Literal* c_literal);
|
||||||
@ -94,6 +91,17 @@ SE_MaybeOwningDeviceMemory ToC(stream_executor::OwningDeviceMemory* mem);
|
|||||||
// 'mem' is unowned.
|
// 'mem' is unowned.
|
||||||
SE_MaybeOwningDeviceMemory ToC(xla::MaybeOwningDeviceMemory& mem, bool aliased);
|
SE_MaybeOwningDeviceMemory ToC(xla::MaybeOwningDeviceMemory& mem, bool aliased);
|
||||||
|
|
||||||
|
// HloModule
|
||||||
|
XLA_HloModule ToC(const xla::HloModule& module);
|
||||||
|
xla::StatusOr<std::unique_ptr<xla::HloModule>> FromC(
|
||||||
|
const XLA_HloModule& c_module);
|
||||||
|
void Free(XLA_HloModule* c_module);
|
||||||
|
|
||||||
|
// HloModuleConfig
|
||||||
|
XLA_HloModuleConfig ToC(const xla::HloModuleConfig& config);
|
||||||
|
xla::HloModuleConfig FromC(const XLA_HloModuleConfig& c_config);
|
||||||
|
void Free(XLA_HloModuleConfig* c_config);
|
||||||
|
|
||||||
// Helper for managing stack based C -> C++ conversions.
|
// Helper for managing stack based C -> C++ conversions.
|
||||||
template <class CType>
|
template <class CType>
|
||||||
struct StackHelper {
|
struct StackHelper {
|
||||||
|
|||||||
@ -310,6 +310,11 @@ TFTPU_CAPI_EXPORT void TpuExecutable_Fingerprint(SE_Executable* executable,
|
|||||||
const char** fingerprint,
|
const char** fingerprint,
|
||||||
size_t* size);
|
size_t* size);
|
||||||
|
|
||||||
|
// Caller is responsible for freeing the returned module's proto and its
|
||||||
|
// config's proto.
|
||||||
|
TFTPU_CAPI_EXPORT XLA_HloModule
|
||||||
|
TpuExecutable_HloModule(SE_Executable* executable);
|
||||||
|
|
||||||
TFTPU_CAPI_EXPORT void TpuExecutable_Free(SE_Executable*);
|
TFTPU_CAPI_EXPORT void TpuExecutable_Free(SE_Executable*);
|
||||||
|
|
||||||
// Converts an XLA `Shape` into its equivalent TPU `Shape` representation.
|
// Converts an XLA `Shape` into its equivalent TPU `Shape` representation.
|
||||||
@ -458,6 +463,7 @@ struct TfTpu_ExecutorApiFn {
|
|||||||
TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_ShapeSize);
|
TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_ShapeSize);
|
||||||
TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_ExecuteAsyncOnStream);
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_ExecuteAsyncOnStream);
|
||||||
TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_Fingerprint);
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_Fingerprint);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_HloModule);
|
||||||
TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_Free);
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_Free);
|
||||||
|
|
||||||
TFTPU_ADD_FN_IN_STRUCT(XlaShapeToTpuShapeRepresentation);
|
TFTPU_ADD_FN_IN_STRUCT(XlaShapeToTpuShapeRepresentation);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user