Expose version information in SavedModels loaded with tf.saved_model.load
Adds tensorflow_version and tensorflow_git_version properties. PiperOrigin-RevId: 247044940
This commit is contained in:
parent
75bd1d5c92
commit
98765e41f9
@ -380,6 +380,9 @@ def load(export_dir, tags=None):
|
|||||||
saved_model_proto,
|
saved_model_proto,
|
||||||
export_dir)
|
export_dir)
|
||||||
root = loader.get(0)
|
root = loader.get(0)
|
||||||
|
root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version
|
||||||
|
root.tensorflow_git_version = (
|
||||||
|
meta_graph_def.meta_info_def.tensorflow_git_version)
|
||||||
else:
|
else:
|
||||||
with ops.init_scope():
|
with ops.init_scope():
|
||||||
root = load_v1_in_v2.load(export_dir, tags)
|
root = load_v1_in_v2.load(export_dir, tags)
|
||||||
|
@ -37,6 +37,7 @@ from tensorflow.python.framework import dtypes
|
|||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.framework import versions
|
||||||
from tensorflow.python.keras.engine import base_layer
|
from tensorflow.python.keras.engine import base_layer
|
||||||
from tensorflow.python.keras.engine import input_layer
|
from tensorflow.python.keras.engine import input_layer
|
||||||
from tensorflow.python.keras.engine import sequential
|
from tensorflow.python.keras.engine import sequential
|
||||||
@ -1547,6 +1548,12 @@ class LoadTest(test.TestCase, parameterized.TestCase):
|
|||||||
original,
|
original,
|
||||||
root.model.traced_call(array_ops.zeros([1, 1])).numpy())
|
root.model.traced_call(array_ops.zeros([1, 1])).numpy())
|
||||||
|
|
||||||
|
def test_version_info(self, cycles):
|
||||||
|
root = util.Checkpoint()
|
||||||
|
root = self.cycle(root, cycles)
|
||||||
|
self.assertEqual(versions.__version__, root.tensorflow_version)
|
||||||
|
self.assertEqual(versions.__git_version__, root.tensorflow_git_version)
|
||||||
|
|
||||||
def test_functional_model_with_conv(self, cycles):
|
def test_functional_model_with_conv(self, cycles):
|
||||||
x = input_layer.Input(name="x", shape=(None, None, 3), dtype=dtypes.float32)
|
x = input_layer.Input(name="x", shape=(None, None, 3), dtype=dtypes.float32)
|
||||||
conved = convolutional.Conv2D(filters=3, kernel_size=3, dilation_rate=2)(x)
|
conved = convolutional.Conv2D(filters=3, kernel_size=3, dilation_rate=2)(x)
|
||||||
|
@ -171,6 +171,10 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
|
|||||||
root.signatures = signature_serialization.create_signature_map(
|
root.signatures = signature_serialization.create_signature_map(
|
||||||
signature_functions)
|
signature_functions)
|
||||||
root.variables = list(wrapped.graph.variables)
|
root.variables = list(wrapped.graph.variables)
|
||||||
|
root.tensorflow_version = (
|
||||||
|
meta_graph_def.meta_info_def.tensorflow_version)
|
||||||
|
root.tensorflow_git_version = (
|
||||||
|
meta_graph_def.meta_info_def.tensorflow_git_version)
|
||||||
return root
|
return root
|
||||||
|
|
||||||
|
|
||||||
|
@ -28,6 +28,7 @@ from tensorflow.python.eager import test
|
|||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import versions
|
||||||
from tensorflow.python.lib.io import file_io
|
from tensorflow.python.lib.io import file_io
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
@ -334,6 +335,13 @@ class LoadTest(test.TestCase):
|
|||||||
imported = load.load(path)
|
imported = load.load(path)
|
||||||
self.assertEqual([2], imported.signatures["key"]()["value"].shape)
|
self.assertEqual([2], imported.signatures["key"]()["value"].shape)
|
||||||
|
|
||||||
|
def test_version_info(self):
|
||||||
|
path = self._signature_with_no_inputs()
|
||||||
|
imported = load.load(path)
|
||||||
|
self.assertEqual(versions.__version__, imported.tensorflow_version)
|
||||||
|
self.assertEqual(versions.__git_version__,
|
||||||
|
imported.tensorflow_git_version)
|
||||||
|
|
||||||
def _unfed_placeholder_signature(self):
|
def _unfed_placeholder_signature(self):
|
||||||
export_graph = ops.Graph()
|
export_graph = ops.Graph()
|
||||||
with export_graph.as_default():
|
with export_graph.as_default():
|
||||||
|
Loading…
Reference in New Issue
Block a user