diff --git a/tensorflow/python/framework/meta_graph_test.py b/tensorflow/python/framework/meta_graph_test.py index 65abb695991..06cee46bf62 100644 --- a/tensorflow/python/framework/meta_graph_test.py +++ b/tensorflow/python/framework/meta_graph_test.py @@ -36,8 +36,10 @@ from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import random_ops from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import test @@ -657,5 +659,42 @@ class MetaGraphWithVariableScopeTest(test.TestCase): initializer = variables.local_variables_initializer() +class ExportImportAcrossScopesTest(test.TestCase): + + def testPartionedVariables(self): + def make_graph_with_partitioned_variables(): + variable_scope.get_variable( + name="weights", + partitioner=partitioned_variables.fixed_size_partitioner(3, axis=0), + initializer=random_ops.truncated_normal([100, 10])) + self._testExportImportAcrossScopes(make_graph_with_partitioned_variables) + + def _testExportImportAcrossScopes(self, graph_fn): + """Tests export and importing a graph across scopes. + + Args: + graph_fn: A closure that creates a graph on the current scope. + """ + with ops.Graph().as_default() as original_graph: + with variable_scope.variable_scope("dropA/dropB/keepA"): + graph_fn() + exported_meta_graph_def = meta_graph.export_scoped_meta_graph( + graph=original_graph, + export_scope="dropA/dropB")[0] + + with ops.Graph().as_default() as imported_graph: + meta_graph.import_scoped_meta_graph( + exported_meta_graph_def, + import_scope="importA") + + with ops.Graph().as_default() as expected_graph: + with variable_scope.variable_scope("importA/keepA"): + graph_fn() + + result = meta_graph.export_scoped_meta_graph(graph=imported_graph)[0] + expected = meta_graph.export_scoped_meta_graph(graph=expected_graph)[0] + self.assertProtoEquals(expected, result) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index 90b4f25d81a..0272f771768 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -394,7 +394,8 @@ class Variable(object): import_scope=import_scope)) if variable_def.HasField("save_slice_info_def"): self._save_slice_info = Variable.SaveSliceInfo( - save_slice_info_def=variable_def.save_slice_info_def) + save_slice_info_def=variable_def.save_slice_info_def, + import_scope=import_scope) else: self._save_slice_info = None self._caching_device = None