Decluster some must-be-constant ops to reduce XLA recompilations

The CL is organized as follows:

 - The main change is in jit/partially_decluster_pass.
 - tf2xla/const_analysis now takes an "edge_filter" to facilitate use by
   jit/partially_decluster_pass.
 - tests/dense_layer_test.py was using the execution of ListDiff as what I
   assume is a sanity check that the XLA cluster ran.  With this CL the
   ListDiff op gets declustered, so we now check for "MatMult" instead.
 - Some tests were overwriting any pre-existing TF_XLA_FLAGS value; fixed them
   to append to it instead.

PiperOrigin-RevId: 212071118
Author: Sanjoy Das, 2018-09-07 18:47:56 -07:00 (committed by TensorFlower Gardener)
Parent: 3e1b06ee93
Commit: 4fd48f57cd
10 changed files with 326 additions and 62 deletions

tensorflow/compiler/jit/BUILD

@ -395,6 +395,7 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:bounds_check",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings",
],
)
@ -480,6 +481,7 @@ tf_cc_test(
":common",
":compilation_passes",
":xla_cluster_util",
":xla_gpu_device",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",
@ -496,6 +498,8 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/grappler/optimizers/data:graph_utils",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
)

tensorflow/compiler/jit/partially_decluster_pass.cc

@ -14,8 +14,11 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/partially_decluster_pass.h"
#include "absl/algorithm/container.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/gtl/flatset.h"
@ -130,30 +133,47 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) {
return Status::OK();
}
} // namespace
Status PartiallyDeclusterPass::Run(
const GraphOptimizationPassOptions& options) {
// NB! In this pass we assume the only XLA-auto-clusterable operations that
// may have side effects are resource variable operations so we don't cluster
// those. The pass will have to be updated if this assumption becomes
// invalid.
Graph* graph = options.graph->get();
bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); }
// Clones nodes to outside their cluster to avoid device-to-host copies. For
// instance, converts this:
//
// .....
// |
// v
// A_Clustered ====> C_Unclustered
// |
// v
// B_Clustered
//
// to:
//
// .....
// | |
// | +-------------+
// | |
// v v
// A_Clustered A_Unclustered ====> C_Unclustered
// |
// v
// B_Clustered
//
// where the ===> arrow has a hostmem source and destination and would entail a
// device to host copy if the source and destination were not in the same XLA
// cluster.
Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) {
// When deciding whether to decluster a node, we base our decision on whether
// some of its consumers have already been marked for declustering.
// Iterating the graph in post-order guarantees that consumers have been
// visited before producers.
std::vector<Node*> post_order;
GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(),
/*edge_filter=*/[](const Edge& edge) {
return !edge.src()->IsNextIteration();
});
/*edge_filter=*/NotBackedge);
gtl::FlatSet<Node*> nodes_to_partially_decluster;
TF_RETURN_IF_ERROR(FindNodesToDecluster(
**options.graph, &nodes_to_partially_decluster, post_order));
TF_RETURN_IF_ERROR(
FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order));
if (VLOG_IS_ON(3)) {
for (Node* n : post_order) {
@ -170,10 +190,133 @@ Status PartiallyDeclusterPass::Run(
}
nodes_to_partially_decluster.clear();
TF_RETURN_IF_ERROR(FindNodesToDecluster(
**options.graph, &nodes_to_partially_decluster, post_order));
TF_RETURN_IF_ERROR(
FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order));
CHECK(nodes_to_partially_decluster.empty());
return Status::OK();
}
bool IsIntraClusterEdge(const Edge& edge) {
absl::optional<absl::string_view> src_cluster_name =
GetXlaClusterForNode(*edge.src());
absl::optional<absl::string_view> dst_cluster_name =
GetXlaClusterForNode(*edge.dst());
return src_cluster_name.has_value() && src_cluster_name == dst_cluster_name;
}
Status MustCompileNode(const Node* n, bool* result) {
DeviceType device_type("");
TF_RETURN_IF_ERROR(
DeviceToDeviceType(n->assigned_device_name(), &device_type));
const XlaOpRegistry::DeviceRegistration* registration;
if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) {
*result = false;
} else {
*result = registration->requires_compilation;
}
return Status::OK();
}
// Declusters nodes to reduce the number of times we think we need to recompile
// a TensorFlow graph.
//
// Abstractly, if we have a cluster of this form:
//
// x0 = arg0
// x1 = arg1
// ...
// shape = f(x0, x1, ...)
// result = Reshape(input=<something>, new_shape=shape)
//
// then pulling `f` out of the cluster may reduce the number of compilations and
// will never increase the number of compilations.
//
// We may reduce the number of compilations if f is many to one. For instance
// if f(x,y) = x-y then x=3,y=1 and x=4,y=2 will generate two different
// compilations if f is in the cluster but only one compilation if f is outside
// the cluster.
//
// Declustering f will increase the number of compilations only if f is a
// one-to-many "function" i.e. isn't a function at all. RNG is one possible
// example, depending on how we look at it. But we never create clusters where
// such f's would be marked as must-be-constant.
//
// We assume here that the extra repeated (repeated compared to a clustered f
// where it will always be constant folded) host-side computation of f does not
// regress performance in any significant manner. We will have to revisit this
// algorithm with a more complex cost model if this assumption turns out to be
// incorrect.
Status DeclusterNodesToReduceRecompilations(Graph* graph) {
std::vector<bool> compile_time_const_nodes(graph->num_node_ids());
TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
*graph, nullptr, &compile_time_const_nodes, IsIntraClusterEdge));
std::vector<Node*> rpo;
GetReversePostOrder(*graph, &rpo, /*stable_comparator=*/NodeComparatorName(),
/*edge_filter=*/NotBackedge);
for (Node* n : rpo) {
if (!compile_time_const_nodes[n->id()]) {
continue;
}
absl::string_view cluster_name = *GetXlaClusterForNode(*n);
bool node_on_cluster_edge =
absl::c_all_of(n->in_edges(), [&](const Edge* e) {
absl::optional<absl::string_view> incoming_cluster =
GetXlaClusterForNode(*e->src());
return !incoming_cluster || *incoming_cluster != cluster_name;
});
// We don't want to decluster F in a graph like
//
// Input -> OP -> Shape -> F -> Reshape
//
// Doing so will break up the cluster. Even if we were okay with breaking
// up the cluster we will at least have to relabel the two clusters to have
// different cluster names.
//
// We may want to revisit this in the future: we may have cases where OP is
// a small computation that does not benefit from XLA while XLA can optimize
// everything that follows the Reshape. In these cases it may be wise to
// remove Input, OP, Shape and F from the cluster, if F is a many-to-one
// function.
//
// Note that we do do the right thing for graphs like:
//
// Input -> F0 -> F1 -> Reshape
//
// Since we iterate in RPO, we'll first encounter F0, decluster it, then
// encounter F1, decluster it and so on.
if (node_on_cluster_edge) {
bool must_compile_node;
TF_RETURN_IF_ERROR(MustCompileNode(n, &must_compile_node));
if (!must_compile_node) {
VLOG(3) << "Declustering must-be-constant node " << n->name();
RemoveFromXlaCluster(n);
}
}
}
return Status::OK();
}
} // namespace
Status PartiallyDeclusterPass::Run(
const GraphOptimizationPassOptions& options) {
// NB! In this pass we assume the only XLA-auto-clusterable operations that
// may have side effects are resource variable operations so we don't cluster
// those. The pass will have to be updated if this assumption becomes
// invalid.
Graph* graph = options.graph->get();
TF_RETURN_IF_ERROR(PartiallyDeclusterToRemoveDeviceToHostCopies(graph));
TF_RETURN_IF_ERROR(DeclusterNodesToReduceRecompilations(graph));
return Status::OK();
}
} // namespace tensorflow
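
To make the compilation-count argument in DeclusterNodesToReduceRecompilations concrete, here is a standalone C++ sketch (not TensorFlow code; the cache-key modeling and all names are purely illustrative). It compares how many distinct compilation-cache keys a many-to-one `f` produces when it sits inside the cluster, where XLA must treat f's raw inputs as constants, versus outside it, where XLA only sees f's already-computed output:

// Standalone sketch (not TensorFlow code): counts hypothetical XLA
// compilation-cache keys for a many-to-one function `f`.
#include <iostream>
#include <set>
#include <utility>
#include <vector>

int f(int x, int y) { return x - y; }  // many-to-one: (3,1) and (4,2) collide

int main() {
  std::set<std::pair<int, int>> keys_f_clustered;  // keys are the raw inputs
  std::set<int> keys_f_declustered;                // key is only f's output
  const std::vector<std::pair<int, int>> runs = {{3, 1}, {4, 2}, {5, 3}};
  for (const auto& run : runs) {
    keys_f_clustered.insert(run);                         // f inside the cluster
    keys_f_declustered.insert(f(run.first, run.second));  // f declustered
  }
  // Three distinct keys (three compilations) vs. one key (one compilation).
  std::cout << keys_f_clustered.size() << " vs. " << keys_f_declustered.size()
            << "\n";
}

The trade-off noted in the pass comment applies here as well: once declustered, f is recomputed on the host each run instead of being constant folded inside the cluster, and the pass assumes that extra host-side work is cheap compared to avoiding recompilation.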

tensorflow/compiler/jit/partially_decluster_pass.h

@ -20,34 +20,11 @@ limitations under the License.
namespace tensorflow {
// Clones nodes from within a cluster to outside the cluster if profitable.
// Clones or moves nodes from within a cluster to outside the cluster if
// profitable. There are two reasons why we do this:
//
// Today this only clones to avoid device-to-host copies, but in the future we
// may consider other reasons to clone. For instance, we convert this:
//
// .....
// |
// v
// A_Clustered ====> C_Unclustered
// |
// v
// B_Clustered
//
// to:
//
// .....
// | |
// | +-------------+
// | |
// v v
// A_Clustered A_Unclustered ====> C_Unclustered
// |
// v
// B_Clustered
//
// where the ===> arrow has a hostmem source and destination and would entail a
// device to host copy if the source and destination were not in the same XLA
// cluster.
// - Reducing device-to-host copies.
// - Reducing the number of XLA recompilations.
class PartiallyDeclusterPass : public GraphOptimizationPass {
public:
Status Run(const GraphOptimizationPassOptions& options) override;
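
As a rough illustration of the first reason above (cloning to avoid device-to-host copies), here is a minimal standalone sketch on a toy graph representation. This is plain C++ with made-up types, not the TensorFlow graph API: a clustered producer that also feeds an unclustered consumer is cloned outside the cluster, and the unclustered consumer is rewired to the clone so that edge no longer crosses the cluster boundary.

// Toy sketch (not the TensorFlow graph API) of the cloning rewrite.
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct Node {
  std::string name;
  std::string cluster;             // empty string == not clustered
  std::vector<std::string> inputs;
};

int main() {
  std::map<std::string, Node> graph;
  graph["A"] = Node{"A", "cluster_0", {}};
  graph["B"] = Node{"B", "cluster_0", {"A"}};   // clustered consumer of A
  graph["C"] = Node{"C", "", {"A"}};            // unclustered consumer of A

  // Clone A outside the cluster.
  Node clone = graph["A"];
  clone.name = "A/declustered";
  clone.cluster = "";
  graph[clone.name] = clone;

  // Rewire unclustered consumers of A to the clone; clustered ones keep A.
  for (auto& entry : graph) {
    Node& n = entry.second;
    if (n.cluster.empty() && n.name != clone.name) {
      for (std::string& input : n.inputs) {
        if (input == "A") input = clone.name;
      }
    }
  }

  for (const auto& entry : graph) {
    std::cout << entry.first << " [cluster=" << entry.second.cluster << "] <-";
    for (const std::string& input : entry.second.inputs)
      std::cout << " " << input;
    std::cout << "\n";
  }
}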

tensorflow/compiler/jit/partially_decluster_pass_test.cc

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/partially_decluster_pass.h"
#include "absl/memory/memory.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
@ -31,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@ -82,8 +84,10 @@ Status PartiallyDecluster(std::unique_ptr<Graph>* graph) {
// Assign all nodes to the CPU device.
static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
for (Node* n : (*graph)->nodes()) {
if (n->assigned_device_name().empty()) {
n->set_assigned_device_name(kCpuDevice);
}
}
GraphOptimizationPassOptions opt_options;
opt_options.graph = graph;
@ -91,8 +95,8 @@ Status PartiallyDecluster(std::unique_ptr<Graph>* graph) {
return pass.Run(opt_options);
}
const Node* FindNodeByName(const Graph& graph, const string& name) {
for (const Node* node : graph.nodes()) {
Node* FindNodeByName(const Graph& graph, const string& name) {
for (Node* node : graph.nodes()) {
if (node->name() == name) {
return node;
}
@ -279,5 +283,128 @@ TEST(PartiallyDeclusterPassTest, DeclusterDependentNodes) {
"ClusteredProducer0/declustered");
EXPECT_EQ(declustered_producer_1_inputs[1]->name(), "Input");
}
void AddToCluster(absl::Span<Node* const> nodes,
absl::string_view cluster_name) {
for (Node* n : nodes) {
n->AddAttr(kXlaClusterAttr, string(cluster_name));
}
}
TEST(PartiallyDeclusterPassTest, DeclusterMustBeConstantNodes) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32,
ops::Placeholder::Attrs{});
Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32,
ops::Placeholder::Attrs{});
Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b);
Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
DT_FLOAT, ops::Placeholder::Attrs{});
Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape);
AddToCluster({shape.node(), reshape.node()}, "cluster_0");
auto graph = absl::make_unique<Graph>(OpRegistry::Global());
TF_ASSERT_OK(s.ToGraph(graph.get()));
TF_ASSERT_OK(PartiallyDecluster(&graph));
const Node* n = FindNodeByName(*graph, "shape");
ASSERT_NE(n, nullptr);
EXPECT_EQ(GetXlaClusterForNode(*n), absl::nullopt);
}
TEST(PartiallyDeclusterPassTest, DeclusteringStopsAtMetadataOps) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output input_a = ops::Placeholder(s.WithOpName("input_a"), DT_INT32,
ops::Placeholder::Attrs{});
Output input_b = ops::Placeholder(s.WithOpName("shape_b"), DT_FLOAT,
ops::Placeholder::Attrs{});
Output mul = ops::Mul(s.WithOpName("mul"), input_b, input_b);
Output shape_of_mul = ops::Shape(s.WithOpName("shape_of_mul"), mul);
Output shape = ops::Add(s.WithOpName("shape"), shape_of_mul, input_a);
Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
DT_FLOAT, ops::Placeholder::Attrs{});
Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape);
AddToCluster({mul.node(), shape_of_mul.node(), shape.node(), reshape.node()},
"cluster_0");
std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
TF_ASSERT_OK(s.ToGraph(graph.get()));
TF_ASSERT_OK(PartiallyDecluster(&graph));
const Node* n = FindNodeByName(*graph, "shape");
ASSERT_NE(n, nullptr);
EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0");
}
TEST(PartiallyDeclusterPassTest, EdgeAcrossDifferentClusters) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32,
ops::Placeholder::Attrs{});
Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32,
ops::Placeholder::Attrs{});
Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b);
Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
DT_FLOAT, ops::Placeholder::Attrs{});
Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape);
AddToCluster({reshape.node()}, "cluster_0");
AddToCluster({shape.node()}, "cluster_1");
std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
TF_ASSERT_OK(s.ToGraph(graph.get()));
TF_ASSERT_OK(PartiallyDecluster(&graph));
const Node* n = FindNodeByName(*graph, "shape");
ASSERT_NE(n, nullptr);
EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_1");
}
TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32,
ops::Placeholder::Attrs{});
Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32,
ops::Placeholder::Attrs{});
Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b);
Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
DT_FLOAT, ops::Placeholder::Attrs{});
Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape);
AddToCluster({shape.node(), reshape.node()}, "cluster_0");
std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
TF_ASSERT_OK(s.ToGraph(graph.get()));
// This is needed to register the XLA_GPU device.
std::vector<Device*> devices;
TF_ASSERT_OK(DeviceFactory::AddDevices(
SessionOptions(), "/job:localhost/replica:0/task:0", &devices));
// Scope::ToGraph loses the assigned device name since it goes through
// GraphDef/NodeDef which does not have a field for the assigned device name.
Node* n = FindNodeByName(*graph, "shape");
ASSERT_NE(n, nullptr);
n->set_assigned_device_name(
"/job:localhost/replica:0/task:0/device:XLA_GPU:0");
TF_ASSERT_OK(PartiallyDecluster(&graph));
EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0");
for (Device* d : devices) {
delete d;
}
}
} // namespace
} // namespace tensorflow

tensorflow/compiler/jit/xla_cluster_util.cc

@ -210,6 +210,8 @@ void RemoveFromXlaCluster(NodeDef* node_def) {
node_def->mutable_attr()->erase(kXlaClusterAttr);
}
void RemoveFromXlaCluster(Node* node) { node->ClearAttr(kXlaClusterAttr); }
Status AdjustCycleDetectionGraphForResourceOps(
const Graph* graph, const FunctionLibraryDefinition* flib_def,
const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore,

tensorflow/compiler/jit/xla_cluster_util.h

@ -53,6 +53,9 @@ absl::optional<absl::string_view> GetXlaClusterForNode(const Node& node);
// Removes `node_def` from its XLA cluster (by clearing its _XlaCluster attribute).
void RemoveFromXlaCluster(NodeDef* node_def);
// Removes `node` from its XLA cluster (by clearing its _XlaCluster attribute).
void RemoveFromXlaCluster(Node* node);
// Returns true if `node` has a DT_RESOURCE typed input or output.
bool HasResourceInputOrOutput(const Node& node);

tensorflow/compiler/tests/dense_layer_test.py

@ -58,7 +58,8 @@ class DenseLayerTest(test.TestCase):
Dense layer should be compiled into a single XlaLaunch op in auto-jit mode.
"""
os.environ["TF_XLA_FLAGS"] = ("--tf_xla_cpu_global_jit")
os.environ["TF_XLA_FLAGS"] = (
"--tf_xla_cpu_global_jit " + os.environ.get("TF_XLA_FLAGS", ""))
config = config_pb2.ConfigProto()
config.graph_options.optimizer_options.global_jit_level = (
config_pb2.OptimizerOptions.ON_1)
@ -77,7 +78,7 @@ class DenseLayerTest(test.TestCase):
labels = GetRunMetadataLabels(run_metadata)
self.assertEqual(1, XlaLaunchOpCount(labels))
self.assertFalse(InLabels(labels, "ListDiff"))
self.assertFalse(InLabels(labels, "MatMult"))
def testDenseLayerJitScopeDefinedShape(self):
"""Tests that the dense layer node is properly compiled in jit scope.
@ -128,7 +129,7 @@ class DenseLayerTest(test.TestCase):
labels = GetRunMetadataLabels(run_metadata)
self.assertEqual(2, XlaLaunchOpCount(labels))
self.assertFalse(InLabels(labels, "ListDiff"))
self.assertFalse(InLabels(labels, "MatMult"))
if __name__ == "__main__":

tensorflow/compiler/tests/jit_test.py

@ -489,8 +489,9 @@ class ElementWiseFusionTest(test.TestCase):
def testElementWiseClustering(self):
arg0 = np.random.rand(2, 2).astype(np.float32)
arg1 = np.random.rand(2, 2).astype(np.float32)
os.environ["TF_XLA_FLAGS"] = ("--tf_xla_fusion_only=true "
"--tf_xla_cpu_global_jit")
os.environ["TF_XLA_FLAGS"] = (
"--tf_xla_fusion_only=true "
"--tf_xla_cpu_global_jit " + os.environ.get("TF_XLA_FLAGS", ""))
tf_op, tf_count = self.simpleTest(arg0, arg1,
config_pb2.OptimizerOptions.OFF)
self.assertEqual(0, tf_count)

tensorflow/compiler/tf2xla/const_analysis.cc

@ -26,8 +26,9 @@ namespace tensorflow {
// Backwards dataflow analysis that finds arguments to a graph that must be
// compile-time constants.
Status BackwardsConstAnalysis(const Graph& g,
std::vector<bool>* compile_time_const_args,
std::vector<bool>* compile_time_const_nodes) {
std::vector<bool>* compile_time_const_arg_indices,
std::vector<bool>* compile_time_const_nodes,
std::function<bool(const Edge&)> edge_filter) {
// Operators that don't look at the data of their inputs, just the shapes.
const std::unordered_set<string> metadata_ops = {
"Rank",
@ -45,8 +46,7 @@ Status BackwardsConstAnalysis(const Graph& g,
}
Status status;
auto visit = [&status, &metadata_ops, compile_time_const_nodes,
compile_time_const_args](Node* node) {
auto visit = [&](Node* node) {
if (!status.ok()) return;
// If this is a metadata-only op, don't propagate the const requirement.
@ -59,13 +59,13 @@ Status BackwardsConstAnalysis(const Graph& g,
int index;
status = GetNodeAttr(node->attrs(), "index", &index);
if (!status.ok()) return;
if (compile_time_const_args) {
(*compile_time_const_args)[index] = true;
if (compile_time_const_arg_indices) {
(*compile_time_const_arg_indices)[index] = true;
}
return;
}
for (const Edge* pred : node->in_edges()) {
if (!pred->IsControlEdge()) {
if (!pred->IsControlEdge() && edge_filter(*pred)) {
(*compile_time_const_nodes)[pred->src()->id()] = true;
}
}
@ -88,7 +88,8 @@ Status BackwardsConstAnalysis(const Graph& g,
for (Edge const* edge : node->in_edges()) {
if (edge->dst_input() >= name_range->second.first &&
edge->dst_input() < name_range->second.second) {
edge->dst_input() < name_range->second.second &&
edge_filter(*edge)) {
(*compile_time_const_nodes)[edge->src()->id()] = true;
}
}
@ -97,7 +98,8 @@ Status BackwardsConstAnalysis(const Graph& g,
// Post-order traversal visits nodes in reverse topological order for an
// acyclic graph.
DFS(g, {}, visit);
DFS(g, /*enter=*/{}, /*leave=*/visit, NodeComparatorName{},
[](const Edge& edge) { return !edge.src()->IsNextIteration(); });
return status;
}
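
A standalone sketch of what the new `edge_filter` hook buys us (plain C++ with a made-up toy graph type, not the TensorFlow implementation): the backwards must-be-constant propagation only walks edges the filter accepts, so a predicate like IsIntraClusterEdge stops the analysis at cluster boundaries.

// Toy sketch (not the TensorFlow implementation): propagate a
// "must be compile-time constant" bit backwards, but only across edges
// accepted by a caller-supplied filter, mirroring the new `edge_filter`
// parameter of BackwardsConstAnalysis.
#include <functional>
#include <iostream>
#include <vector>

struct Edge { int src, dst; bool intra_cluster; };

std::vector<bool> BackwardsConstSketch(
    int num_nodes, const std::vector<Edge>& edges,
    const std::vector<int>& const_sinks,
    const std::function<bool(const Edge&)>& edge_filter) {
  std::vector<bool> must_be_const(num_nodes, false);
  for (int sink : const_sinks) must_be_const[sink] = true;
  // Iterate to a fixpoint.  The real analysis instead visits nodes in
  // reverse topological order via DFS, which needs only one pass.
  bool changed = true;
  while (changed) {
    changed = false;
    for (const Edge& e : edges) {
      if (must_be_const[e.dst] && edge_filter(e) && !must_be_const[e.src]) {
        must_be_const[e.src] = true;
        changed = true;
      }
    }
  }
  return must_be_const;
}

int main() {
  // Edge 0->1 crosses a cluster boundary; edge 1->2 stays inside the cluster.
  // Node 2 needs a compile-time constant input.
  std::vector<Edge> edges = {{0, 1, false}, {1, 2, true}};
  std::vector<bool> result = BackwardsConstSketch(
      3, edges, {2}, [](const Edge& e) { return e.intra_cluster; });
  // With the intra-cluster filter, const-ness reaches node 1 but not node 0.
  for (bool b : result) std::cout << b;  // prints 011
  std::cout << "\n";
}

The default argument declared in the header below (a filter that accepts every edge) keeps the behavior of existing callers unchanged.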

tensorflow/compiler/tf2xla/const_analysis.h

@ -32,9 +32,13 @@ namespace tensorflow {
//
// The ids of the nodes in `graph` that must be constant are returned in
// `compile_time_const_nodes`, if `compile_time_const_nodes` is not null.
Status BackwardsConstAnalysis(const Graph& graph,
//
// Only propagate const-ness along edges for which `edge_filter` returns true.
Status BackwardsConstAnalysis(const Graph& g,
std::vector<bool>* compile_time_const_arg_indices,
std::vector<bool>* compile_time_const_nodes);
std::vector<bool>* compile_time_const_nodes,
std::function<bool(const Edge&)> edge_filter =
[](const Edge& e) { return true; });
} // namespace tensorflow