Support interprocedural constant meta-information propagation for compilation
This CL does two things:

1) Supports interprocedural constant information propagation across PartitionedCall and StatefulPartitionedCall.

2) Done naively, (1) leads to an exponential number of calls, as each function would be re-inlined for each (indirect) caller. To address this performance issue, we cache the argument indices which need to be constant, and attach that information to the Graph object.

This might require some clarification:

a) Caching in a passed map would not work, as duplicating the constant propagation for each top-level caller is still prohibitively expensive.

b) Caching in a global object would not work, as graphs are created and destroyed during transformations.

c) Caching this meta-information on the `Graph` object has the added benefit that we no longer perform the same constant propagation many times (many compilation passes call BackwardsConstAnalysis, and previously all of this work had to be repeated).

PiperOrigin-RevId: 303860413
Change-Id: I78f92ca1487fc952044e5ac6526dcaa5b50d5f21
parent 4d37ea391e
commit f3dcd9dc11

Changed directories: tensorflow/compiler/tf2xla, tensorflow/core/graph, tensorflow/python/eager
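Concretely, the caching described in the commit message amounts to a mutable optional member on the graph object, checked before the analysis runs and populated after a full run. Below is a minimal, self-contained sketch of that pattern (names like GraphSketch and AnalyzeConstArgs are illustrative, not the actual TensorFlow identifiers; the real code follows in the diff):

#include <vector>

#include "absl/types/optional.h"

// Sketch: a graph-like object owns an optional vector of "argument i must be
// compile-time constant" bits. `mutable` lets const analysis passes fill it in.
class GraphSketch {
 public:
  absl::optional<std::vector<bool>>& GetConstArgIndicesCache() const {
    return const_arg_indices_cache_;
  }

 private:
  mutable absl::optional<std::vector<bool>> const_arg_indices_cache_;
};

// The analysis consults the cache first and memoizes its result afterwards,
// so repeated callers pay for the dataflow analysis only once per graph.
void AnalyzeConstArgs(const GraphSketch& g, std::vector<bool>* out) {
  if (g.GetConstArgIndicesCache().has_value()) {
    *out = g.GetConstArgIndicesCache().value();  // Cache hit: skip re-analysis.
    return;
  }
  // ... run the backwards dataflow analysis, writing into *out ...
  g.GetConstArgIndicesCache() = *out;  // Memoize for subsequent callers.
}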
tensorflow/compiler/tf2xla/BUILD:

@@ -612,10 +612,12 @@ tf_cc_test(
         ":xla_compiler",
         "//tensorflow/cc:cc_ops",
         "//tensorflow/cc:function_ops",
+        "//tensorflow/cc:functional_ops",
         "//tensorflow/cc:ops",
         "//tensorflow/compiler/jit:xla_cluster_util",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
+        "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
         "//tensorflow/core:ops",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
tensorflow/compiler/tf2xla/const_analysis.cc:

@@ -145,6 +145,21 @@ Status GetCompileTimeConstInputs(const NodeDef& node, const OpKernel* op_kernel,
     TF_RETURN_IF_ERROR(
         GetFunctionBodies(flib_runtime, node, "branches", &branch_bodies));
     return CondConstInputIndices(branch_bodies, const_input_idxs, flib_runtime);
+  } else if (node.op() == "PartitionedCall" ||
+             node.op() == "StatefulPartitionedCall") {
+    const FunctionBody* fbody;
+    TF_RETURN_IF_ERROR(GetFunctionBody(flib_runtime, node, "f", &fbody));
+    int num_inputs = fbody->fdef.signature().input_arg_size();
+    std::vector<bool> compile_time_const_arg_indices(num_inputs);
+    TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
+        *(fbody->graph), &compile_time_const_arg_indices,
+        /*compile_time_const_nodes=*/nullptr, flib_runtime));
+    for (int i = 0; i < num_inputs; i++) {
+      if (compile_time_const_arg_indices[i]) {
+        const_input_idxs->push_back(i);
+      }
+    }
+    return Status::OK();
   } else if (op_def != nullptr) {
     return XlaOpRegistry::CompileTimeConstantInputs(node, *op_def,
                                                     const_input_idxs);
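In effect, an input of a PartitionedCall or StatefulPartitionedCall must be compile-time constant exactly when BackwardsConstAnalysis marks the corresponding argument of the called function `f` as constant; this recursion into the function body is what makes the propagation interprocedural.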
@@ -166,11 +181,21 @@ Status GetCompileTimeConstInputs(const Node* node,

 // Backwards dataflow analysis that finds arguments to a graph that must be
 // compile-time constants.
-Status BackwardsConstAnalysis(const Graph& g,
-                              std::vector<bool>* compile_time_const_arg_indices,
-                              std::vector<bool>* compile_time_const_nodes,
-                              FunctionLibraryRuntime* flib_runtime,
-                              std::function<bool(const Edge&)> edge_filter) {
+Status BackwardsConstAnalysis(
+    const Graph& g, std::vector<bool>* compile_time_const_arg_indices,
+    std::vector<bool>* compile_time_const_nodes,
+    FunctionLibraryRuntime* flib_runtime,
+    std::function<bool(const Edge&)> edge_filter_input) {
+  if (!compile_time_const_nodes && g.GetConstArgIndicesCache().has_value() &&
+      !edge_filter_input) {
+    VLOG(5) << "Using cached argument indices on graph " << &g;
+    *compile_time_const_arg_indices = g.GetConstArgIndicesCache().value();
+    return Status::OK();
+  }
+  auto edge_filter = [&](const Edge& e) {
+    return edge_filter_input ? edge_filter_input(e) : true;
+  };
+
   std::vector<bool> compile_time_const_nodes_impl;
   if (compile_time_const_nodes) {
     CHECK_EQ(compile_time_const_nodes->size(), g.num_node_ids());

@@ -252,6 +277,10 @@ Status BackwardsConstAnalysis(const Graph& g,
   // acyclic graph.
   DFS(g, /*enter=*/{}, /*leave=*/visit, NodeComparatorName{},
       [](const Edge& edge) { return !edge.src()->IsNextIteration(); });
+  if (compile_time_const_arg_indices && !edge_filter_input) {
+    VLOG(5) << "Setting the cache on the graph: " << &g;
+    g.GetConstArgIndicesCache() = *compile_time_const_arg_indices;
+  }
   return status;
 }
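Note that the cache is consulted and written only when the caller does not request `compile_time_const_nodes` and passes no custom edge filter; in all other cases the full backwards analysis still runs, so cached and uncached callers see identical results.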
tensorflow/compiler/tf2xla/const_analysis.h:

@@ -33,14 +33,13 @@ namespace tensorflow {
 // The ids of the nodes in `graph` that must be constant are returned in
 // `compile_time_const_nodes`, if `compile_time_const_nodes` is not null.
 //
-// Only propagate const-ness along edges for which `edge_filter` returns true.
+// If `edge_filter` is non-null, only propagate const-ness along edges for which
+// `edge_filter` returns true.
 Status BackwardsConstAnalysis(
     const Graph& g, std::vector<bool>* compile_time_const_arg_indices,
     std::vector<bool>* compile_time_const_nodes,
     FunctionLibraryRuntime* flib_runtime,
-    std::function<bool(const Edge&)> edge_filter = [](const Edge& e) {
-      return true;
-    });
+    std::function<bool(const Edge&)> edge_filter_input = nullptr);

 // Given an op kernel and function library runtime, return all the indices of
 // inputs that need to be compile time constant.
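With the null default, a typical call site simply omits the edge filter, which is also the precondition for hitting the per-graph cache. A hedged usage fragment (assumes TensorFlow code where `graph` and `flib_runtime` are already in scope; `num_args` is an illustrative placeholder for the graph's argument count):

  std::vector<bool> const_args(num_args, false);
  TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
      graph, &const_args,
      /*compile_time_const_nodes=*/nullptr, flib_runtime));
  // const_args[i] is now true iff argument i must be compile-time constant;
  // the result is also memoized on `graph` for later passes.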
tensorflow/compiler/tf2xla/const_analysis_test.cc:

@@ -19,11 +19,14 @@ limitations under the License.

 #include "tensorflow/cc/framework/ops.h"
 #include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/functional_ops.h"
 #include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/compiler/jit/xla_cluster_util.h"
+#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
 #include "tensorflow/core/graph/algorithm.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/version.h"

 namespace tensorflow {
 namespace {
@@ -89,6 +92,59 @@ TEST(ConstAnalysisTest, TopologicalOrder) {
   }
 }

+void TestFunctionCall(bool is_stateful_partitioned_call) {
+  FunctionDef callee = FunctionDefHelper::Define(
+      "Callee", {"t:float", "shape:int32"}, {"result:float"}, {},
+      {{{"result"}, "Reshape", {"t", "shape"}, {{"T", DT_FLOAT}}}});
+
+  FunctionDefLibrary flib;
+  *flib.add_function() = callee;
+  FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);
+
+  Scope root = Scope::NewRootScope().ExitOnError();
+
+  auto arg0 = ops::_Arg(root.WithOpName("tensor"), DT_FLOAT, 0);
+  auto arg1 = ops::_Arg(root.WithOpName("shape"), DT_INT32, 1);
+
+  NameAttrList call_attrs;
+  call_attrs.set_name("Callee");
+  if (is_stateful_partitioned_call) {
+    ops::StatefulPartitionedCall b(root.WithOpName("Call"),
+                                   {Output(arg0), Output(arg1)}, {DT_FLOAT},
+                                   call_attrs);
+  } else {
+    ops::PartitionedCall b(root.WithOpName("Call"),
+                           {Output(arg0), Output(arg1)}, {DT_FLOAT},
+                           call_attrs);
+  }
+
+  Graph graph(&flib_def);
+  TF_ASSERT_OK(root.ToGraph(&graph));
+
+  OptimizerOptions opts;
+  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
+      new ProcessFunctionLibraryRuntime(nullptr, Env::Default(),
+                                        /*config=*/nullptr,
+                                        TF_GRAPH_DEF_VERSION, &flib_def, opts));
+  FunctionLibraryRuntime* lib_runtime =
+      pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
+
+  std::vector<bool> const_args(2, false);
+  TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args,
+                                      /*compile_time_const_nodes=*/nullptr,
+                                      lib_runtime));
+
+  EXPECT_EQ(const_args, std::vector<bool>({false, true}));
+}
+
+TEST(ConstAnalysisTest, PartitionedCall) {
+  TestFunctionCall(/*is_stateful_partitioned_call=*/false);
+}
+
+TEST(ConstAnalysisTest, StatefulPartitionedCall) {
+  TestFunctionCall(/*is_stateful_partitioned_call=*/true);
+}
+
 TEST(ConstAnalysisTest, DontFollowControlDependencies) {
   Scope root = Scope::NewRootScope();
tensorflow/core/graph/graph.h:

@@ -41,6 +41,7 @@ limitations under the License.
 #include <string>
 #include <vector>

+#include "absl/types/optional.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/node_def_util.h"

@@ -678,6 +679,10 @@ class Graph {
   // Builds a node name to node pointer index for all nodes in the graph.
   std::unordered_map<string, Node*> BuildNodeNameIndex() const;

+  absl::optional<std::vector<bool>>& GetConstArgIndicesCache() const {
+    return const_arg_indices_cache_;
+  }
+
   // TODO(josh11b): uint64 hash() const;

  private:

@@ -751,6 +756,10 @@ class Graph {
   // AddWhileContext() or Node::while_ctx(), but this manages the lifetime.
   std::map<string, WhileContext> while_ctxs_;

+  // Cache of the indices of the arguments which need to be constant for the XLA
+  // compilation.
+  mutable absl::optional<std::vector<bool>> const_arg_indices_cache_;
+
   TF_DISALLOW_COPY_AND_ASSIGN(Graph);
 };
tensorflow/python/eager (def_function XLA test):

@@ -232,6 +232,24 @@ class DefFunctionTest(test.TestCase):
     with self.assertRaisesRegexp(errors.InvalidArgumentError, 'not compilable'):
       c.f1(inputs)

+  def testMustBeConstantPropagation(self):
+    if test.is_built_with_rocm():
+      return
+
+    @def_function.function(experimental_compile=True)
+    def f():
+      return constant_op.constant([0, 2, 1], dtype=dtypes.int32)
+
+    @def_function.function(experimental_compile=True)
+    def g(a, b):
+      return array_ops.transpose(a, b)
+
+    @def_function.function
+    def z():
+      return g(array_ops.ones([3, 4, 3], dtype=dtypes.float32), f())
+
+    z()
+

 if __name__ == '__main__':
   ops.enable_eager_execution()
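The Python test exercises the new path end to end: XLA requires the permutation argument of `array_ops.transpose` to be a compile-time constant inside `g`, and that requirement now propagates through the function call in `z` back to the output of `f`.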