Add sparsify_gather op to reduce linear model memory footprint.
Also fixed typo: "obsfucate" -> "obfuscate"

Change: 149627297
commit 168e8dacba (parent 62be492ef4)
@@ -65,7 +65,7 @@ cc_library(
         "freeze_requantization_ranges.cc",
         "fuse_convolutions.cc",
         "insert_logging.cc",
-        "obsfucate_names.cc",
+        "obfuscate_names.cc",
         "remove_attribute.cc",
         "remove_device.cc",
         "remove_nodes.cc",
@@ -73,6 +73,7 @@ cc_library(
         "rename_op.cc",
         "set_device.cc",
         "sort_by_execution_order.cc",
+        "sparsify_gather.cc",
         "strip_unused_nodes.cc",
     ] + if_not_windows([
         "quantize_nodes.cc",
@@ -111,7 +112,7 @@ tf_cc_test(
         "freeze_requantization_ranges_test.cc",
         "fuse_convolutions_test.cc",
         "insert_logging_test.cc",
-        "obsfucate_names_test.cc",
+        "obfuscate_names_test.cc",
         "quantize_nodes_test.cc",
         "quantize_weights_test.cc",
         "remove_attribute_test.cc",
@@ -122,6 +123,7 @@ tf_cc_test(
         "round_weights_test.cc",
         "set_device_test.cc",
         "sort_by_execution_order_test.cc",
+        "sparsify_gather_test.cc",
         "strip_unused_nodes_test.cc",
     ],
     deps = [
@@ -20,7 +20,7 @@
 * [fuse_convolutions](#fuse_convolutions)
 * [insert_logging](#insert_logging)
 * [merge_duplicate_nodes](#merge_duplicate_nodes)
-* [obsfucate_names](#obsfucate_names)
+* [obfuscate_names](#obfuscate_names)
 * [quantize_nodes](#quantize_nodes)
 * [quantize_weights](#quantize_weights)
 * [remove_attribute](#remove_attribute)
@@ -29,6 +29,7 @@
 * [rename_attribute](#rename_attribute)
 * [rename_op](#rename_op)
 * [round_weights](#round_weights)
+* [sparsify_gather](#sparsify_gather)
 * [set_device](#set_device)
 * [sort_by_execution_order](#sort_by_execution_order)
 * [strip_unused_nodes](#strip_unused_nodes)
@@ -250,7 +251,7 @@ results are cached and so you shouldn't see the graph run any more slowly.
 So far we've been concentrating on weights because those generally take up the
 most space. If you have a graph with a lot of small nodes in it, the names of
 those nodes can start to take up a noticeable amount of space too. To shrink
-those down, you can run the [obsfucate_names](#obsfucate_names) transform, which
+those down, you can run the [obfuscate_names](#obfuscate_names) transform, which
 replaces all the names (except for inputs and outputs) with short, cryptic but
 unique ids:
 
@@ -262,7 +263,7 @@ bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
 --inputs='Mul:0' \
 --outputs='softmax:0' \
 --transforms='\
-obsfucate_names \
+obfuscate_names \
 '
 ```
 
@@ -520,7 +521,7 @@ of redundancy (e.g. this transform is always run as part of
 [quantize_nodes](#quantize_nodes) since the processing there can introduce
 duplicates of constants that are used in the quantize/dequantize process).
 
-### obsfucate_names
+### obfuscate_names
 
 Args: None \
 Prerequisites: None
@@ -656,6 +657,16 @@ between the largest and smallest values present. This is useful when you'll be
 deploying on mobile, and you want a model that will compress effectively. See
 [shrinking file size](#shrinking-file-size) for more details.
 
+### sparsify_gather
+
+Args: None \
+Prerequisites: None
+
+Transform 'Gather' op to a sparsified version where 'params' input of 'Gather'
+is replaced from a dense 'Const' to a 'HashTable'. 'Gather' op itself is
+replaced by a hashtable lookup. This is mostly useful for reducing sparse
+TF.learn linear model memory footprint.
+
 ### set_device
 
 Args:
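As a usage sketch (not part of the diff itself): the new transform is applied through the same transform_graph tool as the other transforms documented in this README. The graph paths and the input/output tensor names below are illustrative placeholders, not values taken from this change.

```bash
# Hypothetical invocation: apply sparsify_gather to a saved linear model graph.
# The in_graph/out_graph paths and tensor names are examples only.
bazel build tensorflow/tools/graph_transforms:transform_graph
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
  --in_graph=linear_model.pb \
  --out_graph=sparsified_linear_model.pb \
  --inputs='ids:0' \
  --outputs='predictions:0' \
  --transforms='sparsify_gather'
```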
@@ -29,7 +29,7 @@ namespace graph_transforms {
 
 // Renames all nodes not uses as graph inputs or outputs to short numerical
 // forms.
-Status ObsfucateNames(const GraphDef& input_graph_def,
+Status ObfuscateNames(const GraphDef& input_graph_def,
                       const TransformFuncContext& context,
                       GraphDef* output_graph_def) {
   std::unordered_set<string> required_nodes;
@@ -73,7 +73,7 @@ Status ObsfucateNames(const GraphDef& input_graph_def,
   output_graph_def->Clear();
   for (const NodeDef& input_node : input_graph_def.node()) {
     NodeDef* node = output_graph_def->mutable_node()->Add();
-    node->CopyFrom(input_node);
+    *node = input_node;
     const string& old_name = input_node.name();
     node->set_name(new_names[old_name]);
     node->mutable_input()->Clear();
@@ -94,7 +94,7 @@ Status ObsfucateNames(const GraphDef& input_graph_def,
   return Status::OK();
 }
 
-REGISTER_GRAPH_TRANSFORM("obsfucate_names", ObsfucateNames);
+REGISTER_GRAPH_TRANSFORM("obfuscate_names", ObfuscateNames);
 
 }  // namespace graph_transforms
 }  // namespace tensorflow
@@ -29,11 +29,11 @@ namespace tensorflow {
 namespace graph_transforms {
 
 // Declare here, so we don't need a public header.
-Status ObsfucateNames(const GraphDef& input_graph_def,
+Status ObfuscateNames(const GraphDef& input_graph_def,
                       const TransformFuncContext& context,
                       GraphDef* output_graph_def);
 
-class ObsfucateNamesTest : public ::testing::Test {
+class ObfuscateNamesTest : public ::testing::Test {
  protected:
   void TestSimpleTree() {
     GraphDef graph_def;
@@ -74,7 +74,7 @@ class ObsfucateNamesTest : public ::testing::Test {
 
     GraphDef result;
     TF_ASSERT_OK(
-        ObsfucateNames(graph_def, {{"const_node1"}, {"add_node1"}}, &result));
+        ObfuscateNames(graph_def, {{"const_node1"}, {"add_node1"}}, &result));
 
     std::map<string, const NodeDef*> node_lookup;
     MapNamesToNodes(result, &node_lookup);
@@ -97,7 +97,7 @@ class ObsfucateNamesTest : public ::testing::Test {
     }
 
     GraphDef result;
-    TF_ASSERT_OK(ObsfucateNames(graph_def, {{"const_node0"}, {"const_node999"}},
+    TF_ASSERT_OK(ObfuscateNames(graph_def, {{"const_node0"}, {"const_node999"}},
                                 &result));
 
     std::map<string, const NodeDef*> node_lookup;
@@ -116,7 +116,7 @@ class ObsfucateNamesTest : public ::testing::Test {
     }
 
    GraphDef result;
-    TF_ASSERT_OK(ObsfucateNames(graph_def, {{"10"}, {"19"}}, &result));
+    TF_ASSERT_OK(ObfuscateNames(graph_def, {{"10"}, {"19"}}, &result));
 
    std::map<string, const NodeDef*> node_lookup;
    MapNamesToNodes(result, &node_lookup);
@@ -132,11 +132,11 @@ class ObsfucateNamesTest : public ::testing::Test {
   }
 };
 
-TEST_F(ObsfucateNamesTest, TestSimpleTree) { TestSimpleTree(); }
+TEST_F(ObfuscateNamesTest, TestSimpleTree) { TestSimpleTree(); }
 
-TEST_F(ObsfucateNamesTest, TestManyNodes) { TestManyNodes(); }
+TEST_F(ObfuscateNamesTest, TestManyNodes) { TestManyNodes(); }
 
-TEST_F(ObsfucateNamesTest, TestNameClashes) { TestNameClashes(); }
+TEST_F(ObfuscateNamesTest, TestNameClashes) { TestNameClashes(); }
 
 }  // namespace graph_transforms
 }  // namespace tensorflow
tensorflow/tools/graph_transforms/sparsify_gather.cc (new file, 276 lines)
@@ -0,0 +1,276 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cmath>
+
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/graph/subgraph.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/util/command_line_flags.h"
+#include "tensorflow/tools/graph_transforms/transform_utils.h"
+
+namespace tensorflow {
+using strings::StrCat;
+namespace graph_transforms {
+namespace {
+
+// Sparsify Tensor of shape [N, 1]. Return the indices and values vectors for
+// non-zero tensor content.
+Status SparsifyWeights(const Tensor& tensor, Tensor* indices_tensor,
+                       Tensor* values_tensor) {
+  if (tensor.dims() != 2 || tensor.dim_size(1) != 1) {
+    return tensorflow::errors::FailedPrecondition(
+        "Transform only applicable to subgraph with 'Const' with "
+        "tensor of shpae [N, 1]. But instead get shape ",
+        tensor.shape().DebugString(), ".");
+  }
+
+  auto flat = tensor.flat<float>();
+  std::vector<int64> indices;
+  std::vector<float> values;
+
+  for (int64 i = 0; i < flat.size(); i++) {
+    float val = flat(i);
+    if (std::abs(val) >= 1.0e-5) {
+      indices.push_back(i);
+      values.push_back(val);
+    }
+  }
+
+  // During model initialization, InitializeTableOp makes use of
+  // KeyValueTensorIterator, which does not accept empty keys or values.
+  // Consequently, adding a dummy pair of indices and values as a walkaround.
+  if (indices.empty() || values.empty()) {
+    indices.push_back(0);
+    values.push_back(0);
+  }
+  *indices_tensor = Tensor(DataTypeToEnum<int64>::value,
+                           {static_cast<int64>(indices.size())});
+  std::copy_n(indices.begin(), indices.size(),
+              indices_tensor->flat<int64>().data());
+
+  *values_tensor =
+      Tensor(DataTypeToEnum<float>::value, {static_cast<int64>(values.size())});
+  std::copy_n(values.begin(), values.size(),
+              values_tensor->flat<float>().data());
+
+  return Status::OK();
+}
+
+void CreateConstNode(const Tensor& tensor, const string& name,
+                     NodeDef* node_def) {
+  node_def->set_op("Const");
+  node_def->set_name(name);
+  SetNodeTensorAttr<float>("value", tensor, node_def);
+}
+}  // namespace
+
+Status SparsifyGather(const GraphDef& input_graph_def,
+                      const TransformFuncContext& context,
+                      GraphDef* output_graph_def) {
+  GraphDef current_graph_def = input_graph_def;
+  bool any_match_found = false;
+  // The subgraphs may have overlapping components, therefore GraphMatcher
+  // doesn't return all subgraphs in one round -- this has to be multi-round
+  // update.
+  do {
+    any_match_found = false;
+    GraphDef replaced_graph_def = current_graph_def;
+    std::vector<string> init_table_node_names;
+
+    TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
+        current_graph_def,  // clang-format off
+        {"Gather",
+          {
+            {"Identity",
+              {
+                {"Const"}
+              }
+            },
+            {"*"},
+          }
+        },  // clang-format on
+        [&any_match_found, &init_table_node_names](
+            const NodeMatch& match, const std::set<string>& input_nodes,
+            const std::set<string>& output_nodes,
+            std::vector<NodeDef>* new_nodes) {
+          any_match_found = true;
+
+          // The captured subgraph should be of the following pattern:
+          // Const --> Identity --> Gather --> ...
+          //                          ^
+          //                          |
+          //                        (ids)
+          //
+          // After transform, it becomes:
+          //                   --> NoOp(group_deps)
+          //                  |
+          // Const --> InitializeTable --> HashTable
+          //            ^                 |
+          //            |                 |
+          // Const -------------          |
+          //                              v
+          // (ids) ---> LookupTableFind <--- Const(default)
+          //                |
+          //                v
+          //               ...
+
+          // clang-format off
+          // For each subgraph, do the following
+          // 1. Sparsify the `Const`, creating two `Const`, for hashtable
+          //    key/val.
+          // 2. Create a `InitializeTable` op connecting to the above 2 `Const`.
+          // 3. Create a `HashTable` op connecting to `InitializeTable` op.
+          // 4. Replace the `Gather` with a `LookupTableFind` op.
+          // 5. Connect the `LookupTableFind` with
+          //    a. `HashTable`
+          //    b. `Gather`'s ids input
+          //    c. a `default_val` arg, valued at 0
+          // clang-format on
+          const NodeDef& gather_node = match.node;
+          const NodeDef& const_node = match.inputs[0].inputs[0].node;
+
+          DataType data_type;
+          GetNodeAttr(const_node, "dtype", &data_type);
+          if (data_type != DT_FLOAT) {
+            return tensorflow::errors::FailedPrecondition(
+                "Transform only applicable to subgraph with 'Const' of dtype "
+                "'DT_FLOAT'. Found 'Const' with name '",
+                const_node.name(), "' and dtype '", data_type, "'.");
+          }
+          Tensor weight = GetNodeTensorAttr(const_node, "value");
+          Tensor indices_tensor;
+          Tensor values_tensor;
+          TF_RETURN_IF_ERROR(
+              SparsifyWeights(weight, &indices_tensor, &values_tensor));
+
+          // indices and values of sparsified `Const`
+          DataType key_dtype = DT_INT64;
+          NodeDef indices_node;
+          CreateConstNode(indices_tensor, StrCat(const_node.name(), "/indices"),
+                          &indices_node);
+          SetNodeAttr("dtype", key_dtype, &indices_node);
+
+          NodeDef values_node;
+          CreateConstNode(values_tensor, StrCat(const_node.name(), "/values"),
+                          &values_node);
+          SetNodeAttr("dtype", data_type, &values_node);
+
+          // HashTable node
+          NodeDef hashtable_node;
+          hashtable_node.set_op("HashTable");
+          hashtable_node.set_name(StrCat(const_node.name(), "/HashTable"));
+          SetNodeAttr("key_dtype", key_dtype, &hashtable_node);
+          SetNodeAttr("value_dtype", data_type, &hashtable_node);
+
+          // InitializeTable node
+          NodeDef init_table_node;
+          init_table_node.set_op("InitializeTable");
+          init_table_node.set_name(
+              StrCat(const_node.name(), "/InitializeTable"));
+          SetNodeAttr("Tkey", key_dtype, &init_table_node);
+          SetNodeAttr("Tval", data_type, &init_table_node);
+          init_table_node_names.push_back(init_table_node.name());
+
+          // LookupTableFind node
+          NodeDef lookup_node;
+          lookup_node.set_op("LookupTableFind");
+          lookup_node.set_name(StrCat(gather_node.name(), "/LookupTableFind"));
+          SetNodeAttr("Tin", key_dtype, &lookup_node);
+          SetNodeAttr("Tout", data_type, &lookup_node);
+
+          // Default return value of hashtable lookup
+          Tensor zero_tensor(data_type, TensorShape({}));
+          zero_tensor.flat<float>()(0) = 0.0;
+          NodeDef default_value_node;
+          CreateConstNode(zero_tensor, StrCat(gather_node.name(), "/Const"),
+                          &default_value_node);
+          SetNodeAttr("dtype", data_type, &default_value_node);
+
+          // ExpandDims argument
+          Tensor dim_idx(DT_INT32, TensorShape({}));
+          dim_idx.flat<int32>()(0) = -1;
+          NodeDef dim_idx_node;
+          dim_idx_node.set_op("Const");
+          dim_idx_node.set_name(
+              StrCat(gather_node.name(), "/ExpandDims/Const"));
+          SetNodeAttr("value", dim_idx, &dim_idx_node);
+          SetNodeAttr("dtype", DT_INT32, &dim_idx_node);
+
+          // ExpandDims node
+          NodeDef expand_dims_node;
+          expand_dims_node.set_op("ExpandDims");
+          // Reuse gather_node's name so not to change dependent's inputs
+          expand_dims_node.set_name(gather_node.name());
+          SetNodeAttr("T", data_type, &expand_dims_node);
+
+          // Connect nodes
+          AddNodeInput(hashtable_node.name(), &init_table_node);
+          AddNodeInput(indices_node.name(), &init_table_node);
+          AddNodeInput(values_node.name(), &init_table_node);
+
+          AddNodeInput(hashtable_node.name(), &lookup_node);
+          AddNodeInput(gather_node.input(1), &lookup_node);
+          AddNodeInput(default_value_node.name(), &lookup_node);
+
+          AddNodeInput(lookup_node.name(), &expand_dims_node);
+          AddNodeInput(dim_idx_node.name(), &expand_dims_node);
+
+          // Copy 'ids' input of original 'Gather'
+          new_nodes->push_back(match.inputs[1].node);
+          new_nodes->push_back(indices_node);
+          new_nodes->push_back(values_node);
+          new_nodes->push_back(hashtable_node);
+          new_nodes->push_back(init_table_node);
+          new_nodes->push_back(lookup_node);
+          new_nodes->push_back(default_value_node);
+          new_nodes->push_back(dim_idx_node);
+          new_nodes->push_back(expand_dims_node);
+
+          return Status::OK();
+        },
+        {true}, &replaced_graph_def));
+    NodeDef* init_op = nullptr;
+    for (int i = 0; i < replaced_graph_def.node_size(); i++) {
+      if (replaced_graph_def.node(i).name() == "group_deps" &&
+          replaced_graph_def.node(i).op() == "NoOp") {
+        if (init_op != nullptr) {
+          return tensorflow::errors::FailedPrecondition(
+              "Multiple nodes with name: 'group_deps' and type: 'NoOp'.");
+        }
+        init_op = replaced_graph_def.mutable_node(i);
+      }
+    }
+    if (!init_op) {
+      return tensorflow::errors::FailedPrecondition(
+          "No node found with name: 'group_deps' and type: 'NoOp'");
+    }
+    for (const string& name : init_table_node_names) {
+      // Add control dependence from init_table_node to group_deps_node
+      AddNodeInput(StrCat("^", name), init_op);
+    }
+    current_graph_def = replaced_graph_def;
+  } while (any_match_found);
+  *output_graph_def = current_graph_def;
+  return Status::OK();
+}
+
+REGISTER_GRAPH_TRANSFORM("sparsify_gather", SparsifyGather);
+
+}  // namespace graph_transforms
+}  // namespace tensorflow
tensorflow/tools/graph_transforms/sparsify_gather_test.cc (new file, 352 lines)
@@ -0,0 +1,352 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/sendrecv_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/tools/graph_transforms/transform_utils.h"
+
+namespace tensorflow {
+namespace graph_transforms {
+
+// Declare here, so we don't need a public header.
+Status SparsifyGather(const GraphDef& input_graph_def,
+                      const TransformFuncContext& context,
+                      GraphDef* output_graph_def);
+
+class SparsifyGatherTest : public ::testing::Test {
+ protected:
+  NodeDef* CreateNode(const string& name, const string& op,
+                      const std::vector<NodeDef*>& inputs,
+                      GraphDef* graph_def) {
+    NodeDef* node_def = graph_def->add_node();
+    node_def->set_name(name);
+    node_def->set_op(op);
+    std::for_each(inputs.begin(), inputs.end(), [&node_def](NodeDef* input) {
+      node_def->add_input(input->name());
+    });
+    return node_def;
+  }
+
+  void TestSinglePartitionConst() {
+    GraphDef graph_def;
+
+    // Build the graph.
+    NodeDef* input_node = CreateNode("ids", "Const", {}, &graph_def);
+    NodeDef* const_node = CreateNode("const", "Const", {}, &graph_def);
+    SetNodeAttr("dtype", DT_FLOAT, const_node);
+    // Set 'Const' node value.
+    Tensor weights(DT_FLOAT, TensorShape({4, 1}));
+    test::FillValues<float>(&weights, {0.2, 0.000001, 1.2, 0.001});
+    SetNodeTensorAttr<float>("value", weights, const_node);
+
+    NodeDef* identity_node =
+        CreateNode("const/read", "Identity", {const_node}, &graph_def);
+    CreateNode("gather", "Gather", {identity_node, input_node}, &graph_def);
+    CreateNode("group_deps", "NoOp", {}, &graph_def);
+
+    // Run the op.
+    GraphDef result;
+    TransformFuncContext context;
+    context.input_names = {"ids"};
+    context.output_names = {"gather"};
+    TF_ASSERT_OK(SparsifyGather(graph_def, context, &result));
+
+    // Validation begins.
+    std::map<string, const NodeDef*> node_lookup;
+    MapNamesToNodes(result, &node_lookup);
+
+    // Check nodes.
+    EXPECT_EQ(1, node_lookup.count("ids"));
+    EXPECT_EQ("Const", node_lookup.at("ids")->op());
+
+    // Nodes in "const" scope.
+    EXPECT_EQ(1, node_lookup.count("const/indices"));
+    EXPECT_EQ("Const", node_lookup.at("const/indices")->op());
+    Tensor expected_indices_tensor(DT_INT64, TensorShape({3}));
+    test::FillValues<int64>(&expected_indices_tensor, {0, 2, 3});
+    test::ExpectTensorEqual<int64>(
+        expected_indices_tensor,
+        GetNodeTensorAttr(*(node_lookup.at("const/indices")), "value"));
+
+    EXPECT_EQ(1, node_lookup.count("const/values"));
+    EXPECT_EQ("Const", node_lookup.at("const/values")->op());
+    Tensor expected_values_tensor(DT_FLOAT, TensorShape({3}));
+    test::FillValues<float>(&expected_values_tensor, {0.2, 1.2, 0.001});
+    test::ExpectTensorNear<float>(
+        expected_values_tensor,
+        GetNodeTensorAttr(*(node_lookup.at("const/values")), "value"), 1e-5);
+
+    EXPECT_EQ(1, node_lookup.count("const/HashTable"));
+    EXPECT_EQ("HashTable", node_lookup.at("const/HashTable")->op());
+
+    EXPECT_EQ(1, node_lookup.count("const/InitializeTable"));
+    EXPECT_EQ("InitializeTable", node_lookup.at("const/InitializeTable")->op());
+
+    // Nodes in "gather" scope.
+    EXPECT_EQ(1, node_lookup.count("gather/LookupTableFind"));
+    EXPECT_EQ("LookupTableFind",
+              node_lookup.at("gather/LookupTableFind")->op());
+
+    EXPECT_EQ(1, node_lookup.count("gather/Const"));
+    EXPECT_EQ("Const", node_lookup.at("gather/Const")->op());
+    Tensor expected_gather_default_tensor(DT_FLOAT, TensorShape({}));
+    test::FillValues<float>(&expected_gather_default_tensor, {0.0});
+    test::ExpectTensorNear<float>(
+        expected_gather_default_tensor,
+        GetNodeTensorAttr(*(node_lookup.at("gather/Const")), "value"), 1e-5);
+
+    EXPECT_EQ(1, node_lookup.count("gather/ExpandDims/Const"));
+    EXPECT_EQ("Const", node_lookup.at("gather/ExpandDims/Const")->op());
+    Tensor expected_expand_dims_tensor(DT_INT32, TensorShape({}));
+    test::FillValues<int32>(&expected_expand_dims_tensor, {-1});
+    test::ExpectTensorEqual<int32>(
+        expected_expand_dims_tensor,
+        GetNodeTensorAttr(*(node_lookup.at("gather/ExpandDims/Const")),
+                          "value"));
+
+    EXPECT_EQ(1, node_lookup.count("gather"));
+    EXPECT_EQ("ExpandDims", node_lookup.at("gather")->op());
+
+    EXPECT_EQ(1, node_lookup.count("group_deps"));
+    EXPECT_EQ("NoOp", node_lookup.at("group_deps")->op());
+
+    // Check connections
+    EXPECT_EQ("const/HashTable",
+              node_lookup.at("const/InitializeTable")->input(0));
+    EXPECT_EQ("const/indices",
+              node_lookup.at("const/InitializeTable")->input(1));
+    EXPECT_EQ("const/values",
+              node_lookup.at("const/InitializeTable")->input(2));
+
+    EXPECT_EQ("const/HashTable",
+              node_lookup.at("gather/LookupTableFind")->input(0));
+    EXPECT_EQ("ids", node_lookup.at("gather/LookupTableFind")->input(1));
+    EXPECT_EQ("gather/Const",
+              node_lookup.at("gather/LookupTableFind")->input(2));
+
+    EXPECT_EQ("gather/LookupTableFind", node_lookup.at("gather")->input(0));
+
+    // Check control dependency.
+    EXPECT_NE(std::find(node_lookup.at("group_deps")->input().begin(),
+                        node_lookup.at("group_deps")->input().end(),
+                        "^const/InitializeTable"),
+              node_lookup.at("group_deps")->input().end());
+  }
+
+  void TestMultiPartitionConst() {
+    // The 'ids' node is served input for two 'Gather's.
+    GraphDef graph_def;
+
+    // Build Graph:
+    // Shared input node
+    NodeDef* input_node = CreateNode("ids", "Const", {}, &graph_def);
+    // Shared init node
+    CreateNode("group_deps", "NoOp", {}, &graph_def);
+
+    // Two partitions
+    NodeDef* const_node1 = CreateNode("const1", "Const", {}, &graph_def);
+    SetNodeAttr("dtype", DT_FLOAT, const_node1);
+    // Set 'Const' node value.
+    Tensor weights(DT_FLOAT, TensorShape({4, 1}));
+    test::FillValues<float>(&weights, {0.2, 0.000001, 1.2, 0.001});
+    SetNodeTensorAttr<float>("value", weights, const_node1);
+
+    NodeDef* const_node2 = CreateNode("const2", "Const", {}, &graph_def);
+    SetNodeAttr("dtype", DT_FLOAT, const_node2);
+    SetNodeTensorAttr<float>("value", weights, const_node2);
+
+    NodeDef* identity_node1 =
+        CreateNode("const1/read", "Identity", {const_node1}, &graph_def);
+    NodeDef* identity_node2 =
+        CreateNode("const2/read", "Identity", {const_node2}, &graph_def);
+    CreateNode("gather1", "Gather", {identity_node1, input_node}, &graph_def);
+    CreateNode("gather2", "Gather", {identity_node2, input_node}, &graph_def);
+
+    // Run the op.
+    GraphDef result;
+    TransformFuncContext context;
+    context.input_names = {"ids"};
+    context.output_names = {"gather1", "gather2"};
+    TF_ASSERT_OK(SparsifyGather(graph_def, context, &result));
+
+    // Validation begins.
+    std::map<string, const NodeDef*> node_lookup;
+    MapNamesToNodes(result, &node_lookup);
+
+    // Check nodes.
+    // Check shared nodes:
+    EXPECT_EQ(1, node_lookup.count("ids"));
+    EXPECT_EQ("Const", node_lookup.at("ids")->op());
+
+    EXPECT_EQ(1, node_lookup.count("group_deps"));
+    EXPECT_EQ("NoOp", node_lookup.at("group_deps")->op());
+
+    // Nodes in "const1" scope.
+    EXPECT_EQ(1, node_lookup.count("const1/indices"));
+    EXPECT_EQ("Const", node_lookup.at("const1/indices")->op());
+    Tensor expected_indices_tensor1(DT_INT64, TensorShape({3}));
+    test::FillValues<int64>(&expected_indices_tensor1, {0, 2, 3});
+    test::ExpectTensorEqual<int64>(
+        expected_indices_tensor1,
+        GetNodeTensorAttr(*(node_lookup.at("const1/indices")), "value"));
+
+    EXPECT_EQ(1, node_lookup.count("const1/values"));
+    EXPECT_EQ("Const", node_lookup.at("const1/values")->op());
+    Tensor expected_values_tensor1(DT_FLOAT, TensorShape({3}));
+    test::FillValues<float>(&expected_values_tensor1, {0.2, 1.2, 0.001});
+    test::ExpectTensorNear<float>(
+        expected_values_tensor1,
+        GetNodeTensorAttr(*(node_lookup.at("const1/values")), "value"), 1e-5);
+
+    EXPECT_EQ(1, node_lookup.count("const1/HashTable"));
+    EXPECT_EQ("HashTable", node_lookup.at("const1/HashTable")->op());
+
+    EXPECT_EQ(1, node_lookup.count("const1/InitializeTable"));
+    EXPECT_EQ("InitializeTable",
+              node_lookup.at("const1/InitializeTable")->op());
+
+    // Nodes in "gather1" scope.
+    EXPECT_EQ(1, node_lookup.count("gather1/LookupTableFind"));
+    EXPECT_EQ("LookupTableFind",
+              node_lookup.at("gather1/LookupTableFind")->op());
+
+    EXPECT_EQ(1, node_lookup.count("gather1/Const"));
+    EXPECT_EQ("Const", node_lookup.at("gather1/Const")->op());
+    Tensor expected_gather_default_tensor1(DT_FLOAT, TensorShape({}));
+    test::FillValues<float>(&expected_gather_default_tensor1, {0.0});
+    test::ExpectTensorNear<float>(
+        expected_gather_default_tensor1,
+        GetNodeTensorAttr(*(node_lookup.at("gather1/Const")), "value"), 1e-5);
+
+    EXPECT_EQ(1, node_lookup.count("gather1/ExpandDims/Const"));
+    EXPECT_EQ("Const", node_lookup.at("gather1/ExpandDims/Const")->op());
+    Tensor expected_expand_dims_tensor1(DT_INT32, TensorShape({}));
+    test::FillValues<int32>(&expected_expand_dims_tensor1, {-1});
+    test::ExpectTensorEqual<int32>(
+        expected_expand_dims_tensor1,
+        GetNodeTensorAttr(*(node_lookup.at("gather1/ExpandDims/Const")),
+                          "value"));
+
+    EXPECT_EQ(1, node_lookup.count("gather1"));
+    EXPECT_EQ("ExpandDims", node_lookup.at("gather1")->op());
+
+    // Nodes in "const2" scope.
+    EXPECT_EQ(1, node_lookup.count("const2/indices"));
+    EXPECT_EQ("Const", node_lookup.at("const2/indices")->op());
+    Tensor expected_indices_tensor2(DT_INT64, TensorShape({3}));
+    test::FillValues<int64>(&expected_indices_tensor2, {0, 2, 3});
+    test::ExpectTensorEqual<int64>(
+        expected_indices_tensor2,
+        GetNodeTensorAttr(*(node_lookup.at("const2/indices")), "value"));
+
+    EXPECT_EQ(1, node_lookup.count("const2/values"));
+    EXPECT_EQ("Const", node_lookup.at("const2/values")->op());
+    Tensor expected_values_tensor2(DT_FLOAT, TensorShape({3}));
+    test::FillValues<float>(&expected_values_tensor2, {0.2, 1.2, 0.001});
+    test::ExpectTensorNear<float>(
+        expected_values_tensor2,
+        GetNodeTensorAttr(*(node_lookup.at("const2/values")), "value"), 1e-5);
+
+    EXPECT_EQ(1, node_lookup.count("const2/HashTable"));
+    EXPECT_EQ("HashTable", node_lookup.at("const2/HashTable")->op());
+
+    EXPECT_EQ(1, node_lookup.count("const2/InitializeTable"));
+    EXPECT_EQ("InitializeTable",
+              node_lookup.at("const2/InitializeTable")->op());
+
+    // Nodes in "gather2" scope.
+    EXPECT_EQ(1, node_lookup.count("gather2/LookupTableFind"));
+    EXPECT_EQ("LookupTableFind",
+              node_lookup.at("gather2/LookupTableFind")->op());
+
+    EXPECT_EQ(1, node_lookup.count("gather2/Const"));
+    EXPECT_EQ("Const", node_lookup.at("gather2/Const")->op());
+    Tensor expected_gather_default_tensor2(DT_FLOAT, TensorShape({}));
+    test::FillValues<float>(&expected_gather_default_tensor2, {0.0});
+    test::ExpectTensorNear<float>(
+        expected_gather_default_tensor2,
+        GetNodeTensorAttr(*(node_lookup.at("gather2/Const")), "value"), 1e-5);
+
+    EXPECT_EQ(1, node_lookup.count("gather2/ExpandDims/Const"));
+    EXPECT_EQ("Const", node_lookup.at("gather2/ExpandDims/Const")->op());
+    Tensor expected_expand_dims_tensor2(DT_INT32, TensorShape({}));
+    test::FillValues<int32>(&expected_expand_dims_tensor2, {-1});
+    test::ExpectTensorEqual<int32>(
+        expected_expand_dims_tensor2,
+        GetNodeTensorAttr(*(node_lookup.at("gather2/ExpandDims/Const")),
+                          "value"));
+
+    EXPECT_EQ(1, node_lookup.count("gather2"));
+    EXPECT_EQ("ExpandDims", node_lookup.at("gather2")->op());
+
+    // Check connections
+    EXPECT_EQ("const1/HashTable",
+              node_lookup.at("const1/InitializeTable")->input(0));
+    EXPECT_EQ("const1/indices",
+              node_lookup.at("const1/InitializeTable")->input(1));
+    EXPECT_EQ("const1/values",
+              node_lookup.at("const1/InitializeTable")->input(2));
+
+    EXPECT_EQ("const2/HashTable",
+              node_lookup.at("const2/InitializeTable")->input(0));
+    EXPECT_EQ("const2/indices",
+              node_lookup.at("const2/InitializeTable")->input(1));
+    EXPECT_EQ("const2/values",
+              node_lookup.at("const2/InitializeTable")->input(2));
+
+    EXPECT_EQ("const1/HashTable",
+              node_lookup.at("gather1/LookupTableFind")->input(0));
+    EXPECT_EQ("ids", node_lookup.at("gather1/LookupTableFind")->input(1));
+    EXPECT_EQ("gather1/Const",
+              node_lookup.at("gather1/LookupTableFind")->input(2));
+    EXPECT_EQ("gather1/LookupTableFind", node_lookup.at("gather1")->input(0));
+
+    EXPECT_EQ("const2/HashTable",
+              node_lookup.at("gather2/LookupTableFind")->input(0));
+    EXPECT_EQ("ids", node_lookup.at("gather2/LookupTableFind")->input(1));
+    EXPECT_EQ("gather2/Const",
+              node_lookup.at("gather2/LookupTableFind")->input(2));
+    EXPECT_EQ("gather2/LookupTableFind", node_lookup.at("gather2")->input(0));
+
+    // Check control deps.
+    EXPECT_EQ(2, node_lookup.at("group_deps")->input_size());
+    EXPECT_NE(std::find(node_lookup.at("group_deps")->input().begin(),
+                        node_lookup.at("group_deps")->input().end(),
+                        "^const1/InitializeTable"),
+              node_lookup.at("group_deps")->input().end());
+
+    EXPECT_NE(std::find(node_lookup.at("group_deps")->input().begin(),
+                        node_lookup.at("group_deps")->input().end(),
+                        "^const2/InitializeTable"),
+              node_lookup.at("group_deps")->input().end());
+  }
+};
+
+TEST_F(SparsifyGatherTest, TestSinglePartitionConst) {
+  TestSinglePartitionConst();
+}
+
+TEST_F(SparsifyGatherTest, TestMultiPartitionConst) {
+  TestMultiPartitionConst();
+}
+
+}  // namespace graph_transforms
+}  // namespace tensorflow