In TpuCompiler::Compile, initialize TpuExecutables with correct HloModules.

Prior to this change, the compiler in tpu_on_demand_compiler.cc would
initialize TpuExecutables with the same HloModules passed to
Compile(). This isn't always correct, for example when XLA partitions
the computation and changes the input/output shapes.

This change has the compiler initialize the TpuExecutables with the
HloModules from the underlying executables returned from the TPU C
API's compile function.

Other changes included as part of this:
* Adds C API function for getting an executable's HLO module.
* Moves some functions to ApiConverter. The implementations added to
  c_api_conversions.cc are copies with no additional changes.
* Removes some duplicate function declarations from ApiConverter.
* Fixes a double-free in TpuCompiler::Compile.

PiperOrigin-RevId: 332900516
Change-Id: I7a8b6f66921259fe7c50ab1f36f4af236c77f955
This commit is contained in:
Skye Wanderman-Milne 2020-09-21 11:57:34 -07:00 committed by TensorFlower Gardener
parent 14696f7e09
commit f50646565f
6 changed files with 143 additions and 29 deletions

View File

@ -140,6 +140,7 @@ tensorflow::Status SetExecutorStructFn(void* library_handle) {
TFTPU_SET_FN(executor_fn, TpuCompiler_ShapeSize);
TFTPU_SET_FN(executor_fn, TpuExecutable_ExecuteAsyncOnStream);
TFTPU_SET_FN(executor_fn, TpuExecutable_Fingerprint);
TFTPU_SET_FN(executor_fn, TpuExecutable_HloModule);
TFTPU_SET_FN(executor_fn, TpuExecutable_Free);
TFTPU_SET_FN(executor_fn, XlaShapeToTpuShapeRepresentation);

View File

@ -84,20 +84,6 @@ namespace {
using ::tensorflow::tpu::ExecutorApiFn;
// Releases every heap-allocated member of an XLA_HloModuleConfig: the
// per-parameter shape layouts, the result layout, and — when present — the
// serialized static device assignment proto.
void XLA_HloModuleConfig_Free(XLA_HloModuleConfig* module_config) {
  auto& entry_layout = module_config->entry_computation_layout;
  for (auto param = 0; param < entry_layout.parameter_count; ++param) {
    ApiConverter::Free(&entry_layout.parameter_layouts[param]);
  }
  delete[] entry_layout.parameter_layouts;
  ApiConverter::Free(&entry_layout.result_layout);
  if (module_config->has_static_device_assignment) {
    stream_executor::tpu::SerializedProto_Free(
        module_config->static_device_assignment);
  }
}
class TpuExecutable : public TpuExecutableInterface {
public:
TpuExecutable(SE_Executable* se_executable,
@ -273,7 +259,7 @@ class TpuCompiler : public Compiler {
auto cleanup = xla::MakeCleanup([&hlo_module, &result]() {
stream_executor::tpu::SerializedProto_Free(hlo_module.proto);
stream_executor::tpu::SerializedProto_Free(result.proto);
XLA_HloModuleConfig_Free(&hlo_module.module_config);
ApiConverter::Free(&hlo_module.module_config);
});
hlo_module.module_config = HloModuleConfigToC(module->config());
hlo_module.proto = stream_executor::tpu::SerializeProto(module->ToProto());
@ -309,7 +295,7 @@ class TpuCompiler : public Compiler {
XLA_HloModule hlo_module;
auto cleanup = xla::MakeCleanup([&hlo_module]() {
stream_executor::tpu::SerializedProto_Free(hlo_module.proto);
XLA_HloModuleConfig_Free(&hlo_module.module_config);
ApiConverter::Free(&hlo_module.module_config);
});
SE_Executable* result;
hlo_module.module_config = HloModuleConfigToC(module->config());
@ -344,7 +330,7 @@ class TpuCompiler : public Compiler {
auto cleanup_config =
xla::MakeCleanup([&se_module_group, module_group_size]() {
for (auto i = 0; i < module_group_size; ++i) {
XLA_HloModuleConfig_Free(&se_module_group.module_config[i]);
ApiConverter::Free(&se_module_group.module_config[i]);
}
delete[] se_module_group.module_config;
});
@ -378,15 +364,24 @@ class TpuCompiler : public Compiler {
}
std::vector<std::unique_ptr<Executable>> executables;
std::vector<std::unique_ptr<HloModule>> modules =
module_group->ConsumeModules();
for (int i = 0; i < module_group->size(); ++i) {
executables[i] = absl::make_unique<TpuExecutable>(se_executables[i],
std::move(modules[i]));
// We get the HloModule from the compiled executable, rather than reusing
// the input module from 'module_group', in case the module changed in
// some way. For example, if the computation is automatically partitioned
// via XLA, the executable's module may have different input/output shapes
// than the input module.
XLA_HloModule c_module =
ExecutorApiFn()->TpuExecutable_HloModuleFn(se_executables[i]);
auto cleanup_c_module =
xla::MakeCleanup([&c_module]() { ApiConverter::Free(&c_module); });
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
ApiConverter::FromC(c_module));
std::shared_ptr<HloModule> module_shared(module.release());
executables.emplace_back(absl::make_unique<TpuExecutable>(
se_executables[i], std::move(module_shared)));
}
stream_executor::tpu::SerializedProto_Free(se_module_group.proto);
delete se_module_group.module_config;
delete[] se_executables;
return executables;

View File

@ -45,6 +45,8 @@ cc_library(
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service:maybe_owning_device_memory",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/core/tpu:tpu_api",

View File

@ -150,7 +150,7 @@ stream_executor::DeviceMemoryBase FromC(const SE_DeviceMemoryBase& se_base) {
return base;
}
xla::Shape FromC(XLA_Shape* shape) {
xla::Shape FromC(const XLA_Shape* shape) {
xla::ShapeProto p;
p.ParseFromArray(shape->bytes, shape->size);
return xla::Shape(p);
@ -230,4 +230,106 @@ void Free(XLA_ShapedBuffer* c_buffer) {
delete[] c_buffer->bases;
}
// Serializes an xla::HloModule into its C-API representation. The returned
// struct owns heap allocations; release it with Free(XLA_HloModule*).
XLA_HloModule ToC(const xla::HloModule& module) {
  XLA_HloModule result;
  result.module_config = ApiConverter::ToC(module.config());
  result.proto = stream_executor::tpu::SerializeProto(module.ToProto());
  return result;
}
// Reconstructs an xla::HloModule from its C-API representation by
// deserializing the embedded HloModuleProto and converting the module
// config back to its C++ form.
xla::StatusOr<std::unique_ptr<xla::HloModule>> FromC(
    const XLA_HloModule& c_module) {
  auto proto = stream_executor::tpu::DeserializeProto<xla::HloModuleProto>(
      c_module.proto);
  auto config = ApiConverter::FromC(c_module.module_config);
  return xla::HloModule::CreateFromProto(proto, config);
}
// Releases the resources owned by an XLA_HloModule created via ToC: the
// nested module config and the serialized HloModuleProto.
void Free(XLA_HloModule* c_module) {
  Free(&c_module->module_config);
  stream_executor::tpu::SerializedProto_Free(c_module->proto);
}
// Builds an HloModuleConfig whose entry computation layout mirrors the
// layouts recorded in `se_config`: the result layout plus one layout per
// parameter.
static xla::HloModuleConfig ConfigWithLayout(
    const XLA_HloModuleConfig& se_config) {
  const auto& c_layout = se_config.entry_computation_layout;
  xla::ComputationLayout computation_layout(
      xla::ShapeLayout(FromC(&c_layout.result_layout)));
  for (int param = 0; param < c_layout.parameter_count; ++param) {
    computation_layout.add_parameter_layout(
        xla::ShapeLayout(FromC(&c_layout.parameter_layouts[param])));
  }
  return xla::HloModuleConfig(computation_layout);
}
// Converts an xla::HloModuleConfig into its C-API counterpart. Heap
// allocations made here (the parameter layout array, and the serialized
// device-assignment proto when present) are owned by the caller and must be
// released with Free(XLA_HloModuleConfig*).
XLA_HloModuleConfig ToC(const xla::HloModuleConfig& config) {
  // Zero-initialize the struct so the entry_computation_layout fields
  // (notably parameter_count and parameter_layouts) are well-defined even
  // when config.has_entry_computation_layout() is false. Free() reads those
  // fields unconditionally, so leaving them uninitialized is undefined
  // behavior (garbage loop bound and delete[] of a wild pointer).
  XLA_HloModuleConfig hlo_config = {};
  hlo_config.seed = config.seed();
  hlo_config.launch_id = config.launch_id();
  hlo_config.replica_count = config.replica_count();
  hlo_config.num_partitions = config.num_partitions();
  hlo_config.use_spmd_partitioning = config.use_spmd_partitioning();
  hlo_config.has_static_device_assignment =
      config.has_static_device_assignment();
  hlo_config.has_entry_computation_layout =
      config.has_entry_computation_layout();
  if (config.has_static_device_assignment()) {
    xla::DeviceAssignmentProto dev_proto;
    // Serialization failure is ignored here, matching existing behavior;
    // the proto is simply left empty in that case.
    config.static_device_assignment().Serialize(&dev_proto).IgnoreError();
    hlo_config.static_device_assignment =
        stream_executor::tpu::SerializeProto(dev_proto);
  }
  if (config.has_entry_computation_layout()) {
    const auto& layout = config.entry_computation_layout();
    ApiConverter::ToC(layout.result_layout().shape(),
                      &hlo_config.entry_computation_layout.result_layout);
    hlo_config.entry_computation_layout.parameter_layouts =
        new XLA_Shape[layout.parameter_count()];
    for (int i = 0; i < layout.parameter_count(); ++i) {
      ApiConverter::ToC(
          layout.parameter_layout(i).shape(),
          &hlo_config.entry_computation_layout.parameter_layouts[i]);
    }
    hlo_config.entry_computation_layout.parameter_count =
        layout.parameter_count();
  }
  return hlo_config;
}
// Converts the C representation back into an xla::HloModuleConfig,
// restoring the entry computation layout and the static device assignment
// when they were recorded in `c_config`.
xla::HloModuleConfig FromC(const XLA_HloModuleConfig& c_config) {
  xla::HloModuleConfig config = c_config.has_entry_computation_layout
                                    ? ConfigWithLayout(c_config)
                                    : xla::HloModuleConfig();
  config.set_seed(c_config.seed);
  config.set_launch_id(c_config.launch_id);
  config.set_num_partitions(c_config.num_partitions);
  config.set_replica_count(c_config.replica_count);
  config.set_use_spmd_partitioning(c_config.use_spmd_partitioning);
  if (c_config.has_static_device_assignment) {
    auto assignment_proto =
        stream_executor::tpu::DeserializeProto<xla::DeviceAssignmentProto>(
            c_config.static_device_assignment);
    auto assignment = xla::DeviceAssignment::Deserialize(assignment_proto);
    config.set_static_device_assignment(*assignment.ConsumeValueOrDie());
  }
  return config;
}
// Releases all heap allocations owned by an XLA_HloModuleConfig produced
// by ToC(const xla::HloModuleConfig&).
void Free(XLA_HloModuleConfig* c_config) {
  // ToC only populates the entry computation layout when the source config
  // had one; guard here so we never read an uninitialized parameter_count
  // or delete[] an uninitialized parameter_layouts pointer.
  if (c_config->has_entry_computation_layout) {
    for (auto i = 0; i < c_config->entry_computation_layout.parameter_count;
         ++i) {
      ApiConverter::Free(
          &c_config->entry_computation_layout.parameter_layouts[i]);
    }
    delete[] c_config->entry_computation_layout.parameter_layouts;
    ApiConverter::Free(&c_config->entry_computation_layout.result_layout);
  }
  if (c_config->has_static_device_assignment) {
    stream_executor::tpu::SerializedProto_Free(
        c_config->static_device_assignment);
  }
}
} // namespace ApiConverter

View File

@ -19,6 +19,8 @@ limitations under the License.
#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/service/service_executable_run_options.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
@ -41,7 +43,7 @@ stream_executor::DeviceMemoryBase FromC(const SE_DeviceMemoryBase& se_base);
void Free(SE_DeviceMemoryBase*);
// xla::Shape
xla::Shape FromC(XLA_Shape* shape);
xla::Shape FromC(const XLA_Shape* shape);
void ToC(const xla::Shape& xla_shape, XLA_Shape* c_shape);
void Free(XLA_Shape* shape);
@ -65,11 +67,6 @@ SE_DeviceMemoryBase ToC(const stream_executor::DeviceMemoryBase& base);
stream_executor::DeviceMemoryBase FromC(const SE_DeviceMemoryBase& se_base);
void Free(SE_DeviceMemoryBase*);
// xla::Shape
xla::Shape FromC(XLA_Shape* shape);
void ToC(const xla::Shape& xla_shape, XLA_Shape* c_shape);
void Free(XLA_Shape* shape);
// Literal
void ToC(const xla::LiteralSlice& literal, XLA_Literal* c_literal);
xla::MutableBorrowingLiteral FromC(XLA_Literal* c_literal);
@ -94,6 +91,17 @@ SE_MaybeOwningDeviceMemory ToC(stream_executor::OwningDeviceMemory* mem);
// 'mem' is unowned.
SE_MaybeOwningDeviceMemory ToC(xla::MaybeOwningDeviceMemory& mem, bool aliased);
// HloModule
XLA_HloModule ToC(const xla::HloModule& module);
xla::StatusOr<std::unique_ptr<xla::HloModule>> FromC(
const XLA_HloModule& c_module);
void Free(XLA_HloModule* c_module);
// HloModuleConfig
XLA_HloModuleConfig ToC(const xla::HloModuleConfig& config);
xla::HloModuleConfig FromC(const XLA_HloModuleConfig& c_config);
void Free(XLA_HloModuleConfig* c_config);
// Helper for managing stack based C -> C++ conversions.
template <class CType>
struct StackHelper {

View File

@ -310,6 +310,11 @@ TFTPU_CAPI_EXPORT void TpuExecutable_Fingerprint(SE_Executable* executable,
const char** fingerprint,
size_t* size);
// Caller is responsible for freeing the returned module's proto and its
// config's proto.
TFTPU_CAPI_EXPORT XLA_HloModule
TpuExecutable_HloModule(SE_Executable* executable);
TFTPU_CAPI_EXPORT void TpuExecutable_Free(SE_Executable*);
// Converts an XLA `Shape` into its equivalent TPU `Shape` representation.
@ -458,6 +463,7 @@ struct TfTpu_ExecutorApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_ShapeSize);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_ExecuteAsyncOnStream);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_Fingerprint);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_HloModule);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_Free);
TFTPU_ADD_FN_IN_STRUCT(XlaShapeToTpuShapeRepresentation);