Expose version information in SavedModels loaded with tf.saved_model.load
Adds tensorflow_version and tensorflow_git_version properties. PiperOrigin-RevId: 247044940
This commit is contained in:
parent
75bd1d5c92
commit
98765e41f9
@ -380,6 +380,9 @@ def load(export_dir, tags=None):
|
||||
saved_model_proto,
|
||||
export_dir)
|
||||
root = loader.get(0)
|
||||
root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version
|
||||
root.tensorflow_git_version = (
|
||||
meta_graph_def.meta_info_def.tensorflow_git_version)
|
||||
else:
|
||||
with ops.init_scope():
|
||||
root = load_v1_in_v2.load(export_dir, tags)
|
||||
|
@ -37,6 +37,7 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.framework import versions
|
||||
from tensorflow.python.keras.engine import base_layer
|
||||
from tensorflow.python.keras.engine import input_layer
|
||||
from tensorflow.python.keras.engine import sequential
|
||||
@ -1547,6 +1548,12 @@ class LoadTest(test.TestCase, parameterized.TestCase):
|
||||
original,
|
||||
root.model.traced_call(array_ops.zeros([1, 1])).numpy())
|
||||
|
||||
def test_version_info(self, cycles):
|
||||
root = util.Checkpoint()
|
||||
root = self.cycle(root, cycles)
|
||||
self.assertEqual(versions.__version__, root.tensorflow_version)
|
||||
self.assertEqual(versions.__git_version__, root.tensorflow_git_version)
|
||||
|
||||
def test_functional_model_with_conv(self, cycles):
|
||||
x = input_layer.Input(name="x", shape=(None, None, 3), dtype=dtypes.float32)
|
||||
conved = convolutional.Conv2D(filters=3, kernel_size=3, dilation_rate=2)(x)
|
||||
|
@ -171,6 +171,10 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
|
||||
root.signatures = signature_serialization.create_signature_map(
|
||||
signature_functions)
|
||||
root.variables = list(wrapped.graph.variables)
|
||||
root.tensorflow_version = (
|
||||
meta_graph_def.meta_info_def.tensorflow_version)
|
||||
root.tensorflow_git_version = (
|
||||
meta_graph_def.meta_info_def.tensorflow_git_version)
|
||||
return root
|
||||
|
||||
|
||||
|
@ -28,6 +28,7 @@ from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import versions
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
@ -334,6 +335,13 @@ class LoadTest(test.TestCase):
|
||||
imported = load.load(path)
|
||||
self.assertEqual([2], imported.signatures["key"]()["value"].shape)
|
||||
|
||||
def test_version_info(self):
|
||||
path = self._signature_with_no_inputs()
|
||||
imported = load.load(path)
|
||||
self.assertEqual(versions.__version__, imported.tensorflow_version)
|
||||
self.assertEqual(versions.__git_version__,
|
||||
imported.tensorflow_git_version)
|
||||
|
||||
def _unfed_placeholder_signature(self):
|
||||
export_graph = ops.Graph()
|
||||
with export_graph.as_default():
|
||||
|
Loading…
Reference in New Issue
Block a user