Const analysis should peek into PartitionedCall and StatefulPartitionedCall.

PiperOrigin-RevId: 300571637
Change-Id: I4aef56c80e8bd2f14152f49aaf69778c2d916315
A. Unique TensorFlower 2020-03-12 10:09:48 -07:00 committed by TensorFlower Gardener
parent 11edb2ffe4
commit fbbb83b995
6 changed files with 21 additions and 147 deletions
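
For context, the essence of peeking into a call node, as shown in the GetCompileTimeConstInputs hunks further down: resolve the FunctionBody attached to the call's "f" attribute, rerun the backwards const analysis on the callee's graph, and mark the call inputs whose callee arguments must be constant. A minimal sketch of that branch, simplified to the four-argument BackwardsConstAnalysis call rather than the exact overload appearing in this diff:

// Sketch only: mirrors the PartitionedCall / StatefulPartitionedCall branch of
// GetCompileTimeConstInputs from the hunks below.
if (node.op() == "PartitionedCall" || node.op() == "StatefulPartitionedCall") {
  const FunctionBody* fbody;
  // The callee is named by the call node's "f" function attribute.
  TF_RETURN_IF_ERROR(GetFunctionBody(flib_runtime, node, "f", &fbody));
  int num_inputs = fbody->fdef.signature().input_arg_size();
  std::vector<bool> compile_time_const_arg_indices(num_inputs);
  // Recurse into the callee graph: which of its arguments must be constant?
  TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
      *(fbody->graph), &compile_time_const_arg_indices,
      /*compile_time_const_nodes=*/nullptr, flib_runtime));
  // A constant callee argument forces the matching input of the call node to
  // be a compile-time constant as well.
  for (int i = 0; i < num_inputs; i++) {
    if (compile_time_const_arg_indices[i]) {
      const_input_idxs->push_back(i);
    }
  }
  return Status::OK();
}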


@@ -99,8 +99,7 @@ Status XlaCompileOnDemandOp::MustArgumentBeConstant(
// TODO(jmolloy): This could be expensive, so memoize.
std::vector<int> constant_input_indices;
TF_RETURN_IF_ERROR(GetCompileTimeConstInputs(
op_kernel, &constant_input_indices, flib_runtime,
/*cached_arg_indices=*/nullptr));
op_kernel, &constant_input_indices, flib_runtime));
*result = absl::c_binary_search(constant_input_indices, argument_idx);
return Status::OK();
}


@@ -373,7 +373,6 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
@@ -611,12 +610,10 @@ tf_cc_test(
":xla_compiler",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:functional_ops",
"//tensorflow/cc:ops",
"//tensorflow/compiler/jit:xla_cluster_util",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:ops",
"//tensorflow/core:test",
"//tensorflow/core:test_main",


@@ -85,10 +85,10 @@ Status CondConstInputIndices(
return Status::OK();
}
Status GetCompileTimeConstInputs(
const NodeDef& node, const OpKernel* op_kernel, const OpDef* op_def,
std::vector<int>* const_input_idxs, FunctionLibraryRuntime* flib_runtime,
GraphConstArgIndicesCache* cached_arg_indices) {
Status GetCompileTimeConstInputs(const NodeDef& node, const OpKernel* op_kernel,
const OpDef* op_def,
std::vector<int>* const_input_idxs,
FunctionLibraryRuntime* flib_runtime) {
DCHECK(op_def != nullptr || op_kernel != nullptr);
// TODO(b/124403063): Implement similar functionality for function call nodes.
if (node.op() == "While" || node.op() == "StatelessWhile") {
@@ -106,12 +106,10 @@ Status GetCompileTimeConstInputs(
std::vector<bool> compile_time_const_arg_indices(num_inputs);
TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
*(fcond->graph), &compile_time_const_arg_indices,
/*compile_time_const_nodes=*/nullptr, flib_runtime,
[](const Edge&) { return true; }, cached_arg_indices));
/*compile_time_const_nodes=*/nullptr, flib_runtime));
TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
*(fbody->graph), &compile_time_const_arg_indices,
/*compile_time_const_nodes=*/nullptr, flib_runtime,
[](const Edge&) { return true; }, cached_arg_indices));
/*compile_time_const_nodes=*/nullptr, flib_runtime));
for (int i = 0; i < num_inputs; i++) {
if (compile_time_const_arg_indices[i]) {
// Check that this input is actually a loop invariant.
@@ -147,22 +145,6 @@ Status GetCompileTimeConstInputs(
TF_RETURN_IF_ERROR(
GetFunctionBodies(flib_runtime, node, "branches", &branch_bodies));
return CondConstInputIndices(branch_bodies, const_input_idxs, flib_runtime);
} else if (node.op() == "PartitionedCall" ||
node.op() == "StatefulPartitionedCall") {
const FunctionBody* fbody;
TF_RETURN_IF_ERROR(GetFunctionBody(flib_runtime, node, "f", &fbody));
int num_inputs = fbody->fdef.signature().input_arg_size();
std::vector<bool> compile_time_const_arg_indices(num_inputs);
TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
*(fbody->graph), &compile_time_const_arg_indices,
/*compile_time_const_nodes=*/nullptr, flib_runtime,
[](const Edge&) { return true; }, cached_arg_indices));
for (int i = 0; i < num_inputs; i++) {
if (compile_time_const_arg_indices[i]) {
const_input_idxs->push_back(i);
}
}
return Status::OK();
} else if (op_def != nullptr) {
return XlaOpRegistry::CompileTimeConstantInputs(node, *op_def,
const_input_idxs);
@@ -172,13 +154,12 @@ Status GetCompileTimeConstInputs(
}
}
Status GetCompileTimeConstInputs(
const Node* node, std::vector<int>* const_input_idxs,
FunctionLibraryRuntime* flib_runtime,
GraphConstArgIndicesCache* cached_arg_indices) {
Status GetCompileTimeConstInputs(const Node* node,
std::vector<int>* const_input_idxs,
FunctionLibraryRuntime* flib_runtime) {
return GetCompileTimeConstInputs(node->def(), /*op_kernel=*/nullptr,
&node->op_def(), const_input_idxs,
flib_runtime, cached_arg_indices);
flib_runtime);
}
} // namespace
@@ -189,17 +170,7 @@ Status BackwardsConstAnalysis(const Graph& g,
std::vector<bool>* compile_time_const_arg_indices,
std::vector<bool>* compile_time_const_nodes,
FunctionLibraryRuntime* flib_runtime,
std::function<bool(const Edge&)> edge_filter,
GraphConstArgIndicesCache* cached_arg_indices) {
// Avoid exponential runtime by explicit memoization: can do this only
// for the nested calls which don't have `compile_time_const_nodes` set.
if (!compile_time_const_nodes && cached_arg_indices &&
cached_arg_indices->contains(&g)) {
VLOG(3) << "Memoized constant arg indices for the graph: " << &g;
*compile_time_const_arg_indices = cached_arg_indices->at(&g);
return Status::OK();
}
std::function<bool(const Edge&)> edge_filter) {
std::vector<bool> compile_time_const_nodes_impl;
if (compile_time_const_nodes) {
CHECK_EQ(compile_time_const_nodes->size(), g.num_node_ids());
@@ -208,11 +179,6 @@ Status BackwardsConstAnalysis(const Graph& g,
compile_time_const_nodes = &compile_time_const_nodes_impl;
}
GraphConstArgIndicesCache cached_arg_indices_impl;
if (!cached_arg_indices) {
cached_arg_indices = &cached_arg_indices_impl;
}
Status status;
auto visit = [&](Node* node) {
if (!status.ok()) return;
@@ -255,8 +221,7 @@ Status BackwardsConstAnalysis(const Graph& g,
// Mark any compile-time constant operator arguments as const.
std::vector<int> const_input_idxs;
status = GetCompileTimeConstInputs(node, &const_input_idxs, flib_runtime,
cached_arg_indices);
status = GetCompileTimeConstInputs(node, &const_input_idxs, flib_runtime);
if (!status.ok()) {
return;
@@ -287,19 +252,15 @@ Status BackwardsConstAnalysis(const Graph& g,
// acyclic graph.
DFS(g, /*enter=*/{}, /*leave=*/visit, NodeComparatorName{},
[](const Edge& edge) { return !edge.src()->IsNextIteration(); });
if (cached_arg_indices && compile_time_const_arg_indices) {
cached_arg_indices->emplace(&g, *compile_time_const_arg_indices);
}
return status;
}
Status GetCompileTimeConstInputs(
const OpKernel* op_kernel, std::vector<int>* const_input_idxs,
FunctionLibraryRuntime* flib_runtime,
GraphConstArgIndicesCache* cached_arg_indices) {
Status GetCompileTimeConstInputs(const OpKernel* op_kernel,
std::vector<int>* const_input_idxs,
FunctionLibraryRuntime* flib_runtime) {
return GetCompileTimeConstInputs(op_kernel->def(), op_kernel,
/*op_def=*/nullptr, const_input_idxs,
flib_runtime, cached_arg_indices);
flib_runtime);
}
} // namespace tensorflow
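
The cached_arg_indices plumbing removed in the hunks above implemented a simple per-graph memoization of the analysis result, so that nested function calls do not re-run the backwards analysis exponentially many times. Assembled into one place for readability (a sketch pieced together from the removed lines, not the verbatim code):

// Cache type from the removed header declarations.
using GraphConstArgIndicesCache =
    absl::flat_hash_map<const Graph*, std::vector<bool>>;

// Inside BackwardsConstAnalysis(g, compile_time_const_arg_indices,
//                               compile_time_const_nodes, flib_runtime,
//                               edge_filter, cached_arg_indices):

// Reuse a memoized result for this graph, but only for nested calls that did
// not ask for per-node results.
if (!compile_time_const_nodes && cached_arg_indices &&
    cached_arg_indices->contains(&g)) {
  *compile_time_const_arg_indices = cached_arg_indices->at(&g);
  return Status::OK();
}

// ... run the backwards dataflow analysis over the graph ...

// Remember which argument indices must be constant for this graph.
if (cached_arg_indices && compile_time_const_arg_indices) {
  cached_arg_indices->emplace(&g, *compile_time_const_arg_indices);
}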


@@ -18,15 +18,11 @@ limitations under the License.
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
using GraphConstArgIndicesCache =
absl::flat_hash_map<const Graph*, std::vector<bool>>;
// Backwards dataflow analysis that finds nodes in a graph that must be
// compile-time constants for us to be able to lower the graph to XLA.
//
@@ -38,24 +34,19 @@ using GraphConstArgIndicesCache =
// `compile_time_const_nodes`, if `compile_time_const_nodes` is not null.
//
// Only propagate const-ness along edges for which `edge_filter` returns true.
//
// `cached_arg_indices` is a memoization cache used for nested invocations on
// function calls, which caches what argument indices need to be constant for
// each associated graph (e.g. called function).
Status BackwardsConstAnalysis(
const Graph& g, std::vector<bool>* compile_time_const_arg_indices,
std::vector<bool>* compile_time_const_nodes,
FunctionLibraryRuntime* flib_runtime,
std::function<bool(const Edge&)> edge_filter =
[](const Edge& e) { return true; },
GraphConstArgIndicesCache* cached_arg_indices = nullptr);
std::function<bool(const Edge&)> edge_filter = [](const Edge& e) {
return true;
});
// Given an op kernel and function library runtime, return all the indices of
// inputs that need to be compile time constant.
Status GetCompileTimeConstInputs(const OpKernel* op_kernel,
std::vector<int>* const_input_idxs,
FunctionLibraryRuntime* flib_runtime,
GraphConstArgIndicesCache* cached_arg_indices);
FunctionLibraryRuntime* flib_runtime);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_CONST_ANALYSIS_H_
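
The call pattern for this interface is the one exercised by the ConstAnalysisTest hunks below: size a vector<bool> by the number of _Arg nodes and hand it to the analysis, optionally together with a per-node vector. A short usage sketch, assuming a Graph graph and a FunctionLibraryRuntime* lib_runtime are already set up:

// Usage sketch; mirrors the test below rather than any particular caller.
std::vector<bool> const_args(/*number of _Arg nodes=*/2, false);
TF_RETURN_IF_ERROR(BackwardsConstAnalysis(graph, &const_args,
                                          /*compile_time_const_nodes=*/nullptr,
                                          lib_runtime));
// const_args[i] is now true iff argument i must be a compile-time constant.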


@@ -19,14 +19,11 @@ limitations under the License.
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/version.h"
namespace tensorflow {
namespace {
@@ -92,59 +89,6 @@ TEST(ConstAnalysisTest, TopologicalOrder) {
}
}
void TestFunctionCall(bool is_stateful_partitioned_call) {
FunctionDef callee = FunctionDefHelper::Define(
"Callee", {"t:float", "shape:int32"}, {"result:float"}, {},
{{{"result"}, "Reshape", {"t", "shape"}, {{"T", DT_FLOAT}}}});
FunctionDefLibrary flib;
*flib.add_function() = callee;
FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);
Scope root = Scope::NewRootScope().ExitOnError();
auto arg0 = ops::_Arg(root.WithOpName("tensor"), DT_FLOAT, 0);
auto arg1 = ops::_Arg(root.WithOpName("shape"), DT_INT32, 1);
NameAttrList call_attrs;
call_attrs.set_name("Callee");
if (is_stateful_partitioned_call) {
ops::StatefulPartitionedCall b(root.WithOpName("Call"),
{Output(arg0), Output(arg1)}, {DT_FLOAT},
call_attrs);
} else {
ops::PartitionedCall b(root.WithOpName("Call"),
{Output(arg0), Output(arg1)}, {DT_FLOAT},
call_attrs);
}
Graph graph(&flib_def);
TF_ASSERT_OK(root.ToGraph(&graph));
OptimizerOptions opts;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(nullptr, Env::Default(),
/*config=*/nullptr,
TF_GRAPH_DEF_VERSION, &flib_def, opts));
FunctionLibraryRuntime* lib_runtime =
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
std::vector<bool> const_args(2, false);
TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args,
/*compile_time_const_nodes=*/nullptr,
lib_runtime));
EXPECT_EQ(const_args, std::vector<bool>({false, true}));
}
TEST(ConstAnalysisTest, PartitionedCall) {
TestFunctionCall(/*is_stateful_partitioned_call=*/false);
}
TEST(ConstAnalysisTest, StatefulPartitionedCall) {
TestFunctionCall(/*is_stateful_partitioned_call=*/true);
}
TEST(ConstAnalysisTest, DontFollowControlDependencies) {
Scope root = Scope::NewRootScope();


@@ -232,24 +232,6 @@ class DefFunctionTest(test.TestCase):
with self.assertRaisesRegexp(errors.InvalidArgumentError, 'not compilable'):
c.f1(inputs)
def testMustBeConstantPropagation(self):
if test.is_built_with_rocm():
return
@def_function.function(experimental_compile=True)
def f():
return constant_op.constant([0, 2, 1], dtype=dtypes.int32)
@def_function.function(experimental_compile=True)
def g(a, b):
return array_ops.transpose(a, b)
@def_function.function
def z():
return g(array_ops.ones([3, 4, 3], dtype=dtypes.float32), f())
z()
if __name__ == '__main__':
ops.enable_eager_execution()