Expose version information in SavedModels loaded with tf.saved_model.load

Adds tensorflow_version and tensorflow_git_version properties.

PiperOrigin-RevId: 247044940
This commit is contained in:
Allen Lavoie 2019-05-07 10:28:52 -07:00 committed by TensorFlower Gardener
parent 75bd1d5c92
commit 98765e41f9
4 changed files with 22 additions and 0 deletions

View File

@ -380,6 +380,9 @@ def load(export_dir, tags=None):
saved_model_proto,
export_dir)
root = loader.get(0)
root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version
root.tensorflow_git_version = (
meta_graph_def.meta_info_def.tensorflow_git_version)
else:
with ops.init_scope():
root = load_v1_in_v2.load(export_dir, tags)

View File

@ -37,6 +37,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.framework import versions
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import input_layer
from tensorflow.python.keras.engine import sequential
@ -1547,6 +1548,12 @@ class LoadTest(test.TestCase, parameterized.TestCase):
original,
root.model.traced_call(array_ops.zeros([1, 1])).numpy())
def test_version_info(self, cycles):
root = util.Checkpoint()
root = self.cycle(root, cycles)
self.assertEqual(versions.__version__, root.tensorflow_version)
self.assertEqual(versions.__git_version__, root.tensorflow_git_version)
def test_functional_model_with_conv(self, cycles):
x = input_layer.Input(name="x", shape=(None, None, 3), dtype=dtypes.float32)
conved = convolutional.Conv2D(filters=3, kernel_size=3, dilation_rate=2)(x)

View File

@ -171,6 +171,10 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
root.signatures = signature_serialization.create_signature_map(
signature_functions)
root.variables = list(wrapped.graph.variables)
root.tensorflow_version = (
meta_graph_def.meta_info_def.tensorflow_version)
root.tensorflow_git_version = (
meta_graph_def.meta_info_def.tensorflow_git_version)
return root

View File

@ -28,6 +28,7 @@ from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import versions
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@ -334,6 +335,13 @@ class LoadTest(test.TestCase):
imported = load.load(path)
self.assertEqual([2], imported.signatures["key"]()["value"].shape)
def test_version_info(self):
path = self._signature_with_no_inputs()
imported = load.load(path)
self.assertEqual(versions.__version__, imported.tensorflow_version)
self.assertEqual(versions.__git_version__,
imported.tensorflow_git_version)
def _unfed_placeholder_signature(self):
export_graph = ops.Graph()
with export_graph.as_default():