Add sparsify_gather op to reduce linear model memory footprint.
Also fixed typo: "obsfucate" -> "obfuscate" Change: 149627297
This commit is contained in:
parent
62be492ef4
commit
168e8dacba
@ -65,7 +65,7 @@ cc_library(
|
||||
"freeze_requantization_ranges.cc",
|
||||
"fuse_convolutions.cc",
|
||||
"insert_logging.cc",
|
||||
"obsfucate_names.cc",
|
||||
"obfuscate_names.cc",
|
||||
"remove_attribute.cc",
|
||||
"remove_device.cc",
|
||||
"remove_nodes.cc",
|
||||
@ -73,6 +73,7 @@ cc_library(
|
||||
"rename_op.cc",
|
||||
"set_device.cc",
|
||||
"sort_by_execution_order.cc",
|
||||
"sparsify_gather.cc",
|
||||
"strip_unused_nodes.cc",
|
||||
] + if_not_windows([
|
||||
"quantize_nodes.cc",
|
||||
@ -111,7 +112,7 @@ tf_cc_test(
|
||||
"freeze_requantization_ranges_test.cc",
|
||||
"fuse_convolutions_test.cc",
|
||||
"insert_logging_test.cc",
|
||||
"obsfucate_names_test.cc",
|
||||
"obfuscate_names_test.cc",
|
||||
"quantize_nodes_test.cc",
|
||||
"quantize_weights_test.cc",
|
||||
"remove_attribute_test.cc",
|
||||
@ -122,6 +123,7 @@ tf_cc_test(
|
||||
"round_weights_test.cc",
|
||||
"set_device_test.cc",
|
||||
"sort_by_execution_order_test.cc",
|
||||
"sparsify_gather_test.cc",
|
||||
"strip_unused_nodes_test.cc",
|
||||
],
|
||||
deps = [
|
||||
|
@ -20,7 +20,7 @@
|
||||
* [fuse_convolutions](#fuse_convolutions)
|
||||
* [insert_logging](#insert_logging)
|
||||
* [merge_duplicate_nodes](#merge_duplicate_nodes)
|
||||
* [obsfucate_names](#obsfucate_names)
|
||||
* [obfuscate_names](#obfuscate_names)
|
||||
* [quantize_nodes](#quantize_nodes)
|
||||
* [quantize_weights](#quantize_weights)
|
||||
* [remove_attribute](#remove_attribute)
|
||||
@ -29,6 +29,7 @@
|
||||
* [rename_attribute](#rename_attribute)
|
||||
* [rename_op](#rename_op)
|
||||
* [round_weights](#round_weights)
|
||||
* [sparsify_gather](#sparsify_gather)
|
||||
* [set_device](#set_device)
|
||||
* [sort_by_execution_order](#sort_by_execution_order)
|
||||
* [strip_unused_nodes](#strip_unused_nodes)
|
||||
@ -250,7 +251,7 @@ results are cached and so you shouldn't see the graph run any more slowly.
|
||||
So far we've been concentrating on weights because those generally take up the
|
||||
most space. If you have a graph with a lot of small nodes in it, the names of
|
||||
those nodes can start to take up a noticeable amount of space too. To shrink
|
||||
those down, you can run the [obsfucate_names](#obsfucate_names) transform, which
|
||||
those down, you can run the [obfuscate_names](#obfuscate_names) transform, which
|
||||
replaces all the names (except for inputs and outputs) with short, cryptic but
|
||||
unique ids:
|
||||
|
||||
@ -262,7 +263,7 @@ bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
|
||||
--inputs='Mul:0' \
|
||||
--outputs='softmax:0' \
|
||||
--transforms='\
|
||||
obsfucate_names \
|
||||
obfuscate_names \
|
||||
'
|
||||
```
|
||||
|
||||
@ -520,7 +521,7 @@ of redundancy (e.g. this transform is always run as part of
|
||||
[quantize_nodes](#quantize_nodes) since the processing there can introduce
|
||||
duplicates of constants that are used in the quantize/dequantize process).
|
||||
|
||||
### obsfucate_names
|
||||
### obfuscate_names
|
||||
|
||||
Args: None \
|
||||
Prerequisites: None
|
||||
@ -656,6 +657,16 @@ between the largest and smallest values present. This is useful when you'll be
|
||||
deploying on mobile, and you want a model that will compress effectively. See
|
||||
[shrinking file size](#shrinking-file-size) for more details.
|
||||
|
||||
### sparsify_gather
|
||||
|
||||
Args: None \
|
||||
Prerequisites: None
|
||||
|
||||
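Transforms a 'Gather' op into a sparsified version in which the 'params' input
of 'Gather' is changed from a dense 'Const' to a 'HashTable', and the 'Gather'
op itself is replaced by a hashtable lookup. This is mostly useful for reducing
the memory footprint of sparse TF.learn linear models. For a 'Gather' to be
rewritten, its 'Const' must be a float tensor of shape [N, 1], and the graph
must contain exactly one 'NoOp' node named 'group_deps', which receives the
table initializers as control dependencies.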
|
||||
|
||||
### set_device
|
||||
|
||||
Args:
|
||||
|
@ -29,7 +29,7 @@ namespace graph_transforms {
|
||||
|
||||
// Renames all nodes not used as graph inputs or outputs to short numerical
|
||||
// forms.
|
||||
Status ObsfucateNames(const GraphDef& input_graph_def,
|
||||
Status ObfuscateNames(const GraphDef& input_graph_def,
|
||||
const TransformFuncContext& context,
|
||||
GraphDef* output_graph_def) {
|
||||
std::unordered_set<string> required_nodes;
|
||||
@ -73,7 +73,7 @@ Status ObsfucateNames(const GraphDef& input_graph_def,
|
||||
output_graph_def->Clear();
|
||||
for (const NodeDef& input_node : input_graph_def.node()) {
|
||||
NodeDef* node = output_graph_def->mutable_node()->Add();
|
||||
node->CopyFrom(input_node);
|
||||
*node = input_node;
|
||||
const string& old_name = input_node.name();
|
||||
node->set_name(new_names[old_name]);
|
||||
node->mutable_input()->Clear();
|
||||
@ -94,7 +94,7 @@ Status ObsfucateNames(const GraphDef& input_graph_def,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
REGISTER_GRAPH_TRANSFORM("obsfucate_names", ObsfucateNames);
|
||||
REGISTER_GRAPH_TRANSFORM("obfuscate_names", ObfuscateNames);
|
||||
|
||||
} // namespace graph_transforms
|
||||
} // namespace tensorflow
|
@ -29,11 +29,11 @@ namespace tensorflow {
|
||||
namespace graph_transforms {
|
||||
|
||||
// Declare here, so we don't need a public header.
|
||||
Status ObsfucateNames(const GraphDef& input_graph_def,
|
||||
Status ObfuscateNames(const GraphDef& input_graph_def,
|
||||
const TransformFuncContext& context,
|
||||
GraphDef* output_graph_def);
|
||||
|
||||
class ObsfucateNamesTest : public ::testing::Test {
|
||||
class ObfuscateNamesTest : public ::testing::Test {
|
||||
protected:
|
||||
void TestSimpleTree() {
|
||||
GraphDef graph_def;
|
||||
@ -74,7 +74,7 @@ class ObsfucateNamesTest : public ::testing::Test {
|
||||
|
||||
GraphDef result;
|
||||
TF_ASSERT_OK(
|
||||
ObsfucateNames(graph_def, {{"const_node1"}, {"add_node1"}}, &result));
|
||||
ObfuscateNames(graph_def, {{"const_node1"}, {"add_node1"}}, &result));
|
||||
|
||||
std::map<string, const NodeDef*> node_lookup;
|
||||
MapNamesToNodes(result, &node_lookup);
|
||||
@ -97,7 +97,7 @@ class ObsfucateNamesTest : public ::testing::Test {
|
||||
}
|
||||
|
||||
GraphDef result;
|
||||
TF_ASSERT_OK(ObsfucateNames(graph_def, {{"const_node0"}, {"const_node999"}},
|
||||
TF_ASSERT_OK(ObfuscateNames(graph_def, {{"const_node0"}, {"const_node999"}},
|
||||
&result));
|
||||
|
||||
std::map<string, const NodeDef*> node_lookup;
|
||||
@ -116,7 +116,7 @@ class ObsfucateNamesTest : public ::testing::Test {
|
||||
}
|
||||
|
||||
GraphDef result;
|
||||
TF_ASSERT_OK(ObsfucateNames(graph_def, {{"10"}, {"19"}}, &result));
|
||||
TF_ASSERT_OK(ObfuscateNames(graph_def, {{"10"}, {"19"}}, &result));
|
||||
|
||||
std::map<string, const NodeDef*> node_lookup;
|
||||
MapNamesToNodes(result, &node_lookup);
|
||||
@ -132,11 +132,11 @@ class ObsfucateNamesTest : public ::testing::Test {
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(ObsfucateNamesTest, TestSimpleTree) { TestSimpleTree(); }
|
||||
TEST_F(ObfuscateNamesTest, TestSimpleTree) { TestSimpleTree(); }
|
||||
|
||||
TEST_F(ObsfucateNamesTest, TestManyNodes) { TestManyNodes(); }
|
||||
TEST_F(ObfuscateNamesTest, TestManyNodes) { TestManyNodes(); }
|
||||
|
||||
TEST_F(ObsfucateNamesTest, TestNameClashes) { TestNameClashes(); }
|
||||
TEST_F(ObfuscateNamesTest, TestNameClashes) { TestNameClashes(); }
|
||||
|
||||
} // namespace graph_transforms
|
||||
} // namespace tensorflow
|
276
tensorflow/tools/graph_transforms/sparsify_gather.cc
Normal file
276
tensorflow/tools/graph_transforms/sparsify_gather.cc
Normal file
@ -0,0 +1,276 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/graph/subgraph.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
using strings::StrCat;
|
||||
namespace graph_transforms {
|
||||
namespace {
|
||||
|
||||
// Sparsify Tensor of shape [N, 1]. Return the indices and values vectors for
|
||||
// non-zero tensor content.
|
||||
Status SparsifyWeights(const Tensor& tensor, Tensor* indices_tensor,
|
||||
Tensor* values_tensor) {
|
||||
if (tensor.dims() != 2 || tensor.dim_size(1) != 1) {
|
||||
return tensorflow::errors::FailedPrecondition(
|
||||
"Transform only applicable to subgraph with 'Const' with "
|
||||
"tensor of shpae [N, 1]. But instead get shape ",
|
||||
tensor.shape().DebugString(), ".");
|
||||
}
|
||||
|
||||
auto flat = tensor.flat<float>();
|
||||
std::vector<int64> indices;
|
||||
std::vector<float> values;
|
||||
|
||||
for (int64 i = 0; i < flat.size(); i++) {
|
||||
float val = flat(i);
|
||||
if (std::abs(val) >= 1.0e-5) {
|
||||
indices.push_back(i);
|
||||
values.push_back(val);
|
||||
}
|
||||
}
|
||||
|
||||
// During model initialization, InitializeTableOp makes use of
|
||||
// KeyValueTensorIterator, which does not accept empty keys or values.
|
||||
// Consequently, adding a dummy pair of indices and values as a walkaround.
|
||||
if (indices.empty() || values.empty()) {
|
||||
indices.push_back(0);
|
||||
values.push_back(0);
|
||||
}
|
||||
*indices_tensor = Tensor(DataTypeToEnum<int64>::value,
|
||||
{static_cast<int64>(indices.size())});
|
||||
std::copy_n(indices.begin(), indices.size(),
|
||||
indices_tensor->flat<int64>().data());
|
||||
|
||||
*values_tensor =
|
||||
Tensor(DataTypeToEnum<float>::value, {static_cast<int64>(values.size())});
|
||||
std::copy_n(values.begin(), values.size(),
|
||||
values_tensor->flat<float>().data());
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void CreateConstNode(const Tensor& tensor, const string& name,
|
||||
NodeDef* node_def) {
|
||||
node_def->set_op("Const");
|
||||
node_def->set_name(name);
|
||||
SetNodeTensorAttr<float>("value", tensor, node_def);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Graph transform that replaces dense 'Gather' lookups with hash table lookups.
//
// Repeatedly matches the pattern `Const -> Identity -> Gather` (with any node
// as the second 'Gather' input) and rewrites each match so that the retained
// entries of the 'Const' live in a 'HashTable', and the 'Gather' is replaced by
// a 'LookupTableFind' followed by an 'ExpandDims' that reuses the original
// 'Gather' node's name. Every generated 'InitializeTable' node is added as a
// control input of the 'NoOp' node named "group_deps".
//
// `context` is not consulted by this transform. The result is written to
// `output_graph_def`.
//
// Returns FailedPrecondition if a matched 'Const' is not DT_FLOAT or not of
// shape [N, 1], or if tables were created but the graph does not contain
// exactly one 'NoOp' named "group_deps". Errors from ReplaceMatchingOpTypes
// and GetNodeAttr are propagated.
Status SparsifyGather(const GraphDef& input_graph_def,
                      const TransformFuncContext& context,
                      GraphDef* output_graph_def) {
  GraphDef current_graph_def = input_graph_def;
  bool any_match_found = false;
  // The subgraphs may have overlapping components, therefore GraphMatcher
  // doesn't return all subgraphs in one round -- this has to be multi-round
  // update.
  do {
    any_match_found = false;
    GraphDef replaced_graph_def = current_graph_def;
    std::vector<string> init_table_node_names;

    TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
        current_graph_def,  // clang-format off
        {"Gather",
          {
            {"Identity",
              {
                {"Const"}
              }
            },
            {"*"},
          }
        },  // clang-format on
        [&any_match_found, &init_table_node_names](
            const NodeMatch& match, const std::set<string>& input_nodes,
            const std::set<string>& output_nodes,
            std::vector<NodeDef>* new_nodes) {
          any_match_found = true;

          // The captured subgraph should be of the following pattern:
          // Const --> Identity --> Gather --> ...
          //                          ^
          //                          |
          //                        (ids)
          //
          // After transform, it becomes:
          //                   --> NoOp(group_deps)
          //                   |
          // Const --> InitializeTable --> HashTable
          //                   ^              |
          //                   |              |
          // Const -------------              |
          //                                  v
          //               (ids) ---> LookupTableFind <--- Const(default)
          //                                  |
          //                                  v
          //                                 ...
          //
          // The output of `LookupTableFind` is additionally passed through an
          // `ExpandDims` (axis -1) node that takes over the `Gather` name.

          // clang-format off
          // For each subgraph, do the following
          // 1. Sparsify the `Const`, creating two `Const`, for hashtable
          // key/val.
          // 2. Create a `InitializeTable` op connecting to the above 2 `Const`.
          // 3. Create a `HashTable` op connecting to `InitializeTable` op.
          // 4. Replace the `Gather` with a `LookupTableFind` op.
          // 5. Connect the `LookupTableFind` with
          //    a. `HashTable`
          //    b. `Gather`'s ids input
          //    c. a `default_val` arg, valued at 0
          // clang-format on
          const NodeDef& gather_node = match.node;
          const NodeDef& const_node = match.inputs[0].inputs[0].node;

          // Propagate a missing or malformed "dtype" attribute instead of
          // reading an uninitialized `data_type` below.
          DataType data_type;
          TF_RETURN_IF_ERROR(GetNodeAttr(const_node, "dtype", &data_type));
          if (data_type != DT_FLOAT) {
            return tensorflow::errors::FailedPrecondition(
                "Transform only applicable to subgraph with 'Const' of dtype "
                "'DT_FLOAT'. Found 'Const' with name '",
                const_node.name(), "' and dtype '", data_type, "'.");
          }
          Tensor weight = GetNodeTensorAttr(const_node, "value");
          Tensor indices_tensor;
          Tensor values_tensor;
          TF_RETURN_IF_ERROR(
              SparsifyWeights(weight, &indices_tensor, &values_tensor));

          // indices and values of sparsified `Const`
          DataType key_dtype = DT_INT64;
          NodeDef indices_node;
          CreateConstNode(indices_tensor, StrCat(const_node.name(), "/indices"),
                          &indices_node);
          SetNodeAttr("dtype", key_dtype, &indices_node);

          NodeDef values_node;
          CreateConstNode(values_tensor, StrCat(const_node.name(), "/values"),
                          &values_node);
          SetNodeAttr("dtype", data_type, &values_node);

          // HashTable node
          NodeDef hashtable_node;
          hashtable_node.set_op("HashTable");
          hashtable_node.set_name(StrCat(const_node.name(), "/HashTable"));
          SetNodeAttr("key_dtype", key_dtype, &hashtable_node);
          SetNodeAttr("value_dtype", data_type, &hashtable_node);

          // InitializeTable node
          NodeDef init_table_node;
          init_table_node.set_op("InitializeTable");
          init_table_node.set_name(
              StrCat(const_node.name(), "/InitializeTable"));
          SetNodeAttr("Tkey", key_dtype, &init_table_node);
          SetNodeAttr("Tval", data_type, &init_table_node);
          init_table_node_names.push_back(init_table_node.name());

          // LookupTableFind node
          NodeDef lookup_node;
          lookup_node.set_op("LookupTableFind");
          lookup_node.set_name(StrCat(gather_node.name(), "/LookupTableFind"));
          SetNodeAttr("Tin", key_dtype, &lookup_node);
          SetNodeAttr("Tout", data_type, &lookup_node);

          // Default return value of hashtable lookup
          Tensor zero_tensor(data_type, TensorShape({}));
          zero_tensor.flat<float>()(0) = 0.0;
          NodeDef default_value_node;
          CreateConstNode(zero_tensor, StrCat(gather_node.name(), "/Const"),
                          &default_value_node);
          SetNodeAttr("dtype", data_type, &default_value_node);

          // ExpandDims argument
          Tensor dim_idx(DT_INT32, TensorShape({}));
          dim_idx.flat<int32>()(0) = -1;
          NodeDef dim_idx_node;
          dim_idx_node.set_op("Const");
          dim_idx_node.set_name(
              StrCat(gather_node.name(), "/ExpandDims/Const"));
          SetNodeAttr("value", dim_idx, &dim_idx_node);
          SetNodeAttr("dtype", DT_INT32, &dim_idx_node);

          // ExpandDims node
          NodeDef expand_dims_node;
          expand_dims_node.set_op("ExpandDims");
          // Reuse gather_node's name so not to change dependent's inputs
          expand_dims_node.set_name(gather_node.name());
          SetNodeAttr("T", data_type, &expand_dims_node);

          // Connect nodes
          AddNodeInput(hashtable_node.name(), &init_table_node);
          AddNodeInput(indices_node.name(), &init_table_node);
          AddNodeInput(values_node.name(), &init_table_node);

          AddNodeInput(hashtable_node.name(), &lookup_node);
          AddNodeInput(gather_node.input(1), &lookup_node);
          AddNodeInput(default_value_node.name(), &lookup_node);

          AddNodeInput(lookup_node.name(), &expand_dims_node);
          AddNodeInput(dim_idx_node.name(), &expand_dims_node);

          // Copy 'ids' input of original 'Gather'
          new_nodes->push_back(match.inputs[1].node);
          new_nodes->push_back(indices_node);
          new_nodes->push_back(values_node);
          new_nodes->push_back(hashtable_node);
          new_nodes->push_back(init_table_node);
          new_nodes->push_back(lookup_node);
          new_nodes->push_back(default_value_node);
          new_nodes->push_back(dim_idx_node);
          new_nodes->push_back(expand_dims_node);

          return Status::OK();
        },
        {true}, &replaced_graph_def));

    // The initializer hook is only needed when this round actually created
    // tables. A round without matches has nothing to attach, so it must not
    // fail merely because the graph has no "group_deps" node.
    if (!init_table_node_names.empty()) {
      NodeDef* init_op = nullptr;
      for (int i = 0; i < replaced_graph_def.node_size(); i++) {
        if (replaced_graph_def.node(i).name() == "group_deps" &&
            replaced_graph_def.node(i).op() == "NoOp") {
          if (init_op != nullptr) {
            return tensorflow::errors::FailedPrecondition(
                "Multiple nodes with name: 'group_deps' and type: 'NoOp'.");
          }
          init_op = replaced_graph_def.mutable_node(i);
        }
      }
      if (init_op == nullptr) {
        return tensorflow::errors::FailedPrecondition(
            "No node found with name: 'group_deps' and type: 'NoOp'");
      }
      for (const string& name : init_table_node_names) {
        // Add control dependence from init_table_node to group_deps_node
        AddNodeInput(StrCat("^", name), init_op);
      }
    }
    current_graph_def = replaced_graph_def;
  } while (any_match_found);
  *output_graph_def = current_graph_def;
  return Status::OK();
}
|
||||
|
||||
REGISTER_GRAPH_TRANSFORM("sparsify_gather", SparsifyGather);
|
||||
|
||||
} // namespace graph_transforms
|
||||
} // namespace tensorflow
|
352
tensorflow/tools/graph_transforms/sparsify_gather_test.cc
Normal file
352
tensorflow/tools/graph_transforms/sparsify_gather_test.cc
Normal file
@ -0,0 +1,352 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/ops/const_op.h"
|
||||
#include "tensorflow/cc/ops/sendrecv_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace graph_transforms {
|
||||
|
||||
// Declare here, so we don't need a public header.
|
||||
Status SparsifyGather(const GraphDef& input_graph_def,
|
||||
const TransformFuncContext& context,
|
||||
GraphDef* output_graph_def);
|
||||
|
||||
class SparsifyGatherTest : public ::testing::Test {
|
||||
protected:
|
||||
NodeDef* CreateNode(const string& name, const string& op,
|
||||
const std::vector<NodeDef*>& inputs,
|
||||
GraphDef* graph_def) {
|
||||
NodeDef* node_def = graph_def->add_node();
|
||||
node_def->set_name(name);
|
||||
node_def->set_op(op);
|
||||
std::for_each(inputs.begin(), inputs.end(), [&node_def](NodeDef* input) {
|
||||
node_def->add_input(input->name());
|
||||
});
|
||||
return node_def;
|
||||
}
|
||||
|
||||
void TestSinglePartitionConst() {
|
||||
GraphDef graph_def;
|
||||
|
||||
// Build the graph.
|
||||
NodeDef* input_node = CreateNode("ids", "Const", {}, &graph_def);
|
||||
NodeDef* const_node = CreateNode("const", "Const", {}, &graph_def);
|
||||
SetNodeAttr("dtype", DT_FLOAT, const_node);
|
||||
// Set 'Const' node value.
|
||||
Tensor weights(DT_FLOAT, TensorShape({4, 1}));
|
||||
test::FillValues<float>(&weights, {0.2, 0.000001, 1.2, 0.001});
|
||||
SetNodeTensorAttr<float>("value", weights, const_node);
|
||||
|
||||
NodeDef* identity_node =
|
||||
CreateNode("const/read", "Identity", {const_node}, &graph_def);
|
||||
CreateNode("gather", "Gather", {identity_node, input_node}, &graph_def);
|
||||
CreateNode("group_deps", "NoOp", {}, &graph_def);
|
||||
|
||||
// Run the op.
|
||||
GraphDef result;
|
||||
TransformFuncContext context;
|
||||
context.input_names = {"ids"};
|
||||
context.output_names = {"gather"};
|
||||
TF_ASSERT_OK(SparsifyGather(graph_def, context, &result));
|
||||
|
||||
// Validation begins.
|
||||
std::map<string, const NodeDef*> node_lookup;
|
||||
MapNamesToNodes(result, &node_lookup);
|
||||
|
||||
// Check nodes.
|
||||
EXPECT_EQ(1, node_lookup.count("ids"));
|
||||
EXPECT_EQ("Const", node_lookup.at("ids")->op());
|
||||
|
||||
// Nodes in "const" scope.
|
||||
EXPECT_EQ(1, node_lookup.count("const/indices"));
|
||||
EXPECT_EQ("Const", node_lookup.at("const/indices")->op());
|
||||
Tensor expected_indices_tensor(DT_INT64, TensorShape({3}));
|
||||
test::FillValues<int64>(&expected_indices_tensor, {0, 2, 3});
|
||||
test::ExpectTensorEqual<int64>(
|
||||
expected_indices_tensor,
|
||||
GetNodeTensorAttr(*(node_lookup.at("const/indices")), "value"));
|
||||
|
||||
EXPECT_EQ(1, node_lookup.count("const/values"));
|
||||
EXPECT_EQ("Const", node_lookup.at("const/values")->op());
|
||||
Tensor expected_values_tensor(DT_FLOAT, TensorShape({3}));
|
||||
test::FillValues<float>(&expected_values_tensor, {0.2, 1.2, 0.001});
|
||||
test::ExpectTensorNear<float>(
|
||||
expected_values_tensor,
|
||||
GetNodeTensorAttr(*(node_lookup.at("const/values")), "value"), 1e-5);
|
||||
|
||||
EXPECT_EQ(1, node_lookup.count("const/HashTable"));
|
||||
EXPECT_EQ("HashTable", node_lookup.at("const/HashTable")->op());
|
||||
|
||||
EXPECT_EQ(1, node_lookup.count("const/InitializeTable"));
|
||||
EXPECT_EQ("InitializeTable", node_lookup.at("const/InitializeTable")->op());
|
||||
|
||||
// Nodes in "gather" scope.
|
||||
EXPECT_EQ(1, node_lookup.count("gather/LookupTableFind"));
|
||||
EXPECT_EQ("LookupTableFind",
|
||||
node_lookup.at("gather/LookupTableFind")->op());
|
||||
|
||||
EXPECT_EQ(1, node_lookup.count("gather/Const"));
|
||||
EXPECT_EQ("Const", node_lookup.at("gather/Const")->op());
|
||||
Tensor expected_gather_default_tensor(DT_FLOAT, TensorShape({}));
|
||||
test::FillValues<float>(&expected_gather_default_tensor, {0.0});
|
||||
test::ExpectTensorNear<float>(
|
||||
expected_gather_default_tensor,
|
||||
GetNodeTensorAttr(*(node_lookup.at("gather/Const")), "value"), 1e-5);
|
||||
|
||||
EXPECT_EQ(1, node_lookup.count("gather/ExpandDims/Const"));
|
||||
EXPECT_EQ("Const", node_lookup.at("gather/ExpandDims/Const")->op());
|
||||
Tensor expected_expand_dims_tensor(DT_INT32, TensorShape({}));
|
||||
test::FillValues<int32>(&expected_expand_dims_tensor, {-1});
|
||||
test::ExpectTensorEqual<int32>(
|
||||
expected_expand_dims_tensor,
|
||||
GetNodeTensorAttr(*(node_lookup.at("gather/ExpandDims/Const")),
|
||||
"value"));
|
||||
|
||||
EXPECT_EQ(1, node_lookup.count("gather"));
|
||||
EXPECT_EQ("ExpandDims", node_lookup.at("gather")->op());
|
||||
|
||||
EXPECT_EQ(1, node_lookup.count("group_deps"));
|
||||
EXPECT_EQ("NoOp", node_lookup.at("group_deps")->op());
|
||||
|
||||
// Check connections
|
||||
EXPECT_EQ("const/HashTable",
|
||||
node_lookup.at("const/InitializeTable")->input(0));
|
||||
EXPECT_EQ("const/indices",
|
||||
node_lookup.at("const/InitializeTable")->input(1));
|
||||
EXPECT_EQ("const/values",
|
||||
node_lookup.at("const/InitializeTable")->input(2));
|
||||
|
||||
EXPECT_EQ("const/HashTable",
|
||||
node_lookup.at("gather/LookupTableFind")->input(0));
|
||||
EXPECT_EQ("ids", node_lookup.at("gather/LookupTableFind")->input(1));
|
||||
EXPECT_EQ("gather/Const",
|
||||
node_lookup.at("gather/LookupTableFind")->input(2));
|
||||
|
||||
EXPECT_EQ("gather/LookupTableFind", node_lookup.at("gather")->input(0));
|
||||
|
||||
// Check control dependency.
|
||||
EXPECT_NE(std::find(node_lookup.at("group_deps")->input().begin(),
|
||||
node_lookup.at("group_deps")->input().end(),
|
||||
"^const/InitializeTable"),
|
||||
node_lookup.at("group_deps")->input().end());
|
||||
}
|
||||
|
||||
void TestMultiPartitionConst() {
|
||||
// The 'ids' node is served input for two 'Gather's.
|
||||
GraphDef graph_def;
|
||||
|
||||
// Build Graph:
|
||||
// Shared input node
|
||||
NodeDef* input_node = CreateNode("ids", "Const", {}, &graph_def);
|
||||
// Shared init node
|
||||
CreateNode("group_deps", "NoOp", {}, &graph_def);
|
||||
|
||||
// Two partitions
|
||||
NodeDef* const_node1 = CreateNode("const1", "Const", {}, &graph_def);
|
||||
SetNodeAttr("dtype", DT_FLOAT, const_node1);
|
||||
// Set 'Const' node value.
|
||||
Tensor weights(DT_FLOAT, TensorShape({4, 1}));
|
||||
test::FillValues<float>(&weights, {0.2, 0.000001, 1.2, 0.001});
|
||||
SetNodeTensorAttr<float>("value", weights, const_node1);
|
||||
|
||||
NodeDef* const_node2 = CreateNode("const2", "Const", {}, &graph_def);
|
||||
SetNodeAttr("dtype", DT_FLOAT, const_node2);
|
||||
SetNodeTensorAttr<float>("value", weights, const_node2);
|
||||
|
||||
NodeDef* identity_node1 =
|
||||
CreateNode("const1/read", "Identity", {const_node1}, &graph_def);
|
||||
NodeDef* identity_node2 =
|
||||
CreateNode("const2/read", "Identity", {const_node2}, &graph_def);
|
||||
CreateNode("gather1", "Gather", {identity_node1, input_node}, &graph_def);
|
||||
CreateNode("gather2", "Gather", {identity_node2, input_node}, &graph_def);
|
||||
|
||||
// Run the op.
|
||||
GraphDef result;
|
||||
TransformFuncContext context;
|
||||
context.input_names = {"ids"};
|
||||
context.output_names = {"gather1", "gather2"};
|
||||
TF_ASSERT_OK(SparsifyGather(graph_def, context, &result));
|
||||
|
||||
// Validation begins.
|
||||
std::map<string, const NodeDef*> node_lookup;
|
||||
MapNamesToNodes(result, &node_lookup);
|
||||
|
||||
// Check nodes.
|
||||
// Check shared nodes:
|
||||
EXPECT_EQ(1, node_lookup.count("ids"));
|
||||
EXPECT_EQ("Const", node_lookup.at("ids")->op());
|
||||
|
||||
EXPECT_EQ(1, node_lookup.count("group_deps"));
|
||||
EXPECT_EQ("NoOp", node_lookup.at("group_deps")->op());
|
||||
|
||||
// Nodes in "const1" scope.
|
||||
EXPECT_EQ(1, node_lookup.count("const1/indices"));
|
||||
EXPECT_EQ("Const", node_lookup.at("const1/indices")->op());
|
||||
Tensor expected_indices_tensor1(DT_INT64, TensorShape({3}));
|
||||
test::FillValues<int64>(&expected_indices_tensor1, {0, 2, 3});
|
||||
test::ExpectTensorEqual<int64>(
|
||||
expected_indices_tensor1,
|
||||
GetNodeTensorAttr(*(node_lookup.at("const1/indices")), "value"));
|
||||
|
||||
EXPECT_EQ(1, node_lookup.count("const1/values"));
|
||||
EXPECT_EQ("Const", node_lookup.at("const1/values")->op());
|
||||
Tensor expected_values_tensor1(DT_FLOAT, TensorShape({3}));
|
||||
test::FillValues<float>(&expected_values_tensor1, {0.2, 1.2, 0.001});
|
||||
test::ExpectTensorNear<float>(
|
||||
expected_values_tensor1,
|
||||
GetNodeTensorAttr(*(node_lookup.at("const1/values")), "value"), 1e-5);
|
||||
|
||||
EXPECT_EQ(1, node_lookup.count("const1/HashTable"));
|
||||
EXPECT_EQ("HashTable", node_lookup.at("const1/HashTable")->op());
|
||||
|
||||
EXPECT_EQ(1, node_lookup.count("const1/InitializeTable"));
|
||||
EXPECT_EQ("InitializeTable",
|
||||
node_lookup.at("const1/InitializeTable")->op());
|
||||
|
||||
// Nodes in "gather1" scope.
|
||||
EXPECT_EQ(1, node_lookup.count("gather1/LookupTableFind"));
|
||||
EXPECT_EQ("LookupTableFind",
|
||||
node_lookup.at("gather1/LookupTableFind")->op());
|
||||
|
||||
EXPECT_EQ(1, node_lookup.count("gather1/Const"));
|
||||
EXPECT_EQ("Const", node_lookup.at("gather1/Const")->op());
|
||||
Tensor expected_gather_default_tensor1(DT_FLOAT, TensorShape({}));
|
||||
test::FillValues<float>(&expected_gather_default_tensor1, {0.0});
|
||||
test::ExpectTensorNear<float>(
|
||||
expected_gather_default_tensor1,
|
||||
GetNodeTensorAttr(*(node_lookup.at("gather1/Const")), "value"), 1e-5);
|
||||
|
||||
EXPECT_EQ(1, node_lookup.count("gather1/ExpandDims/Const"));
|
||||
EXPECT_EQ("Const", node_lookup.at("gather1/ExpandDims/Const")->op());
|
||||
Tensor expected_expand_dims_tensor1(DT_INT32, TensorShape({}));
|
||||
test::FillValues<int32>(&expected_expand_dims_tensor1, {-1});
|
||||
test::ExpectTensorEqual<int32>(
|
||||
expected_expand_dims_tensor1,
|
||||
GetNodeTensorAttr(*(node_lookup.at("gather1/ExpandDims/Const")),
|
||||
"value"));
|
||||
|
||||
EXPECT_EQ(1, node_lookup.count("gather1"));
|
||||
EXPECT_EQ("ExpandDims", node_lookup.at("gather1")->op());
|
||||
|
||||
// Nodes in "const2" scope.
|
||||
EXPECT_EQ(1, node_lookup.count("const2/indices"));
|
||||
EXPECT_EQ("Const", node_lookup.at("const2/indices")->op());
|
||||
Tensor expected_indices_tensor2(DT_INT64, TensorShape({3}));
|
||||
test::FillValues<int64>(&expected_indices_tensor2, {0, 2, 3});
|
||||
test::ExpectTensorEqual<int64>(
|
||||
expected_indices_tensor2,
|
||||
GetNodeTensorAttr(*(node_lookup.at("const2/indices")), "value"));
|
||||
|
||||
EXPECT_EQ(1, node_lookup.count("const2/values"));
|
||||
EXPECT_EQ("Const", node_lookup.at("const2/values")->op());
|
||||
Tensor expected_values_tensor2(DT_FLOAT, TensorShape({3}));
|
||||
test::FillValues<float>(&expected_values_tensor2, {0.2, 1.2, 0.001});
|
||||
test::ExpectTensorNear<float>(
|
||||
expected_values_tensor2,
|
||||
GetNodeTensorAttr(*(node_lookup.at("const2/values")), "value"), 1e-5);
|
||||
|
||||
EXPECT_EQ(1, node_lookup.count("const2/HashTable"));
|
||||
EXPECT_EQ("HashTable", node_lookup.at("const2/HashTable")->op());
|
||||
|
||||
EXPECT_EQ(1, node_lookup.count("const2/InitializeTable"));
|
||||
EXPECT_EQ("InitializeTable",
|
||||
node_lookup.at("const2/InitializeTable")->op());
|
||||
|
||||
// Nodes in "gather2" scope.
|
||||
EXPECT_EQ(1, node_lookup.count("gather2/LookupTableFind"));
|
||||
EXPECT_EQ("LookupTableFind",
|
||||
node_lookup.at("gather2/LookupTableFind")->op());
|
||||
|
||||
EXPECT_EQ(1, node_lookup.count("gather2/Const"));
|
||||
EXPECT_EQ("Const", node_lookup.at("gather2/Const")->op());
|
||||
Tensor expected_gather_default_tensor2(DT_FLOAT, TensorShape({}));
|
||||
test::FillValues<float>(&expected_gather_default_tensor2, {0.0});
|
||||
test::ExpectTensorNear<float>(
|
||||
expected_gather_default_tensor2,
|
||||
GetNodeTensorAttr(*(node_lookup.at("gather2/Const")), "value"), 1e-5);
|
||||
|
||||
EXPECT_EQ(1, node_lookup.count("gather2/ExpandDims/Const"));
|
||||
EXPECT_EQ("Const", node_lookup.at("gather2/ExpandDims/Const")->op());
|
||||
Tensor expected_expand_dims_tensor2(DT_INT32, TensorShape({}));
|
||||
test::FillValues<int32>(&expected_expand_dims_tensor2, {-1});
|
||||
test::ExpectTensorEqual<int32>(
|
||||
expected_expand_dims_tensor2,
|
||||
GetNodeTensorAttr(*(node_lookup.at("gather2/ExpandDims/Const")),
|
||||
"value"));
|
||||
|
||||
EXPECT_EQ(1, node_lookup.count("gather2"));
|
||||
EXPECT_EQ("ExpandDims", node_lookup.at("gather2")->op());
|
||||
|
||||
// Check connections
|
||||
EXPECT_EQ("const1/HashTable",
|
||||
node_lookup.at("const1/InitializeTable")->input(0));
|
||||
EXPECT_EQ("const1/indices",
|
||||
node_lookup.at("const1/InitializeTable")->input(1));
|
||||
EXPECT_EQ("const1/values",
|
||||
node_lookup.at("const1/InitializeTable")->input(2));
|
||||
|
||||
EXPECT_EQ("const2/HashTable",
|
||||
node_lookup.at("const2/InitializeTable")->input(0));
|
||||
EXPECT_EQ("const2/indices",
|
||||
node_lookup.at("const2/InitializeTable")->input(1));
|
||||
EXPECT_EQ("const2/values",
|
||||
node_lookup.at("const2/InitializeTable")->input(2));
|
||||
|
||||
EXPECT_EQ("const1/HashTable",
|
||||
node_lookup.at("gather1/LookupTableFind")->input(0));
|
||||
EXPECT_EQ("ids", node_lookup.at("gather1/LookupTableFind")->input(1));
|
||||
EXPECT_EQ("gather1/Const",
|
||||
node_lookup.at("gather1/LookupTableFind")->input(2));
|
||||
EXPECT_EQ("gather1/LookupTableFind", node_lookup.at("gather1")->input(0));
|
||||
|
||||
EXPECT_EQ("const2/HashTable",
|
||||
node_lookup.at("gather2/LookupTableFind")->input(0));
|
||||
EXPECT_EQ("ids", node_lookup.at("gather2/LookupTableFind")->input(1));
|
||||
EXPECT_EQ("gather2/Const",
|
||||
node_lookup.at("gather2/LookupTableFind")->input(2));
|
||||
EXPECT_EQ("gather2/LookupTableFind", node_lookup.at("gather2")->input(0));
|
||||
|
||||
// Check control deps.
|
||||
EXPECT_EQ(2, node_lookup.at("group_deps")->input_size());
|
||||
EXPECT_NE(std::find(node_lookup.at("group_deps")->input().begin(),
|
||||
node_lookup.at("group_deps")->input().end(),
|
||||
"^const1/InitializeTable"),
|
||||
node_lookup.at("group_deps")->input().end());
|
||||
|
||||
EXPECT_NE(std::find(node_lookup.at("group_deps")->input().begin(),
|
||||
node_lookup.at("group_deps")->input().end(),
|
||||
"^const2/InitializeTable"),
|
||||
node_lookup.at("group_deps")->input().end());
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(SparsifyGatherTest, TestSinglePartitionConst) {
|
||||
TestSinglePartitionConst();
|
||||
}
|
||||
|
||||
TEST_F(SparsifyGatherTest, TestMultiPartitionConst) {
|
||||
TestMultiPartitionConst();
|
||||
}
|
||||
|
||||
} // namespace graph_transforms
|
||||
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user