Remove the ambiguity of device/host computation layouts within the HloModuleConfig.
PiperOrigin-RevId: 201284741
commit 9ab04addfb (parent c04396e3fd)
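The patch below collapses HloModuleConfig's separate host and device entry ComputationLayouts (which callers had to pick between and keep in sync) into a single entry computation layout. As a rough illustration of the consolidated API this diff leaves behind (a sketch only, not code from the patch; MakeProgramShape() is a hypothetical helper standing in for however a caller obtains a ProgramShape):

    #include "tensorflow/compiler/xla/service/hlo_module_config.h"

    void Example() {
      // Per the new constructor comment: layouts in the ProgramShape are reset
      // to defaults unless ignore_layouts is set to false.
      xla::ProgramShape program_shape = MakeProgramShape();  // hypothetical helper
      xla::HloModuleConfig config(program_shape, /*ignore_layouts=*/false);

      // One accessor pair replaces the host/device variants removed below.
      const xla::ComputationLayout& layout = config.entry_computation_layout();
      xla::ComputationLayout* mutable_layout =
          config.mutable_entry_computation_layout();
      mutable_layout->mutable_result_layout()->SetToDefaultLayout();
    }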
@@ -51,24 +51,17 @@ LocalExecutable::LocalExecutable(std::unique_ptr<Executable> executable,
 Status LocalExecutable::ValidateExecutionOptions(
     const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
     const ExecutableRunOptions& run_options, const Backend& backend) {
-  const ComputationLayout& host_computation_layout =
-      executable_->module_config().host_entry_computation_layout();
-  const ComputationLayout& device_computation_layout =
-      executable_->module_config().device_entry_computation_layout();
+  const ComputationLayout& computation_layout =
+      executable_->module_config().entry_computation_layout();
 
   // Check argument number, shapes, and layouts.
-  if (arguments.size() != host_computation_layout.parameter_count()) {
+  if (arguments.size() != computation_layout.parameter_count()) {
     return InvalidArgument(
         "invalid number of arguments for computation: expected %d, got %zu",
-        host_computation_layout.parameter_count(), arguments.size());
-  }
-  if (arguments.size() != device_computation_layout.parameter_count()) {
-    return InvalidArgument(
-        "invalid number of arguments for computation: expected %d, got %zu",
-        device_computation_layout.parameter_count(), arguments.size());
+        computation_layout.parameter_count(), arguments.size());
   }
   for (int i = 0; i < arguments.size(); ++i) {
-    if (!host_computation_layout.parameter_layout(i).MatchesLayoutInShape(
+    if (!computation_layout.parameter_layout(i).MatchesLayoutInShape(
             arguments[i]->on_host_shape())) {
       return InvalidParameterArgument(
           executable_.get(), i,
@@ -76,24 +69,10 @@ Status LocalExecutable::ValidateExecutionOptions(
           "parameter "
           "%d: want %s, got %s",
           i,
-          ShapeUtil::HumanString(
-              host_computation_layout.parameter_layout(i).shape())
+          ShapeUtil::HumanString(computation_layout.parameter_layout(i).shape())
               .c_str(),
           ShapeUtil::HumanString(arguments[i]->on_host_shape()).c_str());
     }
-    if (!device_computation_layout.parameter_layout(i).MatchesLayoutInShape(
-            arguments[i]->on_device_shape())) {
-      return InvalidParameterArgument(
-          executable_.get(), i,
-          "Argument does not match device shape or layout of computation "
-          "parameter "
-          "%d: want %s, got %s",
-          i,
-          ShapeUtil::HumanString(
-              device_computation_layout.parameter_layout(i).shape())
-              .c_str(),
-          ShapeUtil::HumanString(arguments[i]->on_device_shape()).c_str());
-    }
   }
 
   if (run_options.stream() != nullptr) {
@@ -303,8 +303,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
       ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
 
   pipeline.AddPass<CpuLayoutAssignment>(
-      module->mutable_device_entry_computation_layout(),
-      &target_machine_features);
+      module->mutable_entry_computation_layout(), &target_machine_features);
   // The LayoutAssignment pass may leave behind kCopy instructions which are
   // duplicate or NOPs, so remove them with algebraic simplification and CSE.
   pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(
@@ -206,8 +206,8 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
     tensorflow::gtl::MutableArraySlice<OwningDeviceMemory> buffers) {
   se::Stream* stream = run_options->stream();
   ScopedShapedBuffer result_buffer(
-      /*on_host_shape=*/host_result_shape(),
-      /*on_device_shape=*/host_result_shape(), run_options->allocator(),
+      /*on_host_shape=*/result_shape(),
+      /*on_device_shape=*/result_shape(), run_options->allocator(),
       stream->parent()->device_ordinal());
 
   // Move OwningDeviceMemory values which contain the array(s) of the result
@@ -131,8 +131,8 @@ class Executable {
 
   // The shape (including layout) that results from this execution. This is the
   // shape of the DeviceMemoryBase result value in ExecuteOnStream above.
-  const Shape& host_result_shape() const {
-    return hlo_module_->config().host_entry_computation_layout().result_shape();
+  const Shape& result_shape() const {
+    return hlo_module_->config().entry_computation_layout().result_shape();
   }
 
   // Returns the size of the executable in bytes. Returns -1 by default if the
@@ -205,7 +205,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
   {
     HloPassPipeline pipeline("layout_assignment");
     pipeline.AddPass<GpuLayoutAssignment>(
-        hlo_module->mutable_device_entry_computation_layout(), stream_exec);
+        hlo_module->mutable_entry_computation_layout(), stream_exec);
 
     // The LayoutAssignment pass may leave behind kCopy instructions which are
     // duplicate or NOPs, so remove them with algebraic simplification and CSE.
@@ -58,7 +58,7 @@ HloComputation* HloModule::AddComputationInternal(
 
     // If the module configuration has no entry layout computation set, create a
     // default one based on the program shape.
-    if (!config_.has_host_entry_computation_layout()) {
+    if (!config_.has_entry_computation_layout()) {
       config_.SetDefaultComputationLayout(
           entry_computation_->ComputeProgramShape());
     }
@@ -231,14 +231,11 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
   TF_RET_CHECK(proto.has_program_shape())
       << "No program shape found in the proto";
   const auto& expected_program_shape = proto.program_shape();
-  TF_RET_CHECK(
-      expected_program_shape.parameters_size() ==
-      module_config.device_entry_computation_layout().parameter_count());
+  TF_RET_CHECK(expected_program_shape.parameters_size() ==
+               module_config.entry_computation_layout().parameter_count());
   for (int i = 0; i < expected_program_shape.parameters_size(); ++i) {
     const Shape& parameter_shape =
-        module_config.device_entry_computation_layout()
-            .parameter_layout(i)
-            .shape();
+        module_config.entry_computation_layout().parameter_layout(i).shape();
     TF_RET_CHECK(ShapeUtil::Compatible(expected_program_shape.parameters(i),
                                        parameter_shape))
         << "HloModuleConfig has different shape for parameter " << i
@@ -248,7 +245,7 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
         << ", actual: " << ShapeUtil::HumanStringWithLayout(parameter_shape);
   }
   const Shape& result_shape =
-      module_config.device_entry_computation_layout().result_layout().shape();
+      module_config.entry_computation_layout().result_layout().shape();
   TF_RET_CHECK(
       ShapeUtil::Compatible(expected_program_shape.result(), result_shape))
       << "HloModuleConfig has different result shape than the HLO module. "
@@ -327,7 +324,7 @@ StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
   // The module config is constructed with default layouts regardless of what is
   // passed in via the ProgramShape. Set the layouts to the appropriate values.
   ComputationLayout* entry_layout =
-      module_config.mutable_host_entry_computation_layout();
+      module_config.mutable_entry_computation_layout();
   for (int64 i = 0; i < entry_layout->parameter_count(); ++i) {
     TF_RETURN_IF_ERROR(
         entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
@@ -335,9 +332,6 @@ StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
   }
   TF_RETURN_IF_ERROR(entry_layout->mutable_result_layout()->CopyLayoutFromShape(
       program_shape.result()));
-  *module_config.mutable_device_entry_computation_layout() =
-      module_config.host_entry_computation_layout();
-
   return module_config;
 }
 
@@ -105,20 +105,19 @@ class HloModule {
     return entry_computation_;
   }
 
-  ComputationLayout* mutable_host_entry_computation_layout() {
-    return config_.mutable_host_entry_computation_layout();
+  // Creates the ComputationLayout which describes the current status of the HLO
+  // module entry computation.
+  ComputationLayout compute_computation_layout() const {
+    return ComputationLayout(entry_computation()->ComputeProgramShape(),
+                             /*ignore_layouts=*/false);
   }
 
-  const ComputationLayout& host_entry_computation_layout() const {
-    return config_.host_entry_computation_layout();
+  ComputationLayout* mutable_entry_computation_layout() {
+    return config_.mutable_entry_computation_layout();
   }
 
-  ComputationLayout* mutable_device_entry_computation_layout() {
-    return config_.mutable_device_entry_computation_layout();
-  }
-
-  const ComputationLayout& device_entry_computation_layout() const {
-    return config_.device_entry_computation_layout();
+  const ComputationLayout& entry_computation_layout() const {
+    return config_.entry_computation_layout();
   }
 
   // Gets the computations in this module.
@@ -28,16 +28,14 @@ namespace xla {
 
 using tensorflow::strings::StrAppend;
 
-HloModuleConfig::HloModuleConfig() {}
-
-HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape)
-    : host_entry_computation_layout_(program_shape),
-      device_entry_computation_layout_(program_shape) {}
+HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape,
+                                 bool ignore_layouts)
+    : entry_computation_layout_(
+          ComputationLayout(program_shape, ignore_layouts)) {}
 
 void HloModuleConfig::SetDefaultComputationLayout(
     const ProgramShape& program_shape) {
-  host_entry_computation_layout_ = ComputationLayout(program_shape);
-  device_entry_computation_layout_ = ComputationLayout(program_shape);
+  entry_computation_layout_ = ComputationLayout(program_shape);
 }
 
 string HloModuleConfig::compilation_cache_key() const {
@@ -46,18 +44,11 @@ string HloModuleConfig::compilation_cache_key() const {
   StrAppend(&key, "::(");
   std::vector<string> params;
   for (const ShapeLayout& param_layout :
-       host_entry_computation_layout_->parameter_layouts()) {
+       entry_computation_layout_->parameter_layouts()) {
     params.push_back(param_layout.shape().DebugString());
   }
   StrAppend(&key, tensorflow::str_util::Join(params, ", "), ") => ",
-            host_entry_computation_layout_->result_shape().SerializeAsString());
-  for (const ShapeLayout& param_layout :
-       device_entry_computation_layout_->parameter_layouts()) {
-    params.push_back(param_layout.shape().DebugString());
-  }
-  StrAppend(
-      &key, tensorflow::str_util::Join(params, ", "), ") => ",
-      device_entry_computation_layout_->result_shape().SerializeAsString());
+            entry_computation_layout_->result_shape().SerializeAsString());
   if (seed() != 0) {
     // TODO(b/32083678): force recompilation to reset global state.
     static std::atomic<int> counter{0};
@@ -37,48 +37,34 @@ class HloModuleConfig {
   // ComputationLayout. The default ctor creates it without -- in this case
   // accessing entry_computation_layout will CHECK-fail. The ctor accepting a
   // ProgramShape creates a computation layout using this shape.
-  HloModuleConfig();
-  explicit HloModuleConfig(const ProgramShape& program_shape);
+  // The layouts in the ProgramShape will be reset to default unless
+  // ignore_layouts is set to false.
+  HloModuleConfig() = default;
+
+  explicit HloModuleConfig(const ProgramShape& program_shape,
+                           bool ignore_layouts = true);
 
   // Checks if this config has an entry computation layout already.
-  bool has_host_entry_computation_layout() const {
-    return host_entry_computation_layout_.has_value();
-  }
-
-  bool has_device_entry_computation_layout() const {
-    return device_entry_computation_layout_.has_value();
+  bool has_entry_computation_layout() const {
+    return entry_computation_layout_.has_value();
   }
 
   // Sets the entry computation layout for this config. If the entry computation
   // layout already exists, it is silently replaced.
   void SetDefaultComputationLayout(const ProgramShape& program_shape);
 
-  // Returns a constant reference to the on-host layout of the entry
-  // computation. Assumes the layout was set.
-  const ComputationLayout& host_entry_computation_layout() const {
-    CHECK(host_entry_computation_layout_.has_value());
-    return *host_entry_computation_layout_;
-  }
-
-  // Returns a mutable pointer to the layout of the on-host entry computation.
+  // Returns a constant reference to the layout of the entry computation.
   // Assumes the layout was set.
-  ComputationLayout* mutable_host_entry_computation_layout() {
-    CHECK(host_entry_computation_layout_.has_value());
-    return &(*host_entry_computation_layout_);
+  const ComputationLayout& entry_computation_layout() const {
+    CHECK(entry_computation_layout_.has_value());
+    return *entry_computation_layout_;
   }
 
-  // Returns a constant reference to the on-device layout of the entry
-  // computation. Assumes the layout was set.
-  const ComputationLayout& device_entry_computation_layout() const {
-    CHECK(device_entry_computation_layout_.has_value());
-    return *device_entry_computation_layout_;
-  }
-
-  // Returns a mutable pointer to the layout of the on-device entry computation.
+  // Returns a mutable pointer to the layout of the entry computation.
   // Assumes the layout was set.
-  ComputationLayout* mutable_device_entry_computation_layout() {
-    CHECK(device_entry_computation_layout_.has_value());
-    return &(*device_entry_computation_layout_);
+  ComputationLayout* mutable_entry_computation_layout() {
+    CHECK(entry_computation_layout_.has_value());
+    return &(*entry_computation_layout_);
   }
 
   // Returns whether to enable HLO-level profiling.
@@ -127,8 +113,7 @@ class HloModuleConfig {
  private:
   // If you add new members, be sure to update compilation_cache_key.
 
-  tensorflow::gtl::optional<ComputationLayout> host_entry_computation_layout_;
-  tensorflow::gtl::optional<ComputationLayout> device_entry_computation_layout_;
+  tensorflow::gtl::optional<ComputationLayout> entry_computation_layout_;
 
   // Whether this is a 'host module'.
   bool is_host_module_ = false;
|
@ -327,22 +327,15 @@ bool HloParser::ParseComputations() {
|
||||
// set the layouts to what the hlo text says.
|
||||
for (int p = 0; p < computation->num_parameters(); p++) {
|
||||
const Shape& param_shape = computation->parameter_instruction(p)->shape();
|
||||
TF_CHECK_OK(module_->mutable_host_entry_computation_layout()
|
||||
->mutable_parameter_layout(p)
|
||||
->CopyLayoutFromShape(param_shape));
|
||||
TF_CHECK_OK(module_->mutable_device_entry_computation_layout()
|
||||
TF_CHECK_OK(module_->mutable_entry_computation_layout()
|
||||
->mutable_parameter_layout(p)
|
||||
->CopyLayoutFromShape(param_shape));
|
||||
}
|
||||
const Shape& result_shape = computation->root_instruction()->shape();
|
||||
TF_CHECK_OK(module_->mutable_host_entry_computation_layout()
|
||||
->mutable_result_layout()
|
||||
->CopyLayoutFromShape(result_shape));
|
||||
TF_CHECK_OK(module_->mutable_device_entry_computation_layout()
|
||||
TF_CHECK_OK(module_->mutable_entry_computation_layout()
|
||||
->mutable_result_layout()
|
||||
->CopyLayoutFromShape(result_shape));
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@@ -1302,7 +1302,7 @@ ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] {
 
   auto module = ParseHloString(original);
   TF_ASSERT_OK(module.status());
-  auto program_layout = module.ValueOrDie()->host_entry_computation_layout();
+  auto program_layout = module.ValueOrDie()->entry_computation_layout();
   ASSERT_EQ(program_layout.parameter_count(), 1);
   auto param_layout = program_layout.parameter_layout(0).layout();
   auto result_layout = program_layout.result_layout().layout();
@@ -44,7 +44,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
   HloPassPipeline pipeline("Interpreter");
 
   pipeline.AddPass<LayoutAssignment>(
-      hlo_module->mutable_device_entry_computation_layout());
+      hlo_module->mutable_entry_computation_layout());
   return pipeline.Run(hlo_module).status();
 }
 
@@ -190,10 +190,8 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
       std::unique_ptr<HloModuleConfig> module_config,
       CreateModuleConfig(program_shape, argument_layouts, &execution_options));
 
-  VLOG(3) << "Host Computation Layout: "
-          << module_config->host_entry_computation_layout().ToString();
-  VLOG(3) << "Device Computation Layout: "
-          << module_config->device_entry_computation_layout().ToString();
+  VLOG(3) << "Computation Layout: "
+          << module_config->entry_computation_layout().ToString();
 
   TF_ASSIGN_OR_RETURN(
       se::StreamExecutor * executor,
@@ -244,10 +244,8 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
     tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,
     const ExecutionOptions* execution_options) {
   auto config = MakeUnique<HloModuleConfig>(program_shape);
-  ComputationLayout* host_computation_layout =
-      config->mutable_host_entry_computation_layout();
-  ComputationLayout* device_computation_layout =
-      config->mutable_device_entry_computation_layout();
+  ComputationLayout* computation_layout =
+      config->mutable_entry_computation_layout();
   if (program_shape.parameters_size() != argument_shapes.size()) {
     return InvalidArgument("computation takes %d parameters, but %zu given",
                            program_shape.parameters_size(),
@@ -264,10 +262,9 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
           i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(),
           ShapeUtil::HumanString(*argument_shapes[i]).c_str());
     }
-    TF_RETURN_IF_ERROR(host_computation_layout->mutable_parameter_layout(i)
-                           ->CopyLayoutFromShape(*argument_shapes[i]));
-    TF_RETURN_IF_ERROR(device_computation_layout->mutable_parameter_layout(i)
-                           ->CopyLayoutFromShape(*argument_shapes[i]));
+    TF_RETURN_IF_ERROR(
+        computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
+            *argument_shapes[i]));
   }
   if (execution_options != nullptr &&
       execution_options->has_shape_with_output_layout()) {
@@ -276,20 +273,11 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
     TF_RETURN_IF_ERROR(
         ValidateResultShape(shape_with_output_layout, program_shape.result()));
     TF_RETURN_IF_ERROR(
-        host_computation_layout->mutable_result_layout()->CopyLayoutFromShape(
-            shape_with_output_layout));
-    TF_RETURN_IF_ERROR(
-        device_computation_layout->mutable_result_layout()->CopyLayoutFromShape(
+        computation_layout->mutable_result_layout()->CopyLayoutFromShape(
            shape_with_output_layout));
   } else {
     // If the result layout is not set, then choose the default.
-    // TODO(b/29118294): Allow the compiler to choose a better layout in this
-    // case.
-    // TODO(b/78356948): We are forcing the default layout here. We should fix
-    // clients which expect a default layout, to be explicit about it, by
-    // passing the proper ExecutionOptions with shape_with_output_layout set.
-    host_computation_layout->mutable_result_layout()->SetToDefaultLayout();
-    device_computation_layout->mutable_result_layout()->SetToDefaultLayout();
+    computation_layout->mutable_result_layout()->SetToDefaultLayout();
   }
 
   config->set_replica_count(options_.number_of_replicas());
@@ -377,24 +365,6 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
   return std::move(executables);
 }
 
-Status Service::ValidateEntryComputationLayout(HloModule* module) {
-  const ComputationLayout& on_host = module->host_entry_computation_layout();
-  const ComputationLayout& on_device =
-      module->device_entry_computation_layout();
-  for (int64 i = 0; i < on_device.parameter_count(); ++i) {
-    TF_RET_CHECK(ShapeUtil::Compatible(on_device.parameter_shape(i),
-                                       on_host.parameter_shape(i)))
-        << ShapeUtil::HumanStringWithLayout(on_device.parameter_shape(i))
-        << " vs "
-        << ShapeUtil::HumanStringWithLayout(on_host.parameter_shape(i));
-  }
-  TF_RET_CHECK(
-      ShapeUtil::Compatible(on_device.result_shape(), on_host.result_shape()))
-      << ShapeUtil::HumanStringWithLayout(on_device.result_shape()) << " vs "
-      << ShapeUtil::HumanStringWithLayout(on_host.result_shape());
-  return Status::OK();
-}
-
 StatusOr<std::vector<GlobalDataHandle>>
 Service::ExecuteParallelAndRegisterResult(
     tensorflow::gtl::ArraySlice<Executable*> executables,
@@ -690,7 +660,7 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
                            request.execution_options()));
     VLOG(3)
         << "ExecuteGraphParallel created HloModuleConfig computation layout: "
-        << module_config->host_entry_computation_layout().ToString();
+        << module_config->entry_computation_layout().ToString();
 
     // Adds to the vectors to build and execute the computations after the loop.
     all_arguments.push_back(replicated_arguments);
@@ -851,8 +821,6 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
   TF_ASSIGN_OR_RETURN(
       module, backend->compiler()->RunHloPasses(std::move(module), executor,
                                                 device_allocator));
-  // Check that on-host and on-device shapes are consistent.
-  TF_RETURN_IF_ERROR(ValidateEntryComputationLayout(module.get()));
 
   TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
                       backend->compiler()->RunBackend(
@@ -193,9 +193,6 @@ class Service : public ServiceInterface {
       const ExecutionOptions& execution_options,
       tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments);
 
-  // Assert that host- and device-shapes are in a consistent state.
-  Status ValidateEntryComputationLayout(HloModule* module);
-
  protected:
   friend class LocalExecutable;
 
@@ -185,13 +185,9 @@ class HloTestBase : public ::testing::Test {
   // 'layout'.
   void ForceParameterLayout(HloModule* module, int64 param_no,
                             const Layout& layout) {
-    ASSERT_LT(
-        param_no,
-        module->mutable_host_entry_computation_layout()->parameter_count());
-    module->mutable_host_entry_computation_layout()
-        ->mutable_parameter_layout(param_no)
-        ->ResetLayout(layout);
-    module->mutable_device_entry_computation_layout()
+    ASSERT_LT(param_no,
+              module->mutable_entry_computation_layout()->parameter_count());
+    module->mutable_entry_computation_layout()
         ->mutable_parameter_layout(param_no)
         ->ResetLayout(layout);
   }
@@ -199,10 +195,7 @@ class HloTestBase : public ::testing::Test {
   // Convenience method to force the layout of the computation result in a
   // module. The result layout of 'module' is set to 'layout'.
   void ForceResultLayout(HloModule* module, const Layout& layout) {
-    module->mutable_host_entry_computation_layout()
-        ->mutable_result_layout()
-        ->ResetLayout(layout);
-    module->mutable_device_entry_computation_layout()
+    module->mutable_entry_computation_layout()
         ->mutable_result_layout()
        ->ResetLayout(layout);
   }
@@ -210,10 +203,7 @@ class HloTestBase : public ::testing::Test {
   // Convenience method to clear the layout of the computation result in
   // 'module'.
   void ForceClearResultLayout(HloModule* module) {
-    module->mutable_host_entry_computation_layout()
-        ->mutable_result_layout()
-        ->Clear();
-    module->mutable_device_entry_computation_layout()
+    module->mutable_entry_computation_layout()
         ->mutable_result_layout()
         ->Clear();
   }