diff --git a/tensorflow/compiler/plugin/executor/BUILD b/tensorflow/compiler/plugin/executor/BUILD index bc7c25c1205..ffecd68d921 100644 --- a/tensorflow/compiler/plugin/executor/BUILD +++ b/tensorflow/compiler/plugin/executor/BUILD @@ -16,10 +16,13 @@ cc_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:xla_headers_lib", "//tensorflow/compiler/xla/service", + "//tensorflow/compiler/xla/service:computation_placer", + "//tensorflow/compiler/xla/service:layout_assignment", "//third_party/eigen3", "@local_config_cuda//cuda:cuda_headers", "@protobuf_archive//:protobuf_headers", ], + alwayslink = 1, ) filegroup( diff --git a/tensorflow/compiler/plugin/executor/compiler.cc b/tensorflow/compiler/plugin/executor/compiler.cc index 72fe7ba4519..77193f06c4b 100644 --- a/tensorflow/compiler/plugin/executor/compiler.cc +++ b/tensorflow/compiler/plugin/executor/compiler.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/plugin/executor/compiler.h" #include "tensorflow/compiler/plugin/executor/executable.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" +#include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" #include "tensorflow/compiler/xla/service/hlo_cse.h" @@ -27,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" #include "tensorflow/compiler/xla/service/inliner.h" +#include "tensorflow/compiler/xla/service/layout_assignment.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/lib/core/errors.h" @@ -55,6 +57,8 @@ Status ExecutorCompiler::RunHloOptimization(HloModule* hlo_module) { pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(true); + pipeline.AddPass( + hlo_module->mutable_entry_computation_layout()); pipeline.AddPass(); pipeline.AddPass(); @@ -107,10 +111,16 @@ ExecutorCompiler::ShapeSizeBytesFunction() const { return ExecutorExecutable::ShapeSizeBytes; } +static std::unique_ptr CreateComputationPlacer() { + return xla::MakeUnique(); +} + REGISTER_MODULE_INITIALIZER(executor_compiler, { xla::Compiler::RegisterCompilerFactory(sep::kExecutorPlatformId, []() { return xla::MakeUnique(); }); + xla::ComputationPlacer::RegisterComputationPlacer(sep::kExecutorPlatformId, + &CreateComputationPlacer); }); } // namespace executorplugin diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index 35a563bf227..6271b59a5bf 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -133,7 +133,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } else { // Array shape. if (!shape.has_layout()) { - return InvalidArgument("shape does not have a layout"); + return InvalidArgument("shape %s does not have a layout", + ShapeUtil::HumanString(shape).c_str()); } return ValidateLayoutForShape(shape.layout(), shape); } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index fe7ff8600fb..ea367c74aed 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -1058,6 +1058,8 @@ StatusOr> HloEvaluator::Evaluate( StatusOr> HloEvaluator::Evaluate( const HloComputation& computation, tensorflow::gtl::ArraySlice arg_literals) { + XLA_VLOG_LINES( + 2, "HloEvaluator::Evaluate computation:\n" + computation.ToString()); arg_literals_ = arg_literals; evaluated_.clear(); diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 4bbe0ba0ddd..a946d335ca6 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -106,6 +106,14 @@ cc_binary( ], ) +cc_binary( + name = "replay_computation_hlo_evaluator", + deps = [ + ":replay_computation_library", + "//tensorflow/compiler/plugin/executor:plugin_lib", + ], +) + cc_binary( name = "show_literal", srcs = ["show_literal.cc"],