Const analysis should peek into PartitionedCall and StatefulPartitionedCall.
PiperOrigin-RevId: 300571637 Change-Id: I4aef56c80e8bd2f14152f49aaf69778c2d916315
This commit is contained in:
parent
11edb2ffe4
commit
fbbb83b995
@ -99,8 +99,7 @@ Status XlaCompileOnDemandOp::MustArgumentBeConstant(
|
||||
// TODO(jmolloy): This could be expensive, so memoize.
|
||||
std::vector<int> constant_input_indices;
|
||||
TF_RETURN_IF_ERROR(GetCompileTimeConstInputs(
|
||||
op_kernel, &constant_input_indices, flib_runtime,
|
||||
/*cached_arg_indices=*/nullptr));
|
||||
op_kernel, &constant_input_indices, flib_runtime));
|
||||
*result = absl::c_binary_search(constant_input_indices, argument_idx);
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -373,7 +373,6 @@ cc_library(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
@ -611,12 +610,10 @@ tf_cc_test(
|
||||
":xla_compiler",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:function_ops",
|
||||
"//tensorflow/cc:functional_ops",
|
||||
"//tensorflow/cc:ops",
|
||||
"//tensorflow/compiler/jit:xla_cluster_util",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:ops",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
|
@ -85,10 +85,10 @@ Status CondConstInputIndices(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetCompileTimeConstInputs(
|
||||
const NodeDef& node, const OpKernel* op_kernel, const OpDef* op_def,
|
||||
std::vector<int>* const_input_idxs, FunctionLibraryRuntime* flib_runtime,
|
||||
GraphConstArgIndicesCache* cached_arg_indices) {
|
||||
Status GetCompileTimeConstInputs(const NodeDef& node, const OpKernel* op_kernel,
|
||||
const OpDef* op_def,
|
||||
std::vector<int>* const_input_idxs,
|
||||
FunctionLibraryRuntime* flib_runtime) {
|
||||
DCHECK(op_def != nullptr || op_kernel != nullptr);
|
||||
// TODO(b/124403063): Implement similar functionality for function call nodes.
|
||||
if (node.op() == "While" || node.op() == "StatelessWhile") {
|
||||
@ -106,12 +106,10 @@ Status GetCompileTimeConstInputs(
|
||||
std::vector<bool> compile_time_const_arg_indices(num_inputs);
|
||||
TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
|
||||
*(fcond->graph), &compile_time_const_arg_indices,
|
||||
/*compile_time_const_nodes=*/nullptr, flib_runtime,
|
||||
[](const Edge&) { return true; }, cached_arg_indices));
|
||||
/*compile_time_const_nodes=*/nullptr, flib_runtime));
|
||||
TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
|
||||
*(fbody->graph), &compile_time_const_arg_indices,
|
||||
/*compile_time_const_nodes=*/nullptr, flib_runtime,
|
||||
[](const Edge&) { return true; }, cached_arg_indices));
|
||||
/*compile_time_const_nodes=*/nullptr, flib_runtime));
|
||||
for (int i = 0; i < num_inputs; i++) {
|
||||
if (compile_time_const_arg_indices[i]) {
|
||||
// Check that this input is actually a loop invariant.
|
||||
@ -147,22 +145,6 @@ Status GetCompileTimeConstInputs(
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetFunctionBodies(flib_runtime, node, "branches", &branch_bodies));
|
||||
return CondConstInputIndices(branch_bodies, const_input_idxs, flib_runtime);
|
||||
} else if (node.op() == "PartitionedCall" ||
|
||||
node.op() == "StatefulPartitionedCall") {
|
||||
const FunctionBody* fbody;
|
||||
TF_RETURN_IF_ERROR(GetFunctionBody(flib_runtime, node, "f", &fbody));
|
||||
int num_inputs = fbody->fdef.signature().input_arg_size();
|
||||
std::vector<bool> compile_time_const_arg_indices(num_inputs);
|
||||
TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
|
||||
*(fbody->graph), &compile_time_const_arg_indices,
|
||||
/*compile_time_const_nodes=*/nullptr, flib_runtime,
|
||||
[](const Edge&) { return true; }, cached_arg_indices));
|
||||
for (int i = 0; i < num_inputs; i++) {
|
||||
if (compile_time_const_arg_indices[i]) {
|
||||
const_input_idxs->push_back(i);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
} else if (op_def != nullptr) {
|
||||
return XlaOpRegistry::CompileTimeConstantInputs(node, *op_def,
|
||||
const_input_idxs);
|
||||
@ -172,13 +154,12 @@ Status GetCompileTimeConstInputs(
|
||||
}
|
||||
}
|
||||
|
||||
Status GetCompileTimeConstInputs(
|
||||
const Node* node, std::vector<int>* const_input_idxs,
|
||||
FunctionLibraryRuntime* flib_runtime,
|
||||
GraphConstArgIndicesCache* cached_arg_indices) {
|
||||
Status GetCompileTimeConstInputs(const Node* node,
|
||||
std::vector<int>* const_input_idxs,
|
||||
FunctionLibraryRuntime* flib_runtime) {
|
||||
return GetCompileTimeConstInputs(node->def(), /*op_kernel=*/nullptr,
|
||||
&node->op_def(), const_input_idxs,
|
||||
flib_runtime, cached_arg_indices);
|
||||
flib_runtime);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@ -189,17 +170,7 @@ Status BackwardsConstAnalysis(const Graph& g,
|
||||
std::vector<bool>* compile_time_const_arg_indices,
|
||||
std::vector<bool>* compile_time_const_nodes,
|
||||
FunctionLibraryRuntime* flib_runtime,
|
||||
std::function<bool(const Edge&)> edge_filter,
|
||||
GraphConstArgIndicesCache* cached_arg_indices) {
|
||||
// Avoid exponential runtime by explicit memoization: can do this only
|
||||
// for the nested calls which don't have `compile_time_const_nodes` set.
|
||||
if (!compile_time_const_nodes && cached_arg_indices &&
|
||||
cached_arg_indices->contains(&g)) {
|
||||
VLOG(3) << "Memoized constant arg indices for the graph: " << &g;
|
||||
*compile_time_const_arg_indices = cached_arg_indices->at(&g);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::function<bool(const Edge&)> edge_filter) {
|
||||
std::vector<bool> compile_time_const_nodes_impl;
|
||||
if (compile_time_const_nodes) {
|
||||
CHECK_EQ(compile_time_const_nodes->size(), g.num_node_ids());
|
||||
@ -208,11 +179,6 @@ Status BackwardsConstAnalysis(const Graph& g,
|
||||
compile_time_const_nodes = &compile_time_const_nodes_impl;
|
||||
}
|
||||
|
||||
GraphConstArgIndicesCache cached_arg_indices_impl;
|
||||
if (!cached_arg_indices) {
|
||||
cached_arg_indices = &cached_arg_indices_impl;
|
||||
}
|
||||
|
||||
Status status;
|
||||
auto visit = [&](Node* node) {
|
||||
if (!status.ok()) return;
|
||||
@ -255,8 +221,7 @@ Status BackwardsConstAnalysis(const Graph& g,
|
||||
|
||||
// Mark any compile-time constant operator arguments as const.
|
||||
std::vector<int> const_input_idxs;
|
||||
status = GetCompileTimeConstInputs(node, &const_input_idxs, flib_runtime,
|
||||
cached_arg_indices);
|
||||
status = GetCompileTimeConstInputs(node, &const_input_idxs, flib_runtime);
|
||||
|
||||
if (!status.ok()) {
|
||||
return;
|
||||
@ -287,19 +252,15 @@ Status BackwardsConstAnalysis(const Graph& g,
|
||||
// acyclic graph.
|
||||
DFS(g, /*enter=*/{}, /*leave=*/visit, NodeComparatorName{},
|
||||
[](const Edge& edge) { return !edge.src()->IsNextIteration(); });
|
||||
if (cached_arg_indices && compile_time_const_arg_indices) {
|
||||
cached_arg_indices->emplace(&g, *compile_time_const_arg_indices);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
Status GetCompileTimeConstInputs(
|
||||
const OpKernel* op_kernel, std::vector<int>* const_input_idxs,
|
||||
FunctionLibraryRuntime* flib_runtime,
|
||||
GraphConstArgIndicesCache* cached_arg_indices) {
|
||||
Status GetCompileTimeConstInputs(const OpKernel* op_kernel,
|
||||
std::vector<int>* const_input_idxs,
|
||||
FunctionLibraryRuntime* flib_runtime) {
|
||||
return GetCompileTimeConstInputs(op_kernel->def(), op_kernel,
|
||||
/*op_def=*/nullptr, const_input_idxs,
|
||||
flib_runtime, cached_arg_indices);
|
||||
flib_runtime);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -18,15 +18,11 @@ limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using GraphConstArgIndicesCache =
|
||||
absl::flat_hash_map<const Graph*, std::vector<bool>>;
|
||||
|
||||
// Backwards dataflow analysis that finds nodes in a graph that must be
|
||||
// compile-time constants for us to be able to lower the graph to XLA.
|
||||
//
|
||||
@ -38,24 +34,19 @@ using GraphConstArgIndicesCache =
|
||||
// `compile_time_const_nodes`, if `compile_time_const_nodes` is not null.
|
||||
//
|
||||
// Only propagate const-ness along edges for which `edge_filter` returns true.
|
||||
//
|
||||
// `cached_arg_indices` is a memoization cache used for nested invocations on
|
||||
// function calls, which caches what argument indices need to be constant for
|
||||
// each associated graph (e.g. called function).
|
||||
Status BackwardsConstAnalysis(
|
||||
const Graph& g, std::vector<bool>* compile_time_const_arg_indices,
|
||||
std::vector<bool>* compile_time_const_nodes,
|
||||
FunctionLibraryRuntime* flib_runtime,
|
||||
std::function<bool(const Edge&)> edge_filter =
|
||||
[](const Edge& e) { return true; },
|
||||
GraphConstArgIndicesCache* cached_arg_indices = nullptr);
|
||||
std::function<bool(const Edge&)> edge_filter = [](const Edge& e) {
|
||||
return true;
|
||||
});
|
||||
|
||||
// Given an op kernel and function library runtime, return all the indices of
|
||||
// inputs that need to be compile time constant.
|
||||
Status GetCompileTimeConstInputs(const OpKernel* op_kernel,
|
||||
std::vector<int>* const_input_idxs,
|
||||
FunctionLibraryRuntime* flib_runtime,
|
||||
GraphConstArgIndicesCache* cached_arg_indices);
|
||||
FunctionLibraryRuntime* flib_runtime);
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_TF2XLA_CONST_ANALYSIS_H_
|
||||
|
@ -19,14 +19,11 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/ops/function_ops.h"
|
||||
#include "tensorflow/cc/ops/functional_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/compiler/jit/xla_cluster_util.h"
|
||||
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
@ -92,59 +89,6 @@ TEST(ConstAnalysisTest, TopologicalOrder) {
|
||||
}
|
||||
}
|
||||
|
||||
void TestFunctionCall(bool is_stateful_partitioned_call) {
|
||||
FunctionDef callee = FunctionDefHelper::Define(
|
||||
"Callee", {"t:float", "shape:int32"}, {"result:float"}, {},
|
||||
{{{"result"}, "Reshape", {"t", "shape"}, {{"T", DT_FLOAT}}}});
|
||||
|
||||
FunctionDefLibrary flib;
|
||||
*flib.add_function() = callee;
|
||||
FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);
|
||||
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
|
||||
auto arg0 = ops::_Arg(root.WithOpName("tensor"), DT_FLOAT, 0);
|
||||
auto arg1 = ops::_Arg(root.WithOpName("shape"), DT_INT32, 1);
|
||||
|
||||
NameAttrList call_attrs;
|
||||
call_attrs.set_name("Callee");
|
||||
if (is_stateful_partitioned_call) {
|
||||
ops::StatefulPartitionedCall b(root.WithOpName("Call"),
|
||||
{Output(arg0), Output(arg1)}, {DT_FLOAT},
|
||||
call_attrs);
|
||||
} else {
|
||||
ops::PartitionedCall b(root.WithOpName("Call"),
|
||||
{Output(arg0), Output(arg1)}, {DT_FLOAT},
|
||||
call_attrs);
|
||||
}
|
||||
|
||||
Graph graph(&flib_def);
|
||||
TF_ASSERT_OK(root.ToGraph(&graph));
|
||||
|
||||
OptimizerOptions opts;
|
||||
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
|
||||
new ProcessFunctionLibraryRuntime(nullptr, Env::Default(),
|
||||
/*config=*/nullptr,
|
||||
TF_GRAPH_DEF_VERSION, &flib_def, opts));
|
||||
FunctionLibraryRuntime* lib_runtime =
|
||||
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
|
||||
|
||||
std::vector<bool> const_args(2, false);
|
||||
TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args,
|
||||
/*compile_time_const_nodes=*/nullptr,
|
||||
lib_runtime));
|
||||
|
||||
EXPECT_EQ(const_args, std::vector<bool>({false, true}));
|
||||
}
|
||||
|
||||
TEST(ConstAnalysisTest, PartitionedCall) {
|
||||
TestFunctionCall(/*is_stateful_partitioned_call=*/false);
|
||||
}
|
||||
|
||||
TEST(ConstAnalysisTest, StatefulPartitionedCall) {
|
||||
TestFunctionCall(/*is_stateful_partitioned_call=*/true);
|
||||
}
|
||||
|
||||
TEST(ConstAnalysisTest, DontFollowControlDependencies) {
|
||||
Scope root = Scope::NewRootScope();
|
||||
|
||||
|
@ -232,24 +232,6 @@ class DefFunctionTest(test.TestCase):
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError, 'not compilable'):
|
||||
c.f1(inputs)
|
||||
|
||||
def testMustBeConstantPropagation(self):
|
||||
if test.is_built_with_rocm():
|
||||
return
|
||||
|
||||
@def_function.function(experimental_compile=True)
|
||||
def f():
|
||||
return constant_op.constant([0, 2, 1], dtype=dtypes.int32)
|
||||
|
||||
@def_function.function(experimental_compile=True)
|
||||
def g(a, b):
|
||||
return array_ops.transpose(a, b)
|
||||
|
||||
@def_function.function
|
||||
def z():
|
||||
return g(array_ops.ones([3, 4, 3], dtype=dtypes.float32), f())
|
||||
|
||||
z()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution()
|
||||
|
Loading…
Reference in New Issue
Block a user