Decluster some must-be-constant ops to reduce XLA recompilations

The CL is organized as follows:

- The main change is in jit/partially_decluster_pass.
- tf2xla/const_analysis now takes an "edge_filter" to facilitate use by
  jit/partially_decluster_pass.
- tests/dense_layer_test.py was using the execution of ListDiff as what I
  assume is a sanity check that the XLA cluster ran.  With this CL the
  ListDiff op gets declustered, so we now check for "MatMult" instead.
- Some tests were dropping TF_XLA_FLAGS; fixed them to not do so.

PiperOrigin-RevId: 212071118
parent 3e1b06ee93
commit 4fd48f57cd
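
The pattern this CL targets, sketched with the same C++ graph-construction API
the new tests below use (a minimal sketch, assuming the test file's includes;
node names are illustrative only):

    // `shape` is must-be-constant because it feeds Reshape's new_shape input.
    // While the Add stays clustered, the cluster's compilation cache is keyed
    // on the value pair (shape_a, shape_b); once the pass moves the Add to
    // the host, the key collapses to the single value shape_a + shape_b, so
    // e.g. (3, 1) and (4, 0) no longer trigger two compilations.
    Scope s = Scope::NewRootScope();
    Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32);
    Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32);
    Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b);
    Output input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
    Output reshape = ops::Reshape(s.WithOpName("reshape"), input, shape);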
tensorflow/compiler/jit/BUILD
@@ -395,6 +395,7 @@ cc_library(
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/kernels:bounds_check",
+        "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -480,6 +481,7 @@ tf_cc_test(
         ":common",
         ":compilation_passes",
         ":xla_cluster_util",
+        ":xla_gpu_device",
         "//tensorflow/cc:cc_ops",
         "//tensorflow/cc:cc_ops_internal",
         "//tensorflow/cc:function_ops",
@@ -496,6 +498,8 @@ tf_cc_test(
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core:testlib",
+        "//tensorflow/core/grappler/optimizers/data:graph_utils",
+        "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
     ],
 )
tensorflow/compiler/jit/partially_decluster_pass.cc
@@ -14,8 +14,11 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/compiler/jit/partially_decluster_pass.h"
+#include "absl/algorithm/container.h"
 #include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/jit/xla_cluster_util.h"
+#include "tensorflow/compiler/tf2xla/const_analysis.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/core/framework/memory_types.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/lib/gtl/flatset.h"
@@ -130,30 +133,47 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) {
 
   return Status::OK();
 }
-}  // namespace
-
-Status PartiallyDeclusterPass::Run(
-    const GraphOptimizationPassOptions& options) {
-  // NB! In this pass we assume the only XLA-auto-clusterable operations that
-  // may have side effects are resource variable operations so we don't cluster
-  // those. The pass will have to be updated if this assumption becomes
-  // invalid.
-
-  Graph* graph = options.graph->get();
 
+bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); }
+
+// Clones nodes to outside their cluster to avoid device-to-host copies. For
+// instance, converts this:
+//
+//         .....
+//           |
+//           v
+//      A_Clustered ====> C_Unclustered
+//           |
+//           v
+//      B_Clustered
+//
+// to:
+//
+//         .....
+//          |  |
+//          |  +-------------+
+//          |                |
+//          v                v
+//      A_Clustered  A_Unclustered ====> C_Unclustered
+//          |
+//          v
+//      B_Clustered
+//
+// where the ===> arrow has a hostmem source and destination and would entail a
+// device to host copy if the source and destination were not in the same XLA
+// cluster.
+Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) {
   // When deciding whether to decluster a particular node, we base our decision
   // on if we've decided that some of its consumers have to be declustered too.
   // Iterating the graph in post-order guarantees that consumers have been
   // visited before producers.
   std::vector<Node*> post_order;
   GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(),
-               /*edge_filter=*/[](const Edge& edge) {
-                 return !edge.src()->IsNextIteration();
-               });
+               /*edge_filter=*/NotBackedge);
 
   gtl::FlatSet<Node*> nodes_to_partially_decluster;
-  TF_RETURN_IF_ERROR(FindNodesToDecluster(
-      **options.graph, &nodes_to_partially_decluster, post_order));
+  TF_RETURN_IF_ERROR(
+      FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order));
 
   if (VLOG_IS_ON(3)) {
     for (Node* n : post_order) {
@@ -170,10 +190,133 @@ Status PartiallyDeclusterPass::Run(
     }
   }
 
   nodes_to_partially_decluster.clear();
-  TF_RETURN_IF_ERROR(FindNodesToDecluster(
-      **options.graph, &nodes_to_partially_decluster, post_order));
+  TF_RETURN_IF_ERROR(
+      FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order));
   CHECK(nodes_to_partially_decluster.empty());
 
   return Status::OK();
 }
+
+bool IsIntraClusterEdge(const Edge& edge) {
+  absl::optional<absl::string_view> src_cluster_name =
+      GetXlaClusterForNode(*edge.src());
+  absl::optional<absl::string_view> dst_cluster_name =
+      GetXlaClusterForNode(*edge.dst());
+  return src_cluster_name.has_value() && src_cluster_name == dst_cluster_name;
+}
+
+Status MustCompileNode(const Node* n, bool* result) {
+  DeviceType device_type("");
+  TF_RETURN_IF_ERROR(
+      DeviceToDeviceType(n->assigned_device_name(), &device_type));
+
+  const XlaOpRegistry::DeviceRegistration* registration;
+  if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) {
+    *result = false;
+  } else {
+    *result = registration->requires_compilation;
+  }
+
+  return Status::OK();
+}
+
+// Declusters nodes to reduce the number of times we think we need to recompile
+// a TensorFlow graph.
+//
+// Abstractly, if we have a cluster of this form:
+//
+//   x0 = arg0
+//   x1 = arg1
+//     ...
+//   shape = f(x0, x1, ...)
+//   result = Reshape(input=<something>, new_shape=shape)
+//
+// then pulling `f` out of the cluster may reduce the number of compilations and
+// will never increase the number of compilations.
+//
+// We may reduce the number of compilations if f is many to one. For instance
+// if f(x,y) = x-y then x=3,y=1 and x=4,y=2 will generate two different
+// compilations if f is in the cluster but only one compilation if f is outside
+// the cluster.
+//
+// Declustering f will increase the number of compilations only if f is a
+// one-to-many "function" i.e. isn't a function at all. RNG is one possible
+// example, depending on how we look at it. But we never create clusters where
+// such f's would be marked as must-be-constant.
+//
+// We assume here that the extra repeated (repeated compared to a clustered f
+// where it will always be constant folded) host-side computation of f does not
+// regress performance in any significant manner. We will have to revisit this
+// algorithm with a more complex cost model if this assumption turns out to be
+// incorrect.
+Status DeclusterNodesToReduceRecompilations(Graph* graph) {
+  std::vector<bool> compile_time_const_nodes(graph->num_node_ids());
+  TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
+      *graph, nullptr, &compile_time_const_nodes, IsIntraClusterEdge));
+
+  std::vector<Node*> rpo;
+  GetReversePostOrder(*graph, &rpo, /*stable_comparator=*/NodeComparatorName(),
+                      /*edge_filter=*/NotBackedge);
+  for (Node* n : rpo) {
+    if (!compile_time_const_nodes[n->id()]) {
+      continue;
+    }
+
+    absl::string_view cluster_name = *GetXlaClusterForNode(*n);
+    bool node_on_cluster_edge =
+        absl::c_all_of(n->in_edges(), [&](const Edge* e) {
+          absl::optional<absl::string_view> incoming_cluster =
+              GetXlaClusterForNode(*e->src());
+          return !incoming_cluster || *incoming_cluster != cluster_name;
+        });
+
+    // We don't want to decluster F in a graph like
+    //
+    //   Input -> OP -> Shape -> F -> Reshape
+    //
+    // Doing so will break up the cluster. Even if we were okay with breaking
+    // up the cluster we will at least have to relabel the two clusters to have
+    // different cluster names.
+    //
+    // We may want to revisit this in the future: we may have cases where OP is
+    // a small computation that does not benefit from XLA while XLA can optimize
+    // everything that follows the Reshape. In these cases it may be wise to
+    // remove Input, OP, Shape and F from the cluster, if F is a many-to-one
+    // function.
+    //
+    // Note that we do do the right thing for graphs like:
+    //
+    //   Input -> F0 -> F1 -> Reshape
+    //
+    // Since we iterate in RPO, we'll first encounter F0, decluster it, then
+    // encounter F1, decluster it and so on.
+    if (node_on_cluster_edge) {
+      bool must_compile_node;
+      TF_RETURN_IF_ERROR(MustCompileNode(n, &must_compile_node));
+      if (!must_compile_node) {
+        VLOG(3) << "Declustering must-be-constant node " << n->name();
+        RemoveFromXlaCluster(n);
+      }
+    }
+  }
+
+  return Status::OK();
+}
+
+}  // namespace
+
+Status PartiallyDeclusterPass::Run(
+    const GraphOptimizationPassOptions& options) {
+  // NB! In this pass we assume the only XLA-auto-clusterable operations that
+  // may have side effects are resource variable operations so we don't cluster
+  // those. The pass will have to be updated if this assumption becomes
+  // invalid.
+
+  Graph* graph = options.graph->get();
+
+  TF_RETURN_IF_ERROR(PartiallyDeclusterToRemoveDeviceToHostCopies(graph));
+  TF_RETURN_IF_ERROR(DeclusterNodesToReduceRecompilations(graph));
+
+  return Status::OK();
+}
 }  // namespace tensorflow
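
To make the compilation-cache argument in the comment above concrete, here is a
self-contained toy model (not TensorFlow code; the maps standing in for the
per-cluster compilation cache are invented for illustration):

    #include <iostream>
    #include <map>
    #include <utility>

    int f(int x, int y) { return x - y; }  // many-to-one

    int main() {
      // f clustered: the cache key is the tuple of must-be-constant cluster
      // inputs, here the pair (x, y).
      std::map<std::pair<int, int>, int> clustered_cache;
      // f declustered: f runs on the host, the cluster only sees its result,
      // and the key collapses to f(x, y).
      std::map<int, int> declustered_cache;
      for (auto [x, y] : {std::pair{3, 1}, std::pair{4, 2}}) {
        clustered_cache.emplace(std::pair{x, y}, 0);
        declustered_cache.emplace(f(x, y), 0);
      }
      // Prints "2 vs 1": two compilations with f inside the cluster, one with
      // f outside, matching the f(x,y) = x-y example in the comment above.
      std::cout << clustered_cache.size() << " vs " << declustered_cache.size()
                << "\n";
    }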
tensorflow/compiler/jit/partially_decluster_pass.h
@@ -20,34 +20,11 @@ limitations under the License.
 
 namespace tensorflow {
 
-// Clones nodes from within a cluster to outside the cluster if profitable.
+// Clones or moves nodes from within a cluster to outside the cluster if
+// profitable. There are two reasons why we do this:
 //
-// Today this only clones to avoid device-to-host copies, but in the future we
-// may consider other reasons to clone. For instance, we convert this:
-//
-//         .....
-//           |
-//           v
-//      A_Clustered ====> C_Unclustered
-//           |
-//           v
-//      B_Clustered
-//
-// to:
-//
-//         .....
-//          |  |
-//          |  +-------------+
-//          |                |
-//          v                v
-//      A_Clustered  A_Unclustered ====> C_Unclustered
-//          |
-//          v
-//      B_Clustered
-//
-// where the ===> arrow has a hostmem source and destination and would entail a
-// device to host copy if the source and destination were not in the same XLA
-// cluster.
+//  - Reducing device-to-host copies.
+//  - Reducing the number of XLA recompilations.
 class PartiallyDeclusterPass : public GraphOptimizationPass {
  public:
   Status Run(const GraphOptimizationPassOptions& options) override;
 
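
For reference, the pass can also be driven directly rather than through the
optimization-pass registry; a rough sketch (it mirrors the PartiallyDecluster
helper in the test file below, with error handling elided and `graph` assumed
to be a std::unique_ptr<Graph>*):

    GraphOptimizationPassOptions opt_options;
    opt_options.graph = graph;  // std::unique_ptr<Graph>*
    PartiallyDeclusterPass pass;
    TF_RETURN_IF_ERROR(pass.Run(opt_options));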
tensorflow/compiler/jit/partially_decluster_pass_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/jit/partially_decluster_pass.h"
 
+#include "absl/memory/memory.h"
 #include "tensorflow/cc/framework/ops.h"
 #include "tensorflow/cc/ops/array_ops.h"
 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
@@ -31,6 +32,7 @@ limitations under the License.
 #include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/graph/graph_def_builder.h"
 #include "tensorflow/core/graph/graph_def_builder_util.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
 
@@ -82,8 +84,10 @@ Status PartiallyDecluster(std::unique_ptr<Graph>* graph) {
   // Assign all nodes to the CPU device.
   static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
   for (Node* n : (*graph)->nodes()) {
-    n->set_assigned_device_name(kCpuDevice);
+    if (n->assigned_device_name().empty()) {
+      n->set_assigned_device_name(kCpuDevice);
+    }
   }
 
   GraphOptimizationPassOptions opt_options;
   opt_options.graph = graph;
@@ -91,8 +95,8 @@ Status PartiallyDecluster(std::unique_ptr<Graph>* graph) {
   return pass.Run(opt_options);
 }
 
-const Node* FindNodeByName(const Graph& graph, const string& name) {
-  for (const Node* node : graph.nodes()) {
+Node* FindNodeByName(const Graph& graph, const string& name) {
+  for (Node* node : graph.nodes()) {
     if (node->name() == name) {
       return node;
     }
@@ -279,5 +283,128 @@ TEST(PartiallyDeclusterPassTest, DeclusterDependentNodes) {
                 "ClusteredProducer0/declustered");
   EXPECT_EQ(declustered_producer_1_inputs[1]->name(), "Input");
 }
+
+void AddToCluster(absl::Span<Node* const> nodes,
+                  absl::string_view cluster_name) {
+  for (Node* n : nodes) {
+    n->AddAttr(kXlaClusterAttr, string(cluster_name));
+  }
+}
+
+TEST(PartiallyDeclusterPassTest, DeclusterMustBeConstantNodes) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32,
+                                    ops::Placeholder::Attrs{});
+  Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32,
+                                    ops::Placeholder::Attrs{});
+  Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b);
+
+  Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
+                                          DT_FLOAT, ops::Placeholder::Attrs{});
+  Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape);
+
+  AddToCluster({shape.node(), reshape.node()}, "cluster_0");
+
+  auto graph = absl::make_unique<Graph>(OpRegistry::Global());
+  TF_ASSERT_OK(s.ToGraph(graph.get()));
+  TF_ASSERT_OK(PartiallyDecluster(&graph));
+
+  const Node* n = FindNodeByName(*graph, "shape");
+  ASSERT_NE(n, nullptr);
+
+  EXPECT_EQ(GetXlaClusterForNode(*n), absl::nullopt);
+}
+
+TEST(PartiallyDeclusterPassTest, DeclusteringStopsAtMetadataOps) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output input_a = ops::Placeholder(s.WithOpName("input_a"), DT_INT32,
+                                    ops::Placeholder::Attrs{});
+  Output input_b = ops::Placeholder(s.WithOpName("shape_b"), DT_FLOAT,
+                                    ops::Placeholder::Attrs{});
+  Output mul = ops::Mul(s.WithOpName("mul"), input_b, input_b);
+  Output shape_of_mul = ops::Shape(s.WithOpName("shape_of_mul"), mul);
+
+  Output shape = ops::Add(s.WithOpName("shape"), shape_of_mul, input_a);
+
+  Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
+                                          DT_FLOAT, ops::Placeholder::Attrs{});
+  Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape);
+
+  AddToCluster({mul.node(), shape_of_mul.node(), shape.node(), reshape.node()},
+               "cluster_0");
+
+  std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
+  TF_ASSERT_OK(s.ToGraph(graph.get()));
+  TF_ASSERT_OK(PartiallyDecluster(&graph));
+
+  const Node* n = FindNodeByName(*graph, "shape");
+  ASSERT_NE(n, nullptr);
+
+  EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0");
+}
+
+TEST(PartiallyDeclusterPassTest, EdgeAcrossDifferentClusters) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32,
+                                    ops::Placeholder::Attrs{});
+  Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32,
+                                    ops::Placeholder::Attrs{});
+  Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b);
+
+  Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
+                                          DT_FLOAT, ops::Placeholder::Attrs{});
+  Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape);
+
+  AddToCluster({reshape.node()}, "cluster_0");
+  AddToCluster({shape.node()}, "cluster_1");
+
+  std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
+  TF_ASSERT_OK(s.ToGraph(graph.get()));
+  TF_ASSERT_OK(PartiallyDecluster(&graph));
+
+  const Node* n = FindNodeByName(*graph, "shape");
+  ASSERT_NE(n, nullptr);
+
+  EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_1");
+}
+
+TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32,
+                                    ops::Placeholder::Attrs{});
+  Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32,
+                                    ops::Placeholder::Attrs{});
+  Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b);
+
+  Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
+                                          DT_FLOAT, ops::Placeholder::Attrs{});
+  Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape);
+
+  AddToCluster({shape.node(), reshape.node()}, "cluster_0");
+
+  std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
+  TF_ASSERT_OK(s.ToGraph(graph.get()));
+
+  // This is needed to register the XLA_GPU device.
+  std::vector<Device*> devices;
+  TF_ASSERT_OK(DeviceFactory::AddDevices(
+      SessionOptions(), "/job:localhost/replica:0/task:0", &devices));
+
+  // Scope::ToGraph loses the assigned device name since it goes through
+  // GraphDef/NodeDef which does not have a field for the assigned device name.
+  Node* n = FindNodeByName(*graph, "shape");
+  ASSERT_NE(n, nullptr);
+  n->set_assigned_device_name(
+      "/job:localhost/replica:0/task:0/device:XLA_GPU:0");
+
+  TF_ASSERT_OK(PartiallyDecluster(&graph));
+
+  EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0");
+
+  for (Device* d : devices) {
+    delete d;
+  }
+}
 
 }  // namespace
 }  // namespace tensorflow
tensorflow/compiler/jit/xla_cluster_util.cc
@@ -210,6 +210,8 @@ void RemoveFromXlaCluster(NodeDef* node_def) {
   node_def->mutable_attr()->erase(kXlaClusterAttr);
 }
 
+void RemoveFromXlaCluster(Node* node) { node->ClearAttr(kXlaClusterAttr); }
+
 Status AdjustCycleDetectionGraphForResourceOps(
     const Graph* graph, const FunctionLibraryDefinition* flib_def,
     const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore,
tensorflow/compiler/jit/xla_cluster_util.h
@@ -53,6 +53,9 @@ absl::optional<absl::string_view> GetXlaClusterForNode(const Node& node);
 // Removes `node_def` from its XLA cluster (by clearing its _XlaCluster
 // attribute).
 void RemoveFromXlaCluster(NodeDef* node_def);
 
+// Removes `node` from its XLA cluster (by clearing its _XlaCluster attribute).
+void RemoveFromXlaCluster(Node* node);
+
 // Returns true if `node` has a DT_RESOURCE typed input or output.
 bool HasResourceInputOrOutput(const Node& node);
 
tensorflow/compiler/tests/dense_layer_test.py
@@ -58,7 +58,8 @@ class DenseLayerTest(test.TestCase):
     Dense layer should be compiled into a single XlaLaunch op in auto-jit mode.
     """
 
-    os.environ["TF_XLA_FLAGS"] = ("--tf_xla_cpu_global_jit")
+    os.environ["TF_XLA_FLAGS"] = (
+        "--tf_xla_cpu_global_jit " + os.environ.get("TF_XLA_FLAGS", ""))
     config = config_pb2.ConfigProto()
     config.graph_options.optimizer_options.global_jit_level = (
         config_pb2.OptimizerOptions.ON_1)
@@ -77,7 +78,7 @@ class DenseLayerTest(test.TestCase):
 
     labels = GetRunMetadataLabels(run_metadata)
     self.assertEqual(1, XlaLaunchOpCount(labels))
-    self.assertFalse(InLabels(labels, "ListDiff"))
+    self.assertFalse(InLabels(labels, "MatMult"))
 
   def testDenseLayerJitScopeDefinedShape(self):
     """Tests that the dense layer node is properly compiled in jit scope.
@@ -128,7 +129,7 @@ class DenseLayerTest(test.TestCase):
 
     labels = GetRunMetadataLabels(run_metadata)
     self.assertEqual(2, XlaLaunchOpCount(labels))
-    self.assertFalse(InLabels(labels, "ListDiff"))
+    self.assertFalse(InLabels(labels, "MatMult"))
 
 
 if __name__ == "__main__":
tensorflow/compiler/tests/jit_test.py
@@ -489,8 +489,9 @@ class ElementWiseFusionTest(test.TestCase):
   def testElementWiseClustering(self):
     arg0 = np.random.rand(2, 2).astype(np.float32)
     arg1 = np.random.rand(2, 2).astype(np.float32)
-    os.environ["TF_XLA_FLAGS"] = ("--tf_xla_fusion_only=true "
-                                  "--tf_xla_cpu_global_jit")
+    os.environ["TF_XLA_FLAGS"] = (
+        "--tf_xla_fusion_only=true "
+        "--tf_xla_cpu_global_jit " + os.environ.get("TF_XLA_FLAGS", ""))
     tf_op, tf_count = self.simpleTest(arg0, arg1,
                                       config_pb2.OptimizerOptions.OFF)
     self.assertEqual(0, tf_count)
tensorflow/compiler/tf2xla/const_analysis.cc
@@ -26,8 +26,9 @@ namespace tensorflow {
 
 // Backwards dataflow analysis that finds arguments to a graph that must be
 // compile-time constants.
 Status BackwardsConstAnalysis(const Graph& g,
-                              std::vector<bool>* compile_time_const_args,
-                              std::vector<bool>* compile_time_const_nodes) {
+                              std::vector<bool>* compile_time_const_arg_indices,
+                              std::vector<bool>* compile_time_const_nodes,
+                              std::function<bool(const Edge&)> edge_filter) {
   // Operators that don't look at the data of their inputs, just the shapes.
   const std::unordered_set<string> metadata_ops = {
       "Rank",
@@ -45,8 +46,7 @@ Status BackwardsConstAnalysis(const Graph& g,
   }
 
   Status status;
-  auto visit = [&status, &metadata_ops, compile_time_const_nodes,
-                compile_time_const_args](Node* node) {
+  auto visit = [&](Node* node) {
     if (!status.ok()) return;
 
     // If this is a metadata-only op, don't propagate the const requirement.
@@ -59,13 +59,13 @@ Status BackwardsConstAnalysis(const Graph& g,
       int index;
       status = GetNodeAttr(node->attrs(), "index", &index);
       if (!status.ok()) return;
-      if (compile_time_const_args) {
-        (*compile_time_const_args)[index] = true;
+      if (compile_time_const_arg_indices) {
+        (*compile_time_const_arg_indices)[index] = true;
       }
       return;
     }
     for (const Edge* pred : node->in_edges()) {
-      if (!pred->IsControlEdge()) {
+      if (!pred->IsControlEdge() && edge_filter(*pred)) {
         (*compile_time_const_nodes)[pred->src()->id()] = true;
       }
     }
@@ -88,7 +88,8 @@ Status BackwardsConstAnalysis(const Graph& g,
 
     for (Edge const* edge : node->in_edges()) {
       if (edge->dst_input() >= name_range->second.first &&
-          edge->dst_input() < name_range->second.second) {
+          edge->dst_input() < name_range->second.second &&
+          edge_filter(*edge)) {
         (*compile_time_const_nodes)[edge->src()->id()] = true;
       }
     }
@@ -97,7 +98,8 @@ Status BackwardsConstAnalysis(const Graph& g,
 
   // Post-order traversal visits nodes in reverse topological order for an
   // acyclic graph.
-  DFS(g, {}, visit);
+  DFS(g, /*enter=*/{}, /*leave=*/visit, NodeComparatorName{},
+      [](const Edge& edge) { return !edge.src()->IsNextIteration(); });
   return status;
 }
 
tensorflow/compiler/tf2xla/const_analysis.h
@@ -32,9 +32,13 @@ namespace tensorflow {
 //
 // The ids of the nodes in `graph` that must be constant are returned in
 // `compile_time_const_nodes`, if `compile_time_const_nodes` is not null.
-Status BackwardsConstAnalysis(const Graph& graph,
-                              std::vector<bool>* compile_time_const_arg_indices,
-                              std::vector<bool>* compile_time_const_nodes);
+//
+// Only propagate const-ness along edges for which `edge_filter` returns true.
+Status BackwardsConstAnalysis(const Graph& g,
+                              std::vector<bool>* compile_time_const_arg_indices,
+                              std::vector<bool>* compile_time_const_nodes,
+                              std::function<bool(const Edge&)> edge_filter =
+                                  [](const Edge& e) { return true; });
 
 }  // namespace tensorflow
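
Because `edge_filter` is defaulted to accept every edge, existing callers keep
their current call shape, while the new decluster pass can restrict propagation
to intra-cluster edges. A hedged sketch of both call shapes (`num_args` and the
surrounding error handling are assumptions for illustration; IsIntraClusterEdge
is the helper added in partially_decluster_pass.cc above):

    // Old-style call: the defaulted edge_filter accepts every edge.
    std::vector<bool> const_args(num_args);
    std::vector<bool> const_nodes(graph.num_node_ids());
    TF_RETURN_IF_ERROR(BackwardsConstAnalysis(graph, &const_args, &const_nodes));

    // New call from the decluster pass: only propagate const-ness along edges
    // that stay inside a single XLA cluster.
    TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
        graph, /*compile_time_const_arg_indices=*/nullptr, &const_nodes,
        IsIntraClusterEdge));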