1. saved_model_cli now properly identifies read-only and non-read-only variables in a graph and only freezes the read-only variables (marking the others as not read-only).
2. Fixed bugs in convert_variables_to_constants where the blacklist/whitelist was not properly respected for chains of operations. While the VarHandleOp node was properly blacklisted, downstream ops that used the resource would be converted as if the variable had been frozen, so for example the following graph would break: VarHandleOp -> Identity -> [first arg in] ResourceAssign. The attrs of the Identity op would be changed to DT_FLOAT even though they should stay as DT_RESOURCE.
3. Added support for freezing of *Nd ops (ResourceGatherNd, ResourceScatterNd).

PiperOrigin-RevId: 293239272
Change-Id: I06de3d139c5585a93ba585f076edf92137e4c48a
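To illustrate the second fix, here is a minimal hypothetical sketch (not part of this change; it assumes the TF1-style tf.compat.v1.graph_util.convert_variables_to_constants API with its variable_names_blacklist argument, and the variable/node names are made up): blacklisting a resource variable should leave the VarHandleOp and every downstream consumer of its DT_RESOURCE handle untouched.

# Hypothetical sketch of the blacklist scenario; not the regression test added
# by this change.
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

with tf.Graph().as_default() as graph:
  # Resource variable: its handle is produced by a VarHandleOp node; ops such
  # as ReadVariableOp / AssignVariableOp consume that DT_RESOURCE handle.
  v = tf.get_variable('v', shape=[], dtype=tf.float32,
                      initializer=tf.ones_initializer(), use_resource=True)
  out = tf.identity(v.read_value(), name='out')

  with tf.Session() as sess:
    sess.run(v.initializer)
    # With 'v' blacklisted, the VarHandleOp and the ops consuming its
    # DT_RESOURCE output must stay as-is; before this fix, such downstream
    # consumers could be rewritten as if 'v' had been frozen to a constant
    # (e.g. an op fed by the handle ending up with a DT_FLOAT attr).
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, graph.as_graph_def(), ['out'],
        variable_names_blacklist=['v'])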
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for SavedModelCLI tool."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import os
import pickle
import shutil
import sys

from absl.testing import parameterized
import numpy as np
from six import StringIO

from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.debug.wrappers import local_cli_wrapper
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import save
from tensorflow.python.tools import saved_model_cli
from tensorflow.python.training.tracking import tracking

SAVED_MODEL_PATH = ('cc/saved_model/testdata/half_plus_two/00000123')


@contextlib.contextmanager
def captured_output():
  new_out, new_err = StringIO(), StringIO()
  old_out, old_err = sys.stdout, sys.stderr
  try:
    sys.stdout, sys.stderr = new_out, new_err
    yield sys.stdout, sys.stderr
  finally:
    sys.stdout, sys.stderr = old_out, old_err


class SavedModelCLITestCase(test.TestCase, parameterized.TestCase):

  def testShowCommandAll(self):
    base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
    self.parser = saved_model_cli.create_parser()
    args = self.parser.parse_args(['show', '--dir', base_path, '--all'])
    with captured_output() as (out, err):
      saved_model_cli.show(args)
    output = out.getvalue().strip()
    # pylint: disable=line-too-long
    exp_out = """MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['classify_x2_to_y3']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['inputs'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: x2:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['scores'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: y3:0
  Method name is: tensorflow/serving/classify

signature_def['classify_x_to_y']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['inputs'] tensor_info:
        dtype: DT_STRING
        shape: unknown_rank
        name: tf_example:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['scores'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: y:0
  Method name is: tensorflow/serving/classify

signature_def['regress_x2_to_y3']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['inputs'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: x2:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['outputs'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: y3:0
  Method name is: tensorflow/serving/regress

signature_def['regress_x_to_y']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['inputs'] tensor_info:
        dtype: DT_STRING
        shape: unknown_rank
        name: tf_example:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['outputs'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: y:0
  Method name is: tensorflow/serving/regress

signature_def['regress_x_to_y2']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['inputs'] tensor_info:
        dtype: DT_STRING
        shape: unknown_rank
        name: tf_example:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['outputs'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: y2:0
  Method name is: tensorflow/serving/regress

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['x'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: x:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['y'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: y:0
  Method name is: tensorflow/serving/predict"""
    # pylint: enable=line-too-long
    self.maxDiff = None  # Produce a useful error msg if the comparison fails
    self.assertMultiLineEqual(output, exp_out)
    self.assertEqual(err.getvalue().strip(), '')

  def testShowAllWithFunctions(self):

    class DummyModel(tracking.AutoTrackable):
      """Model with callable polymorphic functions specified."""

      @def_function.function
      def func1(self, a, b, c):
        if c:
          return a + b
        else:
          return a * b

      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32)
      ])
      def func2(self, x):
        return x + 2

      @def_function.function
      def __call__(self, y, c=7):
        return y + 2 * c

    saved_model_dir = os.path.join(test.get_temp_dir(), 'dummy_model')
    dummy_model = DummyModel()
    # Call with specific values to create new polymorphic function traces.
    dummy_model.func1(
        constant_op.constant(5), constant_op.constant(9), True)
    dummy_model(constant_op.constant(5))
    save.save(dummy_model, saved_model_dir)
    self.parser = saved_model_cli.create_parser()
    args = self.parser.parse_args(['show', '--dir', saved_model_dir, '--all'])
    with captured_output() as (out, err):
      saved_model_cli.show(args)
    output = out.getvalue().strip()
    exp_out = """MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['__saved_model_init_op']:
  The given SavedModel SignatureDef contains the following input(s):
  The given SavedModel SignatureDef contains the following output(s):
    outputs['__saved_model_init_op'] tensor_info:
        dtype: DT_INVALID
        shape: unknown_rank
        name: NoOp
  Method name is: 

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['x'] tensor_info:
        dtype: DT_FLOAT
        shape: (2, 2)
        name: serving_default_x:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['output_0'] tensor_info:
        dtype: DT_FLOAT
        shape: (2, 2)
        name: PartitionedCall:0
  Method name is: tensorflow/serving/predict

Defined Functions:
  Function Name: '__call__'
    Option #1
      Callable with:
        Argument #1
          y: TensorSpec(shape=(), dtype=tf.int32, name='y')
        Argument #2
          DType: int
          Value: 7

  Function Name: 'func1'
    Option #1
      Callable with:
        Argument #1
          a: TensorSpec(shape=(), dtype=tf.int32, name='a')
        Argument #2
          b: TensorSpec(shape=(), dtype=tf.int32, name='b')
        Argument #3
          DType: bool
          Value: True

  Function Name: 'func2'
    Option #1
      Callable with:
        Argument #1
          x: TensorSpec(shape=(2, 2), dtype=tf.float32, name='x')
""".strip()  # pylint: enable=line-too-long
    self.maxDiff = None  # Produce a useful error msg if the comparison fails
    self.assertMultiLineEqual(output, exp_out)
    self.assertEqual(err.getvalue().strip(), '')

  def testShowAllWithPureConcreteFunction(self):

    class DummyModel(tracking.AutoTrackable):
      """Model with a callable concrete function."""

      def __init__(self):
        function = def_function.function(
            self.multiply,
            input_signature=[
                tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32),
                tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32)
            ])
        self.pure_concrete_function = function.get_concrete_function()
        super(DummyModel, self).__init__()

      def multiply(self, a, b):
        return a * b

    saved_model_dir = os.path.join(test.get_temp_dir(), 'dummy_model')
    dummy_model = DummyModel()
    save.save(dummy_model, saved_model_dir)
    self.parser = saved_model_cli.create_parser()
    args = self.parser.parse_args(['show', '--dir', saved_model_dir, '--all'])
    with captured_output() as (out, err):
      saved_model_cli.show(args)
    output = out.getvalue().strip()
    exp_out = """MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['__saved_model_init_op']:
  The given SavedModel SignatureDef contains the following input(s):
  The given SavedModel SignatureDef contains the following output(s):
    outputs['__saved_model_init_op'] tensor_info:
        dtype: DT_INVALID
        shape: unknown_rank
        name: NoOp
  Method name is: 

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['a'] tensor_info:
        dtype: DT_FLOAT
        shape: ()
        name: serving_default_a:0
    inputs['b'] tensor_info:
        dtype: DT_FLOAT
        shape: ()
        name: serving_default_b:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['output_0'] tensor_info:
        dtype: DT_FLOAT
        shape: ()
        name: PartitionedCall:0
  Method name is: tensorflow/serving/predict

Defined Functions:
  Function Name: 'pure_concrete_function'
    Option #1
      Callable with:
        Argument #1
          a: TensorSpec(shape=(), dtype=tf.float32, name='a')
        Argument #2
          b: TensorSpec(shape=(), dtype=tf.float32, name='b')
""".strip()  # pylint: enable=line-too-long
    self.maxDiff = None  # Produce a useful error msg if the comparison fails
    self.assertMultiLineEqual(output, exp_out)
    self.assertEqual(err.getvalue().strip(), '')

  def testShowCommandTags(self):
    base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
    self.parser = saved_model_cli.create_parser()
    args = self.parser.parse_args(['show', '--dir', base_path])
    with captured_output() as (out, err):
      saved_model_cli.show(args)
    output = out.getvalue().strip()
    exp_out = 'The given SavedModel contains the following tag-sets:\n\'serve\''
    self.assertMultiLineEqual(output, exp_out)
    self.assertEqual(err.getvalue().strip(), '')

  def testShowCommandSignature(self):
    base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
    self.parser = saved_model_cli.create_parser()
    args = self.parser.parse_args(
        ['show', '--dir', base_path, '--tag_set', 'serve'])
    with captured_output() as (out, err):
      saved_model_cli.show(args)
    output = out.getvalue().strip()
    exp_header = ('The given SavedModel MetaGraphDef contains SignatureDefs '
                  'with the following keys:')
    exp_start = 'SignatureDef key: '
    exp_keys = [
        '"classify_x2_to_y3"', '"classify_x_to_y"', '"regress_x2_to_y3"',
        '"regress_x_to_y"', '"regress_x_to_y2"', '"serving_default"'
    ]
    # Order of signatures does not matter
    self.assertMultiLineEqual(
        output,
        '\n'.join([exp_header] + [exp_start + exp_key for exp_key in exp_keys]))
    self.assertEqual(err.getvalue().strip(), '')

  def testShowCommandErrorNoTagSet(self):
    base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
    self.parser = saved_model_cli.create_parser()
    args = self.parser.parse_args(
        ['show', '--dir', base_path, '--tag_set', 'badtagset'])
    with self.assertRaises(RuntimeError):
      saved_model_cli.show(args)

  def testShowCommandInputsOutputs(self):
    base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
    self.parser = saved_model_cli.create_parser()
    args = self.parser.parse_args([
        'show', '--dir', base_path, '--tag_set', 'serve', '--signature_def',
        'serving_default'
    ])
    with captured_output() as (out, err):
      saved_model_cli.show(args)
    output = out.getvalue().strip()
    expected_output = (
        'The given SavedModel SignatureDef contains the following input(s):\n'
        '  inputs[\'x\'] tensor_info:\n'
        '      dtype: DT_FLOAT\n      shape: (-1, 1)\n      name: x:0\n'
        'The given SavedModel SignatureDef contains the following output(s):\n'
        '  outputs[\'y\'] tensor_info:\n'
        '      dtype: DT_FLOAT\n      shape: (-1, 1)\n      name: y:0\n'
        'Method name is: tensorflow/serving/predict')
    self.assertEqual(output, expected_output)
    self.assertEqual(err.getvalue().strip(), '')

  def testPrintREFTypeTensor(self):
    ref_tensor_info = meta_graph_pb2.TensorInfo()
    ref_tensor_info.dtype = types_pb2.DT_FLOAT_REF
    with captured_output() as (out, err):
      saved_model_cli._print_tensor_info(ref_tensor_info)
    self.assertTrue('DT_FLOAT_REF' in out.getvalue().strip())
    self.assertEqual(err.getvalue().strip(), '')

  def testInputPreProcessFormats(self):
    input_str = 'input1=/path/file.txt[ab3];input2=file2'
    input_expr_str = 'input3=np.zeros([2,2]);input4=[4,5]'
    input_dict = saved_model_cli.preprocess_inputs_arg_string(input_str)
    input_expr_dict = saved_model_cli.preprocess_input_exprs_arg_string(
        input_expr_str)
    self.assertTrue(input_dict['input1'] == ('/path/file.txt', 'ab3'))
    self.assertTrue(input_dict['input2'] == ('file2', None))
    print(input_expr_dict['input3'])
    self.assertAllClose(input_expr_dict['input3'], np.zeros([2, 2]))
    self.assertAllClose(input_expr_dict['input4'], [4, 5])
    self.assertTrue(len(input_dict) == 2)
    self.assertTrue(len(input_expr_dict) == 2)

  def testInputPreProcessFileNames(self):
    input_str = (r'inputx=C:\Program Files\data.npz[v:0];'
                 r'input:0=c:\PROGRA~1\data.npy')
    input_dict = saved_model_cli.preprocess_inputs_arg_string(input_str)
    self.assertTrue(input_dict['inputx'] == (r'C:\Program Files\data.npz',
                                             'v:0'))
    self.assertTrue(input_dict['input:0'] == (r'c:\PROGRA~1\data.npy', None))

  def testInputPreProcessErrorBadFormat(self):
    input_str = 'inputx=file[[v1]v2'
    with self.assertRaises(RuntimeError):
      saved_model_cli.preprocess_inputs_arg_string(input_str)
    input_str = 'inputx:file'
    with self.assertRaises(RuntimeError):
      saved_model_cli.preprocess_inputs_arg_string(input_str)
    input_str = 'inputx:np.zeros((5))'
    with self.assertRaises(RuntimeError):
      saved_model_cli.preprocess_input_exprs_arg_string(input_str)

  def testInputParserNPY(self):
    x0 = np.array([[1], [2]])
    x1 = np.array(range(6)).reshape(2, 3)
    input0_path = os.path.join(test.get_temp_dir(), 'input0.npy')
    input1_path = os.path.join(test.get_temp_dir(), 'input1.npy')
    np.save(input0_path, x0)
    np.save(input1_path, x1)
    input_str = 'x0=' + input0_path + '[x0];x1=' + input1_path
    feed_dict = saved_model_cli.load_inputs_from_input_arg_string(
        input_str, '', '')
    self.assertTrue(np.all(feed_dict['x0'] == x0))
    self.assertTrue(np.all(feed_dict['x1'] == x1))

  def testInputParserNPZ(self):
    x0 = np.array([[1], [2]])
    input_path = os.path.join(test.get_temp_dir(), 'input.npz')
    np.savez(input_path, a=x0)
    input_str = 'x=' + input_path + '[a];y=' + input_path
    feed_dict = saved_model_cli.load_inputs_from_input_arg_string(
        input_str, '', '')
    self.assertTrue(np.all(feed_dict['x'] == x0))
    self.assertTrue(np.all(feed_dict['y'] == x0))

  def testInputParserPickle(self):
    pkl0 = {'a': 5, 'b': np.array(range(4))}
    pkl1 = np.array([1])
    pkl2 = np.array([[1], [3]])
    input_path0 = os.path.join(test.get_temp_dir(), 'pickle0.pkl')
    input_path1 = os.path.join(test.get_temp_dir(), 'pickle1.pkl')
    input_path2 = os.path.join(test.get_temp_dir(), 'pickle2.pkl')
    with open(input_path0, 'wb') as f:
      pickle.dump(pkl0, f)
    with open(input_path1, 'wb') as f:
      pickle.dump(pkl1, f)
    with open(input_path2, 'wb') as f:
      pickle.dump(pkl2, f)
    input_str = 'x=' + input_path0 + '[b];y=' + input_path1 + '[c];'
    input_str += 'z=' + input_path2
    feed_dict = saved_model_cli.load_inputs_from_input_arg_string(
        input_str, '', '')
    self.assertTrue(np.all(feed_dict['x'] == pkl0['b']))
    self.assertTrue(np.all(feed_dict['y'] == pkl1))
    self.assertTrue(np.all(feed_dict['z'] == pkl2))

  def testInputParserPythonExpression(self):
    x1 = np.ones([2, 10])
    x2 = np.array([[1], [2], [3]])
    x3 = np.mgrid[0:5, 0:5]
    x4 = [[3], [4]]
    input_expr_str = ('x1=np.ones([2,10]);x2=np.array([[1],[2],[3]]);'
                      'x3=np.mgrid[0:5,0:5];x4=[[3],[4]]')
    feed_dict = saved_model_cli.load_inputs_from_input_arg_string(
        '', input_expr_str, '')
    self.assertTrue(np.all(feed_dict['x1'] == x1))
    self.assertTrue(np.all(feed_dict['x2'] == x2))
    self.assertTrue(np.all(feed_dict['x3'] == x3))
    self.assertTrue(np.all(feed_dict['x4'] == x4))

  def testInputParserBoth(self):
    x0 = np.array([[1], [2]])
    input_path = os.path.join(test.get_temp_dir(), 'input.npz')
    np.savez(input_path, a=x0)
    x1 = np.ones([2, 10])
    input_str = 'x0=' + input_path + '[a]'
    input_expr_str = 'x1=np.ones([2,10])'
    feed_dict = saved_model_cli.load_inputs_from_input_arg_string(
        input_str, input_expr_str, '')
    self.assertTrue(np.all(feed_dict['x0'] == x0))
    self.assertTrue(np.all(feed_dict['x1'] == x1))

  def testInputParserBothDuplicate(self):
    x0 = np.array([[1], [2]])
    input_path = os.path.join(test.get_temp_dir(), 'input.npz')
    np.savez(input_path, a=x0)
    x1 = np.ones([2, 10])
    input_str = 'x0=' + input_path + '[a]'
    input_expr_str = 'x0=np.ones([2,10])'
    feed_dict = saved_model_cli.load_inputs_from_input_arg_string(
        input_str, input_expr_str, '')
    self.assertTrue(np.all(feed_dict['x0'] == x1))

  def testInputParserErrorNoName(self):
    x0 = np.array([[1], [2]])
    x1 = np.array(range(5))
    input_path = os.path.join(test.get_temp_dir(), 'input.npz')
    np.savez(input_path, a=x0, b=x1)
    input_str = 'x=' + input_path
    with self.assertRaises(RuntimeError):
      saved_model_cli.load_inputs_from_input_arg_string(input_str, '', '')

  def testInputParserErrorWrongName(self):
    x0 = np.array([[1], [2]])
    x1 = np.array(range(5))
    input_path = os.path.join(test.get_temp_dir(), 'input.npz')
    np.savez(input_path, a=x0, b=x1)
    input_str = 'x=' + input_path + '[c]'
    with self.assertRaises(RuntimeError):
      saved_model_cli.load_inputs_from_input_arg_string(input_str, '', '')

  def testRunCommandInputExamples(self):
    self.parser = saved_model_cli.create_parser()
    base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
    output_dir = os.path.join(test.get_temp_dir(), 'new_dir')
    args = self.parser.parse_args([
        'run', '--dir', base_path, '--tag_set', 'serve', '--signature_def',
        'regress_x_to_y', '--input_examples',
        'inputs=[{"x":[8.0],"x2":[5.0]}, {"x":[4.0],"x2":[3.0]}]', '--outdir',
        output_dir
    ])
    saved_model_cli.run(args)
    y_actual = np.load(os.path.join(output_dir, 'outputs.npy'))
    y_expected = np.array([[6.0], [4.0]])
    self.assertAllEqual(y_expected, y_actual)

  def testRunCommandExistingOutdir(self):
    self.parser = saved_model_cli.create_parser()
    base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
    x = np.array([[1], [2]])
    x_notused = np.zeros((6, 3))
    input_path = os.path.join(test.get_temp_dir(), 'testRunCommand_inputs.npz')
    np.savez(input_path, x0=x, x1=x_notused)
    output_file = os.path.join(test.get_temp_dir(), 'outputs.npy')
    if os.path.exists(output_file):
      os.remove(output_file)
    args = self.parser.parse_args([
        'run', '--dir', base_path, '--tag_set', 'serve', '--signature_def',
        'regress_x2_to_y3', '--inputs', 'inputs=' + input_path + '[x0]',
        '--outdir',
        test.get_temp_dir()
    ])
    saved_model_cli.run(args)
    y_actual = np.load(output_file)
    y_expected = np.array([[3.5], [4.0]])
    self.assertAllClose(y_expected, y_actual)

  def testRunCommandNewOutdir(self):
    self.parser = saved_model_cli.create_parser()
    base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
    x = np.array([[1], [2]])
    x_notused = np.zeros((6, 3))
    input_path = os.path.join(test.get_temp_dir(),
                              'testRunCommandNewOutdir_inputs.npz')
    output_dir = os.path.join(test.get_temp_dir(), 'new_dir')
    if os.path.isdir(output_dir):
      shutil.rmtree(output_dir)
    np.savez(input_path, x0=x, x1=x_notused)
    args = self.parser.parse_args([
        'run', '--dir', base_path, '--tag_set', 'serve', '--signature_def',
        'serving_default', '--inputs', 'x=' + input_path + '[x0]', '--outdir',
        output_dir
    ])
    saved_model_cli.run(args)
    y_actual = np.load(os.path.join(output_dir, 'y.npy'))
    y_expected = np.array([[2.5], [3.0]])
    self.assertAllClose(y_expected, y_actual)

  def testRunCommandOutOverwrite(self):
    self.parser = saved_model_cli.create_parser()
    base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
    x = np.array([[1], [2]])
    x_notused = np.zeros((6, 3))
    input_path = os.path.join(test.get_temp_dir(),
                              'testRunCommandOutOverwrite_inputs.npz')
    np.savez(input_path, x0=x, x1=x_notused)
    output_file = os.path.join(test.get_temp_dir(), 'y.npy')
    open(output_file, 'a').close()
    args = self.parser.parse_args([
        'run', '--dir', base_path, '--tag_set', 'serve', '--signature_def',
        'serving_default', '--inputs', 'x=' + input_path + '[x0]', '--outdir',
        test.get_temp_dir(), '--overwrite'
    ])
    saved_model_cli.run(args)
    y_actual = np.load(output_file)
    y_expected = np.array([[2.5], [3.0]])
    self.assertAllClose(y_expected, y_actual)

  def testRunCommandInvalidInputKeyError(self):
    self.parser = saved_model_cli.create_parser()
    base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
    args = self.parser.parse_args([
        'run', '--dir', base_path, '--tag_set', 'serve', '--signature_def',
        'regress_x2_to_y3', '--input_exprs', 'x2=np.ones((3,1))'
    ])
    with self.assertRaises(ValueError):
      saved_model_cli.run(args)

  def testRunCommandInputExamplesNotListError(self):
    self.parser = saved_model_cli.create_parser()
    base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
    output_dir = os.path.join(test.get_temp_dir(), 'new_dir')
    args = self.parser.parse_args([
        'run', '--dir', base_path, '--tag_set', 'serve', '--signature_def',
        'regress_x_to_y', '--input_examples', 'inputs={"x":8.0,"x2":5.0}',
        '--outdir', output_dir
    ])
    with self.assertRaisesRegexp(ValueError, 'must be a list'):
      saved_model_cli.run(args)

  def testRunCommandInputExamplesFeatureValueNotListError(self):
    self.parser = saved_model_cli.create_parser()
    base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
    output_dir = os.path.join(test.get_temp_dir(), 'new_dir')
    args = self.parser.parse_args([
        'run', '--dir', base_path, '--tag_set', 'serve', '--signature_def',
        'regress_x_to_y', '--input_examples', 'inputs=[{"x":8.0,"x2":5.0}]',
        '--outdir', output_dir
    ])
    with self.assertRaisesRegexp(ValueError, 'feature value must be a list'):
      saved_model_cli.run(args)

  def testRunCommandInputExamplesFeatureBadType(self):
    self.parser = saved_model_cli.create_parser()
    base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
    output_dir = os.path.join(test.get_temp_dir(), 'new_dir')
    args = self.parser.parse_args([
        'run', '--dir', base_path, '--tag_set', 'serve', '--signature_def',
        'regress_x_to_y', '--input_examples', 'inputs=[{"x":[[1],[2]]}]',
        '--outdir', output_dir
    ])
    with self.assertRaisesRegexp(ValueError, 'is not supported'):
      saved_model_cli.run(args)

  def testRunCommandOutputFileExistError(self):
    self.parser = saved_model_cli.create_parser()
    base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
    x = np.array([[1], [2]])
    x_notused = np.zeros((6, 3))
    input_path = os.path.join(test.get_temp_dir(),
                              'testRunCommandOutOverwrite_inputs.npz')
    np.savez(input_path, x0=x, x1=x_notused)
    output_file = os.path.join(test.get_temp_dir(), 'y.npy')
    open(output_file, 'a').close()
    args = self.parser.parse_args([
        'run', '--dir', base_path, '--tag_set', 'serve', '--signature_def',
        'serving_default', '--inputs', 'x=' + input_path + '[x0]', '--outdir',
        test.get_temp_dir()
    ])
    with self.assertRaises(RuntimeError):
      saved_model_cli.run(args)

  def testRunCommandInputNotGivenError(self):
    self.parser = saved_model_cli.create_parser()
    base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
    args = self.parser.parse_args([
        'run', '--dir', base_path, '--tag_set', 'serve', '--signature_def',
        'serving_default'
    ])
    with self.assertRaises(AttributeError):
      saved_model_cli.run(args)

  def testRunCommandWithDebuggerEnabled(self):
    self.parser = saved_model_cli.create_parser()
    base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
    x = np.array([[1], [2]])
    x_notused = np.zeros((6, 3))
    input_path = os.path.join(test.get_temp_dir(),
                              'testRunCommandNewOutdir_inputs.npz')
    output_dir = os.path.join(test.get_temp_dir(), 'new_dir')
    if os.path.isdir(output_dir):
      shutil.rmtree(output_dir)
    np.savez(input_path, x0=x, x1=x_notused)
    args = self.parser.parse_args([
        'run', '--dir', base_path, '--tag_set', 'serve', '--signature_def',
        'serving_default', '--inputs', 'x=' + input_path + '[x0]', '--outdir',
        output_dir, '--tf_debug'
    ])

    def fake_wrapper_session(sess):
      return sess

    with test.mock.patch.object(local_cli_wrapper,
                                'LocalCLIDebugWrapperSession',
                                side_effect=fake_wrapper_session,
                                autospec=True) as fake:
      saved_model_cli.run(args)
      fake.assert_called_with(test.mock.ANY)

    y_actual = np.load(os.path.join(output_dir, 'y.npy'))
    y_expected = np.array([[2.5], [3.0]])
    self.assertAllClose(y_expected, y_actual)

  def testScanCommand(self):
    self.parser = saved_model_cli.create_parser()
    base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
    args = self.parser.parse_args(['scan', '--dir', base_path])
    with captured_output() as (out, _):
      saved_model_cli.scan(args)
    output = out.getvalue().strip()
    self.assertTrue('does not contain blacklisted ops' in output)

  def testScanCommandFoundBlacklistedOp(self):
    self.parser = saved_model_cli.create_parser()
    base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
    args = self.parser.parse_args(
        ['scan', '--dir', base_path, '--tag_set', 'serve'])
    op_blacklist = saved_model_cli._OP_BLACKLIST
    saved_model_cli._OP_BLACKLIST = set(['VariableV2'])
    with captured_output() as (out, _):
      saved_model_cli.scan(args)
    saved_model_cli._OP_BLACKLIST = op_blacklist
    output = out.getvalue().strip()
    self.assertTrue('\'VariableV2\'' in output)

  def testAOTCompileCPUWrongSignatureDefKey(self):
    if not test.is_built_with_xla():
      self.skipTest('Skipping test because XLA is not compiled in.')

    self.parser = saved_model_cli.create_parser()
    base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
    output_dir = os.path.join(test.get_temp_dir(), 'aot_compile_cpu_dir')
    args = self.parser.parse_args(
        ['aot_compile_cpu', '--dir', base_path, '--tag_set', 'serve',
         '--output_prefix', output_dir,
         '--cpp_class', 'Compiled',
         '--signature_def_key', 'MISSING'])
    with self.assertRaisesRegexp(ValueError, 'Unable to find signature_def'):
      saved_model_cli.aot_compile_cpu(args)

  class AOTCompileDummyModel(tracking.AutoTrackable):
    """Model compatible with XLA compilation."""

    def __init__(self):
      self.var = variables.Variable(1.0, name='my_var')
      self.write_var = variables.Variable(1.0, name='write_var')

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32),
        # Test unused inputs.
        tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32),
    ])
    def func2(self, x, y):
      del y
      return {'res': x + self.var}

    @def_function.function(input_signature=[
        # Test large inputs.
        tensor_spec.TensorSpec(shape=(2048, 16), dtype=dtypes.float32),
        tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32),
    ])
    def func3(self, x, y):
      del y
      return {'res': x + self.var}

    @def_function.function(input_signature=[
        tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32),
        tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32),
    ])
    def func_write(self, x, y):
      del y
      self.write_var.assign(x + self.var)
      return {'res': self.write_var}

  @parameterized.named_parameters(
      ('VariablesToFeedNone', '', 'func2'),
      ('VariablesToFeedAll', 'all', 'func2'),
      ('VariablesToFeedMyVar', 'my_var', 'func2'),
      ('VariablesToFeedNoneLargeConstant', '', 'func3'),
      ('WriteToWriteVar', 'all', 'func_write'),
  )
  def testAOTCompileCPUFreezesAndCompiles(self, variables_to_feed, func):
    if not test.is_built_with_xla():
      self.skipTest('Skipping test because XLA is not compiled in.')

    saved_model_dir = os.path.join(test.get_temp_dir(), 'dummy_model')
    dummy_model = self.AOTCompileDummyModel()
    func = getattr(dummy_model, func)
    with self.cached_session():
      self.evaluate(dummy_model.var.initializer)
      self.evaluate(dummy_model.write_var.initializer)
      save.save(dummy_model, saved_model_dir, signatures={'func': func})

    self.parser = saved_model_cli.create_parser()
    output_prefix = os.path.join(test.get_temp_dir(), 'aot_compile_cpu_dir/out')
    args = self.parser.parse_args([
        'aot_compile_cpu', '--dir', saved_model_dir, '--tag_set', 'serve',
        '--signature_def_key', 'func',
        '--output_prefix', output_prefix, '--variables_to_feed',
        variables_to_feed, '--cpp_class', 'Generated'
    ])  # Use the default serving signature_key.
    with test.mock.patch.object(logging, 'warn') as captured_warn:
      saved_model_cli.aot_compile_cpu(args)
    self.assertRegexpMatches(
        str(captured_warn.call_args),
        'Signature input key \'y\'.*has been pruned while freezing the graph.')
    self.assertTrue(file_io.file_exists('{}.o'.format(output_prefix)))
    self.assertTrue(file_io.file_exists('{}.h'.format(output_prefix)))
    self.assertTrue(file_io.file_exists('{}_metadata.o'.format(output_prefix)))
    self.assertTrue(
        file_io.file_exists('{}_makefile.inc'.format(output_prefix)))
    header_contents = file_io.read_file_to_string('{}.h'.format(output_prefix))
    self.assertIn('class Generated', header_contents)
    self.assertIn('arg_feed_x_data', header_contents)
    self.assertIn('result_fetch_res_data', header_contents)
    # arg_y got filtered out as it's not used by the output.
    self.assertNotIn('arg_feed_y_data', header_contents)
    if variables_to_feed:
      # Read-only-variables' setters preserve constness.
      self.assertIn('set_var_param_my_var_data(const float', header_contents)
      self.assertNotIn('set_var_param_my_var_data(float', header_contents)
      if func == dummy_model.func_write:
        # Writeable variables setters do not preserve constness.
        self.assertIn('set_var_param_write_var_data(float', header_contents)
        self.assertNotIn(
            'set_var_param_write_var_data(const float', header_contents)

    makefile_contents = file_io.read_file_to_string(
        '{}_makefile.inc'.format(output_prefix))
    self.assertIn('-D_GLIBCXX_USE_CXX11_ABI=', makefile_contents)


if __name__ == '__main__':
  test.main()