Add sparsify_gather op to reduce linear model memory footprint.

Also fixed typo: "obsfucate" -> "obfuscate"
Change: 149627297
This commit is contained in:
James Qin 2017-03-09 00:31:26 -08:00 committed by TensorFlower Gardener
parent 62be492ef4
commit 168e8dacba
6 changed files with 658 additions and 17 deletions

View File

@ -65,7 +65,7 @@ cc_library(
"freeze_requantization_ranges.cc", "freeze_requantization_ranges.cc",
"fuse_convolutions.cc", "fuse_convolutions.cc",
"insert_logging.cc", "insert_logging.cc",
"obsfucate_names.cc", "obfuscate_names.cc",
"remove_attribute.cc", "remove_attribute.cc",
"remove_device.cc", "remove_device.cc",
"remove_nodes.cc", "remove_nodes.cc",
@ -73,6 +73,7 @@ cc_library(
"rename_op.cc", "rename_op.cc",
"set_device.cc", "set_device.cc",
"sort_by_execution_order.cc", "sort_by_execution_order.cc",
"sparsify_gather.cc",
"strip_unused_nodes.cc", "strip_unused_nodes.cc",
] + if_not_windows([ ] + if_not_windows([
"quantize_nodes.cc", "quantize_nodes.cc",
@ -111,7 +112,7 @@ tf_cc_test(
"freeze_requantization_ranges_test.cc", "freeze_requantization_ranges_test.cc",
"fuse_convolutions_test.cc", "fuse_convolutions_test.cc",
"insert_logging_test.cc", "insert_logging_test.cc",
"obsfucate_names_test.cc", "obfuscate_names_test.cc",
"quantize_nodes_test.cc", "quantize_nodes_test.cc",
"quantize_weights_test.cc", "quantize_weights_test.cc",
"remove_attribute_test.cc", "remove_attribute_test.cc",
@ -122,6 +123,7 @@ tf_cc_test(
"round_weights_test.cc", "round_weights_test.cc",
"set_device_test.cc", "set_device_test.cc",
"sort_by_execution_order_test.cc", "sort_by_execution_order_test.cc",
"sparsify_gather_test.cc",
"strip_unused_nodes_test.cc", "strip_unused_nodes_test.cc",
], ],
deps = [ deps = [

View File

@ -20,7 +20,7 @@
* [fuse_convolutions](#fuse_convolutions) * [fuse_convolutions](#fuse_convolutions)
* [insert_logging](#insert_logging) * [insert_logging](#insert_logging)
* [merge_duplicate_nodes](#merge_duplicate_nodes) * [merge_duplicate_nodes](#merge_duplicate_nodes)
* [obsfucate_names](#obsfucate_names) * [obfuscate_names](#obfuscate_names)
* [quantize_nodes](#quantize_nodes) * [quantize_nodes](#quantize_nodes)
* [quantize_weights](#quantize_weights) * [quantize_weights](#quantize_weights)
* [remove_attribute](#remove_attribute) * [remove_attribute](#remove_attribute)
@ -29,6 +29,7 @@
* [rename_attribute](#rename_attribute) * [rename_attribute](#rename_attribute)
* [rename_op](#rename_op) * [rename_op](#rename_op)
* [round_weights](#round_weights) * [round_weights](#round_weights)
* [sparsify_gather](#sparsify_gather)
* [set_device](#set_device) * [set_device](#set_device)
* [sort_by_execution_order](#sort_by_execution_order) * [sort_by_execution_order](#sort_by_execution_order)
* [strip_unused_nodes](#strip_unused_nodes) * [strip_unused_nodes](#strip_unused_nodes)
@ -250,7 +251,7 @@ results are cached and so you shouldn't see the graph run any more slowly.
So far we've been concentrating on weights because those generally take up the So far we've been concentrating on weights because those generally take up the
most space. If you have a graph with a lot of small nodes in it, the names of most space. If you have a graph with a lot of small nodes in it, the names of
those nodes can start to take up a noticeable amount of space too. To shrink those nodes can start to take up a noticeable amount of space too. To shrink
those down, you can run the [obsfucate_names](#obsfucate_names) transform, which those down, you can run the [obfuscate_names](#obfuscate_names) transform, which
replaces all the names (except for inputs and outputs) with short, cryptic but replaces all the names (except for inputs and outputs) with short, cryptic but
unique ids: unique ids:
@ -262,7 +263,7 @@ bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
--inputs='Mul:0' \ --inputs='Mul:0' \
--outputs='softmax:0' \ --outputs='softmax:0' \
--transforms='\ --transforms='\
obsfucate_names \ obfuscate_names \
' '
``` ```
@ -520,7 +521,7 @@ of redundancy (e.g. this transform is always run as part of
[quantize_nodes](#quantize_nodes) since the processing there can introduce [quantize_nodes](#quantize_nodes) since the processing there can introduce
duplicates of constants that are used in the quantize/dequantize process). duplicates of constants that are used in the quantize/dequantize process).
### obsfucate_names ### obfuscate_names
Args: None \ Args: None \
Prerequisites: None Prerequisites: None
@ -656,6 +657,16 @@ between the largest and smallest values present. This is useful when you'll be
deploying on mobile, and you want a model that will compress effectively. See deploying on mobile, and you want a model that will compress effectively. See
[shrinking file size](#shrinking-file-size) for more details. [shrinking file size](#shrinking-file-size) for more details.
### sparsify_gather
Args: None \
Prerequisites: None
Transform 'Gather' op to a sparsified version where 'params' input of 'Gather'
is replaced from a dense 'Const' to a 'HashTable'. 'Gather' op itself is
replaced by a hashtable lookup. This is mostly useful for reducing sparse
TF.learn linear model memory footprint.
### set_device ### set_device
Args: Args:

View File

@ -29,7 +29,7 @@ namespace graph_transforms {
// Renames all nodes not uses as graph inputs or outputs to short numerical // Renames all nodes not uses as graph inputs or outputs to short numerical
// forms. // forms.
Status ObsfucateNames(const GraphDef& input_graph_def, Status ObfuscateNames(const GraphDef& input_graph_def,
const TransformFuncContext& context, const TransformFuncContext& context,
GraphDef* output_graph_def) { GraphDef* output_graph_def) {
std::unordered_set<string> required_nodes; std::unordered_set<string> required_nodes;
@ -73,7 +73,7 @@ Status ObsfucateNames(const GraphDef& input_graph_def,
output_graph_def->Clear(); output_graph_def->Clear();
for (const NodeDef& input_node : input_graph_def.node()) { for (const NodeDef& input_node : input_graph_def.node()) {
NodeDef* node = output_graph_def->mutable_node()->Add(); NodeDef* node = output_graph_def->mutable_node()->Add();
node->CopyFrom(input_node); *node = input_node;
const string& old_name = input_node.name(); const string& old_name = input_node.name();
node->set_name(new_names[old_name]); node->set_name(new_names[old_name]);
node->mutable_input()->Clear(); node->mutable_input()->Clear();
@ -94,7 +94,7 @@ Status ObsfucateNames(const GraphDef& input_graph_def,
return Status::OK(); return Status::OK();
} }
REGISTER_GRAPH_TRANSFORM("obsfucate_names", ObsfucateNames); REGISTER_GRAPH_TRANSFORM("obfuscate_names", ObfuscateNames);
} // namespace graph_transforms } // namespace graph_transforms
} // namespace tensorflow } // namespace tensorflow

View File

@ -29,11 +29,11 @@ namespace tensorflow {
namespace graph_transforms { namespace graph_transforms {
// Declare here, so we don't need a public header. // Declare here, so we don't need a public header.
Status ObsfucateNames(const GraphDef& input_graph_def, Status ObfuscateNames(const GraphDef& input_graph_def,
const TransformFuncContext& context, const TransformFuncContext& context,
GraphDef* output_graph_def); GraphDef* output_graph_def);
class ObsfucateNamesTest : public ::testing::Test { class ObfuscateNamesTest : public ::testing::Test {
protected: protected:
void TestSimpleTree() { void TestSimpleTree() {
GraphDef graph_def; GraphDef graph_def;
@ -74,7 +74,7 @@ class ObsfucateNamesTest : public ::testing::Test {
GraphDef result; GraphDef result;
TF_ASSERT_OK( TF_ASSERT_OK(
ObsfucateNames(graph_def, {{"const_node1"}, {"add_node1"}}, &result)); ObfuscateNames(graph_def, {{"const_node1"}, {"add_node1"}}, &result));
std::map<string, const NodeDef*> node_lookup; std::map<string, const NodeDef*> node_lookup;
MapNamesToNodes(result, &node_lookup); MapNamesToNodes(result, &node_lookup);
@ -97,7 +97,7 @@ class ObsfucateNamesTest : public ::testing::Test {
} }
GraphDef result; GraphDef result;
TF_ASSERT_OK(ObsfucateNames(graph_def, {{"const_node0"}, {"const_node999"}}, TF_ASSERT_OK(ObfuscateNames(graph_def, {{"const_node0"}, {"const_node999"}},
&result)); &result));
std::map<string, const NodeDef*> node_lookup; std::map<string, const NodeDef*> node_lookup;
@ -116,7 +116,7 @@ class ObsfucateNamesTest : public ::testing::Test {
} }
GraphDef result; GraphDef result;
TF_ASSERT_OK(ObsfucateNames(graph_def, {{"10"}, {"19"}}, &result)); TF_ASSERT_OK(ObfuscateNames(graph_def, {{"10"}, {"19"}}, &result));
std::map<string, const NodeDef*> node_lookup; std::map<string, const NodeDef*> node_lookup;
MapNamesToNodes(result, &node_lookup); MapNamesToNodes(result, &node_lookup);
@ -132,11 +132,11 @@ class ObsfucateNamesTest : public ::testing::Test {
} }
}; };
TEST_F(ObsfucateNamesTest, TestSimpleTree) { TestSimpleTree(); } TEST_F(ObfuscateNamesTest, TestSimpleTree) { TestSimpleTree(); }
TEST_F(ObsfucateNamesTest, TestManyNodes) { TestManyNodes(); } TEST_F(ObfuscateNamesTest, TestManyNodes) { TestManyNodes(); }
TEST_F(ObsfucateNamesTest, TestNameClashes) { TestNameClashes(); } TEST_F(ObfuscateNamesTest, TestNameClashes) { TestNameClashes(); }
} // namespace graph_transforms } // namespace graph_transforms
} // namespace tensorflow } // namespace tensorflow

View File

@ -0,0 +1,276 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cmath>
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
using strings::StrCat;
namespace graph_transforms {
namespace {
// Sparsify Tensor of shape [N, 1]. Return the indices and values vectors for
// non-zero tensor content.
Status SparsifyWeights(const Tensor& tensor, Tensor* indices_tensor,
Tensor* values_tensor) {
if (tensor.dims() != 2 || tensor.dim_size(1) != 1) {
return tensorflow::errors::FailedPrecondition(
"Transform only applicable to subgraph with 'Const' with "
"tensor of shpae [N, 1]. But instead get shape ",
tensor.shape().DebugString(), ".");
}
auto flat = tensor.flat<float>();
std::vector<int64> indices;
std::vector<float> values;
for (int64 i = 0; i < flat.size(); i++) {
float val = flat(i);
if (std::abs(val) >= 1.0e-5) {
indices.push_back(i);
values.push_back(val);
}
}
// During model initialization, InitializeTableOp makes use of
// KeyValueTensorIterator, which does not accept empty keys or values.
// Consequently, adding a dummy pair of indices and values as a walkaround.
if (indices.empty() || values.empty()) {
indices.push_back(0);
values.push_back(0);
}
*indices_tensor = Tensor(DataTypeToEnum<int64>::value,
{static_cast<int64>(indices.size())});
std::copy_n(indices.begin(), indices.size(),
indices_tensor->flat<int64>().data());
*values_tensor =
Tensor(DataTypeToEnum<float>::value, {static_cast<int64>(values.size())});
std::copy_n(values.begin(), values.size(),
values_tensor->flat<float>().data());
return Status::OK();
}
void CreateConstNode(const Tensor& tensor, const string& name,
NodeDef* node_def) {
node_def->set_op("Const");
node_def->set_name(name);
SetNodeTensorAttr<float>("value", tensor, node_def);
}
} // namespace
Status SparsifyGather(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def) {
GraphDef current_graph_def = input_graph_def;
bool any_match_found = false;
// The subgraphs may have overlapping components, therefore GraphMatcher
// doesn't return all subgraphs in one round -- this has to be multi-round
// update.
do {
any_match_found = false;
GraphDef replaced_graph_def = current_graph_def;
std::vector<string> init_table_node_names;
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
current_graph_def, // clang-format off
{"Gather",
{
{"Identity",
{
{"Const"}
}
},
{"*"},
}
}, // clang-format on
[&any_match_found, &init_table_node_names](
const NodeMatch& match, const std::set<string>& input_nodes,
const std::set<string>& output_nodes,
std::vector<NodeDef>* new_nodes) {
any_match_found = true;
// The captured subgraph should be of the following pattern:
// Const --> Identity --> Gather --> ...
// ^
// |
// (ids)
//
// After transform, it becomes:
// --> NoOp(group_deps)
// |
// Const --> InitializeTable --> HashTable
// ^ |
// | |
// Const ------------- |
// v
// (ids) ---> LookupTableFind <--- Const(default)
// |
// v
// ...
// clang-format off
// For each subgraph, do the following
// 1. Sparsify the `Const`, creating two `Const`, for hashtable
// key/val.
// 2. Create a `InitializeTable` op connecting to the above 2 `Const`.
// 3. Create a `HashTable` op connecting to `InitializeTable` op.
// 4. Replace the `Gather` with a `LookupTableFind` op.
// 5. Connect the `LookupTableFind` with
// a. `HashTable`
// b. `Gather`'s ids input
// c. a `default_val` arg, valued at 0
// clang-format on
const NodeDef& gather_node = match.node;
const NodeDef& const_node = match.inputs[0].inputs[0].node;
DataType data_type;
GetNodeAttr(const_node, "dtype", &data_type);
if (data_type != DT_FLOAT) {
return tensorflow::errors::FailedPrecondition(
"Transform only applicable to subgraph with 'Const' of dtype "
"'DT_FLOAT'. Found 'Const' with name '",
const_node.name(), "' and dtype '", data_type, "'.");
}
Tensor weight = GetNodeTensorAttr(const_node, "value");
Tensor indices_tensor;
Tensor values_tensor;
TF_RETURN_IF_ERROR(
SparsifyWeights(weight, &indices_tensor, &values_tensor));
// indices and values of sparsified `Const`
DataType key_dtype = DT_INT64;
NodeDef indices_node;
CreateConstNode(indices_tensor, StrCat(const_node.name(), "/indices"),
&indices_node);
SetNodeAttr("dtype", key_dtype, &indices_node);
NodeDef values_node;
CreateConstNode(values_tensor, StrCat(const_node.name(), "/values"),
&values_node);
SetNodeAttr("dtype", data_type, &values_node);
// HashTable node
NodeDef hashtable_node;
hashtable_node.set_op("HashTable");
hashtable_node.set_name(StrCat(const_node.name(), "/HashTable"));
SetNodeAttr("key_dtype", key_dtype, &hashtable_node);
SetNodeAttr("value_dtype", data_type, &hashtable_node);
// InitializeTable node
NodeDef init_table_node;
init_table_node.set_op("InitializeTable");
init_table_node.set_name(
StrCat(const_node.name(), "/InitializeTable"));
SetNodeAttr("Tkey", key_dtype, &init_table_node);
SetNodeAttr("Tval", data_type, &init_table_node);
init_table_node_names.push_back(init_table_node.name());
// LookupTableFind node
NodeDef lookup_node;
lookup_node.set_op("LookupTableFind");
lookup_node.set_name(StrCat(gather_node.name(), "/LookupTableFind"));
SetNodeAttr("Tin", key_dtype, &lookup_node);
SetNodeAttr("Tout", data_type, &lookup_node);
// Default return value of hashtable lookup
Tensor zero_tensor(data_type, TensorShape({}));
zero_tensor.flat<float>()(0) = 0.0;
NodeDef default_value_node;
CreateConstNode(zero_tensor, StrCat(gather_node.name(), "/Const"),
&default_value_node);
SetNodeAttr("dtype", data_type, &default_value_node);
// ExpandDims argument
Tensor dim_idx(DT_INT32, TensorShape({}));
dim_idx.flat<int32>()(0) = -1;
NodeDef dim_idx_node;
dim_idx_node.set_op("Const");
dim_idx_node.set_name(
StrCat(gather_node.name(), "/ExpandDims/Const"));
SetNodeAttr("value", dim_idx, &dim_idx_node);
SetNodeAttr("dtype", DT_INT32, &dim_idx_node);
// ExpandDims node
NodeDef expand_dims_node;
expand_dims_node.set_op("ExpandDims");
// Reuse gather_node's name so not to change dependent's inputs
expand_dims_node.set_name(gather_node.name());
SetNodeAttr("T", data_type, &expand_dims_node);
// Connect nodes
AddNodeInput(hashtable_node.name(), &init_table_node);
AddNodeInput(indices_node.name(), &init_table_node);
AddNodeInput(values_node.name(), &init_table_node);
AddNodeInput(hashtable_node.name(), &lookup_node);
AddNodeInput(gather_node.input(1), &lookup_node);
AddNodeInput(default_value_node.name(), &lookup_node);
AddNodeInput(lookup_node.name(), &expand_dims_node);
AddNodeInput(dim_idx_node.name(), &expand_dims_node);
// Copy 'ids' input of original 'Gather'
new_nodes->push_back(match.inputs[1].node);
new_nodes->push_back(indices_node);
new_nodes->push_back(values_node);
new_nodes->push_back(hashtable_node);
new_nodes->push_back(init_table_node);
new_nodes->push_back(lookup_node);
new_nodes->push_back(default_value_node);
new_nodes->push_back(dim_idx_node);
new_nodes->push_back(expand_dims_node);
return Status::OK();
},
{true}, &replaced_graph_def));
NodeDef* init_op = nullptr;
for (int i = 0; i < replaced_graph_def.node_size(); i++) {
if (replaced_graph_def.node(i).name() == "group_deps" &&
replaced_graph_def.node(i).op() == "NoOp") {
if (init_op != nullptr) {
return tensorflow::errors::FailedPrecondition(
"Multiple nodes with name: 'group_deps' and type: 'NoOp'.");
}
init_op = replaced_graph_def.mutable_node(i);
}
}
if (!init_op) {
return tensorflow::errors::FailedPrecondition(
"No node found with name: 'group_deps' and type: 'NoOp'");
}
for (const string& name : init_table_node_names) {
// Add control dependence from init_table_node to group_deps_node
AddNodeInput(StrCat("^", name), init_op);
}
current_graph_def = replaced_graph_def;
} while (any_match_found);
*output_graph_def = current_graph_def;
return Status::OK();
}
REGISTER_GRAPH_TRANSFORM("sparsify_gather", SparsifyGather);
} // namespace graph_transforms
} // namespace tensorflow

View File

@ -0,0 +1,352 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
// Declare here, so we don't need a public header.
Status SparsifyGather(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def);
class SparsifyGatherTest : public ::testing::Test {
protected:
NodeDef* CreateNode(const string& name, const string& op,
const std::vector<NodeDef*>& inputs,
GraphDef* graph_def) {
NodeDef* node_def = graph_def->add_node();
node_def->set_name(name);
node_def->set_op(op);
std::for_each(inputs.begin(), inputs.end(), [&node_def](NodeDef* input) {
node_def->add_input(input->name());
});
return node_def;
}
void TestSinglePartitionConst() {
GraphDef graph_def;
// Build the graph.
NodeDef* input_node = CreateNode("ids", "Const", {}, &graph_def);
NodeDef* const_node = CreateNode("const", "Const", {}, &graph_def);
SetNodeAttr("dtype", DT_FLOAT, const_node);
// Set 'Const' node value.
Tensor weights(DT_FLOAT, TensorShape({4, 1}));
test::FillValues<float>(&weights, {0.2, 0.000001, 1.2, 0.001});
SetNodeTensorAttr<float>("value", weights, const_node);
NodeDef* identity_node =
CreateNode("const/read", "Identity", {const_node}, &graph_def);
CreateNode("gather", "Gather", {identity_node, input_node}, &graph_def);
CreateNode("group_deps", "NoOp", {}, &graph_def);
// Run the op.
GraphDef result;
TransformFuncContext context;
context.input_names = {"ids"};
context.output_names = {"gather"};
TF_ASSERT_OK(SparsifyGather(graph_def, context, &result));
// Validation begins.
std::map<string, const NodeDef*> node_lookup;
MapNamesToNodes(result, &node_lookup);
// Check nodes.
EXPECT_EQ(1, node_lookup.count("ids"));
EXPECT_EQ("Const", node_lookup.at("ids")->op());
// Nodes in "const" scope.
EXPECT_EQ(1, node_lookup.count("const/indices"));
EXPECT_EQ("Const", node_lookup.at("const/indices")->op());
Tensor expected_indices_tensor(DT_INT64, TensorShape({3}));
test::FillValues<int64>(&expected_indices_tensor, {0, 2, 3});
test::ExpectTensorEqual<int64>(
expected_indices_tensor,
GetNodeTensorAttr(*(node_lookup.at("const/indices")), "value"));
EXPECT_EQ(1, node_lookup.count("const/values"));
EXPECT_EQ("Const", node_lookup.at("const/values")->op());
Tensor expected_values_tensor(DT_FLOAT, TensorShape({3}));
test::FillValues<float>(&expected_values_tensor, {0.2, 1.2, 0.001});
test::ExpectTensorNear<float>(
expected_values_tensor,
GetNodeTensorAttr(*(node_lookup.at("const/values")), "value"), 1e-5);
EXPECT_EQ(1, node_lookup.count("const/HashTable"));
EXPECT_EQ("HashTable", node_lookup.at("const/HashTable")->op());
EXPECT_EQ(1, node_lookup.count("const/InitializeTable"));
EXPECT_EQ("InitializeTable", node_lookup.at("const/InitializeTable")->op());
// Nodes in "gather" scope.
EXPECT_EQ(1, node_lookup.count("gather/LookupTableFind"));
EXPECT_EQ("LookupTableFind",
node_lookup.at("gather/LookupTableFind")->op());
EXPECT_EQ(1, node_lookup.count("gather/Const"));
EXPECT_EQ("Const", node_lookup.at("gather/Const")->op());
Tensor expected_gather_default_tensor(DT_FLOAT, TensorShape({}));
test::FillValues<float>(&expected_gather_default_tensor, {0.0});
test::ExpectTensorNear<float>(
expected_gather_default_tensor,
GetNodeTensorAttr(*(node_lookup.at("gather/Const")), "value"), 1e-5);
EXPECT_EQ(1, node_lookup.count("gather/ExpandDims/Const"));
EXPECT_EQ("Const", node_lookup.at("gather/ExpandDims/Const")->op());
Tensor expected_expand_dims_tensor(DT_INT32, TensorShape({}));
test::FillValues<int32>(&expected_expand_dims_tensor, {-1});
test::ExpectTensorEqual<int32>(
expected_expand_dims_tensor,
GetNodeTensorAttr(*(node_lookup.at("gather/ExpandDims/Const")),
"value"));
EXPECT_EQ(1, node_lookup.count("gather"));
EXPECT_EQ("ExpandDims", node_lookup.at("gather")->op());
EXPECT_EQ(1, node_lookup.count("group_deps"));
EXPECT_EQ("NoOp", node_lookup.at("group_deps")->op());
// Check connections
EXPECT_EQ("const/HashTable",
node_lookup.at("const/InitializeTable")->input(0));
EXPECT_EQ("const/indices",
node_lookup.at("const/InitializeTable")->input(1));
EXPECT_EQ("const/values",
node_lookup.at("const/InitializeTable")->input(2));
EXPECT_EQ("const/HashTable",
node_lookup.at("gather/LookupTableFind")->input(0));
EXPECT_EQ("ids", node_lookup.at("gather/LookupTableFind")->input(1));
EXPECT_EQ("gather/Const",
node_lookup.at("gather/LookupTableFind")->input(2));
EXPECT_EQ("gather/LookupTableFind", node_lookup.at("gather")->input(0));
// Check control dependency.
EXPECT_NE(std::find(node_lookup.at("group_deps")->input().begin(),
node_lookup.at("group_deps")->input().end(),
"^const/InitializeTable"),
node_lookup.at("group_deps")->input().end());
}
void TestMultiPartitionConst() {
// The 'ids' node is served input for two 'Gather's.
GraphDef graph_def;
// Build Graph:
// Shared input node
NodeDef* input_node = CreateNode("ids", "Const", {}, &graph_def);
// Shared init node
CreateNode("group_deps", "NoOp", {}, &graph_def);
// Two partitions
NodeDef* const_node1 = CreateNode("const1", "Const", {}, &graph_def);
SetNodeAttr("dtype", DT_FLOAT, const_node1);
// Set 'Const' node value.
Tensor weights(DT_FLOAT, TensorShape({4, 1}));
test::FillValues<float>(&weights, {0.2, 0.000001, 1.2, 0.001});
SetNodeTensorAttr<float>("value", weights, const_node1);
NodeDef* const_node2 = CreateNode("const2", "Const", {}, &graph_def);
SetNodeAttr("dtype", DT_FLOAT, const_node2);
SetNodeTensorAttr<float>("value", weights, const_node2);
NodeDef* identity_node1 =
CreateNode("const1/read", "Identity", {const_node1}, &graph_def);
NodeDef* identity_node2 =
CreateNode("const2/read", "Identity", {const_node2}, &graph_def);
CreateNode("gather1", "Gather", {identity_node1, input_node}, &graph_def);
CreateNode("gather2", "Gather", {identity_node2, input_node}, &graph_def);
// Run the op.
GraphDef result;
TransformFuncContext context;
context.input_names = {"ids"};
context.output_names = {"gather1", "gather2"};
TF_ASSERT_OK(SparsifyGather(graph_def, context, &result));
// Validation begins.
std::map<string, const NodeDef*> node_lookup;
MapNamesToNodes(result, &node_lookup);
// Check nodes.
// Check shared nodes:
EXPECT_EQ(1, node_lookup.count("ids"));
EXPECT_EQ("Const", node_lookup.at("ids")->op());
EXPECT_EQ(1, node_lookup.count("group_deps"));
EXPECT_EQ("NoOp", node_lookup.at("group_deps")->op());
// Nodes in "const1" scope.
EXPECT_EQ(1, node_lookup.count("const1/indices"));
EXPECT_EQ("Const", node_lookup.at("const1/indices")->op());
Tensor expected_indices_tensor1(DT_INT64, TensorShape({3}));
test::FillValues<int64>(&expected_indices_tensor1, {0, 2, 3});
test::ExpectTensorEqual<int64>(
expected_indices_tensor1,
GetNodeTensorAttr(*(node_lookup.at("const1/indices")), "value"));
EXPECT_EQ(1, node_lookup.count("const1/values"));
EXPECT_EQ("Const", node_lookup.at("const1/values")->op());
Tensor expected_values_tensor1(DT_FLOAT, TensorShape({3}));
test::FillValues<float>(&expected_values_tensor1, {0.2, 1.2, 0.001});
test::ExpectTensorNear<float>(
expected_values_tensor1,
GetNodeTensorAttr(*(node_lookup.at("const1/values")), "value"), 1e-5);
EXPECT_EQ(1, node_lookup.count("const1/HashTable"));
EXPECT_EQ("HashTable", node_lookup.at("const1/HashTable")->op());
EXPECT_EQ(1, node_lookup.count("const1/InitializeTable"));
EXPECT_EQ("InitializeTable",
node_lookup.at("const1/InitializeTable")->op());
// Nodes in "gather1" scope.
EXPECT_EQ(1, node_lookup.count("gather1/LookupTableFind"));
EXPECT_EQ("LookupTableFind",
node_lookup.at("gather1/LookupTableFind")->op());
EXPECT_EQ(1, node_lookup.count("gather1/Const"));
EXPECT_EQ("Const", node_lookup.at("gather1/Const")->op());
Tensor expected_gather_default_tensor1(DT_FLOAT, TensorShape({}));
test::FillValues<float>(&expected_gather_default_tensor1, {0.0});
test::ExpectTensorNear<float>(
expected_gather_default_tensor1,
GetNodeTensorAttr(*(node_lookup.at("gather1/Const")), "value"), 1e-5);
EXPECT_EQ(1, node_lookup.count("gather1/ExpandDims/Const"));
EXPECT_EQ("Const", node_lookup.at("gather1/ExpandDims/Const")->op());
Tensor expected_expand_dims_tensor1(DT_INT32, TensorShape({}));
test::FillValues<int32>(&expected_expand_dims_tensor1, {-1});
test::ExpectTensorEqual<int32>(
expected_expand_dims_tensor1,
GetNodeTensorAttr(*(node_lookup.at("gather1/ExpandDims/Const")),
"value"));
EXPECT_EQ(1, node_lookup.count("gather1"));
EXPECT_EQ("ExpandDims", node_lookup.at("gather1")->op());
// Nodes in "const2" scope.
EXPECT_EQ(1, node_lookup.count("const2/indices"));
EXPECT_EQ("Const", node_lookup.at("const2/indices")->op());
Tensor expected_indices_tensor2(DT_INT64, TensorShape({3}));
test::FillValues<int64>(&expected_indices_tensor2, {0, 2, 3});
test::ExpectTensorEqual<int64>(
expected_indices_tensor2,
GetNodeTensorAttr(*(node_lookup.at("const2/indices")), "value"));
EXPECT_EQ(1, node_lookup.count("const2/values"));
EXPECT_EQ("Const", node_lookup.at("const2/values")->op());
Tensor expected_values_tensor2(DT_FLOAT, TensorShape({3}));
test::FillValues<float>(&expected_values_tensor2, {0.2, 1.2, 0.001});
test::ExpectTensorNear<float>(
expected_values_tensor2,
GetNodeTensorAttr(*(node_lookup.at("const2/values")), "value"), 1e-5);
EXPECT_EQ(1, node_lookup.count("const2/HashTable"));
EXPECT_EQ("HashTable", node_lookup.at("const2/HashTable")->op());
EXPECT_EQ(1, node_lookup.count("const2/InitializeTable"));
EXPECT_EQ("InitializeTable",
node_lookup.at("const2/InitializeTable")->op());
// Nodes in "gather2" scope.
EXPECT_EQ(1, node_lookup.count("gather2/LookupTableFind"));
EXPECT_EQ("LookupTableFind",
node_lookup.at("gather2/LookupTableFind")->op());
EXPECT_EQ(1, node_lookup.count("gather2/Const"));
EXPECT_EQ("Const", node_lookup.at("gather2/Const")->op());
Tensor expected_gather_default_tensor2(DT_FLOAT, TensorShape({}));
test::FillValues<float>(&expected_gather_default_tensor2, {0.0});
test::ExpectTensorNear<float>(
expected_gather_default_tensor2,
GetNodeTensorAttr(*(node_lookup.at("gather2/Const")), "value"), 1e-5);
EXPECT_EQ(1, node_lookup.count("gather2/ExpandDims/Const"));
EXPECT_EQ("Const", node_lookup.at("gather2/ExpandDims/Const")->op());
Tensor expected_expand_dims_tensor2(DT_INT32, TensorShape({}));
test::FillValues<int32>(&expected_expand_dims_tensor2, {-1});
test::ExpectTensorEqual<int32>(
expected_expand_dims_tensor2,
GetNodeTensorAttr(*(node_lookup.at("gather2/ExpandDims/Const")),
"value"));
EXPECT_EQ(1, node_lookup.count("gather2"));
EXPECT_EQ("ExpandDims", node_lookup.at("gather2")->op());
// Check connections
EXPECT_EQ("const1/HashTable",
node_lookup.at("const1/InitializeTable")->input(0));
EXPECT_EQ("const1/indices",
node_lookup.at("const1/InitializeTable")->input(1));
EXPECT_EQ("const1/values",
node_lookup.at("const1/InitializeTable")->input(2));
EXPECT_EQ("const2/HashTable",
node_lookup.at("const2/InitializeTable")->input(0));
EXPECT_EQ("const2/indices",
node_lookup.at("const2/InitializeTable")->input(1));
EXPECT_EQ("const2/values",
node_lookup.at("const2/InitializeTable")->input(2));
EXPECT_EQ("const1/HashTable",
node_lookup.at("gather1/LookupTableFind")->input(0));
EXPECT_EQ("ids", node_lookup.at("gather1/LookupTableFind")->input(1));
EXPECT_EQ("gather1/Const",
node_lookup.at("gather1/LookupTableFind")->input(2));
EXPECT_EQ("gather1/LookupTableFind", node_lookup.at("gather1")->input(0));
EXPECT_EQ("const2/HashTable",
node_lookup.at("gather2/LookupTableFind")->input(0));
EXPECT_EQ("ids", node_lookup.at("gather2/LookupTableFind")->input(1));
EXPECT_EQ("gather2/Const",
node_lookup.at("gather2/LookupTableFind")->input(2));
EXPECT_EQ("gather2/LookupTableFind", node_lookup.at("gather2")->input(0));
// Check control deps.
EXPECT_EQ(2, node_lookup.at("group_deps")->input_size());
EXPECT_NE(std::find(node_lookup.at("group_deps")->input().begin(),
node_lookup.at("group_deps")->input().end(),
"^const1/InitializeTable"),
node_lookup.at("group_deps")->input().end());
EXPECT_NE(std::find(node_lookup.at("group_deps")->input().begin(),
node_lookup.at("group_deps")->input().end(),
"^const2/InitializeTable"),
node_lookup.at("group_deps")->input().end());
}
};
TEST_F(SparsifyGatherTest, TestSinglePartitionConst) {
TestSinglePartitionConst();
}
TEST_F(SparsifyGatherTest, TestMultiPartitionConst) {
TestMultiPartitionConst();
}
} // namespace graph_transforms
} // namespace tensorflow