[TF/XLA] Force all tensors which need to be constant during the XLA compilation to be located on the host
Otherwise, this leads to strange crashes during the compilation. PiperOrigin-RevId: 304226917 Change-Id: Ia2f1e77b13a25c7e15f009787af81f93b90e8bca
This commit is contained in:
parent
c58518c893
commit
9771765f41
@ -422,6 +422,7 @@ cc_library(
|
|||||||
"xla_kernel_creator.h",
|
"xla_kernel_creator.h",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
":compilability_check_util",
|
||||||
":flags",
|
":flags",
|
||||||
":jit_compilation_passes",
|
":jit_compilation_passes",
|
||||||
":xla_kernel_creator_util",
|
":xla_kernel_creator_util",
|
||||||
@ -591,6 +592,7 @@ cc_library(
|
|||||||
"encapsulate_subgraphs_pass.cc",
|
"encapsulate_subgraphs_pass.cc",
|
||||||
"encapsulate_xla_computations_pass.cc",
|
"encapsulate_xla_computations_pass.cc",
|
||||||
"extract_outside_compilation_pass.cc",
|
"extract_outside_compilation_pass.cc",
|
||||||
|
"force_xla_constants_on_host_pass.cc",
|
||||||
"increase_dynamism_for_auto_jit_pass.cc",
|
"increase_dynamism_for_auto_jit_pass.cc",
|
||||||
"introduce_floating_point_jitter_pass.cc",
|
"introduce_floating_point_jitter_pass.cc",
|
||||||
"mark_for_compilation_pass.cc",
|
"mark_for_compilation_pass.cc",
|
||||||
@ -606,6 +608,7 @@ cc_library(
|
|||||||
"encapsulate_subgraphs_pass.h",
|
"encapsulate_subgraphs_pass.h",
|
||||||
"encapsulate_xla_computations_pass.h",
|
"encapsulate_xla_computations_pass.h",
|
||||||
"extract_outside_compilation_pass.h",
|
"extract_outside_compilation_pass.h",
|
||||||
|
"force_xla_constants_on_host_pass.h",
|
||||||
"increase_dynamism_for_auto_jit_pass.h",
|
"increase_dynamism_for_auto_jit_pass.h",
|
||||||
"introduce_floating_point_jitter_pass.h",
|
"introduce_floating_point_jitter_pass.h",
|
||||||
"mark_for_compilation_pass.h",
|
"mark_for_compilation_pass.h",
|
||||||
@ -774,6 +777,7 @@ tf_cc_test(
|
|||||||
"encapsulate_subgraphs_pass_test.cc",
|
"encapsulate_subgraphs_pass_test.cc",
|
||||||
"encapsulate_xla_computations_pass_test.cc",
|
"encapsulate_xla_computations_pass_test.cc",
|
||||||
"extract_outside_compilation_pass_test.cc",
|
"extract_outside_compilation_pass_test.cc",
|
||||||
|
"force_xla_constants_on_host_pass_test.cc",
|
||||||
"increase_dynamism_for_auto_jit_pass_test.cc",
|
"increase_dynamism_for_auto_jit_pass_test.cc",
|
||||||
"introduce_floating_point_jitter_pass_internal.h",
|
"introduce_floating_point_jitter_pass_internal.h",
|
||||||
"introduce_floating_point_jitter_pass_test.cc",
|
"introduce_floating_point_jitter_pass_test.cc",
|
||||||
@ -786,6 +790,7 @@ tf_cc_test(
|
|||||||
tags = ["nomsan"] + tf_cuda_tests_tags(),
|
tags = ["nomsan"] + tf_cuda_tests_tags(),
|
||||||
deps = [
|
deps = [
|
||||||
":common",
|
":common",
|
||||||
|
":compilability_check_util",
|
||||||
":compilation_passes",
|
":compilation_passes",
|
||||||
":compilation_passes_test_main",
|
":compilation_passes_test_main",
|
||||||
":encapsulate_util",
|
":encapsulate_util",
|
||||||
|
|||||||
@ -518,4 +518,50 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool CanCreateXlaKernel(const NodeDef& node_def) {
|
||||||
|
// If kXlaMustCompileAttr is set on the node_def, use its value.
|
||||||
|
const auto& it = node_def.attr().find(kXlaMustCompileAttr);
|
||||||
|
return it != node_def.attr().end() && it->second.b();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
|
||||||
|
const NodeDef& node_def,
|
||||||
|
const FunctionBody** fbody,
|
||||||
|
std::vector<int>* constant_arg_indices,
|
||||||
|
std::vector<int>* resource_arg_indices) {
|
||||||
|
FunctionLibraryRuntime::Handle handle;
|
||||||
|
// If node_def is not instantiable, e.g., the function does not exist,
|
||||||
|
// simply bail out.
|
||||||
|
NameAttrList function;
|
||||||
|
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle));
|
||||||
|
*fbody = flr->GetFunctionBody(handle);
|
||||||
|
CHECK(*fbody); // Can't be nullptr since we just instantiated it.
|
||||||
|
const DataTypeVector& arg_types = (*fbody)->arg_types;
|
||||||
|
std::vector<bool> const_args(arg_types.size());
|
||||||
|
// If we can't analyze the const args. Bail out.
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
BackwardsConstAnalysis(*((*fbody)->graph), &const_args,
|
||||||
|
/*compile_time_const_nodes=*/nullptr, flr));
|
||||||
|
|
||||||
|
for (size_t i = 0; i < const_args.size(); ++i) {
|
||||||
|
if (const_args[i]) {
|
||||||
|
constant_arg_indices->push_back(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// There can be hundreds of resource variables. Reserve the space for them.
|
||||||
|
// We don't reserve for constants above as they are usually few.
|
||||||
|
resource_arg_indices->reserve(arg_types.size());
|
||||||
|
for (size_t i = 0; i < arg_types.size(); ++i) {
|
||||||
|
if (arg_types[i] == DT_RESOURCE) {
|
||||||
|
resource_arg_indices->push_back(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|||||||
@ -265,6 +265,23 @@ class RecursiveCompilabilityChecker {
|
|||||||
|
|
||||||
RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
|
RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter(
|
||||||
const XlaOpRegistry::DeviceRegistration& registration);
|
const XlaOpRegistry::DeviceRegistration& registration);
|
||||||
|
|
||||||
|
// Given a FunctionLibraryRuntime and a NodeDef calling a function in the
|
||||||
|
// runtime, returns this function's body in `fbody` as well as the indices
|
||||||
|
// of its constant and resource arguments.
|
||||||
|
// `fbody` is owned by `flr`.
|
||||||
|
// `constant_arg_indices` and `resource_arg_indices` should be empty vector.
|
||||||
|
// They are sorted in ascending order on this function's return.
|
||||||
|
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
|
||||||
|
const NodeDef& node_def,
|
||||||
|
const FunctionBody** fbody,
|
||||||
|
std::vector<int>* constant_arg_indices,
|
||||||
|
std::vector<int>* resource_arg_indices);
|
||||||
|
|
||||||
|
// Given a NodeDef `node_def` returns true iff `node_def` has kXlaCompileAttr
|
||||||
|
// set.
|
||||||
|
bool CanCreateXlaKernel(const NodeDef& node_def);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_
|
#endif // TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_
|
||||||
|
|||||||
54
tensorflow/compiler/jit/force_xla_constants_on_host_pass.cc
Normal file
54
tensorflow/compiler/jit/force_xla_constants_on_host_pass.cc
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/jit/force_xla_constants_on_host_pass.h"
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/jit/compilability_check_util.h"
|
||||||
|
#include "tensorflow/compiler/jit/defs.h"
|
||||||
|
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
Status ForceXlaConstantsOnHostPass::Run(
|
||||||
|
const GraphOptimizationPassOptions& options) {
|
||||||
|
Graph* graph = options.graph->get();
|
||||||
|
|
||||||
|
OptimizerOptions opts;
|
||||||
|
auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
|
||||||
|
nullptr, options.session_options->env, /*config=*/nullptr,
|
||||||
|
TF_GRAPH_DEF_VERSION, options.flib_def, opts);
|
||||||
|
FunctionLibraryRuntime* flr =
|
||||||
|
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
|
||||||
|
|
||||||
|
for (Node* node : graph->nodes()) {
|
||||||
|
if (CanCreateXlaKernel(node->def())) {
|
||||||
|
const FunctionBody* fbody = nullptr;
|
||||||
|
std::vector<int> constant_arg_indices;
|
||||||
|
std::vector<int> resource_arg_indices;
|
||||||
|
|
||||||
|
// Force all constants to be on the host memory.
|
||||||
|
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
|
||||||
|
flr, node->def(), &fbody, &constant_arg_indices,
|
||||||
|
&resource_arg_indices));
|
||||||
|
VLOG(3) << "Found constant arg indices: "
|
||||||
|
<< absl::StrJoin(constant_arg_indices, ", ");
|
||||||
|
|
||||||
|
node->AddAttr("_input_hostmem", constant_arg_indices);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
36
tensorflow/compiler/jit/force_xla_constants_on_host_pass.h
Normal file
36
tensorflow/compiler/jit/force_xla_constants_on_host_pass.h
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_JIT_FORCE_XLA_CONSTANTS_ON_HOST_PASS_H_
|
||||||
|
#define TENSORFLOW_COMPILER_JIT_FORCE_XLA_CONSTANTS_ON_HOST_PASS_H_
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_set.h"
|
||||||
|
#include "tensorflow/compiler/jit/compilability_check_util.h"
|
||||||
|
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// An optimization pass which marks the constants which have to be resolved for
|
||||||
|
// XLA compilation with `_input_hostmem`.
|
||||||
|
class ForceXlaConstantsOnHostPass : public GraphOptimizationPass {
|
||||||
|
public:
|
||||||
|
ForceXlaConstantsOnHostPass() = default;
|
||||||
|
|
||||||
|
Status Run(const GraphOptimizationPassOptions& options) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_JIT_FORCE_XLA_CONSTANTS_ON_HOST_PASS_H_
|
||||||
108
tensorflow/compiler/jit/force_xla_constants_on_host_pass_test.cc
Normal file
108
tensorflow/compiler/jit/force_xla_constants_on_host_pass_test.cc
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/jit/force_xla_constants_on_host_pass.h"
|
||||||
|
|
||||||
|
#include "absl/strings/match.h"
|
||||||
|
#include "absl/strings/str_cat.h"
|
||||||
|
#include "tensorflow/cc/framework/ops.h"
|
||||||
|
#include "tensorflow/cc/ops/functional_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
|
#include "tensorflow/compiler/jit/compilability_check_util.h"
|
||||||
|
#include "tensorflow/compiler/jit/defs.h"
|
||||||
|
#include "tensorflow/compiler/jit/test_util.h"
|
||||||
|
#include "tensorflow/core/common_runtime/function.h"
|
||||||
|
#include "tensorflow/core/framework/function_testlib.h"
|
||||||
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
|
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||||
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/public/session_options.h"
|
||||||
|
#include "tensorflow/core/public/version.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
Status ForceXlaConstantsOnHost(const Scope& s,
|
||||||
|
FunctionLibraryDefinition* flib_def,
|
||||||
|
std::unique_ptr<Graph>* result) {
|
||||||
|
auto graph = absl::make_unique<Graph>(OpRegistry::Global());
|
||||||
|
GraphOptimizationPassOptions options;
|
||||||
|
SessionOptions session_options;
|
||||||
|
session_options.env = Env::Default();
|
||||||
|
options.graph = &graph;
|
||||||
|
options.session_options = &session_options;
|
||||||
|
options.flib_def = flib_def;
|
||||||
|
TF_RETURN_IF_ERROR(s.ToGraph(graph.get()));
|
||||||
|
ForceXlaConstantsOnHostPass rewriter;
|
||||||
|
TF_RETURN_IF_ERROR(rewriter.Run(options));
|
||||||
|
*result = std::move(graph);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ForceXlaConstantsOnHostPassTest, Simple) {
|
||||||
|
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
|
||||||
|
Scope root = Scope::NewRootScope().ExitOnError();
|
||||||
|
FunctionDefLibrary library;
|
||||||
|
|
||||||
|
FunctionDef called_func =
|
||||||
|
FunctionDefHelper::Create("TransposeCall",
|
||||||
|
/*in_def=*/{"a:float", "b:int32"},
|
||||||
|
/*out_def=*/{"c:float"}, {},
|
||||||
|
{{{"t0"},
|
||||||
|
"Transpose",
|
||||||
|
{"a", "b"},
|
||||||
|
{
|
||||||
|
{"T", DT_FLOAT},
|
||||||
|
{"Tperm", DT_INT32},
|
||||||
|
}}},
|
||||||
|
{{"c", "t0:y:0"}});
|
||||||
|
|
||||||
|
AttrValue true_attribute;
|
||||||
|
true_attribute.set_b(true);
|
||||||
|
(*called_func.mutable_attr())[kXlaMustCompileAttr] = true_attribute;
|
||||||
|
*library.add_function() = called_func;
|
||||||
|
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(library));
|
||||||
|
FunctionLibraryDefinition flib_def(OpRegistry::Global(), library);
|
||||||
|
Output in = ops::Placeholder(root, DT_FLOAT);
|
||||||
|
Output perm = ops::Const(root, {3, 1, 2, 0});
|
||||||
|
|
||||||
|
NameAttrList b_name_attr;
|
||||||
|
b_name_attr.set_name("TransposeCall");
|
||||||
|
ops::PartitionedCall call(root.WithOpName("call"), {in, perm}, {DT_FLOAT},
|
||||||
|
b_name_attr);
|
||||||
|
call.output.front().node()->AddAttr(kXlaMustCompileAttr, true);
|
||||||
|
|
||||||
|
std::unique_ptr<Graph> graph;
|
||||||
|
TF_ASSERT_OK(ForceXlaConstantsOnHost(root, &flib_def, &graph));
|
||||||
|
|
||||||
|
bool found = false;
|
||||||
|
for (Node* node : graph->nodes()) {
|
||||||
|
if (CanCreateXlaKernel(node->def())) {
|
||||||
|
EXPECT_FALSE(found);
|
||||||
|
found = true;
|
||||||
|
std::vector<int32> hostmem_attr;
|
||||||
|
EXPECT_TRUE(TryGetNodeAttr(node->def(), "_input_hostmem", &hostmem_attr));
|
||||||
|
EXPECT_EQ(hostmem_attr.size(), 1);
|
||||||
|
EXPECT_EQ(hostmem_attr[0], 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
EXPECT_TRUE(found);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tensorflow
|
||||||
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/jit/cluster_scoping_pass.h"
|
#include "tensorflow/compiler/jit/cluster_scoping_pass.h"
|
||||||
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
|
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
|
||||||
#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
|
#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
|
||||||
|
#include "tensorflow/compiler/jit/force_xla_constants_on_host_pass.h"
|
||||||
#include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h"
|
#include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h"
|
||||||
#include "tensorflow/compiler/jit/introduce_floating_point_jitter_pass.h"
|
#include "tensorflow/compiler/jit/introduce_floating_point_jitter_pass.h"
|
||||||
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
|
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
|
||||||
@ -57,6 +58,9 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 9,
|
|||||||
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10,
|
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10,
|
||||||
MarkForCompilationPass);
|
MarkForCompilationPass);
|
||||||
|
|
||||||
|
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 12,
|
||||||
|
ForceXlaConstantsOnHostPass);
|
||||||
|
|
||||||
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20,
|
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20,
|
||||||
IncreaseDynamismForAutoJitPass);
|
IncreaseDynamismForAutoJitPass);
|
||||||
|
|
||||||
|
|||||||
@ -14,6 +14,7 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#include "tensorflow/compiler/jit/xla_kernel_creator.h"
|
#include "tensorflow/compiler/jit/xla_kernel_creator.h"
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/jit/compilability_check_util.h"
|
||||||
#include "tensorflow/compiler/jit/flags.h"
|
#include "tensorflow/compiler/jit/flags.h"
|
||||||
#include "tensorflow/compiler/jit/xla_kernel_creator_util.h"
|
#include "tensorflow/compiler/jit/xla_kernel_creator_util.h"
|
||||||
#include "tensorflow/core/common_runtime/function.h"
|
#include "tensorflow/core/common_runtime/function.h"
|
||||||
|
|||||||
@ -70,58 +70,6 @@ class SinglePassSearch {
|
|||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
bool CanCreateXlaKernel(const NodeDef& node_def) {
|
|
||||||
// If kXlaMustCompileAttr is set on the node_def, use its value.
|
|
||||||
const auto& it = node_def.attr().find(kXlaMustCompileAttr);
|
|
||||||
return it != node_def.attr().end() && it->second.b();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Given a FunctionLibraryRuntime and a NodeDef calling a function in the
|
|
||||||
// runtime, returns this function's body in `fbody` as well as the indices
|
|
||||||
// of its constant and resource arguments.
|
|
||||||
// `fbody` is owned by `flr`.
|
|
||||||
// `constant_arg_indices` and `resource_arg_indices` should be empty vector.
|
|
||||||
// They are sorted in ascending order on this function's return.
|
|
||||||
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
|
|
||||||
const NodeDef& node_def,
|
|
||||||
const FunctionBody** fbody,
|
|
||||||
std::vector<int>* constant_arg_indices,
|
|
||||||
std::vector<int>* resource_arg_indices) {
|
|
||||||
FunctionLibraryRuntime::Handle handle;
|
|
||||||
// If node_def is not instantiable, e.g., the function does not exist,
|
|
||||||
// simply bail out.
|
|
||||||
NameAttrList function;
|
|
||||||
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
|
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle));
|
|
||||||
*fbody = flr->GetFunctionBody(handle);
|
|
||||||
CHECK(*fbody); // Can't be nullptr since we just instantiated it.
|
|
||||||
const DataTypeVector& arg_types = (*fbody)->arg_types;
|
|
||||||
std::vector<bool> const_args(arg_types.size());
|
|
||||||
// If we can't analyze the const args. Bail out.
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
BackwardsConstAnalysis(*((*fbody)->graph), &const_args,
|
|
||||||
/*compile_time_const_nodes=*/nullptr, flr));
|
|
||||||
|
|
||||||
for (size_t i = 0; i < const_args.size(); ++i) {
|
|
||||||
if (const_args[i]) {
|
|
||||||
constant_arg_indices->push_back(i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// There can be hundreds of resource variables. Reserve the space for them.
|
|
||||||
// We don't reserve for constants above as they are usually few.
|
|
||||||
resource_arg_indices->reserve(arg_types.size());
|
|
||||||
for (size_t i = 0; i < arg_types.size(); ++i) {
|
|
||||||
if (arg_types[i] == DT_RESOURCE) {
|
|
||||||
resource_arg_indices->push_back(i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
|
Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
|
||||||
std::unique_ptr<OpKernel>* kernel) {
|
std::unique_ptr<OpKernel>* kernel) {
|
||||||
if (!CanCreateXlaKernel(node_def)) {
|
if (!CanCreateXlaKernel(node_def)) {
|
||||||
|
|||||||
@ -24,10 +24,6 @@ namespace tensorflow {
|
|||||||
class FunctionLibraryRuntime;
|
class FunctionLibraryRuntime;
|
||||||
class OpKernel;
|
class OpKernel;
|
||||||
|
|
||||||
// Given a NodeDef `node_def` returns true iff `node_def` has kXlaCompileAttr
|
|
||||||
// set.
|
|
||||||
bool CanCreateXlaKernel(const NodeDef& node_def);
|
|
||||||
|
|
||||||
// Given a supported NodeDef, returns a XlaLaunchOp that computes the node.
|
// Given a supported NodeDef, returns a XlaLaunchOp that computes the node.
|
||||||
Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
|
Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
|
||||||
std::unique_ptr<OpKernel>* kernel);
|
std::unique_ptr<OpKernel>* kernel);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user