Fix import of meta graphs with partitioned variables into a scope.
Saver inspects SliceInfo to decide the variable name when creating a checkpoint. Before this fix even if a partitioned variable ("weights") was imported into a scope "a" it would still be checkpointed as ("weights") instead of ("a/weights") since import_scoped_meta_graph was not adjusting the SliceInfo. WARNING: if you use import_meta_graph on graphs with partitioned_variables WITH an import_scope argument AND then create a Saver to write/read checkpoints this change may break your checkpoint loading. PiperOrigin-RevId: 173105796
This commit is contained in:
parent
eea089bdb6
commit
dc13a8e2f7
@ -36,8 +36,10 @@ from tensorflow.python.ops import data_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import metrics
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import partitioned_variables
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import test
|
||||
@ -657,5 +659,42 @@ class MetaGraphWithVariableScopeTest(test.TestCase):
|
||||
initializer = variables.local_variables_initializer()
|
||||
|
||||
|
||||
class ExportImportAcrossScopesTest(test.TestCase):
|
||||
|
||||
def testPartionedVariables(self):
|
||||
def make_graph_with_partitioned_variables():
|
||||
variable_scope.get_variable(
|
||||
name="weights",
|
||||
partitioner=partitioned_variables.fixed_size_partitioner(3, axis=0),
|
||||
initializer=random_ops.truncated_normal([100, 10]))
|
||||
self._testExportImportAcrossScopes(make_graph_with_partitioned_variables)
|
||||
|
||||
def _testExportImportAcrossScopes(self, graph_fn):
|
||||
"""Tests export and importing a graph across scopes.
|
||||
|
||||
Args:
|
||||
graph_fn: A closure that creates a graph on the current scope.
|
||||
"""
|
||||
with ops.Graph().as_default() as original_graph:
|
||||
with variable_scope.variable_scope("dropA/dropB/keepA"):
|
||||
graph_fn()
|
||||
exported_meta_graph_def = meta_graph.export_scoped_meta_graph(
|
||||
graph=original_graph,
|
||||
export_scope="dropA/dropB")[0]
|
||||
|
||||
with ops.Graph().as_default() as imported_graph:
|
||||
meta_graph.import_scoped_meta_graph(
|
||||
exported_meta_graph_def,
|
||||
import_scope="importA")
|
||||
|
||||
with ops.Graph().as_default() as expected_graph:
|
||||
with variable_scope.variable_scope("importA/keepA"):
|
||||
graph_fn()
|
||||
|
||||
result = meta_graph.export_scoped_meta_graph(graph=imported_graph)[0]
|
||||
expected = meta_graph.export_scoped_meta_graph(graph=expected_graph)[0]
|
||||
self.assertProtoEquals(expected, result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -394,7 +394,8 @@ class Variable(object):
|
||||
import_scope=import_scope))
|
||||
if variable_def.HasField("save_slice_info_def"):
|
||||
self._save_slice_info = Variable.SaveSliceInfo(
|
||||
save_slice_info_def=variable_def.save_slice_info_def)
|
||||
save_slice_info_def=variable_def.save_slice_info_def,
|
||||
import_scope=import_scope)
|
||||
else:
|
||||
self._save_slice_info = None
|
||||
self._caching_device = None
|
||||
|
Loading…
x
Reference in New Issue
Block a user