[TF/XLA] Force all tensors which need to be constant during the XLA compilation to be located on the host

Otherwise, compilation can fail with hard-to-diagnose crashes when a compile-time-constant input resides in device memory.

PiperOrigin-RevId: 304226917
Change-Id: Ia2f1e77b13a25c7e15f009787af81f93b90e8bca
This commit is contained in:
George Karpenkov 2020-04-01 11:31:31 -07:00 committed by TensorFlower Gardener
parent c58518c893
commit 9771765f41
10 changed files with 271 additions and 56 deletions

View File

@ -422,6 +422,7 @@ cc_library(
"xla_kernel_creator.h",
],
deps = [
":compilability_check_util",
":flags",
":jit_compilation_passes",
":xla_kernel_creator_util",
@ -591,6 +592,7 @@ cc_library(
"encapsulate_subgraphs_pass.cc",
"encapsulate_xla_computations_pass.cc",
"extract_outside_compilation_pass.cc",
"force_xla_constants_on_host_pass.cc",
"increase_dynamism_for_auto_jit_pass.cc",
"introduce_floating_point_jitter_pass.cc",
"mark_for_compilation_pass.cc",
@ -606,6 +608,7 @@ cc_library(
"encapsulate_subgraphs_pass.h",
"encapsulate_xla_computations_pass.h",
"extract_outside_compilation_pass.h",
"force_xla_constants_on_host_pass.h",
"increase_dynamism_for_auto_jit_pass.h",
"introduce_floating_point_jitter_pass.h",
"mark_for_compilation_pass.h",
@ -774,6 +777,7 @@ tf_cc_test(
"encapsulate_subgraphs_pass_test.cc",
"encapsulate_xla_computations_pass_test.cc",
"extract_outside_compilation_pass_test.cc",
"force_xla_constants_on_host_pass_test.cc",
"increase_dynamism_for_auto_jit_pass_test.cc",
"introduce_floating_point_jitter_pass_internal.h",
"introduce_floating_point_jitter_pass_test.cc",
@ -786,6 +790,7 @@ tf_cc_test(
tags = ["nomsan"] + tf_cuda_tests_tags(),
deps = [
":common",
":compilability_check_util",
":compilation_passes",
":compilation_passes_test_main",
":encapsulate_util",

View File

@ -518,4 +518,50 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
}
}
// Returns true iff `node_def` explicitly requests XLA compilation, i.e. it
// carries kXlaMustCompileAttr and that attribute's boolean value is true.
bool CanCreateXlaKernel(const NodeDef& node_def) {
  const auto attr_it = node_def.attr().find(kXlaMustCompileAttr);
  if (attr_it == node_def.attr().end()) {
    return false;
  }
  return attr_it->second.b();
}
// Instantiates the function called by `node_def` through `flr` and returns:
//   * `fbody`: the instantiated function body (owned by `flr`),
//   * `constant_arg_indices`: argument positions that must be compile-time
//     constants according to backwards const-analysis,
//   * `resource_arg_indices`: argument positions of type DT_RESOURCE.
// Both index vectors should be empty on entry and come back sorted in
// ascending order. Returns an error if the function cannot be instantiated
// or the const-analysis fails.
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
                                       const NodeDef& node_def,
                                       const FunctionBody** fbody,
                                       std::vector<int>* constant_arg_indices,
                                       std::vector<int>* resource_arg_indices) {
  // Bail out early if `node_def` is not instantiable, e.g. the function it
  // names does not exist in the runtime.
  NameAttrList func;
  TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &func));
  FunctionLibraryRuntime::Handle handle;
  TF_RETURN_IF_ERROR(
      flr->Instantiate(func.name(), AttrSlice(&func.attr()), &handle));
  *fbody = flr->GetFunctionBody(handle);
  CHECK(*fbody);  // Instantiation just succeeded, so the body must exist.

  const DataTypeVector& arg_types = (*fbody)->arg_types;

  // Determine which arguments must be compile-time constants; bail out if the
  // analysis itself cannot be performed.
  std::vector<bool> must_be_const(arg_types.size());
  TF_RETURN_IF_ERROR(
      BackwardsConstAnalysis(*((*fbody)->graph), &must_be_const,
                             /*compile_time_const_nodes=*/nullptr, flr));
  for (size_t idx = 0; idx < must_be_const.size(); ++idx) {
    if (must_be_const[idx]) {
      constant_arg_indices->push_back(idx);
    }
  }

  // There can be hundreds of resource variables, so reserve space for them up
  // front; constants are usually few, so no reservation is made above.
  resource_arg_indices->reserve(arg_types.size());
  for (size_t idx = 0; idx < arg_types.size(); ++idx) {
    if (arg_types[idx] == DT_RESOURCE) {
      resource_arg_indices->push_back(idx);
    }
  }
  return Status::OK();
}
} // namespace tensorflow

View File

@ -265,6 +265,23 @@ class RecursiveCompilabilityChecker {
RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
const XlaOpRegistry::DeviceRegistration& registration);
// Given a FunctionLibraryRuntime and a NodeDef calling a function in the
// runtime, returns this function's body in `fbody` as well as the indices
// of its constant and resource arguments.
// `fbody` is owned by `flr`.
// `constant_arg_indices` and `resource_arg_indices` should be empty vector.
// They are sorted in ascending order on this function's return.
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
const NodeDef& node_def,
const FunctionBody** fbody,
std::vector<int>* constant_arg_indices,
std::vector<int>* resource_arg_indices);
// Given a NodeDef `node_def` returns true iff `node_def` has kXlaCompileAttr
// set.
bool CanCreateXlaKernel(const NodeDef& node_def);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_

View File

@ -0,0 +1,54 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/force_xla_constants_on_host_pass.h"
#include "tensorflow/compiler/jit/compilability_check_util.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
namespace tensorflow {
Status ForceXlaConstantsOnHostPass::Run(
    const GraphOptimizationPassOptions& options) {
  // Annotates every node that requests XLA compilation (kXlaMustCompileAttr)
  // with an `_input_hostmem` attribute listing the indices of its
  // compile-time-constant arguments, so those inputs are placed in host
  // memory before compilation.
  //
  // GraphOptimizationPassOptions fields are nullable depending on how the
  // pass is invoked; guard against missing context instead of crashing.
  if (options.graph == nullptr || options.session_options == nullptr ||
      options.flib_def == nullptr) {
    return Status::OK();
  }
  Graph* graph = options.graph->get();
  OptimizerOptions opts;
  // Build a process-wide function runtime so function bodies named by call
  // nodes can be instantiated and analyzed.
  auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
      nullptr, options.session_options->env, /*config=*/nullptr,
      TF_GRAPH_DEF_VERSION, options.flib_def, opts);
  FunctionLibraryRuntime* flr =
      pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);

  for (Node* node : graph->nodes()) {
    if (CanCreateXlaKernel(node->def())) {
      const FunctionBody* fbody = nullptr;
      std::vector<int> constant_arg_indices;
      std::vector<int> resource_arg_indices;

      // Find the argument indices XLA needs as compile-time constants and
      // force those inputs to be on the host memory.
      TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
          flr, node->def(), &fbody, &constant_arg_indices,
          &resource_arg_indices));
      VLOG(3) << "Found constant arg indices: "
              << absl::StrJoin(constant_arg_indices, ", ");

      node->AddAttr("_input_hostmem", constant_arg_indices);
    }
  }
  return Status::OK();
}
} // namespace tensorflow

View File

@ -0,0 +1,36 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_FORCE_XLA_CONSTANTS_ON_HOST_PASS_H_
#define TENSORFLOW_COMPILER_JIT_FORCE_XLA_CONSTANTS_ON_HOST_PASS_H_
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/jit/compilability_check_util.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
namespace tensorflow {
// An optimization pass which marks the constants which have to be resolved for
// XLA compilation with `_input_hostmem`.
class ForceXlaConstantsOnHostPass : public GraphOptimizationPass {
public:
ForceXlaConstantsOnHostPass() = default;
// Scans `options.graph` for nodes carrying kXlaMustCompileAttr and attaches
// an `_input_hostmem` attribute listing the indices of their
// compile-time-constant arguments. Returns the first error encountered.
Status Run(const GraphOptimizationPassOptions& options) override;
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_FORCE_XLA_CONSTANTS_ON_HOST_PASS_H_

View File

@ -0,0 +1,108 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/force_xla_constants_on_host_pass.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/compilability_check_util.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/test_util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
namespace tensorflow {
namespace {
// Converts the graph built in scope `s` to a Graph, runs
// ForceXlaConstantsOnHostPass over it, and returns the rewritten graph in
// `result`. Propagates any conversion or pass error.
Status ForceXlaConstantsOnHost(const Scope& s,
                               FunctionLibraryDefinition* flib_def,
                               std::unique_ptr<Graph>* result) {
  auto graph = absl::make_unique<Graph>(OpRegistry::Global());
  TF_RETURN_IF_ERROR(s.ToGraph(graph.get()));

  SessionOptions session_options;
  session_options.env = Env::Default();

  GraphOptimizationPassOptions options;
  options.graph = &graph;
  options.session_options = &session_options;
  options.flib_def = flib_def;

  ForceXlaConstantsOnHostPass pass;
  TF_RETURN_IF_ERROR(pass.Run(options));
  *result = std::move(graph);
  return Status::OK();
}
// Builds a graph with a single PartitionedCall to a Transpose wrapper whose
// second argument (`perm`) is a compile-time constant, runs the pass, and
// checks that exactly that argument index is recorded in `_input_hostmem`.
TEST(ForceXlaConstantsOnHostPassTest, Simple) {
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
Scope root = Scope::NewRootScope().ExitOnError();
FunctionDefLibrary library;
// Function TransposeCall(a: float, b: int32) -> c: float wraps a single
// Transpose node; Transpose requires its `perm` input (arg `b`) to be a
// compile-time constant.
FunctionDef called_func =
FunctionDefHelper::Create("TransposeCall",
/*in_def=*/{"a:float", "b:int32"},
/*out_def=*/{"c:float"}, {},
{{{"t0"},
"Transpose",
{"a", "b"},
{
{"T", DT_FLOAT},
{"Tperm", DT_INT32},
}}},
{{"c", "t0:y:0"}});
// Mark the function itself as must-compile so the pass considers its calls.
AttrValue true_attribute;
true_attribute.set_b(true);
(*called_func.mutable_attr())[kXlaMustCompileAttr] = true_attribute;
*library.add_function() = called_func;
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library));
FunctionLibraryDefinition flib_def(OpRegistry::Global(), library);
Output in = ops::Placeholder(root, DT_FLOAT);
Output perm = ops::Const(root, {3, 1, 2, 0});
NameAttrList b_name_attr;
b_name_attr.set_name("TransposeCall");
// Call node: arg 0 = placeholder (not constant), arg 1 = perm (constant).
ops::PartitionedCall call(root.WithOpName("call"), {in, perm}, {DT_FLOAT},
b_name_attr);
// Also tag the call node so CanCreateXlaKernel() selects it.
call.output.front().node()->AddAttr(kXlaMustCompileAttr, true);
std::unique_ptr<Graph> graph;
TF_ASSERT_OK(ForceXlaConstantsOnHost(root, &flib_def, &graph));
bool found = false;
for (Node* node : graph->nodes()) {
if (CanCreateXlaKernel(node->def())) {
// Exactly one node should be selected by the pass.
EXPECT_FALSE(found);
found = true;
std::vector<int32> hostmem_attr;
EXPECT_TRUE(TryGetNodeAttr(node->def(), "_input_hostmem", &hostmem_attr));
// Only arg index 1 (`perm`) must be a compile-time constant.
EXPECT_EQ(hostmem_attr.size(), 1);
EXPECT_EQ(hostmem_attr[0], 1);
}
}
EXPECT_TRUE(found);
}
} // namespace
} // namespace tensorflow

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/cluster_scoping_pass.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
#include "tensorflow/compiler/jit/force_xla_constants_on_host_pass.h"
#include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h"
#include "tensorflow/compiler/jit/introduce_floating_point_jitter_pass.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
@ -57,6 +58,9 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 9,
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10,
MarkForCompilationPass);
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 12,
ForceXlaConstantsOnHostPass);
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20,
IncreaseDynamismForAutoJitPass);

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_kernel_creator.h"
#include "tensorflow/compiler/jit/compilability_check_util.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_kernel_creator_util.h"
#include "tensorflow/core/common_runtime/function.h"

View File

@ -70,58 +70,6 @@ class SinglePassSearch {
};
} // namespace
// Returns true iff `node_def` carries kXlaMustCompileAttr set to true.
// NOTE(review): this definition moves to compilability_check_util.cc in this
// change; this is the copy being deleted from xla_kernel_creator_util.cc.
bool CanCreateXlaKernel(const NodeDef& node_def) {
// If kXlaMustCompileAttr is set on the node_def, use its value.
const auto& it = node_def.attr().find(kXlaMustCompileAttr);
return it != node_def.attr().end() && it->second.b();
}
// Given a FunctionLibraryRuntime and a NodeDef calling a function in the
// runtime, returns this function's body in `fbody` as well as the indices
// of its constant and resource arguments.
// `fbody` is owned by `flr`.
// `constant_arg_indices` and `resource_arg_indices` should be empty vector.
// They are sorted in ascending order on this function's return.
// NOTE(review): this definition moves to compilability_check_util.cc in this
// change; this is the copy being deleted from xla_kernel_creator_util.cc.
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
const NodeDef& node_def,
const FunctionBody** fbody,
std::vector<int>* constant_arg_indices,
std::vector<int>* resource_arg_indices) {
FunctionLibraryRuntime::Handle handle;
// If node_def is not instantiable, e.g., the function does not exist,
// simply bail out.
NameAttrList function;
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
TF_RETURN_IF_ERROR(
flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle));
*fbody = flr->GetFunctionBody(handle);
CHECK(*fbody); // Can't be nullptr since we just instantiated it.
const DataTypeVector& arg_types = (*fbody)->arg_types;
std::vector<bool> const_args(arg_types.size());
// If we can't analyze the const args. Bail out.
TF_RETURN_IF_ERROR(
BackwardsConstAnalysis(*((*fbody)->graph), &const_args,
/*compile_time_const_nodes=*/nullptr, flr));
// Collect indices of arguments that must be compile-time constants.
for (size_t i = 0; i < const_args.size(); ++i) {
if (const_args[i]) {
constant_arg_indices->push_back(i);
}
}
// There can be hundreds of resource variables. Reserve the space for them.
// We don't reserve for constants above as they are usually few.
resource_arg_indices->reserve(arg_types.size());
for (size_t i = 0; i < arg_types.size(); ++i) {
if (arg_types[i] == DT_RESOURCE) {
resource_arg_indices->push_back(i);
}
}
return Status::OK();
}
Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
std::unique_ptr<OpKernel>* kernel) {
if (!CanCreateXlaKernel(node_def)) {

View File

@ -24,10 +24,6 @@ namespace tensorflow {
class FunctionLibraryRuntime;
class OpKernel;
// Given a NodeDef `node_def` returns true iff `node_def` has kXlaCompileAttr
// set.
bool CanCreateXlaKernel(const NodeDef& node_def);
// Given a supported NodeDef, returns a XlaLaunchOp that computes the node.
Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
std::unique_ptr<OpKernel>* kernel);