[TF:XLA] Fixes to the "evaluator" plugin.
* Mark the evaluator plugin as alwayslink so it doesn't get stripped out by the linker. * Add a generic LayoutAssignment pass to the pass pipeline; otherwise the entry computation has no layout and Service::Execute CHECK-fails in the AllocationTracker. * Register the default computation placer for the evaluator backend. * Add an replay_computation_hlo_evaluator binary that can replay computation snapshots via the HLO evaluator. PiperOrigin-RevId: 164364780
This commit is contained in:
parent
c62eaccec8
commit
5d52208e6e
@ -16,10 +16,13 @@ cc_library(
|
|||||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||||
"//tensorflow/compiler/xla:xla_headers_lib",
|
"//tensorflow/compiler/xla:xla_headers_lib",
|
||||||
"//tensorflow/compiler/xla/service",
|
"//tensorflow/compiler/xla/service",
|
||||||
|
"//tensorflow/compiler/xla/service:computation_placer",
|
||||||
|
"//tensorflow/compiler/xla/service:layout_assignment",
|
||||||
"//third_party/eigen3",
|
"//third_party/eigen3",
|
||||||
"@local_config_cuda//cuda:cuda_headers",
|
"@local_config_cuda//cuda:cuda_headers",
|
||||||
"@protobuf_archive//:protobuf_headers",
|
"@protobuf_archive//:protobuf_headers",
|
||||||
],
|
],
|
||||||
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
filegroup(
|
filegroup(
|
||||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/plugin/executor/compiler.h"
|
#include "tensorflow/compiler/plugin/executor/compiler.h"
|
||||||
#include "tensorflow/compiler/plugin/executor/executable.h"
|
#include "tensorflow/compiler/plugin/executor/executable.h"
|
||||||
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
|
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/computation_placer.h"
|
||||||
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
|
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
|
#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_cse.h"
|
#include "tensorflow/compiler/xla/service/hlo_cse.h"
|
||||||
@ -27,6 +28,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
|
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
|
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
|
||||||
#include "tensorflow/compiler/xla/service/inliner.h"
|
#include "tensorflow/compiler/xla/service/inliner.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/layout_assignment.h"
|
||||||
#include "tensorflow/compiler/xla/service/reshape_mover.h"
|
#include "tensorflow/compiler/xla/service/reshape_mover.h"
|
||||||
#include "tensorflow/compiler/xla/status_macros.h"
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
@ -55,6 +57,8 @@ Status ExecutorCompiler::RunHloOptimization(HloModule* hlo_module) {
|
|||||||
pipeline.AddPass<ReshapeMover>();
|
pipeline.AddPass<ReshapeMover>();
|
||||||
pipeline.AddPass<HloConstantFolding>();
|
pipeline.AddPass<HloConstantFolding>();
|
||||||
pipeline.AddPass<HloCSE>(true);
|
pipeline.AddPass<HloCSE>(true);
|
||||||
|
pipeline.AddPass<LayoutAssignment>(
|
||||||
|
hlo_module->mutable_entry_computation_layout());
|
||||||
|
|
||||||
pipeline.AddPass<HloDCE>();
|
pipeline.AddPass<HloDCE>();
|
||||||
pipeline.AddPass<FlattenCallGraph>();
|
pipeline.AddPass<FlattenCallGraph>();
|
||||||
@ -107,10 +111,16 @@ ExecutorCompiler::ShapeSizeBytesFunction() const {
|
|||||||
return ExecutorExecutable::ShapeSizeBytes;
|
return ExecutorExecutable::ShapeSizeBytes;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static std::unique_ptr<xla::ComputationPlacer> CreateComputationPlacer() {
|
||||||
|
return xla::MakeUnique<xla::ComputationPlacer>();
|
||||||
|
}
|
||||||
|
|
||||||
REGISTER_MODULE_INITIALIZER(executor_compiler, {
|
REGISTER_MODULE_INITIALIZER(executor_compiler, {
|
||||||
xla::Compiler::RegisterCompilerFactory(sep::kExecutorPlatformId, []() {
|
xla::Compiler::RegisterCompilerFactory(sep::kExecutorPlatformId, []() {
|
||||||
return xla::MakeUnique<xla::executorplugin::ExecutorCompiler>();
|
return xla::MakeUnique<xla::executorplugin::ExecutorCompiler>();
|
||||||
});
|
});
|
||||||
|
xla::ComputationPlacer::RegisterComputationPlacer(sep::kExecutorPlatformId,
|
||||||
|
&CreateComputationPlacer);
|
||||||
});
|
});
|
||||||
|
|
||||||
} // namespace executorplugin
|
} // namespace executorplugin
|
||||||
|
@ -133,7 +133,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
|
|||||||
} else {
|
} else {
|
||||||
// Array shape.
|
// Array shape.
|
||||||
if (!shape.has_layout()) {
|
if (!shape.has_layout()) {
|
||||||
return InvalidArgument("shape does not have a layout");
|
return InvalidArgument("shape %s does not have a layout",
|
||||||
|
ShapeUtil::HumanString(shape).c_str());
|
||||||
}
|
}
|
||||||
return ValidateLayoutForShape(shape.layout(), shape);
|
return ValidateLayoutForShape(shape.layout(), shape);
|
||||||
}
|
}
|
||||||
|
@ -1058,6 +1058,8 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
|
|||||||
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
|
StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate(
|
||||||
const HloComputation& computation,
|
const HloComputation& computation,
|
||||||
tensorflow::gtl::ArraySlice<const Literal*> arg_literals) {
|
tensorflow::gtl::ArraySlice<const Literal*> arg_literals) {
|
||||||
|
XLA_VLOG_LINES(
|
||||||
|
2, "HloEvaluator::Evaluate computation:\n" + computation.ToString());
|
||||||
arg_literals_ = arg_literals;
|
arg_literals_ = arg_literals;
|
||||||
evaluated_.clear();
|
evaluated_.clear();
|
||||||
|
|
||||||
|
@ -106,6 +106,14 @@ cc_binary(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_binary(
|
||||||
|
name = "replay_computation_hlo_evaluator",
|
||||||
|
deps = [
|
||||||
|
":replay_computation_library",
|
||||||
|
"//tensorflow/compiler/plugin/executor:plugin_lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_binary(
|
cc_binary(
|
||||||
name = "show_literal",
|
name = "show_literal",
|
||||||
srcs = ["show_literal.cc"],
|
srcs = ["show_literal.cc"],
|
||||||
|
Loading…
Reference in New Issue
Block a user