From 98765e41f94a1d2cd088c593f18a3a877f5059ca Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Tue, 7 May 2019 10:28:52 -0700 Subject: [PATCH] Expose version information in SavedModels loaded with tf.saved_model.load Adds tensorflow_version and tensorflow_git_version properties. PiperOrigin-RevId: 247044940 --- tensorflow/python/saved_model/load.py | 3 +++ tensorflow/python/saved_model/load_test.py | 7 +++++++ tensorflow/python/saved_model/load_v1_in_v2.py | 4 ++++ tensorflow/python/saved_model/load_v1_in_v2_test.py | 8 ++++++++ 4 files changed, 22 insertions(+) diff --git a/tensorflow/python/saved_model/load.py b/tensorflow/python/saved_model/load.py index 17c1024fe34..14569a240a0 100644 --- a/tensorflow/python/saved_model/load.py +++ b/tensorflow/python/saved_model/load.py @@ -380,6 +380,9 @@ def load(export_dir, tags=None): saved_model_proto, export_dir) root = loader.get(0) + root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version + root.tensorflow_git_version = ( + meta_graph_def.meta_info_def.tensorflow_git_version) else: with ops.init_scope(): root = load_v1_in_v2.load(export_dir, tags) diff --git a/tensorflow/python/saved_model/load_test.py b/tensorflow/python/saved_model/load_test.py index 098e2d330fd..bca1cdc70b7 100644 --- a/tensorflow/python/saved_model/load_test.py +++ b/tensorflow/python/saved_model/load_test.py @@ -37,6 +37,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util +from tensorflow.python.framework import versions from tensorflow.python.keras.engine import base_layer from tensorflow.python.keras.engine import input_layer from tensorflow.python.keras.engine import sequential @@ -1547,6 +1548,12 @@ class LoadTest(test.TestCase, parameterized.TestCase): original, root.model.traced_call(array_ops.zeros([1, 1])).numpy()) + def test_version_info(self, cycles): + root = util.Checkpoint() + root = self.cycle(root, cycles) + self.assertEqual(versions.__version__, root.tensorflow_version) + self.assertEqual(versions.__git_version__, root.tensorflow_git_version) + def test_functional_model_with_conv(self, cycles): x = input_layer.Input(name="x", shape=(None, None, 3), dtype=dtypes.float32) conved = convolutional.Conv2D(filters=3, kernel_size=3, dilation_rate=2)(x) diff --git a/tensorflow/python/saved_model/load_v1_in_v2.py b/tensorflow/python/saved_model/load_v1_in_v2.py index 4d0ef7ba89f..cb1464be780 100644 --- a/tensorflow/python/saved_model/load_v1_in_v2.py +++ b/tensorflow/python/saved_model/load_v1_in_v2.py @@ -171,6 +171,10 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader): root.signatures = signature_serialization.create_signature_map( signature_functions) root.variables = list(wrapped.graph.variables) + root.tensorflow_version = ( + meta_graph_def.meta_info_def.tensorflow_version) + root.tensorflow_git_version = ( + meta_graph_def.meta_info_def.tensorflow_git_version) return root diff --git a/tensorflow/python/saved_model/load_v1_in_v2_test.py b/tensorflow/python/saved_model/load_v1_in_v2_test.py index b6a1c9d0c47..6a27a268a41 100644 --- a/tensorflow/python/saved_model/load_v1_in_v2_test.py +++ b/tensorflow/python/saved_model/load_v1_in_v2_test.py @@ -28,6 +28,7 @@ from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import versions from tensorflow.python.lib.io import file_io from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -334,6 +335,13 @@ class LoadTest(test.TestCase): imported = load.load(path) self.assertEqual([2], imported.signatures["key"]()["value"].shape) + def test_version_info(self): + path = self._signature_with_no_inputs() + imported = load.load(path) + self.assertEqual(versions.__version__, imported.tensorflow_version) + self.assertEqual(versions.__git_version__, + imported.tensorflow_git_version) + def _unfed_placeholder_signature(self): export_graph = ops.Graph() with export_graph.as_default():