[TF:XLA] Fixes to the "evaluator" plugin.
* Mark the evaluator plugin as alwayslink so it doesn't get stripped out by the linker. * Add a generic LayoutAssignment pass to the pass pipeline; otherwise the entry computation has no layout and Service::Execute CHECK-fails in the AllocationTracker. * Register the default computation placer for the evaluator backend. * Add an replay_computation_hlo_evaluator binary that can replay computation snapshots via the HLO evaluator. PiperOrigin-RevId: 164364780
This commit is contained in:
parent
c62eaccec8
commit
5d52208e6e
tensorflow/compiler
@ -16,10 +16,13 @@ cc_library(
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/xla:xla_headers_lib",
|
||||
"//tensorflow/compiler/xla/service",
|
||||
"//tensorflow/compiler/xla/service:computation_placer",
|
||||
"//tensorflow/compiler/xla/service:layout_assignment",
|
||||
"//third_party/eigen3",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@protobuf_archive//:protobuf_headers",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
filegroup(
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/plugin/executor/compiler.h"
|
||||
#include "tensorflow/compiler/plugin/executor/executable.h"
|
||||
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
|
||||
#include "tensorflow/compiler/xla/service/computation_placer.h"
|
||||
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_cse.h"
|
||||
@ -27,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
|
||||
#include "tensorflow/compiler/xla/service/inliner.h"
|
||||
#include "tensorflow/compiler/xla/service/layout_assignment.h"
|
||||
#include "tensorflow/compiler/xla/service/reshape_mover.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
@ -55,6 +57,8 @@ Status ExecutorCompiler::RunHloOptimization(HloModule* hlo_module) {
|
||||
pipeline.AddPass<ReshapeMover>();
|
||||
pipeline.AddPass<HloConstantFolding>();
|
||||
pipeline.AddPass<HloCSE>(true);
|
||||
pipeline.AddPass<LayoutAssignment>(
|
||||
hlo_module->mutable_entry_computation_layout());
|
||||
|
||||
pipeline.AddPass<HloDCE>();
|
||||
pipeline.AddPass<FlattenCallGraph>();
|
||||
@ -107,10 +111,16 @@ ExecutorCompiler::ShapeSizeBytesFunction() const {
|
||||
return ExecutorExecutable::ShapeSizeBytes;
|
||||
}
|
||||
|
||||
static std::unique_ptr<xla::ComputationPlacer> CreateComputationPlacer() {
|
||||
return xla::MakeUnique<xla::ComputationPlacer>();
|
||||
}
|
||||
|
||||
REGISTER_MODULE_INITIALIZER(executor_compiler, {
|
||||
xla::Compiler::RegisterCompilerFactory(sep::kExecutorPlatformId, []() {
|
||||
return xla::MakeUnique<xla::executorplugin::ExecutorCompiler>();
|
||||
});
|
||||
xla::ComputationPlacer::RegisterComputationPlacer(sep::kExecutorPlatformId,
|
||||
&CreateComputationPlacer);
|
||||
});
|
||||
|
||||
} // namespace executorplugin
|
||||
|
@ -133,7 +133,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
|
||||
} else {
|
||||
// Array shape.
|
||||
if (!shape.has_layout()) {
|
||||
return InvalidArgument("shape does not have a layout");
|
||||
return InvalidArgument("shape %s does not have a layout",
|
||||
ShapeUtil::HumanString(shape).c_str());
|
||||
}
|
||||
return ValidateLayoutForShape(shape.layout(), shape);
|
||||
}
|
||||
|
@ -1058,6 +1058,8 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
|
||||
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
|
||||
const HloComputation& computation,
|
||||
tensorflow::gtl::ArraySlice<const Literal*> arg_literals) {
|
||||
XLA_VLOG_LINES(
|
||||
2, "HloEvaluator::Evaluate computation:\n" + computation.ToString());
|
||||
arg_literals_ = arg_literals;
|
||||
evaluated_.clear();
|
||||
|
||||
|
@ -106,6 +106,14 @@ cc_binary(
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "replay_computation_hlo_evaluator",
|
||||
deps = [
|
||||
":replay_computation_library",
|
||||
"//tensorflow/compiler/plugin/executor:plugin_lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "show_literal",
|
||||
srcs = ["show_literal.cc"],
|
||||
|
Loading…
Reference in New Issue
Block a user