Fix import of meta graphs with partitioned variables into a scope.

Saver inspects SliceInfo to decide the variable name when creating a
checkpoint. Before this fix even if a partitioned variable ("weights")
was imported into a scope "a" it would still be checkpointed as ("weights")
instead of ("a/weights") since import_scoped_meta_graph was not adjusting
the SliceInfo.

WARNING: if you use import_meta_graph on graphs with partitioned_variables WITH an import_scope argument AND then create a Saver to write/read checkpoints this change
may break your checkpoint loading.
PiperOrigin-RevId: 173105796
This commit is contained in:
A. Unique TensorFlower 2017-10-23 06:03:28 -07:00 committed by TensorFlower Gardener
parent eea089bdb6
commit dc13a8e2f7
2 changed files with 41 additions and 1 deletions

View File

@ -36,8 +36,10 @@ from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
@ -657,5 +659,42 @@ class MetaGraphWithVariableScopeTest(test.TestCase):
initializer = variables.local_variables_initializer()
class ExportImportAcrossScopesTest(test.TestCase):
def testPartionedVariables(self):
def make_graph_with_partitioned_variables():
variable_scope.get_variable(
name="weights",
partitioner=partitioned_variables.fixed_size_partitioner(3, axis=0),
initializer=random_ops.truncated_normal([100, 10]))
self._testExportImportAcrossScopes(make_graph_with_partitioned_variables)
def _testExportImportAcrossScopes(self, graph_fn):
"""Tests export and importing a graph across scopes.
Args:
graph_fn: A closure that creates a graph on the current scope.
"""
with ops.Graph().as_default() as original_graph:
with variable_scope.variable_scope("dropA/dropB/keepA"):
graph_fn()
exported_meta_graph_def = meta_graph.export_scoped_meta_graph(
graph=original_graph,
export_scope="dropA/dropB")[0]
with ops.Graph().as_default() as imported_graph:
meta_graph.import_scoped_meta_graph(
exported_meta_graph_def,
import_scope="importA")
with ops.Graph().as_default() as expected_graph:
with variable_scope.variable_scope("importA/keepA"):
graph_fn()
result = meta_graph.export_scoped_meta_graph(graph=imported_graph)[0]
expected = meta_graph.export_scoped_meta_graph(graph=expected_graph)[0]
self.assertProtoEquals(expected, result)
if __name__ == "__main__":
test.main()

View File

@ -394,7 +394,8 @@ class Variable(object):
import_scope=import_scope))
if variable_def.HasField("save_slice_info_def"):
self._save_slice_info = Variable.SaveSliceInfo(
save_slice_info_def=variable_def.save_slice_info_def)
save_slice_info_def=variable_def.save_slice_info_def,
import_scope=import_scope)
else:
self._save_slice_info = None
self._caching_device = None