From 9ab04addfb80cbf9334bb330acee5fca09353d23 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 19 Jun 2018 19:40:00 -0700
Subject: [PATCH] Remove the ambiguity of device/host computation layouts
 within the HloModuleConfig.

PiperOrigin-RevId: 201284741
---
 .../compiler/xla/client/local_client.cc       | 33 +++----------
 .../compiler/xla/service/cpu/cpu_compiler.cc  |  3 +-
 .../xla/service/cpu/cpu_executable.cc         |  4 +-
 tensorflow/compiler/xla/service/executable.h  |  4 +-
 .../compiler/xla/service/gpu/gpu_compiler.cc  |  2 +-
 tensorflow/compiler/xla/service/hlo_module.cc | 18 +++---
 tensorflow/compiler/xla/service/hlo_module.h  | 19 ++++---
 .../compiler/xla/service/hlo_module_config.cc | 23 +++------
 .../compiler/xla/service/hlo_module_config.h  | 49 +++++++------
 tensorflow/compiler/xla/service/hlo_parser.cc | 11 +----
 .../compiler/xla/service/hlo_parser_test.cc   |  2 +-
 .../xla/service/interpreter/compiler.cc       |  2 +-
 .../compiler/xla/service/local_service.cc     |  6 +--
 tensorflow/compiler/xla/service/service.cc    | 48 +++---------
 tensorflow/compiler/xla/service/service.h     |  3 --
 tensorflow/compiler/xla/tests/hlo_test_base.h | 20 ++------
 16 files changed, 70 insertions(+), 177 deletions(-)

diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index cf07910c4a2..5f9710914bd 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -51,24 +51,17 @@ LocalExecutable::LocalExecutable(std::unique_ptr<Executable> executable,
 Status LocalExecutable::ValidateExecutionOptions(
     const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
     const ExecutableRunOptions& run_options, const Backend& backend) {
-  const ComputationLayout& host_computation_layout =
-      executable_->module_config().host_entry_computation_layout();
-  const ComputationLayout& device_computation_layout =
-      executable_->module_config().device_entry_computation_layout();
+  const ComputationLayout& computation_layout =
+      executable_->module_config().entry_computation_layout();
 
   // Check argument number, shapes, and layouts.
-  if (arguments.size() != host_computation_layout.parameter_count()) {
+  if (arguments.size() != computation_layout.parameter_count()) {
     return InvalidArgument(
         "invalid number of arguments for computation: expected %d, got %zu",
-        host_computation_layout.parameter_count(), arguments.size());
-  }
-  if (arguments.size() != device_computation_layout.parameter_count()) {
-    return InvalidArgument(
-        "invalid number of arguments for computation: expected %d, got %zu",
-        device_computation_layout.parameter_count(), arguments.size());
+        computation_layout.parameter_count(), arguments.size());
   }
   for (int i = 0; i < arguments.size(); ++i) {
-    if (!host_computation_layout.parameter_layout(i).MatchesLayoutInShape(
+    if (!computation_layout.parameter_layout(i).MatchesLayoutInShape(
             arguments[i]->on_host_shape())) {
       return InvalidParameterArgument(
           executable_.get(), i,
@@ -76,24 +69,10 @@ Status LocalExecutable::ValidateExecutionOptions(
           "parameter "
           "%d: want %s, got %s",
           i,
-          ShapeUtil::HumanString(
-              host_computation_layout.parameter_layout(i).shape())
+          ShapeUtil::HumanString(computation_layout.parameter_layout(i).shape())
               .c_str(),
           ShapeUtil::HumanString(arguments[i]->on_host_shape()).c_str());
     }
-    if (!device_computation_layout.parameter_layout(i).MatchesLayoutInShape(
-            arguments[i]->on_device_shape())) {
-      return InvalidParameterArgument(
-          executable_.get(), i,
-          "Argument does not match device shape or layout of computation "
-          "parameter "
-          "%d: want %s, got %s",
-          i,
-          ShapeUtil::HumanString(
-              device_computation_layout.parameter_layout(i).shape())
-              .c_str(),
-          ShapeUtil::HumanString(arguments[i]->on_device_shape()).c_str());
-    }
   }
 
   if (run_options.stream() != nullptr) {
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index d0391325350..52da9d6eac7 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -303,8 +303,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
       ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
 
   pipeline.AddPass<CpuLayoutAssignment>(
-      module->mutable_device_entry_computation_layout(),
-      &target_machine_features);
+      module->mutable_entry_computation_layout(), &target_machine_features);
   // The LayoutAssignment pass may leave behind kCopy instructions which are
   // duplicate or NOPs, so remove them with algebraic simplification and CSE.
   pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(
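The two hunks above collapse the paired host/device checks into one pass over a single entry ComputationLayout, comparing only on-host shapes. A minimal self-contained sketch of that validation pattern follows; ShapeDesc, ParamLayout, HostArg, and ValidateArgs are hypothetical stand-ins, not XLA API:

// Sketch of the single-layout argument validation introduced above.
#include <cstdio>
#include <string>
#include <vector>

struct ShapeDesc { std::string human_string; };  // stands in for xla::Shape
struct ParamLayout { ShapeDesc shape; };         // stands in for ShapeLayout
struct HostArg { ShapeDesc on_host_shape; };     // stands in for ShapedBuffer

bool ValidateArgs(const std::vector<ParamLayout>& parameter_layouts,
                  const std::vector<HostArg>& arguments) {
  // One count check against the single entry layout (previously two).
  if (arguments.size() != parameter_layouts.size()) {
    std::fprintf(stderr, "expected %zu arguments, got %zu\n",
                 parameter_layouts.size(), arguments.size());
    return false;
  }
  // One shape/layout check per parameter, against the on-host shape only.
  for (size_t i = 0; i < arguments.size(); ++i) {
    if (arguments[i].on_host_shape.human_string !=
        parameter_layouts[i].shape.human_string) {
      std::fprintf(stderr, "parameter %zu: want %s, got %s\n", i,
                   parameter_layouts[i].shape.human_string.c_str(),
                   arguments[i].on_host_shape.human_string.c_str());
      return false;
    }
  }
  return true;
}

int main() {
  std::vector<ParamLayout> params = {{{"f32[2,2]{1,0}"}}};
  std::vector<HostArg> args = {{{"f32[2,2]{1,0}"}}};
  std::printf("valid: %d\n", ValidateArgs(params, args));  // prints "valid: 1"
}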
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
index cf43b74c699..1093559892d 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -206,8 +206,8 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
     tensorflow::gtl::MutableArraySlice<OwningDeviceMemory> buffers) {
   se::Stream* stream = run_options->stream();
   ScopedShapedBuffer result_buffer(
-      /*on_host_shape=*/host_result_shape(),
-      /*on_device_shape=*/host_result_shape(), run_options->allocator(),
+      /*on_host_shape=*/result_shape(),
+      /*on_device_shape=*/result_shape(), run_options->allocator(),
       stream->parent()->device_ordinal());
 
   // Move OwningDeviceMemory values which contain the array(s) of the result
diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h
index bd92bfa50f3..98eaeee30a6 100644
--- a/tensorflow/compiler/xla/service/executable.h
+++ b/tensorflow/compiler/xla/service/executable.h
@@ -131,8 +131,8 @@ class Executable {
 
   // The shape (including layout) that results from this execution. This is the
   // shape of the DeviceMemoryBase result value in ExecuteOnStream above.
-  const Shape& host_result_shape() const {
-    return hlo_module_->config().host_entry_computation_layout().result_shape();
+  const Shape& result_shape() const {
+    return hlo_module_->config().entry_computation_layout().result_shape();
   }
 
   // Returns the size of the executable in bytes. Returns -1 by default if the
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index a040e6b6816..decfc40dafa 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -205,7 +205,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
   {
     HloPassPipeline pipeline("layout_assignment");
     pipeline.AddPass<GpuLayoutAssignment>(
-        hlo_module->mutable_device_entry_computation_layout(), stream_exec);
+        hlo_module->mutable_entry_computation_layout(), stream_exec);
 
     // The LayoutAssignment pass may leave behind kCopy instructions which are
     // duplicate or NOPs, so remove them with algebraic simplification and CSE.
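With a single entry layout, host_result_shape() collapses into result_shape(), and the CPU result buffer above is built from that one shape for both its host and device views. A toy sketch of the idea, with ResultBuffer standing in for ScopedShapedBuffer (not XLA API):

// Sketch: one result shape feeds both views of the result buffer.
#include <iostream>
#include <string>

struct ResultBuffer {
  std::string on_host_shape;
  std::string on_device_shape;
};

ResultBuffer CreateResult(const std::string& result_shape) {
  // Same shape for both views, mirroring the cpu_executable.cc change.
  return ResultBuffer{/*on_host_shape=*/result_shape,
                      /*on_device_shape=*/result_shape};
}

int main() {
  ResultBuffer rb = CreateResult("f32[8,16]{1,0}");
  std::cout << rb.on_host_shape << " / " << rb.on_device_shape << "\n";
}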
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 11384c1456d..39bc25ba42c 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -58,7 +58,7 @@ HloComputation* HloModule::AddComputationInternal(
   // If the module configuration has no entry layout computation set, create a
   // default one based on the program shape.
-  if (!config_.has_host_entry_computation_layout()) {
+  if (!config_.has_entry_computation_layout()) {
     config_.SetDefaultComputationLayout(
         entry_computation_->ComputeProgramShape());
   }
@@ -231,14 +231,11 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
   TF_RET_CHECK(proto.has_program_shape())
       << "No program shape found in the proto";
   const auto& expected_program_shape = proto.program_shape();
-  TF_RET_CHECK(
-      expected_program_shape.parameters_size() ==
-      module_config.device_entry_computation_layout().parameter_count());
+  TF_RET_CHECK(expected_program_shape.parameters_size() ==
+               module_config.entry_computation_layout().parameter_count());
   for (int i = 0; i < expected_program_shape.parameters_size(); ++i) {
     const Shape& parameter_shape =
-        module_config.device_entry_computation_layout()
-            .parameter_layout(i)
-            .shape();
+        module_config.entry_computation_layout().parameter_layout(i).shape();
     TF_RET_CHECK(ShapeUtil::Compatible(expected_program_shape.parameters(i),
                                        parameter_shape))
         << "HloModuleConfig has different shape for parameter " << i
@@ -248,7 +245,7 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
       << ", actual: " << ShapeUtil::HumanStringWithLayout(parameter_shape);
   }
   const Shape& result_shape =
-      module_config.device_entry_computation_layout().result_layout().shape();
+      module_config.entry_computation_layout().result_layout().shape();
   TF_RET_CHECK(
       ShapeUtil::Compatible(expected_program_shape.result(), result_shape))
       << "HloModuleConfig has different result shape than the HLO module. "
@@ -327,7 +324,7 @@ StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
   // The module config is constructed with default layouts regardless of what is
   // passed in via the ProgramShape. Set the layouts to the appropriate values.
   ComputationLayout* entry_layout =
-      module_config.mutable_host_entry_computation_layout();
+      module_config.mutable_entry_computation_layout();
   for (int64 i = 0; i < entry_layout->parameter_count(); ++i) {
     TF_RETURN_IF_ERROR(
         entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
             program_shape.parameters(i)));
   }
   TF_RETURN_IF_ERROR(entry_layout->mutable_result_layout()->CopyLayoutFromShape(
       program_shape.result()));
-  *module_config.mutable_device_entry_computation_layout() =
-      module_config.host_entry_computation_layout();
-
   return module_config;
 }
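The hlo_module.cc hunks above keep the lazy-default behavior: if the config carries no entry layout yet, one is derived from the entry computation's program shape, and there is no longer a device-side copy to mirror. A small sketch of that optional-backed pattern, using std::optional in place of tensorflow::gtl::optional and hypothetical Config/ProgramShapeLike types:

// Sketch: create a default entry layout on first use, CHECK on access.
#include <cassert>
#include <optional>
#include <string>

struct ProgramShapeLike { std::string shape_text; };

class Config {
 public:
  bool has_entry_computation_layout() const { return layout_.has_value(); }
  void SetDefaultComputationLayout(const ProgramShapeLike& ps) {
    layout_ = ps.shape_text;  // default layouts derived from the program shape
  }
  const std::string& entry_computation_layout() const {
    assert(layout_.has_value());  // mirrors the CHECK in HloModuleConfig
    return *layout_;
  }

 private:
  std::optional<std::string> layout_;
};

int main() {
  Config config;
  if (!config.has_entry_computation_layout()) {
    config.SetDefaultComputationLayout({"(f32[8,16]) -> f32[8]"});
  }
  assert(config.entry_computation_layout() == "(f32[8,16]) -> f32[8]");
}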
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index 5dc94e78e3c..d2e726a0db6 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -105,20 +105,19 @@ class HloModule {
     return entry_computation_;
   }
 
-  ComputationLayout* mutable_host_entry_computation_layout() {
-    return config_.mutable_host_entry_computation_layout();
+  // Creates the ComputationLayout which describes the current status of the
+  // HLO module entry computation.
+  ComputationLayout compute_computation_layout() const {
+    return ComputationLayout(entry_computation()->ComputeProgramShape(),
+                             /*ignore_layouts=*/false);
   }
 
-  const ComputationLayout& host_entry_computation_layout() const {
-    return config_.host_entry_computation_layout();
+  ComputationLayout* mutable_entry_computation_layout() {
+    return config_.mutable_entry_computation_layout();
   }
 
-  ComputationLayout* mutable_device_entry_computation_layout() {
-    return config_.mutable_device_entry_computation_layout();
-  }
-
-  const ComputationLayout& device_entry_computation_layout() const {
-    return config_.device_entry_computation_layout();
+  const ComputationLayout& entry_computation_layout() const {
+    return config_.entry_computation_layout();
   }
 
   // Gets the computations in this module.
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc
index dae5578a315..07a8c798dbe 100644
--- a/tensorflow/compiler/xla/service/hlo_module_config.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_config.cc
@@ -28,16 +28,14 @@ namespace xla {
 
 using tensorflow::strings::StrAppend;
 
-HloModuleConfig::HloModuleConfig() {}
-
-HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape)
-    : host_entry_computation_layout_(program_shape),
-      device_entry_computation_layout_(program_shape) {}
+HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape,
+                                 bool ignore_layouts)
+    : entry_computation_layout_(
+          ComputationLayout(program_shape, ignore_layouts)) {}
 
 void HloModuleConfig::SetDefaultComputationLayout(
     const ProgramShape& program_shape) {
-  host_entry_computation_layout_ = ComputationLayout(program_shape);
-  device_entry_computation_layout_ = ComputationLayout(program_shape);
+  entry_computation_layout_ = ComputationLayout(program_shape);
 }
 
 string HloModuleConfig::compilation_cache_key() const {
@@ -46,18 +44,11 @@ string HloModuleConfig::compilation_cache_key() const {
   StrAppend(&key, "::(");
   std::vector<string> params;
   for (const ShapeLayout& param_layout :
-       host_entry_computation_layout_->parameter_layouts()) {
+       entry_computation_layout_->parameter_layouts()) {
     params.push_back(param_layout.shape().DebugString());
   }
   StrAppend(&key, tensorflow::str_util::Join(params, ", "), ") => ",
-            host_entry_computation_layout_->result_shape().SerializeAsString());
-  for (const ShapeLayout& param_layout :
-       device_entry_computation_layout_->parameter_layouts()) {
-    params.push_back(param_layout.shape().DebugString());
-  }
-  StrAppend(
-      &key, tensorflow::str_util::Join(params, ", "), ") => ",
-      device_entry_computation_layout_->result_shape().SerializeAsString());
+            entry_computation_layout_->result_shape().SerializeAsString());
   if (seed() != 0) {
     // TODO(b/32083678): force recompilation to reset global state.
     static std::atomic<int> counter{0};
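The new HloModuleConfig constructor threads an ignore_layouts flag through to ComputationLayout: by default the incoming layouts are discarded in favor of defaults, while compute_computation_layout() passes false to preserve what the HLO actually carries. A hedged sketch of just that flag's semantics; MakeLayout and the layout strings are illustrative stand-ins, not XLA API:

// Sketch: ignore_layouts == true resets to a default layout,
// ignore_layouts == false preserves the incoming layout.
#include <iostream>
#include <string>

std::string MakeLayout(const std::string& program_shape_layout,
                       bool ignore_layouts) {
  const std::string kDefault = "{1,0}";  // assumed default minor-to-major
  return ignore_layouts ? kDefault : program_shape_layout;
}

int main() {
  std::cout << MakeLayout("{0,1}", /*ignore_layouts=*/true) << "\n";   // {1,0}
  std::cout << MakeLayout("{0,1}", /*ignore_layouts=*/false) << "\n";  // {0,1}
}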
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h
index cdb0b29a239..074e9c90705 100644
--- a/tensorflow/compiler/xla/service/hlo_module_config.h
+++ b/tensorflow/compiler/xla/service/hlo_module_config.h
@@ -37,48 +37,34 @@ class HloModuleConfig {
   // ComputationLayout. The default ctor creates it without -- in this case
   // accessing entry_computation_layout will CHECK-fail. The ctor accepting a
   // ProgramShape creates a computation layout using this shape.
-  HloModuleConfig();
-  explicit HloModuleConfig(const ProgramShape& program_shape);
+  // The layouts in the ProgramShape will be reset to default unless
+  // ignore_layouts is set to false.
+  HloModuleConfig() = default;
+
+  explicit HloModuleConfig(const ProgramShape& program_shape,
+                           bool ignore_layouts = true);
 
   // Checks if this config has an entry computation layout already.
-  bool has_host_entry_computation_layout() const {
-    return host_entry_computation_layout_.has_value();
-  }
-
-  bool has_device_entry_computation_layout() const {
-    return device_entry_computation_layout_.has_value();
+  bool has_entry_computation_layout() const {
+    return entry_computation_layout_.has_value();
   }
 
   // Sets the entry computation layout for this config. If the entry computation
   // layout already exists, it is silently replaced.
   void SetDefaultComputationLayout(const ProgramShape& program_shape);
 
-  // Returns a constant reference to the on-host layout of the entry
-  // computation. Assumes the layout was set.
-  const ComputationLayout& host_entry_computation_layout() const {
-    CHECK(host_entry_computation_layout_.has_value());
-    return *host_entry_computation_layout_;
-  }
-
-  // Returns a mutable pointer to the layout of the on-host entry computation.
+  // Returns a constant reference to the layout of the entry computation.
   // Assumes the layout was set.
-  ComputationLayout* mutable_host_entry_computation_layout() {
-    CHECK(host_entry_computation_layout_.has_value());
-    return &(*host_entry_computation_layout_);
+  const ComputationLayout& entry_computation_layout() const {
+    CHECK(entry_computation_layout_.has_value());
+    return *entry_computation_layout_;
   }
 
-  // Returns a constant reference to the on-device layout of the entry
-  // computation. Assumes the layout was set.
-  const ComputationLayout& device_entry_computation_layout() const {
-    CHECK(device_entry_computation_layout_.has_value());
-    return *device_entry_computation_layout_;
-  }
-
-  // Returns a mutable pointer to the layout of the on-device entry computation.
+  // Returns a mutable pointer to the layout of the entry computation.
   // Assumes the layout was set.
-  ComputationLayout* mutable_device_entry_computation_layout() {
-    CHECK(device_entry_computation_layout_.has_value());
-    return &(*device_entry_computation_layout_);
+  ComputationLayout* mutable_entry_computation_layout() {
+    CHECK(entry_computation_layout_.has_value());
+    return &(*entry_computation_layout_);
   }
 
   // Returns whether to enable HLO-level profiling.
@@ -127,8 +113,7 @@ class HloModuleConfig {
  private:
   // If you add new members, be sure to update compilation_cache_key.
 
-  tensorflow::gtl::optional<ComputationLayout> host_entry_computation_layout_;
-  tensorflow::gtl::optional<ComputationLayout> device_entry_computation_layout_;
+  tensorflow::gtl::optional<ComputationLayout> entry_computation_layout_;
 
   // Whether this is a 'host module'.
   bool is_host_module_ = false;
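In the header above, the mutable accessor hands out a pointer into the engaged optional via &(*entry_computation_layout_), guarded by a CHECK. A compact sketch of that accessor pattern, again with std::optional and hypothetical stand-in types; the pointer stays valid as long as the optional stays engaged:

// Sketch: CHECK-guarded mutable access into an optional member.
#include <cassert>
#include <optional>

struct ComputationLayoutLike { int result_layout = 0; };

class ConfigLike {
 public:
  explicit ConfigLike(ComputationLayoutLike l) : layout_(l) {}
  ComputationLayoutLike* mutable_entry_computation_layout() {
    assert(layout_.has_value());  // mirrors the CHECK in the header
    return &(*layout_);           // pointer into the engaged optional
  }

 private:
  std::optional<ComputationLayoutLike> layout_;
};

int main() {
  ConfigLike config(ComputationLayoutLike{});
  config.mutable_entry_computation_layout()->result_layout = 42;
  assert(config.mutable_entry_computation_layout()->result_layout == 42);
}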
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index daa3bc42324..2cee74c314f 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -327,22 +327,15 @@ bool HloParser::ParseComputations() {
     // set the layouts to what the hlo text says.
     for (int p = 0; p < computation->num_parameters(); p++) {
       const Shape& param_shape = computation->parameter_instruction(p)->shape();
-      TF_CHECK_OK(module_->mutable_host_entry_computation_layout()
-                      ->mutable_parameter_layout(p)
-                      ->CopyLayoutFromShape(param_shape));
-      TF_CHECK_OK(module_->mutable_device_entry_computation_layout()
+      TF_CHECK_OK(module_->mutable_entry_computation_layout()
                       ->mutable_parameter_layout(p)
                       ->CopyLayoutFromShape(param_shape));
     }
     const Shape& result_shape = computation->root_instruction()->shape();
-    TF_CHECK_OK(module_->mutable_host_entry_computation_layout()
-                    ->mutable_result_layout()
-                    ->CopyLayoutFromShape(result_shape));
-    TF_CHECK_OK(module_->mutable_device_entry_computation_layout()
+    TF_CHECK_OK(module_->mutable_entry_computation_layout()
                     ->mutable_result_layout()
                     ->CopyLayoutFromShape(result_shape));
   }
-
   return true;
 }
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index d551400d1ec..d481e07f60a 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -1302,7 +1302,7 @@ ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] {
 
   auto module = ParseHloString(original);
   TF_ASSERT_OK(module.status());
-  auto program_layout = module.ValueOrDie()->host_entry_computation_layout();
+  auto program_layout = module.ValueOrDie()->entry_computation_layout();
   ASSERT_EQ(program_layout.parameter_count(), 1);
   auto param_layout = program_layout.parameter_layout(0).layout();
   auto result_layout = program_layout.result_layout().layout();
diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc
index c1666530687..9f8f4bda875 100644
--- a/tensorflow/compiler/xla/service/interpreter/compiler.cc
+++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc
@@ -44,7 +44,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
   HloPassPipeline pipeline("Interpreter");
 
   pipeline.AddPass<LayoutAssignment>(
-      hlo_module->mutable_device_entry_computation_layout());
+      hlo_module->mutable_entry_computation_layout());
   return pipeline.Run(hlo_module).status();
 }
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index a6aa8bf82c2..53efc30c365 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -190,10 +190,8 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
       std::unique_ptr<HloModuleConfig> module_config,
       CreateModuleConfig(program_shape, argument_layouts, &execution_options));
 
-  VLOG(3) << "Host Computation Layout: "
-          << module_config->host_entry_computation_layout().ToString();
-  VLOG(3) << "Device Computation Layout: "
-          << module_config->device_entry_computation_layout().ToString();
+  VLOG(3) << "Computation Layout: "
+          << module_config->entry_computation_layout().ToString();
 
   TF_ASSIGN_OR_RETURN(
       se::StreamExecutor * executor,
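The parser hunk above now copies each parsed parameter layout, plus the root instruction's layout, into one entry layout instead of two mirrored ones. A stand-in sketch of that propagation step; EntryLayout and the shape strings are hypothetical, not the parser's real types:

// Sketch: propagate parsed layouts into a single entry layout.
#include <cassert>
#include <string>
#include <vector>

struct EntryLayout {
  std::vector<std::string> parameter_layouts;
  std::string result_layout;
};

void CopyLayoutsFromText(const std::vector<std::string>& parsed_param_shapes,
                         const std::string& parsed_root_shape,
                         EntryLayout* entry) {
  entry->parameter_layouts = parsed_param_shapes;  // one copy per parameter
  entry->result_layout = parsed_root_shape;        // and one for the result
}

int main() {
  EntryLayout entry;
  CopyLayoutsFromText({"f32[8,16,256]{2,1,0}"}, "f32[8,16]{1,0}", &entry);
  assert(entry.parameter_layouts.size() == 1);
  assert(entry.result_layout == "f32[8,16]{1,0}");
}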
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 7ab39e01f2f..da3b622bfae 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -244,10 +244,8 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
     tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,
     const ExecutionOptions* execution_options) {
   auto config = MakeUnique<HloModuleConfig>(program_shape);
-  ComputationLayout* host_computation_layout =
-      config->mutable_host_entry_computation_layout();
-  ComputationLayout* device_computation_layout =
-      config->mutable_device_entry_computation_layout();
+  ComputationLayout* computation_layout =
+      config->mutable_entry_computation_layout();
 
   if (program_shape.parameters_size() != argument_shapes.size()) {
     return InvalidArgument("computation takes %d parameters, but %zu given",
                            program_shape.parameters_size(),
                            argument_shapes.size());
@@ -264,10 +262,9 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
           i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(),
           ShapeUtil::HumanString(*argument_shapes[i]).c_str());
     }
-    TF_RETURN_IF_ERROR(host_computation_layout->mutable_parameter_layout(i)
-                           ->CopyLayoutFromShape(*argument_shapes[i]));
-    TF_RETURN_IF_ERROR(device_computation_layout->mutable_parameter_layout(i)
-                           ->CopyLayoutFromShape(*argument_shapes[i]));
+    TF_RETURN_IF_ERROR(
+        computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
+            *argument_shapes[i]));
   }
   if (execution_options != nullptr &&
       execution_options->has_shape_with_output_layout()) {
@@ -276,20 +273,11 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
     TF_RETURN_IF_ERROR(
         ValidateResultShape(shape_with_output_layout, program_shape.result()));
     TF_RETURN_IF_ERROR(
-        host_computation_layout->mutable_result_layout()->CopyLayoutFromShape(
-            shape_with_output_layout));
-    TF_RETURN_IF_ERROR(
-        device_computation_layout->mutable_result_layout()->CopyLayoutFromShape(
+        computation_layout->mutable_result_layout()->CopyLayoutFromShape(
             shape_with_output_layout));
   } else {
     // If the result layout is not set, then choose the default.
-    // TODO(b/29118294): Allow the compiler to choose a better layout in this
-    // case.
-    // TODO(b/78356948): We are forcing the default layout here. We should fix
-    // clients which expect a default layout, to be explicit about it, by
-    // passing the proper ExecutionOptions with shape_with_output_layout set.
-    host_computation_layout->mutable_result_layout()->SetToDefaultLayout();
-    device_computation_layout->mutable_result_layout()->SetToDefaultLayout();
+    computation_layout->mutable_result_layout()->SetToDefaultLayout();
   }
 
   config->set_replica_count(options_.number_of_replicas());
@@ -377,24 +365,6 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
   return std::move(executables);
 }
 
-Status Service::ValidateEntryComputationLayout(HloModule* module) {
-  const ComputationLayout& on_host = module->host_entry_computation_layout();
-  const ComputationLayout& on_device =
-      module->device_entry_computation_layout();
-  for (int64 i = 0; i < on_device.parameter_count(); ++i) {
-    TF_RET_CHECK(ShapeUtil::Compatible(on_device.parameter_shape(i),
-                                       on_host.parameter_shape(i)))
-        << ShapeUtil::HumanStringWithLayout(on_device.parameter_shape(i))
-        << " vs "
-        << ShapeUtil::HumanStringWithLayout(on_host.parameter_shape(i));
-  }
-  TF_RET_CHECK(
-      ShapeUtil::Compatible(on_device.result_shape(), on_host.result_shape()))
-      << ShapeUtil::HumanStringWithLayout(on_device.result_shape()) << " vs "
-      << ShapeUtil::HumanStringWithLayout(on_host.result_shape());
-  return Status::OK();
-}
-
 StatusOr<std::vector<GlobalDataHandle>>
 Service::ExecuteParallelAndRegisterResult(
     tensorflow::gtl::ArraySlice<Executable*> executables,
@@ -690,7 +660,7 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
                                          request.execution_options()));
     VLOG(3)
         << "ExecuteGraphParallel created HloModuleConfig computation layout: "
-        << module_config->host_entry_computation_layout().ToString();
+        << module_config->entry_computation_layout().ToString();
 
     // Adds to the vectors to build and execute the computations after the loop.
     all_arguments.push_back(replicated_arguments);
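Service::CreateModuleConfig above picks the result layout from the client's shape_with_output_layout when present and falls back to the default layout otherwise, now writing it into a single ComputationLayout. A sketch of that selection logic with stand-in types (ChooseResultLayout and the layout strings are hypothetical):

// Sketch: explicit client-provided result layout wins, else the default.
#include <iostream>
#include <optional>
#include <string>

std::string ChooseResultLayout(
    const std::optional<std::string>& shape_with_output_layout,
    const std::string& default_layout) {
  // Mirrors the if/else in Service::CreateModuleConfig.
  return shape_with_output_layout ? *shape_with_output_layout : default_layout;
}

int main() {
  std::cout << ChooseResultLayout(std::nullopt, "f32[8]{0}") << "\n";  // default
  std::cout << ChooseResultLayout(std::string("f32[8]{0}"), "f32[8]{0}") << "\n";
}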
@@ -851,8 +821,6 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
   TF_ASSIGN_OR_RETURN(
       module, backend->compiler()->RunHloPasses(std::move(module), executor,
                                                 device_allocator));
-  // Check that on-host and on-device shapes are consistent.
-  TF_RETURN_IF_ERROR(ValidateEntryComputationLayout(module.get()));
 
   TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
                       backend->compiler()->RunBackend(
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index 79604290848..47d196fb2aa 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -193,9 +193,6 @@ class Service : public ServiceInterface {
       const ExecutionOptions& execution_options,
       tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments);
 
-  // Assert that host- and device-shapes are in a consistent state.
-  Status ValidateEntryComputationLayout(HloModule* module);
-
  protected:
   friend class LocalExecutable;
 
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
index 249da87f489..9009d67cea6 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -185,13 +185,9 @@ class HloTestBase : public ::testing::Test {
   // 'layout'.
   void ForceParameterLayout(HloModule* module, int64 param_no,
                             const Layout& layout) {
-    ASSERT_LT(
-        param_no,
-        module->mutable_host_entry_computation_layout()->parameter_count());
-    module->mutable_host_entry_computation_layout()
-        ->mutable_parameter_layout(param_no)
-        ->ResetLayout(layout);
-    module->mutable_device_entry_computation_layout()
+    ASSERT_LT(param_no,
+              module->mutable_entry_computation_layout()->parameter_count());
+    module->mutable_entry_computation_layout()
         ->mutable_parameter_layout(param_no)
         ->ResetLayout(layout);
   }
@@ -199,10 +195,7 @@ class HloTestBase : public ::testing::Test {
   // Convenience method to force the layout of the computation result in a
   // module. The result layout of 'module' is set to 'layout'.
   void ForceResultLayout(HloModule* module, const Layout& layout) {
-    module->mutable_host_entry_computation_layout()
-        ->mutable_result_layout()
-        ->ResetLayout(layout);
-    module->mutable_device_entry_computation_layout()
+    module->mutable_entry_computation_layout()
         ->mutable_result_layout()
         ->ResetLayout(layout);
   }
@@ -210,10 +203,7 @@ class HloTestBase : public ::testing::Test {
   // Convenience method to clear the layout of the computation result in
   // 'module'.
   void ForceClearResultLayout(HloModule* module) {
-    module->mutable_host_entry_computation_layout()
-        ->mutable_result_layout()
-        ->Clear();
-    module->mutable_device_entry_computation_layout()
+    module->mutable_entry_computation_layout()
         ->mutable_result_layout()
         ->Clear();
   }
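After this patch, the HloTestBase helpers above perform one write against one layout; there is no second device-side copy to keep in sync, which is also why the consistency check ValidateEntryComputationLayout could be deleted outright. A stand-in sketch of the simplified helper shape (ModuleLike, EntryLayoutLike, and LayoutLike are hypothetical, not XLA types):

// Sketch: a test helper that forces the result layout via the single
// entry-layout accessor.
#include <cassert>

struct LayoutLike { int minor_to_major_hint = 0; };

struct EntryLayoutLike { LayoutLike result_layout; };

struct ModuleLike {
  EntryLayoutLike entry;
  EntryLayoutLike* mutable_entry_computation_layout() { return &entry; }
};

// One write, one source of truth -- no mirrored device-side copy to update.
void ForceResultLayout(ModuleLike* module, const LayoutLike& layout) {
  module->mutable_entry_computation_layout()->result_layout = layout;
}

int main() {
  ModuleLike m;
  ForceResultLayout(&m, LayoutLike{1});
  assert(m.entry.result_layout.minor_to_major_hint == 1);
}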