STT-tensorflow/tensorflow/compiler/tf2xla/const_analysis_test.cc
commit f3dcd9dc11 by George Karpenkov
Support interprocedural constant meta-information propagation for compilation
This CL does two things:

1) Supports interprocedural constant information propagation across
PartitionedCall and StatefulPartitionedCall (illustrated below).

2) Done naively, (1) leads to an exponential number of calls, as each function
would be reinlined for each (indirect) caller.
To address this performance issue, we cache the argument indices which need to
be constant, and attach that information to the Graph object.
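
For example (mirroring the PartitionedCall test in the file below; the
argument names x and y are just for illustration): if a callee reshapes its
first parameter using its second, then at each call site the tensor wired into
that second parameter must itself be a compile-time constant:

    Callee(t, shape) = Reshape(t, shape)   // `shape` feeds a shape input
    PartitionedCall(Callee, x, y)          // ... so `y` must be constant too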

This might require some clarification:

a) Caching in a map passed through the analysis would not work, as the
constant propagation would still be duplicated for each top-level caller,
which is prohibitively expensive.

b) Caching in a global object would not work, as graphs are created and
destroyed during transformations.

c) Caching this meta-information on a `Graph` object has the added benefit
that we no longer perform the same constant propagation many times (many
compilation passes call BackwardsConstAnalysis, and previously all of that
work had to be repeated); see the sketch below.
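
A minimal, self-contained sketch of the caching scheme in (c). Every name
here is invented for illustration; the actual fields and helpers in this CL
may differ:

    #include <optional>
    #include <vector>

    struct HypotheticalGraph {
      // ... nodes, edges, function library, etc. ...

      // Cached result of the backwards const analysis: for each argument,
      // whether it must be a compile-time constant. Lives and dies with the
      // graph (unlike a global cache) and is shared by all callers (unlike a
      // map passed through the analysis).
      mutable std::optional<std::vector<bool>> const_arg_indices_cache;
    };

    // Computes which arguments of `g` must be compile-time constants,
    // running the (expensive) backwards analysis at most once per graph.
    const std::vector<bool>& ConstArgIndices(const HypotheticalGraph& g) {
      if (!g.const_arg_indices_cache.has_value()) {
        g.const_arg_indices_cache =
            std::vector<bool>{/* result of BackwardsConstAnalysis on g */};
      }
      return *g.const_arg_indices_cache;
    }

Because the cache sits on the graph itself, every subsequent pass that asks
the same question gets the memoized answer instead of re-running the analysis.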

PiperOrigin-RevId: 303860413
Change-Id: I78f92ca1487fc952044e5ac6526dcaa5b50d5f21

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Tests for the backward const analysis.
#include "tensorflow/compiler/tf2xla/const_analysis.h"

#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {
namespace {

TEST(ConstAnalysisTest, Basics) {
  Scope root = Scope::NewRootScope();
  auto arg0 = ops::_Arg(root.WithOpName("Arg0"), DT_INT32, 0);
  auto arg1 = ops::_Arg(root.WithOpName("Arg1"), DT_INT32, 1);
  auto arg2 = ops::_Arg(root.WithOpName("Arg2"), DT_INT32, 2);
  auto arg3 = ops::_Arg(root.WithOpName("Arg3"), DT_INT32, 3);
  auto a = ops::Shape(root, arg0);
  auto b = ops::Add(root, a, arg1);
  auto c = ops::Reshape(root, arg2, b);
  auto d = ops::Mul(root, c, ops::Sum(root, arg3, arg3));
  FixupSourceAndSinkEdges(root.graph());

  std::vector<bool> const_args(4, false);
  std::vector<bool> const_nodes(root.graph()->num_node_ids(), false);
  TF_ASSERT_OK(BackwardsConstAnalysis(*root.graph(), &const_args, &const_nodes,
                                      /*flib_runtime=*/nullptr));

  // Arg 0 doesn't need to be constant since the graph only uses its shape.
  // Arg 1 must be constant because it flows to the shape argument of a Reshape.
  // Arg 2 is used only as the value input to a Reshape and need not be const.
  // Arg 3 is used as the reduction-indices argument to Sum and must be const.
  EXPECT_EQ(const_args, std::vector<bool>({false, true, false, true}));

  EXPECT_FALSE(const_nodes[arg0.node()->id()]);
  EXPECT_TRUE(const_nodes[arg1.node()->id()]);
  EXPECT_FALSE(const_nodes[arg2.node()->id()]);
  EXPECT_TRUE(const_nodes[arg3.node()->id()]);
}

// Regression test for a case where the backward const analysis did
// not visit nodes in topological order.
TEST(ConstAnalysisTest, TopologicalOrder) {
  for (bool order : {false, true}) {
    Scope root = Scope::NewRootScope();
    auto arg0 = ops::_Arg(root.WithOpName("Arg0"), DT_INT32, 0);
    auto arg1 = ops::_Arg(root.WithOpName("Arg1"), DT_INT32, 1);
    auto arg2 = ops::_Arg(root.WithOpName("Arg2"), DT_INT32, 2);
    auto a = ops::Reshape(root, arg0, arg1);
    auto b = ops::Reshape(root, arg2, a);
    if (order) {
      // Consider both argument orders for the Add so we aren't sensitive
      // to the DFS traversal order.
      std::swap(a, b);
    }
    auto c = ops::Add(root, a, b);

    Graph graph(OpRegistry::Global());
    TF_ASSERT_OK(root.ToGraph(&graph));

    std::vector<bool> const_args(3, false);
    TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args,
                                        /*compile_time_const_nodes=*/nullptr,
                                        /*flib_runtime=*/nullptr));

    EXPECT_EQ(const_args, std::vector<bool>({true, true, false}));
  }
}
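
// Helper for the two tests below: builds a graph that invokes `Callee` (whose
// second parameter feeds the shape input of a Reshape) through either
// PartitionedCall or StatefulPartitionedCall, then checks that the
// interprocedural analysis marks the corresponding top-level argument as a
// compile-time constant.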
void TestFunctionCall(bool is_stateful_partitioned_call) {
  FunctionDef callee = FunctionDefHelper::Define(
      "Callee", {"t:float", "shape:int32"}, {"result:float"}, {},
      {{{"result"}, "Reshape", {"t", "shape"}, {{"T", DT_FLOAT}}}});

  FunctionDefLibrary flib;
  *flib.add_function() = callee;
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);

  Scope root = Scope::NewRootScope().ExitOnError();
  auto arg0 = ops::_Arg(root.WithOpName("tensor"), DT_FLOAT, 0);
  auto arg1 = ops::_Arg(root.WithOpName("shape"), DT_INT32, 1);

  NameAttrList call_attrs;
  call_attrs.set_name("Callee");
  if (is_stateful_partitioned_call) {
    ops::StatefulPartitionedCall b(root.WithOpName("Call"),
                                   {Output(arg0), Output(arg1)}, {DT_FLOAT},
                                   call_attrs);
  } else {
    ops::PartitionedCall b(root.WithOpName("Call"),
                           {Output(arg0), Output(arg1)}, {DT_FLOAT},
                           call_attrs);
  }

  Graph graph(&flib_def);
  TF_ASSERT_OK(root.ToGraph(&graph));

  OptimizerOptions opts;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
      new ProcessFunctionLibraryRuntime(nullptr, Env::Default(),
                                        /*config=*/nullptr,
                                        TF_GRAPH_DEF_VERSION, &flib_def, opts));
  FunctionLibraryRuntime* lib_runtime =
      pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);

  std::vector<bool> const_args(2, false);
  TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args,
                                      /*compile_time_const_nodes=*/nullptr,
                                      lib_runtime));

  EXPECT_EQ(const_args, std::vector<bool>({false, true}));
}

TEST(ConstAnalysisTest, PartitionedCall) {
  TestFunctionCall(/*is_stateful_partitioned_call=*/false);
}

TEST(ConstAnalysisTest, StatefulPartitionedCall) {
  TestFunctionCall(/*is_stateful_partitioned_call=*/true);
}
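
// A control dependency from an argument to a constant must not make that
// argument a compile-time constant; only data edges propagate constness
// backwards.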
TEST(ConstAnalysisTest, DontFollowControlDependencies) {
  Scope root = Scope::NewRootScope();
  Output arg0 = ops::_Arg(root.WithOpName("Arg0"), DT_INT32, 0);
  Output arg1 = ops::_Arg(root.WithOpName("Arg1"), DT_INT32, 1);
  Output c1 =
      ops::Const(root.WithOpName("c1").WithControlDependencies(arg0), 1, {1});
  Output add = ops::Add(root, arg1, c1);
  Output reshape = ops::Reshape(root, arg1, add);

  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(&graph));

  std::vector<bool> const_args(2, false);
  TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args,
                                      /*compile_time_const_nodes=*/nullptr,
                                      /*flib_runtime=*/nullptr));

  EXPECT_EQ(const_args, std::vector<bool>({false, true}));
}

TEST(ConstAnalysisTest, RespectExplicitAttr_0) {
  Scope root = Scope::NewRootScope();
  Output arg0 = ops::_Arg(root.WithOpName("Arg0"), DT_INT32, 0);
  Output arg1 = ops::_Arg(root.WithOpName("Arg1"), DT_INT32, 1);
  Output c1 =
      ops::Const(root.WithOpName("c1").WithControlDependencies(arg0), 1, {1});
  Output add = ops::Add(root, arg1, c1);

  // Force const analysis to pretend that the shape argument to `reshape` does
  // not need to be a constant.
  Output reshape = ops::Reshape(root, arg1, add);
  reshape.node()->AddAttr(kXlaCompileTimeConstantInputsAttr,
                          std::vector<string>());

  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(&graph));

  std::vector<bool> const_args(2, false);
  TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args,
                                      /*compile_time_const_nodes=*/nullptr,
                                      /*flib_runtime=*/nullptr));

  EXPECT_EQ(const_args, std::vector<bool>({false, false}));
}

TEST(ConstAnalysisTest, RespectExplicitAttr_1) {
  Scope root = Scope::NewRootScope();
  Output arg0 = ops::_Arg(root.WithOpName("Arg0"), DT_INT32, 0);
  Output c1 =
      ops::Const(root.WithOpName("c1").WithControlDependencies(arg0), 1, {1});
  Output add = ops::Add(root, arg0, c1);

  // Force const analysis to pretend that the first argument to `add` needs to
  // be a constant.
  std::vector<string> add_constant_inputs;
  add_constant_inputs.push_back("x");
  add.node()->AddAttr(kXlaCompileTimeConstantInputsAttr, add_constant_inputs);

  Graph graph(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(&graph));

  std::vector<bool> const_args(1, false);
  TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args,
                                      /*compile_time_const_nodes=*/nullptr,
                                      /*flib_runtime=*/nullptr));

  EXPECT_EQ(const_args, std::vector<bool>({true}));
}

}  // namespace
}  // namespace tensorflow