Support interprocedural constant meta-information propagation for compilation

This CL does two things:

1) Supports interprocedural constant-information propagation across
PartitionedCall and StatefulPartitionedCall.

2) Done naively, (1) leads to an exponential number of calls, as each function
would be re-inlined for each (indirect) caller.
To address this performance issue, we cache the argument indices which need to
be constant and attach that information to the Graph object.

This might require some clarification:

a) Caching in a passed-in map would not work, as the constant propagation would
still be duplicated for each top-level caller, which is prohibitively expensive.

b) Caching in a global object would not work, as graphs are created and
destroyed during transformations.

c) Caching this meta-information on the `Graph` object has the added benefit
that we no longer perform the same constant propagation many times: many
compilation passes call BackwardsConstAnalysis, and previously all of that work
had to be repeated. A condensed sketch of the resulting caching pattern is
shown below.
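
To make (c) concrete, here is a condensed, illustrative sketch of the caching
pattern (plain C++17, with std::optional standing in for absl::optional; in the
actual change the accessor is Graph::GetConstArgIndicesCache() and the consumer
is BackwardsConstAnalysis, both shown in the diffs below):

#include <optional>
#include <vector>

class Graph {
 public:
  // Const accessor returning a mutable reference: analysis passes may read or
  // fill the cache even through a const Graph&.
  std::optional<std::vector<bool>>& const_arg_indices_cache() const {
    return const_arg_indices_cache_;
  }

 private:
  mutable std::optional<std::vector<bool>> const_arg_indices_cache_;
};

// Stand-in for BackwardsConstAnalysis: serve cached indices when present,
// otherwise run the (here faked) analysis and populate the cache.
std::vector<bool> ConstArgIndices(const Graph& g) {
  if (g.const_arg_indices_cache().has_value()) {
    return *g.const_arg_indices_cache();     // cache hit: no re-analysis
  }
  std::vector<bool> result = {false, true};  // pretend analysis result
  g.const_arg_indices_cache() = result;      // populate for later callers
  return result;
}

int main() {
  Graph g;
  ConstArgIndices(g);  // first call performs the analysis and fills the cache
  ConstArgIndices(g);  // second call is answered from the cache
}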

PiperOrigin-RevId: 303860413
Change-Id: I78f92ca1487fc952044e5ac6526dcaa5b50d5f21
George Karpenkov 2020-03-30 17:47:40 -07:00 committed by TensorFlower Gardener
parent 4d37ea391e
commit f3dcd9dc11
6 changed files with 122 additions and 9 deletions

View File

@@ -612,10 +612,12 @@ tf_cc_test(
":xla_compiler",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:functional_ops",
"//tensorflow/cc:ops",
"//tensorflow/compiler/jit:xla_cluster_util",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:ops",
"//tensorflow/core:test",
"//tensorflow/core:test_main",

View File

@@ -145,6 +145,21 @@ Status GetCompileTimeConstInputs(const NodeDef& node, const OpKernel* op_kernel,
TF_RETURN_IF_ERROR(
GetFunctionBodies(flib_runtime, node, "branches", &branch_bodies));
return CondConstInputIndices(branch_bodies, const_input_idxs, flib_runtime);
} else if (node.op() == "PartitionedCall" ||
node.op() == "StatefulPartitionedCall") {
const FunctionBody* fbody;
TF_RETURN_IF_ERROR(GetFunctionBody(flib_runtime, node, "f", &fbody));
int num_inputs = fbody->fdef.signature().input_arg_size();
std::vector<bool> compile_time_const_arg_indices(num_inputs);
TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
*(fbody->graph), &compile_time_const_arg_indices,
/*compile_time_const_nodes=*/nullptr, flib_runtime));
for (int i = 0; i < num_inputs; i++) {
if (compile_time_const_arg_indices[i]) {
const_input_idxs->push_back(i);
}
}
return Status::OK();
} else if (op_def != nullptr) {
return XlaOpRegistry::CompileTimeConstantInputs(node, *op_def,
const_input_idxs);
@@ -166,11 +181,21 @@ Status GetCompileTimeConstInputs(const Node* node,
// Backwards dataflow analysis that finds arguments to a graph that must be
// compile-time constants.
Status BackwardsConstAnalysis(const Graph& g,
std::vector<bool>* compile_time_const_arg_indices,
std::vector<bool>* compile_time_const_nodes,
FunctionLibraryRuntime* flib_runtime,
std::function<bool(const Edge&)> edge_filter) {
Status BackwardsConstAnalysis(
const Graph& g, std::vector<bool>* compile_time_const_arg_indices,
std::vector<bool>* compile_time_const_nodes,
FunctionLibraryRuntime* flib_runtime,
std::function<bool(const Edge&)> edge_filter_input) {
if (!compile_time_const_nodes && g.GetConstArgIndicesCache().has_value() &&
!edge_filter_input) {
VLOG(5) << "Using cached argument indices on graph " << &g;
*compile_time_const_arg_indices = g.GetConstArgIndicesCache().value();
return Status::OK();
}
auto edge_filter = [&](const Edge& e) {
return edge_filter_input ? edge_filter_input(e) : true;
};
std::vector<bool> compile_time_const_nodes_impl;
if (compile_time_const_nodes) {
CHECK_EQ(compile_time_const_nodes->size(), g.num_node_ids());
@@ -252,6 +277,10 @@ Status BackwardsConstAnalysis(const Graph& g,
// acyclic graph.
DFS(g, /*enter=*/{}, /*leave=*/visit, NodeComparatorName{},
[](const Edge& edge) { return !edge.src()->IsNextIteration(); });
if (compile_time_const_arg_indices && !edge_filter_input) {
VLOG(5) << "Setting the cache on the graph: " << &g;
g.GetConstArgIndicesCache() = *compile_time_const_arg_indices;
}
return status;
}

View File

@@ -33,14 +33,13 @@ namespace tensorflow {
// The ids of the nodes in `graph` that must be constant are returned in
// `compile_time_const_nodes`, if `compile_time_const_nodes` is not null.
//
// Only propagate const-ness along edges for which `edge_filter` returns true.
// If `edge_filter` is non-null, only propagate const-ness along edges for which
// `edge_filter` returns true.
Status BackwardsConstAnalysis(
const Graph& g, std::vector<bool>* compile_time_const_arg_indices,
std::vector<bool>* compile_time_const_nodes,
FunctionLibraryRuntime* flib_runtime,
std::function<bool(const Edge&)> edge_filter = [](const Edge& e) {
return true;
});
std::function<bool(const Edge&)> edge_filter_input = nullptr);
// Given an op kernel and function library runtime, return all the indices of
// inputs that need to be compile time constant.

View File

@@ -19,11 +19,14 @@ limitations under the License.
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/version.h"
namespace tensorflow {
namespace {
@@ -89,6 +92,59 @@ TEST(ConstAnalysisTest, TopologicalOrder) {
}
}
void TestFunctionCall(bool is_stateful_partitioned_call) {
FunctionDef callee = FunctionDefHelper::Define(
"Callee", {"t:float", "shape:int32"}, {"result:float"}, {},
{{{"result"}, "Reshape", {"t", "shape"}, {{"T", DT_FLOAT}}}});
FunctionDefLibrary flib;
*flib.add_function() = callee;
FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);
Scope root = Scope::NewRootScope().ExitOnError();
auto arg0 = ops::_Arg(root.WithOpName("tensor"), DT_FLOAT, 0);
auto arg1 = ops::_Arg(root.WithOpName("shape"), DT_INT32, 1);
NameAttrList call_attrs;
call_attrs.set_name("Callee");
if (is_stateful_partitioned_call) {
ops::StatefulPartitionedCall b(root.WithOpName("Call"),
{Output(arg0), Output(arg1)}, {DT_FLOAT},
call_attrs);
} else {
ops::PartitionedCall b(root.WithOpName("Call"),
{Output(arg0), Output(arg1)}, {DT_FLOAT},
call_attrs);
}
Graph graph(&flib_def);
TF_ASSERT_OK(root.ToGraph(&graph));
OptimizerOptions opts;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(nullptr, Env::Default(),
/*config=*/nullptr,
TF_GRAPH_DEF_VERSION, &flib_def, opts));
FunctionLibraryRuntime* lib_runtime =
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
std::vector<bool> const_args(2, false);
TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args,
/*compile_time_const_nodes=*/nullptr,
lib_runtime));
EXPECT_EQ(const_args, std::vector<bool>({false, true}));
}
TEST(ConstAnalysisTest, PartitionedCall) {
TestFunctionCall(/*is_stateful_partitioned_call=*/false);
}
TEST(ConstAnalysisTest, StatefulPartitionedCall) {
TestFunctionCall(/*is_stateful_partitioned_call=*/true);
}
TEST(ConstAnalysisTest, DontFollowControlDependencies) {
Scope root = Scope::NewRootScope();

View File

@@ -41,6 +41,7 @@ limitations under the License.
#include <string>
#include <vector>
#include "absl/types/optional.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
@@ -678,6 +679,10 @@ class Graph {
// Builds a node name to node pointer index for all nodes in the graph.
std::unordered_map<string, Node*> BuildNodeNameIndex() const;
absl::optional<std::vector<bool>>& GetConstArgIndicesCache() const {
return const_arg_indices_cache_;
}
// TODO(josh11b): uint64 hash() const;
private:
@@ -751,6 +756,10 @@ class Graph {
// AddWhileContext() or Node::while_ctx(), but this manages the lifetime.
std::map<string, WhileContext> while_ctxs_;
// Cache of the indices of the arguments which need to be constant for XLA
// compilation.
mutable absl::optional<std::vector<bool>> const_arg_indices_cache_;
TF_DISALLOW_COPY_AND_ASSIGN(Graph);
};

View File

@@ -232,6 +232,24 @@ class DefFunctionTest(test.TestCase):
    with self.assertRaisesRegexp(errors.InvalidArgumentError, 'not compilable'):
      c.f1(inputs)

  def testMustBeConstantPropagation(self):
    if test.is_built_with_rocm():
      return

    @def_function.function(experimental_compile=True)
    def f():
      return constant_op.constant([0, 2, 1], dtype=dtypes.int32)

    @def_function.function(experimental_compile=True)
    def g(a, b):
      return array_ops.transpose(a, b)

    @def_function.function
    def z():
      return g(array_ops.ones([3, 4, 3], dtype=dtypes.float32), f())

    z()


if __name__ == '__main__':
  ops.enable_eager_execution()