diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD index b5299b908c6..fca411337a9 100644 --- a/tensorflow/tools/graph_transforms/BUILD +++ b/tensorflow/tools/graph_transforms/BUILD @@ -65,7 +65,7 @@ cc_library( "freeze_requantization_ranges.cc", "fuse_convolutions.cc", "insert_logging.cc", - "obsfucate_names.cc", + "obfuscate_names.cc", "remove_attribute.cc", "remove_device.cc", "remove_nodes.cc", @@ -73,6 +73,7 @@ cc_library( "rename_op.cc", "set_device.cc", "sort_by_execution_order.cc", + "sparsify_gather.cc", "strip_unused_nodes.cc", ] + if_not_windows([ "quantize_nodes.cc", @@ -111,7 +112,7 @@ tf_cc_test( "freeze_requantization_ranges_test.cc", "fuse_convolutions_test.cc", "insert_logging_test.cc", - "obsfucate_names_test.cc", + "obfuscate_names_test.cc", "quantize_nodes_test.cc", "quantize_weights_test.cc", "remove_attribute_test.cc", @@ -122,6 +123,7 @@ tf_cc_test( "round_weights_test.cc", "set_device_test.cc", "sort_by_execution_order_test.cc", + "sparsify_gather_test.cc", "strip_unused_nodes_test.cc", ], deps = [ diff --git a/tensorflow/tools/graph_transforms/README.md b/tensorflow/tools/graph_transforms/README.md index 73af8699b24..b036084207f 100644 --- a/tensorflow/tools/graph_transforms/README.md +++ b/tensorflow/tools/graph_transforms/README.md @@ -20,7 +20,7 @@ * [fuse_convolutions](#fuse_convolutions) * [insert_logging](#insert_logging) * [merge_duplicate_nodes](#merge_duplicate_nodes) - * [obsfucate_names](#obsfucate_names) + * [obfuscate_names](#obfuscate_names) * [quantize_nodes](#quantize_nodes) * [quantize_weights](#quantize_weights) * [remove_attribute](#remove_attribute) @@ -29,6 +29,7 @@ * [rename_attribute](#rename_attribute) * [rename_op](#rename_op) * [round_weights](#round_weights) + * [sparsify_gather](#sparsify_gather) * [set_device](#set_device) * [sort_by_execution_order](#sort_by_execution_order) * [strip_unused_nodes](#strip_unused_nodes) @@ -250,7 +251,7 @@ results are cached and so you shouldn't see the graph run any more slowly. So far we've been concentrating on weights because those generally take up the most space. If you have a graph with a lot of small nodes in it, the names of those nodes can start to take up a noticeable amount of space too. To shrink -those down, you can run the [obsfucate_names](#obsfucate_names) transform, which +those down, you can run the [obfuscate_names](#obfuscate_names) transform, which replaces all the names (except for inputs and outputs) with short, cryptic but unique ids: @@ -262,7 +263,7 @@ bazel-bin/tensorflow/tools/graph_transforms/transform_graph \ --inputs='Mul:0' \ --outputs='softmax:0' \ --transforms='\ -obsfucate_names \ +obfuscate_names \ ' ``` @@ -520,7 +521,7 @@ of redundancy (e.g. this transform is always run as part of [quantize_nodes](#quantize_nodes) since the processing there can introduce duplicates of constants that are used in the quantize/dequantize process). -### obsfucate_names +### obfuscate_names Args: None \ Prerequisites: None @@ -656,6 +657,16 @@ between the largest and smallest values present. This is useful when you'll be deploying on mobile, and you want a model that will compress effectively. See [shrinking file size](#shrinking-file-size) for more details. +### sparsify_gather + +Args: None \ +Prerequisites: None + +Transform 'Gather' op to a sparsified version where 'params' input of 'Gather' +is replaced from a dense 'Const' to a 'HashTable'. 'Gather' op itself is +replaced by a hashtable lookup. This is mostly useful for reducing sparse +TF.learn linear model memory footprint. + ### set_device Args: diff --git a/tensorflow/tools/graph_transforms/obsfucate_names.cc b/tensorflow/tools/graph_transforms/obfuscate_names.cc similarity index 95% rename from tensorflow/tools/graph_transforms/obsfucate_names.cc rename to tensorflow/tools/graph_transforms/obfuscate_names.cc index c665ed947af..c470b51b960 100644 --- a/tensorflow/tools/graph_transforms/obsfucate_names.cc +++ b/tensorflow/tools/graph_transforms/obfuscate_names.cc @@ -29,7 +29,7 @@ namespace graph_transforms { // Renames all nodes not uses as graph inputs or outputs to short numerical // forms. -Status ObsfucateNames(const GraphDef& input_graph_def, +Status ObfuscateNames(const GraphDef& input_graph_def, const TransformFuncContext& context, GraphDef* output_graph_def) { std::unordered_set required_nodes; @@ -73,7 +73,7 @@ Status ObsfucateNames(const GraphDef& input_graph_def, output_graph_def->Clear(); for (const NodeDef& input_node : input_graph_def.node()) { NodeDef* node = output_graph_def->mutable_node()->Add(); - node->CopyFrom(input_node); + *node = input_node; const string& old_name = input_node.name(); node->set_name(new_names[old_name]); node->mutable_input()->Clear(); @@ -94,7 +94,7 @@ Status ObsfucateNames(const GraphDef& input_graph_def, return Status::OK(); } -REGISTER_GRAPH_TRANSFORM("obsfucate_names", ObsfucateNames); +REGISTER_GRAPH_TRANSFORM("obfuscate_names", ObfuscateNames); } // namespace graph_transforms } // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/obsfucate_names_test.cc b/tensorflow/tools/graph_transforms/obfuscate_names_test.cc similarity index 90% rename from tensorflow/tools/graph_transforms/obsfucate_names_test.cc rename to tensorflow/tools/graph_transforms/obfuscate_names_test.cc index 90b34a707ab..14df7ba74e0 100644 --- a/tensorflow/tools/graph_transforms/obsfucate_names_test.cc +++ b/tensorflow/tools/graph_transforms/obfuscate_names_test.cc @@ -29,11 +29,11 @@ namespace tensorflow { namespace graph_transforms { // Declare here, so we don't need a public header. -Status ObsfucateNames(const GraphDef& input_graph_def, +Status ObfuscateNames(const GraphDef& input_graph_def, const TransformFuncContext& context, GraphDef* output_graph_def); -class ObsfucateNamesTest : public ::testing::Test { +class ObfuscateNamesTest : public ::testing::Test { protected: void TestSimpleTree() { GraphDef graph_def; @@ -74,7 +74,7 @@ class ObsfucateNamesTest : public ::testing::Test { GraphDef result; TF_ASSERT_OK( - ObsfucateNames(graph_def, {{"const_node1"}, {"add_node1"}}, &result)); + ObfuscateNames(graph_def, {{"const_node1"}, {"add_node1"}}, &result)); std::map node_lookup; MapNamesToNodes(result, &node_lookup); @@ -97,7 +97,7 @@ class ObsfucateNamesTest : public ::testing::Test { } GraphDef result; - TF_ASSERT_OK(ObsfucateNames(graph_def, {{"const_node0"}, {"const_node999"}}, + TF_ASSERT_OK(ObfuscateNames(graph_def, {{"const_node0"}, {"const_node999"}}, &result)); std::map node_lookup; @@ -116,7 +116,7 @@ class ObsfucateNamesTest : public ::testing::Test { } GraphDef result; - TF_ASSERT_OK(ObsfucateNames(graph_def, {{"10"}, {"19"}}, &result)); + TF_ASSERT_OK(ObfuscateNames(graph_def, {{"10"}, {"19"}}, &result)); std::map node_lookup; MapNamesToNodes(result, &node_lookup); @@ -132,11 +132,11 @@ class ObsfucateNamesTest : public ::testing::Test { } }; -TEST_F(ObsfucateNamesTest, TestSimpleTree) { TestSimpleTree(); } +TEST_F(ObfuscateNamesTest, TestSimpleTree) { TestSimpleTree(); } -TEST_F(ObsfucateNamesTest, TestManyNodes) { TestManyNodes(); } +TEST_F(ObfuscateNamesTest, TestManyNodes) { TestManyNodes(); } -TEST_F(ObsfucateNamesTest, TestNameClashes) { TestNameClashes(); } +TEST_F(ObfuscateNamesTest, TestNameClashes) { TestNameClashes(); } } // namespace graph_transforms } // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/sparsify_gather.cc b/tensorflow/tools/graph_transforms/sparsify_gather.cc new file mode 100644 index 00000000000..bddc2fad730 --- /dev/null +++ b/tensorflow/tools/graph_transforms/sparsify_gather.cc @@ -0,0 +1,276 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/subgraph.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/util/command_line_flags.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +using strings::StrCat; +namespace graph_transforms { +namespace { + +// Sparsify Tensor of shape [N, 1]. Return the indices and values vectors for +// non-zero tensor content. +Status SparsifyWeights(const Tensor& tensor, Tensor* indices_tensor, + Tensor* values_tensor) { + if (tensor.dims() != 2 || tensor.dim_size(1) != 1) { + return tensorflow::errors::FailedPrecondition( + "Transform only applicable to subgraph with 'Const' with " + "tensor of shpae [N, 1]. But instead get shape ", + tensor.shape().DebugString(), "."); + } + + auto flat = tensor.flat(); + std::vector indices; + std::vector values; + + for (int64 i = 0; i < flat.size(); i++) { + float val = flat(i); + if (std::abs(val) >= 1.0e-5) { + indices.push_back(i); + values.push_back(val); + } + } + + // During model initialization, InitializeTableOp makes use of + // KeyValueTensorIterator, which does not accept empty keys or values. + // Consequently, adding a dummy pair of indices and values as a walkaround. + if (indices.empty() || values.empty()) { + indices.push_back(0); + values.push_back(0); + } + *indices_tensor = Tensor(DataTypeToEnum::value, + {static_cast(indices.size())}); + std::copy_n(indices.begin(), indices.size(), + indices_tensor->flat().data()); + + *values_tensor = + Tensor(DataTypeToEnum::value, {static_cast(values.size())}); + std::copy_n(values.begin(), values.size(), + values_tensor->flat().data()); + + return Status::OK(); +} + +void CreateConstNode(const Tensor& tensor, const string& name, + NodeDef* node_def) { + node_def->set_op("Const"); + node_def->set_name(name); + SetNodeTensorAttr("value", tensor, node_def); +} +} // namespace + +Status SparsifyGather(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { + GraphDef current_graph_def = input_graph_def; + bool any_match_found = false; + // The subgraphs may have overlapping components, therefore GraphMatcher + // doesn't return all subgraphs in one round -- this has to be multi-round + // update. + do { + any_match_found = false; + GraphDef replaced_graph_def = current_graph_def; + std::vector init_table_node_names; + + TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( + current_graph_def, // clang-format off + {"Gather", + { + {"Identity", + { + {"Const"} + } + }, + {"*"}, + } + }, // clang-format on + [&any_match_found, &init_table_node_names]( + const NodeMatch& match, const std::set& input_nodes, + const std::set& output_nodes, + std::vector* new_nodes) { + any_match_found = true; + + // The captured subgraph should be of the following pattern: + // Const --> Identity --> Gather --> ... + // ^ + // | + // (ids) + // + // After transform, it becomes: + // --> NoOp(group_deps) + // | + // Const --> InitializeTable --> HashTable + // ^ | + // | | + // Const ------------- | + // v + // (ids) ---> LookupTableFind <--- Const(default) + // | + // v + // ... + + // clang-format off + // For each subgraph, do the following + // 1. Sparsify the `Const`, creating two `Const`, for hashtable + // key/val. + // 2. Create a `InitializeTable` op connecting to the above 2 `Const`. + // 3. Create a `HashTable` op connecting to `InitializeTable` op. + // 4. Replace the `Gather` with a `LookupTableFind` op. + // 5. Connect the `LookupTableFind` with + // a. `HashTable` + // b. `Gather`'s ids input + // c. a `default_val` arg, valued at 0 + // clang-format on + const NodeDef& gather_node = match.node; + const NodeDef& const_node = match.inputs[0].inputs[0].node; + + DataType data_type; + GetNodeAttr(const_node, "dtype", &data_type); + if (data_type != DT_FLOAT) { + return tensorflow::errors::FailedPrecondition( + "Transform only applicable to subgraph with 'Const' of dtype " + "'DT_FLOAT'. Found 'Const' with name '", + const_node.name(), "' and dtype '", data_type, "'."); + } + Tensor weight = GetNodeTensorAttr(const_node, "value"); + Tensor indices_tensor; + Tensor values_tensor; + TF_RETURN_IF_ERROR( + SparsifyWeights(weight, &indices_tensor, &values_tensor)); + + // indices and values of sparsified `Const` + DataType key_dtype = DT_INT64; + NodeDef indices_node; + CreateConstNode(indices_tensor, StrCat(const_node.name(), "/indices"), + &indices_node); + SetNodeAttr("dtype", key_dtype, &indices_node); + + NodeDef values_node; + CreateConstNode(values_tensor, StrCat(const_node.name(), "/values"), + &values_node); + SetNodeAttr("dtype", data_type, &values_node); + + // HashTable node + NodeDef hashtable_node; + hashtable_node.set_op("HashTable"); + hashtable_node.set_name(StrCat(const_node.name(), "/HashTable")); + SetNodeAttr("key_dtype", key_dtype, &hashtable_node); + SetNodeAttr("value_dtype", data_type, &hashtable_node); + + // InitializeTable node + NodeDef init_table_node; + init_table_node.set_op("InitializeTable"); + init_table_node.set_name( + StrCat(const_node.name(), "/InitializeTable")); + SetNodeAttr("Tkey", key_dtype, &init_table_node); + SetNodeAttr("Tval", data_type, &init_table_node); + init_table_node_names.push_back(init_table_node.name()); + + // LookupTableFind node + NodeDef lookup_node; + lookup_node.set_op("LookupTableFind"); + lookup_node.set_name(StrCat(gather_node.name(), "/LookupTableFind")); + SetNodeAttr("Tin", key_dtype, &lookup_node); + SetNodeAttr("Tout", data_type, &lookup_node); + + // Default return value of hashtable lookup + Tensor zero_tensor(data_type, TensorShape({})); + zero_tensor.flat()(0) = 0.0; + NodeDef default_value_node; + CreateConstNode(zero_tensor, StrCat(gather_node.name(), "/Const"), + &default_value_node); + SetNodeAttr("dtype", data_type, &default_value_node); + + // ExpandDims argument + Tensor dim_idx(DT_INT32, TensorShape({})); + dim_idx.flat()(0) = -1; + NodeDef dim_idx_node; + dim_idx_node.set_op("Const"); + dim_idx_node.set_name( + StrCat(gather_node.name(), "/ExpandDims/Const")); + SetNodeAttr("value", dim_idx, &dim_idx_node); + SetNodeAttr("dtype", DT_INT32, &dim_idx_node); + + // ExpandDims node + NodeDef expand_dims_node; + expand_dims_node.set_op("ExpandDims"); + // Reuse gather_node's name so not to change dependent's inputs + expand_dims_node.set_name(gather_node.name()); + SetNodeAttr("T", data_type, &expand_dims_node); + + // Connect nodes + AddNodeInput(hashtable_node.name(), &init_table_node); + AddNodeInput(indices_node.name(), &init_table_node); + AddNodeInput(values_node.name(), &init_table_node); + + AddNodeInput(hashtable_node.name(), &lookup_node); + AddNodeInput(gather_node.input(1), &lookup_node); + AddNodeInput(default_value_node.name(), &lookup_node); + + AddNodeInput(lookup_node.name(), &expand_dims_node); + AddNodeInput(dim_idx_node.name(), &expand_dims_node); + + // Copy 'ids' input of original 'Gather' + new_nodes->push_back(match.inputs[1].node); + new_nodes->push_back(indices_node); + new_nodes->push_back(values_node); + new_nodes->push_back(hashtable_node); + new_nodes->push_back(init_table_node); + new_nodes->push_back(lookup_node); + new_nodes->push_back(default_value_node); + new_nodes->push_back(dim_idx_node); + new_nodes->push_back(expand_dims_node); + + return Status::OK(); + }, + {true}, &replaced_graph_def)); + NodeDef* init_op = nullptr; + for (int i = 0; i < replaced_graph_def.node_size(); i++) { + if (replaced_graph_def.node(i).name() == "group_deps" && + replaced_graph_def.node(i).op() == "NoOp") { + if (init_op != nullptr) { + return tensorflow::errors::FailedPrecondition( + "Multiple nodes with name: 'group_deps' and type: 'NoOp'."); + } + init_op = replaced_graph_def.mutable_node(i); + } + } + if (!init_op) { + return tensorflow::errors::FailedPrecondition( + "No node found with name: 'group_deps' and type: 'NoOp'"); + } + for (const string& name : init_table_node_names) { + // Add control dependence from init_table_node to group_deps_node + AddNodeInput(StrCat("^", name), init_op); + } + current_graph_def = replaced_graph_def; + } while (any_match_found); + *output_graph_def = current_graph_def; + return Status::OK(); +} + +REGISTER_GRAPH_TRANSFORM("sparsify_gather", SparsifyGather); + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc new file mode 100644 index 00000000000..8d353d34763 --- /dev/null +++ b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc @@ -0,0 +1,352 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Declare here, so we don't need a public header. +Status SparsifyGather(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); + +class SparsifyGatherTest : public ::testing::Test { + protected: + NodeDef* CreateNode(const string& name, const string& op, + const std::vector& inputs, + GraphDef* graph_def) { + NodeDef* node_def = graph_def->add_node(); + node_def->set_name(name); + node_def->set_op(op); + std::for_each(inputs.begin(), inputs.end(), [&node_def](NodeDef* input) { + node_def->add_input(input->name()); + }); + return node_def; + } + + void TestSinglePartitionConst() { + GraphDef graph_def; + + // Build the graph. + NodeDef* input_node = CreateNode("ids", "Const", {}, &graph_def); + NodeDef* const_node = CreateNode("const", "Const", {}, &graph_def); + SetNodeAttr("dtype", DT_FLOAT, const_node); + // Set 'Const' node value. + Tensor weights(DT_FLOAT, TensorShape({4, 1})); + test::FillValues(&weights, {0.2, 0.000001, 1.2, 0.001}); + SetNodeTensorAttr("value", weights, const_node); + + NodeDef* identity_node = + CreateNode("const/read", "Identity", {const_node}, &graph_def); + CreateNode("gather", "Gather", {identity_node, input_node}, &graph_def); + CreateNode("group_deps", "NoOp", {}, &graph_def); + + // Run the op. + GraphDef result; + TransformFuncContext context; + context.input_names = {"ids"}; + context.output_names = {"gather"}; + TF_ASSERT_OK(SparsifyGather(graph_def, context, &result)); + + // Validation begins. + std::map node_lookup; + MapNamesToNodes(result, &node_lookup); + + // Check nodes. + EXPECT_EQ(1, node_lookup.count("ids")); + EXPECT_EQ("Const", node_lookup.at("ids")->op()); + + // Nodes in "const" scope. + EXPECT_EQ(1, node_lookup.count("const/indices")); + EXPECT_EQ("Const", node_lookup.at("const/indices")->op()); + Tensor expected_indices_tensor(DT_INT64, TensorShape({3})); + test::FillValues(&expected_indices_tensor, {0, 2, 3}); + test::ExpectTensorEqual( + expected_indices_tensor, + GetNodeTensorAttr(*(node_lookup.at("const/indices")), "value")); + + EXPECT_EQ(1, node_lookup.count("const/values")); + EXPECT_EQ("Const", node_lookup.at("const/values")->op()); + Tensor expected_values_tensor(DT_FLOAT, TensorShape({3})); + test::FillValues(&expected_values_tensor, {0.2, 1.2, 0.001}); + test::ExpectTensorNear( + expected_values_tensor, + GetNodeTensorAttr(*(node_lookup.at("const/values")), "value"), 1e-5); + + EXPECT_EQ(1, node_lookup.count("const/HashTable")); + EXPECT_EQ("HashTable", node_lookup.at("const/HashTable")->op()); + + EXPECT_EQ(1, node_lookup.count("const/InitializeTable")); + EXPECT_EQ("InitializeTable", node_lookup.at("const/InitializeTable")->op()); + + // Nodes in "gather" scope. + EXPECT_EQ(1, node_lookup.count("gather/LookupTableFind")); + EXPECT_EQ("LookupTableFind", + node_lookup.at("gather/LookupTableFind")->op()); + + EXPECT_EQ(1, node_lookup.count("gather/Const")); + EXPECT_EQ("Const", node_lookup.at("gather/Const")->op()); + Tensor expected_gather_default_tensor(DT_FLOAT, TensorShape({})); + test::FillValues(&expected_gather_default_tensor, {0.0}); + test::ExpectTensorNear( + expected_gather_default_tensor, + GetNodeTensorAttr(*(node_lookup.at("gather/Const")), "value"), 1e-5); + + EXPECT_EQ(1, node_lookup.count("gather/ExpandDims/Const")); + EXPECT_EQ("Const", node_lookup.at("gather/ExpandDims/Const")->op()); + Tensor expected_expand_dims_tensor(DT_INT32, TensorShape({})); + test::FillValues(&expected_expand_dims_tensor, {-1}); + test::ExpectTensorEqual( + expected_expand_dims_tensor, + GetNodeTensorAttr(*(node_lookup.at("gather/ExpandDims/Const")), + "value")); + + EXPECT_EQ(1, node_lookup.count("gather")); + EXPECT_EQ("ExpandDims", node_lookup.at("gather")->op()); + + EXPECT_EQ(1, node_lookup.count("group_deps")); + EXPECT_EQ("NoOp", node_lookup.at("group_deps")->op()); + + // Check connections + EXPECT_EQ("const/HashTable", + node_lookup.at("const/InitializeTable")->input(0)); + EXPECT_EQ("const/indices", + node_lookup.at("const/InitializeTable")->input(1)); + EXPECT_EQ("const/values", + node_lookup.at("const/InitializeTable")->input(2)); + + EXPECT_EQ("const/HashTable", + node_lookup.at("gather/LookupTableFind")->input(0)); + EXPECT_EQ("ids", node_lookup.at("gather/LookupTableFind")->input(1)); + EXPECT_EQ("gather/Const", + node_lookup.at("gather/LookupTableFind")->input(2)); + + EXPECT_EQ("gather/LookupTableFind", node_lookup.at("gather")->input(0)); + + // Check control dependency. + EXPECT_NE(std::find(node_lookup.at("group_deps")->input().begin(), + node_lookup.at("group_deps")->input().end(), + "^const/InitializeTable"), + node_lookup.at("group_deps")->input().end()); + } + + void TestMultiPartitionConst() { + // The 'ids' node is served input for two 'Gather's. + GraphDef graph_def; + + // Build Graph: + // Shared input node + NodeDef* input_node = CreateNode("ids", "Const", {}, &graph_def); + // Shared init node + CreateNode("group_deps", "NoOp", {}, &graph_def); + + // Two partitions + NodeDef* const_node1 = CreateNode("const1", "Const", {}, &graph_def); + SetNodeAttr("dtype", DT_FLOAT, const_node1); + // Set 'Const' node value. + Tensor weights(DT_FLOAT, TensorShape({4, 1})); + test::FillValues(&weights, {0.2, 0.000001, 1.2, 0.001}); + SetNodeTensorAttr("value", weights, const_node1); + + NodeDef* const_node2 = CreateNode("const2", "Const", {}, &graph_def); + SetNodeAttr("dtype", DT_FLOAT, const_node2); + SetNodeTensorAttr("value", weights, const_node2); + + NodeDef* identity_node1 = + CreateNode("const1/read", "Identity", {const_node1}, &graph_def); + NodeDef* identity_node2 = + CreateNode("const2/read", "Identity", {const_node2}, &graph_def); + CreateNode("gather1", "Gather", {identity_node1, input_node}, &graph_def); + CreateNode("gather2", "Gather", {identity_node2, input_node}, &graph_def); + + // Run the op. + GraphDef result; + TransformFuncContext context; + context.input_names = {"ids"}; + context.output_names = {"gather1", "gather2"}; + TF_ASSERT_OK(SparsifyGather(graph_def, context, &result)); + + // Validation begins. + std::map node_lookup; + MapNamesToNodes(result, &node_lookup); + + // Check nodes. + // Check shared nodes: + EXPECT_EQ(1, node_lookup.count("ids")); + EXPECT_EQ("Const", node_lookup.at("ids")->op()); + + EXPECT_EQ(1, node_lookup.count("group_deps")); + EXPECT_EQ("NoOp", node_lookup.at("group_deps")->op()); + + // Nodes in "const1" scope. + EXPECT_EQ(1, node_lookup.count("const1/indices")); + EXPECT_EQ("Const", node_lookup.at("const1/indices")->op()); + Tensor expected_indices_tensor1(DT_INT64, TensorShape({3})); + test::FillValues(&expected_indices_tensor1, {0, 2, 3}); + test::ExpectTensorEqual( + expected_indices_tensor1, + GetNodeTensorAttr(*(node_lookup.at("const1/indices")), "value")); + + EXPECT_EQ(1, node_lookup.count("const1/values")); + EXPECT_EQ("Const", node_lookup.at("const1/values")->op()); + Tensor expected_values_tensor1(DT_FLOAT, TensorShape({3})); + test::FillValues(&expected_values_tensor1, {0.2, 1.2, 0.001}); + test::ExpectTensorNear( + expected_values_tensor1, + GetNodeTensorAttr(*(node_lookup.at("const1/values")), "value"), 1e-5); + + EXPECT_EQ(1, node_lookup.count("const1/HashTable")); + EXPECT_EQ("HashTable", node_lookup.at("const1/HashTable")->op()); + + EXPECT_EQ(1, node_lookup.count("const1/InitializeTable")); + EXPECT_EQ("InitializeTable", + node_lookup.at("const1/InitializeTable")->op()); + + // Nodes in "gather1" scope. + EXPECT_EQ(1, node_lookup.count("gather1/LookupTableFind")); + EXPECT_EQ("LookupTableFind", + node_lookup.at("gather1/LookupTableFind")->op()); + + EXPECT_EQ(1, node_lookup.count("gather1/Const")); + EXPECT_EQ("Const", node_lookup.at("gather1/Const")->op()); + Tensor expected_gather_default_tensor1(DT_FLOAT, TensorShape({})); + test::FillValues(&expected_gather_default_tensor1, {0.0}); + test::ExpectTensorNear( + expected_gather_default_tensor1, + GetNodeTensorAttr(*(node_lookup.at("gather1/Const")), "value"), 1e-5); + + EXPECT_EQ(1, node_lookup.count("gather1/ExpandDims/Const")); + EXPECT_EQ("Const", node_lookup.at("gather1/ExpandDims/Const")->op()); + Tensor expected_expand_dims_tensor1(DT_INT32, TensorShape({})); + test::FillValues(&expected_expand_dims_tensor1, {-1}); + test::ExpectTensorEqual( + expected_expand_dims_tensor1, + GetNodeTensorAttr(*(node_lookup.at("gather1/ExpandDims/Const")), + "value")); + + EXPECT_EQ(1, node_lookup.count("gather1")); + EXPECT_EQ("ExpandDims", node_lookup.at("gather1")->op()); + + // Nodes in "const2" scope. + EXPECT_EQ(1, node_lookup.count("const2/indices")); + EXPECT_EQ("Const", node_lookup.at("const2/indices")->op()); + Tensor expected_indices_tensor2(DT_INT64, TensorShape({3})); + test::FillValues(&expected_indices_tensor2, {0, 2, 3}); + test::ExpectTensorEqual( + expected_indices_tensor2, + GetNodeTensorAttr(*(node_lookup.at("const2/indices")), "value")); + + EXPECT_EQ(1, node_lookup.count("const2/values")); + EXPECT_EQ("Const", node_lookup.at("const2/values")->op()); + Tensor expected_values_tensor2(DT_FLOAT, TensorShape({3})); + test::FillValues(&expected_values_tensor2, {0.2, 1.2, 0.001}); + test::ExpectTensorNear( + expected_values_tensor2, + GetNodeTensorAttr(*(node_lookup.at("const2/values")), "value"), 1e-5); + + EXPECT_EQ(1, node_lookup.count("const2/HashTable")); + EXPECT_EQ("HashTable", node_lookup.at("const2/HashTable")->op()); + + EXPECT_EQ(1, node_lookup.count("const2/InitializeTable")); + EXPECT_EQ("InitializeTable", + node_lookup.at("const2/InitializeTable")->op()); + + // Nodes in "gather2" scope. + EXPECT_EQ(1, node_lookup.count("gather2/LookupTableFind")); + EXPECT_EQ("LookupTableFind", + node_lookup.at("gather2/LookupTableFind")->op()); + + EXPECT_EQ(1, node_lookup.count("gather2/Const")); + EXPECT_EQ("Const", node_lookup.at("gather2/Const")->op()); + Tensor expected_gather_default_tensor2(DT_FLOAT, TensorShape({})); + test::FillValues(&expected_gather_default_tensor2, {0.0}); + test::ExpectTensorNear( + expected_gather_default_tensor2, + GetNodeTensorAttr(*(node_lookup.at("gather2/Const")), "value"), 1e-5); + + EXPECT_EQ(1, node_lookup.count("gather2/ExpandDims/Const")); + EXPECT_EQ("Const", node_lookup.at("gather2/ExpandDims/Const")->op()); + Tensor expected_expand_dims_tensor2(DT_INT32, TensorShape({})); + test::FillValues(&expected_expand_dims_tensor2, {-1}); + test::ExpectTensorEqual( + expected_expand_dims_tensor2, + GetNodeTensorAttr(*(node_lookup.at("gather2/ExpandDims/Const")), + "value")); + + EXPECT_EQ(1, node_lookup.count("gather2")); + EXPECT_EQ("ExpandDims", node_lookup.at("gather2")->op()); + + // Check connections + EXPECT_EQ("const1/HashTable", + node_lookup.at("const1/InitializeTable")->input(0)); + EXPECT_EQ("const1/indices", + node_lookup.at("const1/InitializeTable")->input(1)); + EXPECT_EQ("const1/values", + node_lookup.at("const1/InitializeTable")->input(2)); + + EXPECT_EQ("const2/HashTable", + node_lookup.at("const2/InitializeTable")->input(0)); + EXPECT_EQ("const2/indices", + node_lookup.at("const2/InitializeTable")->input(1)); + EXPECT_EQ("const2/values", + node_lookup.at("const2/InitializeTable")->input(2)); + + EXPECT_EQ("const1/HashTable", + node_lookup.at("gather1/LookupTableFind")->input(0)); + EXPECT_EQ("ids", node_lookup.at("gather1/LookupTableFind")->input(1)); + EXPECT_EQ("gather1/Const", + node_lookup.at("gather1/LookupTableFind")->input(2)); + EXPECT_EQ("gather1/LookupTableFind", node_lookup.at("gather1")->input(0)); + + EXPECT_EQ("const2/HashTable", + node_lookup.at("gather2/LookupTableFind")->input(0)); + EXPECT_EQ("ids", node_lookup.at("gather2/LookupTableFind")->input(1)); + EXPECT_EQ("gather2/Const", + node_lookup.at("gather2/LookupTableFind")->input(2)); + EXPECT_EQ("gather2/LookupTableFind", node_lookup.at("gather2")->input(0)); + + // Check control deps. + EXPECT_EQ(2, node_lookup.at("group_deps")->input_size()); + EXPECT_NE(std::find(node_lookup.at("group_deps")->input().begin(), + node_lookup.at("group_deps")->input().end(), + "^const1/InitializeTable"), + node_lookup.at("group_deps")->input().end()); + + EXPECT_NE(std::find(node_lookup.at("group_deps")->input().begin(), + node_lookup.at("group_deps")->input().end(), + "^const2/InitializeTable"), + node_lookup.at("group_deps")->input().end()); + } +}; + +TEST_F(SparsifyGatherTest, TestSinglePartitionConst) { + TestSinglePartitionConst(); +} + +TEST_F(SparsifyGatherTest, TestMultiPartitionConst) { + TestMultiPartitionConst(); +} + +} // namespace graph_transforms +} // namespace tensorflow