[TF/XLA] Force all tensors which need to be constant during the XLA compilation to be located on the host
Otherwise, this leads to strange crashes during the compilation. PiperOrigin-RevId: 304226917 Change-Id: Ia2f1e77b13a25c7e15f009787af81f93b90e8bca
This commit is contained in:
parent
c58518c893
commit
9771765f41
|
@ -422,6 +422,7 @@ cc_library(
|
|||
"xla_kernel_creator.h",
|
||||
],
|
||||
deps = [
|
||||
":compilability_check_util",
|
||||
":flags",
|
||||
":jit_compilation_passes",
|
||||
":xla_kernel_creator_util",
|
||||
|
@ -591,6 +592,7 @@ cc_library(
|
|||
"encapsulate_subgraphs_pass.cc",
|
||||
"encapsulate_xla_computations_pass.cc",
|
||||
"extract_outside_compilation_pass.cc",
|
||||
"force_xla_constants_on_host_pass.cc",
|
||||
"increase_dynamism_for_auto_jit_pass.cc",
|
||||
"introduce_floating_point_jitter_pass.cc",
|
||||
"mark_for_compilation_pass.cc",
|
||||
|
@ -606,6 +608,7 @@ cc_library(
|
|||
"encapsulate_subgraphs_pass.h",
|
||||
"encapsulate_xla_computations_pass.h",
|
||||
"extract_outside_compilation_pass.h",
|
||||
"force_xla_constants_on_host_pass.h",
|
||||
"increase_dynamism_for_auto_jit_pass.h",
|
||||
"introduce_floating_point_jitter_pass.h",
|
||||
"mark_for_compilation_pass.h",
|
||||
|
@ -774,6 +777,7 @@ tf_cc_test(
|
|||
"encapsulate_subgraphs_pass_test.cc",
|
||||
"encapsulate_xla_computations_pass_test.cc",
|
||||
"extract_outside_compilation_pass_test.cc",
|
||||
"force_xla_constants_on_host_pass_test.cc",
|
||||
"increase_dynamism_for_auto_jit_pass_test.cc",
|
||||
"introduce_floating_point_jitter_pass_internal.h",
|
||||
"introduce_floating_point_jitter_pass_test.cc",
|
||||
|
@ -786,6 +790,7 @@ tf_cc_test(
|
|||
tags = ["nomsan"] + tf_cuda_tests_tags(),
|
||||
deps = [
|
||||
":common",
|
||||
":compilability_check_util",
|
||||
":compilation_passes",
|
||||
":compilation_passes_test_main",
|
||||
":encapsulate_util",
|
||||
|
|
|
@ -518,4 +518,50 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
|
|||
}
|
||||
}
|
||||
|
||||
bool CanCreateXlaKernel(const NodeDef& node_def) {
|
||||
// If kXlaMustCompileAttr is set on the node_def, use its value.
|
||||
const auto& it = node_def.attr().find(kXlaMustCompileAttr);
|
||||
return it != node_def.attr().end() && it->second.b();
|
||||
}
|
||||
|
||||
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
|
||||
const NodeDef& node_def,
|
||||
const FunctionBody** fbody,
|
||||
std::vector<int>* constant_arg_indices,
|
||||
std::vector<int>* resource_arg_indices) {
|
||||
FunctionLibraryRuntime::Handle handle;
|
||||
// If node_def is not instantiable, e.g., the function does not exist,
|
||||
// simply bail out.
|
||||
NameAttrList function;
|
||||
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle));
|
||||
*fbody = flr->GetFunctionBody(handle);
|
||||
CHECK(*fbody); // Can't be nullptr since we just instantiated it.
|
||||
const DataTypeVector& arg_types = (*fbody)->arg_types;
|
||||
std::vector<bool> const_args(arg_types.size());
|
||||
// If we can't analyze the const args. Bail out.
|
||||
TF_RETURN_IF_ERROR(
|
||||
BackwardsConstAnalysis(*((*fbody)->graph), &const_args,
|
||||
/*compile_time_const_nodes=*/nullptr, flr));
|
||||
|
||||
for (size_t i = 0; i < const_args.size(); ++i) {
|
||||
if (const_args[i]) {
|
||||
constant_arg_indices->push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
// There can be hundreds of resource variables. Reserve the space for them.
|
||||
// We don't reserve for constants above as they are usually few.
|
||||
resource_arg_indices->reserve(arg_types.size());
|
||||
for (size_t i = 0; i < arg_types.size(); ++i) {
|
||||
if (arg_types[i] == DT_RESOURCE) {
|
||||
resource_arg_indices->push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
|
|
@ -265,6 +265,23 @@ class RecursiveCompilabilityChecker {
|
|||
|
||||
RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
|
||||
const XlaOpRegistry::DeviceRegistration& registration);
|
||||
|
||||
// Given a FunctionLibraryRuntime and a NodeDef calling a function in the
|
||||
// runtime, returns this function's body in `fbody` as well as the indices
|
||||
// of its constant and resource arguments.
|
||||
// `fbody` is owned by `flr`.
|
||||
// `constant_arg_indices` and `resource_arg_indices` should be empty vector.
|
||||
// They are sorted in ascending order on this function's return.
|
||||
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
|
||||
const NodeDef& node_def,
|
||||
const FunctionBody** fbody,
|
||||
std::vector<int>* constant_arg_indices,
|
||||
std::vector<int>* resource_arg_indices);
|
||||
|
||||
// Given a NodeDef `node_def` returns true iff `node_def` has kXlaCompileAttr
|
||||
// set.
|
||||
bool CanCreateXlaKernel(const NodeDef& node_def);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/force_xla_constants_on_host_pass.h"
|
||||
|
||||
#include "tensorflow/compiler/jit/compilability_check_util.h"
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Status ForceXlaConstantsOnHostPass::Run(
|
||||
const GraphOptimizationPassOptions& options) {
|
||||
Graph* graph = options.graph->get();
|
||||
|
||||
OptimizerOptions opts;
|
||||
auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
|
||||
nullptr, options.session_options->env, /*config=*/nullptr,
|
||||
TF_GRAPH_DEF_VERSION, options.flib_def, opts);
|
||||
FunctionLibraryRuntime* flr =
|
||||
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
|
||||
|
||||
for (Node* node : graph->nodes()) {
|
||||
if (CanCreateXlaKernel(node->def())) {
|
||||
const FunctionBody* fbody = nullptr;
|
||||
std::vector<int> constant_arg_indices;
|
||||
std::vector<int> resource_arg_indices;
|
||||
|
||||
// Force all constants to be on the host memory.
|
||||
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
|
||||
flr, node->def(), &fbody, &constant_arg_indices,
|
||||
&resource_arg_indices));
|
||||
VLOG(3) << "Found constant arg indices: "
|
||||
<< absl::StrJoin(constant_arg_indices, ", ");
|
||||
|
||||
node->AddAttr("_input_hostmem", constant_arg_indices);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
|
@ -0,0 +1,36 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_FORCE_XLA_CONSTANTS_ON_HOST_PASS_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_FORCE_XLA_CONSTANTS_ON_HOST_PASS_H_
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "tensorflow/compiler/jit/compilability_check_util.h"
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// An optimization pass which marks the constants which have to be resolved for
|
||||
// XLA compilation with `_input_hostmem`.
|
||||
class ForceXlaConstantsOnHostPass : public GraphOptimizationPass {
|
||||
public:
|
||||
ForceXlaConstantsOnHostPass() = default;
|
||||
|
||||
Status Run(const GraphOptimizationPassOptions& options) override;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_FORCE_XLA_CONSTANTS_ON_HOST_PASS_H_
|
|
@ -0,0 +1,108 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/force_xla_constants_on_host_pass.h"
|
||||
|
||||
#include "absl/strings/match.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/ops/functional_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/compiler/jit/compilability_check_util.h"
|
||||
#include "tensorflow/compiler/jit/defs.h"
|
||||
#include "tensorflow/compiler/jit/test_util.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
Status ForceXlaConstantsOnHost(const Scope& s,
|
||||
FunctionLibraryDefinition* flib_def,
|
||||
std::unique_ptr<Graph>* result) {
|
||||
auto graph = absl::make_unique<Graph>(OpRegistry::Global());
|
||||
GraphOptimizationPassOptions options;
|
||||
SessionOptions session_options;
|
||||
session_options.env = Env::Default();
|
||||
options.graph = &graph;
|
||||
options.session_options = &session_options;
|
||||
options.flib_def = flib_def;
|
||||
TF_RETURN_IF_ERROR(s.ToGraph(graph.get()));
|
||||
ForceXlaConstantsOnHostPass rewriter;
|
||||
TF_RETURN_IF_ERROR(rewriter.Run(options));
|
||||
*result = std::move(graph);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TEST(ForceXlaConstantsOnHostPassTest, Simple) {
|
||||
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
FunctionDefLibrary library;
|
||||
|
||||
FunctionDef called_func =
|
||||
FunctionDefHelper::Create("TransposeCall",
|
||||
/*in_def=*/{"a:float", "b:int32"},
|
||||
/*out_def=*/{"c:float"}, {},
|
||||
{{{"t0"},
|
||||
"Transpose",
|
||||
{"a", "b"},
|
||||
{
|
||||
{"T", DT_FLOAT},
|
||||
{"Tperm", DT_INT32},
|
||||
}}},
|
||||
{{"c", "t0:y:0"}});
|
||||
|
||||
AttrValue true_attribute;
|
||||
true_attribute.set_b(true);
|
||||
(*called_func.mutable_attr())[kXlaMustCompileAttr] = true_attribute;
|
||||
*library.add_function() = called_func;
|
||||
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library));
|
||||
FunctionLibraryDefinition flib_def(OpRegistry::Global(), library);
|
||||
Output in = ops::Placeholder(root, DT_FLOAT);
|
||||
Output perm = ops::Const(root, {3, 1, 2, 0});
|
||||
|
||||
NameAttrList b_name_attr;
|
||||
b_name_attr.set_name("TransposeCall");
|
||||
ops::PartitionedCall call(root.WithOpName("call"), {in, perm}, {DT_FLOAT},
|
||||
b_name_attr);
|
||||
call.output.front().node()->AddAttr(kXlaMustCompileAttr, true);
|
||||
|
||||
std::unique_ptr<Graph> graph;
|
||||
TF_ASSERT_OK(ForceXlaConstantsOnHost(root, &flib_def, &graph));
|
||||
|
||||
bool found = false;
|
||||
for (Node* node : graph->nodes()) {
|
||||
if (CanCreateXlaKernel(node->def())) {
|
||||
EXPECT_FALSE(found);
|
||||
found = true;
|
||||
std::vector<int32> hostmem_attr;
|
||||
EXPECT_TRUE(TryGetNodeAttr(node->def(), "_input_hostmem", &hostmem_attr));
|
||||
EXPECT_EQ(hostmem_attr.size(), 1);
|
||||
EXPECT_EQ(hostmem_attr[0], 1);
|
||||
}
|
||||
}
|
||||
EXPECT_TRUE(found);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||
#include "tensorflow/compiler/jit/cluster_scoping_pass.h"
|
||||
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
|
||||
#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
|
||||
#include "tensorflow/compiler/jit/force_xla_constants_on_host_pass.h"
|
||||
#include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h"
|
||||
#include "tensorflow/compiler/jit/introduce_floating_point_jitter_pass.h"
|
||||
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
|
||||
|
@ -57,6 +58,9 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 9,
|
|||
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10,
|
||||
MarkForCompilationPass);
|
||||
|
||||
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 12,
|
||||
ForceXlaConstantsOnHostPass);
|
||||
|
||||
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20,
|
||||
IncreaseDynamismForAutoJitPass);
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@ limitations under the License.
|
|||
==============================================================================*/
|
||||
#include "tensorflow/compiler/jit/xla_kernel_creator.h"
|
||||
|
||||
#include "tensorflow/compiler/jit/compilability_check_util.h"
|
||||
#include "tensorflow/compiler/jit/flags.h"
|
||||
#include "tensorflow/compiler/jit/xla_kernel_creator_util.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
|
|
|
@ -70,58 +70,6 @@ class SinglePassSearch {
|
|||
};
|
||||
} // namespace
|
||||
|
||||
bool CanCreateXlaKernel(const NodeDef& node_def) {
|
||||
// If kXlaMustCompileAttr is set on the node_def, use its value.
|
||||
const auto& it = node_def.attr().find(kXlaMustCompileAttr);
|
||||
return it != node_def.attr().end() && it->second.b();
|
||||
}
|
||||
|
||||
// Given a FunctionLibraryRuntime and a NodeDef calling a function in the
|
||||
// runtime, returns this function's body in `fbody` as well as the indices
|
||||
// of its constant and resource arguments.
|
||||
// `fbody` is owned by `flr`.
|
||||
// `constant_arg_indices` and `resource_arg_indices` should be empty vector.
|
||||
// They are sorted in ascending order on this function's return.
|
||||
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
|
||||
const NodeDef& node_def,
|
||||
const FunctionBody** fbody,
|
||||
std::vector<int>* constant_arg_indices,
|
||||
std::vector<int>* resource_arg_indices) {
|
||||
FunctionLibraryRuntime::Handle handle;
|
||||
// If node_def is not instantiable, e.g., the function does not exist,
|
||||
// simply bail out.
|
||||
NameAttrList function;
|
||||
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle));
|
||||
*fbody = flr->GetFunctionBody(handle);
|
||||
CHECK(*fbody); // Can't be nullptr since we just instantiated it.
|
||||
const DataTypeVector& arg_types = (*fbody)->arg_types;
|
||||
std::vector<bool> const_args(arg_types.size());
|
||||
// If we can't analyze the const args. Bail out.
|
||||
TF_RETURN_IF_ERROR(
|
||||
BackwardsConstAnalysis(*((*fbody)->graph), &const_args,
|
||||
/*compile_time_const_nodes=*/nullptr, flr));
|
||||
|
||||
for (size_t i = 0; i < const_args.size(); ++i) {
|
||||
if (const_args[i]) {
|
||||
constant_arg_indices->push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
// There can be hundreds of resource variables. Reserve the space for them.
|
||||
// We don't reserve for constants above as they are usually few.
|
||||
resource_arg_indices->reserve(arg_types.size());
|
||||
for (size_t i = 0; i < arg_types.size(); ++i) {
|
||||
if (arg_types[i] == DT_RESOURCE) {
|
||||
resource_arg_indices->push_back(i);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
|
||||
std::unique_ptr<OpKernel>* kernel) {
|
||||
if (!CanCreateXlaKernel(node_def)) {
|
||||
|
|
|
@ -24,10 +24,6 @@ namespace tensorflow {
|
|||
class FunctionLibraryRuntime;
|
||||
class OpKernel;
|
||||
|
||||
// Given a NodeDef `node_def` returns true iff `node_def` has kXlaCompileAttr
|
||||
// set.
|
||||
bool CanCreateXlaKernel(const NodeDef& node_def);
|
||||
|
||||
// Given a supported NodeDef, returns a XlaLaunchOp that computes the node.
|
||||
Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
|
||||
std::unique_ptr<OpKernel>* kernel);
|
||||
|
|
Loading…
Reference in New Issue