Remove saved_model test from examples.
PiperOrigin-RevId: 336914830 Change-Id: Ib15d58225c837d9550901eddc623961be028cac7
This commit is contained in:
parent
12c7ef6bec
commit
c8a9751c55
@ -1,84 +0,0 @@
|
||||
load("//tensorflow/core/platform/default:distribute.bzl", "distribute_py_test")
|
||||
|
||||
package(
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "integration_scripts",
|
||||
srcs = [
|
||||
"deploy_mnist_cnn.py",
|
||||
"export_mnist_cnn.py",
|
||||
"export_rnn_cell.py",
|
||||
"export_simple_text_embedding.py",
|
||||
"export_text_rnn_model.py",
|
||||
"integration_scripts.py",
|
||||
"use_mnist_cnn.py",
|
||||
"use_model_in_sequential_keras.py",
|
||||
"use_rnn_cell.py",
|
||||
"use_text_embedding_in_dataset.py",
|
||||
"use_text_rnn_model.py",
|
||||
],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
":distribution_strategy_utils",
|
||||
":mnist_util",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "mnist_util",
|
||||
srcs = ["mnist_util.py"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "distribution_strategy_utils",
|
||||
srcs = ["distribution_strategy_utils.py"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
"//tensorflow/python/distribute:strategy_combinations",
|
||||
],
|
||||
)
|
||||
|
||||
distribute_py_test(
|
||||
name = "saved_model_test",
|
||||
srcs = [
|
||||
"saved_model_test.py",
|
||||
],
|
||||
shard_count = 4,
|
||||
tags = [
|
||||
"no_pip", # b/131697937 and b/132196869
|
||||
"noasan", # forge input size exceeded
|
||||
"nomsan", # forge input size exceeded
|
||||
"notsan", # forge input size exceeded
|
||||
],
|
||||
tpu_tags = [
|
||||
"no_oss", # Test infra collision (b/157754990)
|
||||
],
|
||||
deps = [
|
||||
":distribution_strategy_utils",
|
||||
":integration_scripts",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:framework_combinations",
|
||||
"//tensorflow/python/distribute:combinations",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
# b/132234211: Target added to support internal test target that runs the test
|
||||
# in an environment that has the extra dependencies required to test integration
|
||||
# with non core tensorflow packages.
|
||||
py_library(
|
||||
name = "saved_model_test_lib",
|
||||
srcs = [
|
||||
"saved_model_test.py",
|
||||
],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [":integration_scripts"],
|
||||
)
|
@ -1,107 +0,0 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Deploys a SavedModel with an MNIST classifier to TFLite."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
import numpy as np
|
||||
import tensorflow.compat.v2 as tf
|
||||
|
||||
from tensorflow.examples.saved_model.integration_tests import mnist_util
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_string(
|
||||
'saved_model_dir', None,
|
||||
'Directory of the SavedModel to deploy.')
|
||||
flags.DEFINE_bool(
|
||||
'use_fashion_mnist', False,
|
||||
'Use Fashion MNIST (products) instead of the real MNIST (digits).')
|
||||
flags.DEFINE_bool(
|
||||
'fast_test_mode', False,
|
||||
'Limit amount of test data for running in unit tests.')
|
||||
flags.DEFINE_string(
|
||||
'tflite_output_file', None,
|
||||
'The filename of the .tflite model file to write (optional).')
|
||||
flags.DEFINE_bool(
|
||||
'reload_as_keras_model', True,
|
||||
'Also test tf.keras.models.load_model() on --saved_model_dir.')
|
||||
|
||||
|
||||
def main(argv):
|
||||
del argv
|
||||
|
||||
# First convert the SavedModel in a pristine environment.
|
||||
converter = tf.lite.TFLiteConverter.from_saved_model(FLAGS.saved_model_dir)
|
||||
lite_model_content = converter.convert()
|
||||
# Here is how you can save it for actual deployment.
|
||||
if FLAGS.tflite_output_file:
|
||||
with open(FLAGS.tflite_output_file, 'wb') as outfile:
|
||||
outfile.write(lite_model_content)
|
||||
# For testing, the TFLite model can be executed like this.
|
||||
interpreter = tf.lite.Interpreter(model_content=lite_model_content)
|
||||
def lite_model(images):
|
||||
interpreter.allocate_tensors()
|
||||
interpreter.set_tensor(interpreter.get_input_details()[0]['index'], images)
|
||||
interpreter.invoke()
|
||||
return interpreter.get_tensor(interpreter.get_output_details()[0]['index'])
|
||||
|
||||
# Load the SavedModel again for use as a test baseline.
|
||||
imported = tf.saved_model.load(FLAGS.saved_model_dir)
|
||||
def tf_model(images):
|
||||
output_dict = imported.signatures['serving_default'](tf.constant(images))
|
||||
logits, = output_dict.values() # Unpack single value.
|
||||
return logits
|
||||
|
||||
# Compare model outputs on the test inputs.
|
||||
(_, _), (x_test, _) = mnist_util.load_reshaped_data(
|
||||
use_fashion_mnist=FLAGS.use_fashion_mnist,
|
||||
fake_tiny_data=FLAGS.fast_test_mode)
|
||||
for i, x in enumerate(x_test):
|
||||
x = x[None, ...] # Make batch of size 1.
|
||||
y_lite = lite_model(x)
|
||||
y_tf = tf_model(x)
|
||||
# This numpy primitive uses plain `raise` and works outside tf.TestCase.
|
||||
# Model outputs are probabilities that sum to 1, so atol makes sense here.
|
||||
np.testing.assert_allclose(
|
||||
y_lite, y_tf, rtol=0, atol=1e-5,
|
||||
err_msg='Mismatch with TF Lite at test example %d' % i)
|
||||
|
||||
# Test that the SavedModel loads correctly with v1 load APIs as well.
|
||||
with tf.compat.v1.Graph().as_default(), tf.compat.v1.Session() as session:
|
||||
tf.compat.v1.saved_model.load(
|
||||
session,
|
||||
[tf.compat.v1.saved_model.SERVING],
|
||||
FLAGS.saved_model_dir)
|
||||
|
||||
# The SavedModel actually was a Keras Model; test that it also loads as that.
|
||||
if FLAGS.reload_as_keras_model:
|
||||
keras_model = tf.keras.models.load_model(FLAGS.saved_model_dir)
|
||||
for i, x in enumerate(x_test):
|
||||
x = x[None, ...] # Make batch of size 1.
|
||||
y_tf = tf_model(x)
|
||||
y_keras = keras_model(x)
|
||||
# This numpy primitive uses plain `raise` and works outside tf.TestCase.
|
||||
# Model outputs are probabilities that sum to 1, so atol makes sense here.
|
||||
np.testing.assert_allclose(
|
||||
y_tf, y_keras, rtol=0, atol=1e-5,
|
||||
err_msg='Mismatch with Keras at test example %d' % i)
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
@ -1,68 +0,0 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Utils related to tf.distribute.strategy."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import sys
|
||||
|
||||
from tensorflow.python.distribute import strategy_combinations
|
||||
|
||||
_strategies = [
|
||||
strategy_combinations.one_device_strategy,
|
||||
strategy_combinations.mirrored_strategy_with_one_cpu,
|
||||
strategy_combinations.mirrored_strategy_with_one_gpu,
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.mirrored_strategy_with_two_gpus,
|
||||
strategy_combinations.tpu_strategy,
|
||||
]
|
||||
|
||||
# The presence of GPU strategies upsets TPU initialization,
|
||||
# despite their test instances being skipped early. This is a workaround
|
||||
# for b/145386854.
|
||||
if "test_tpu" in sys.argv[0]:
|
||||
_strategies = [s for s in _strategies if "GPU" not in str(s)]
|
||||
|
||||
|
||||
named_strategies = collections.OrderedDict(
|
||||
[(None, None)] +
|
||||
[(str(s), s) for s in _strategies]
|
||||
)
|
||||
|
||||
|
||||
class MaybeDistributionScope(object):
|
||||
"""Provides a context allowing no distribution strategy."""
|
||||
|
||||
@staticmethod
|
||||
def from_name(name):
|
||||
return MaybeDistributionScope(named_strategies[name].strategy if name
|
||||
else None)
|
||||
|
||||
def __init__(self, distribution):
|
||||
self._distribution = distribution
|
||||
self._scope = None
|
||||
|
||||
def __enter__(self):
|
||||
if self._distribution:
|
||||
self._scope = self._distribution.scope()
|
||||
self._scope.__enter__()
|
||||
|
||||
def __exit__(self, exc_type, value, traceback):
|
||||
if self._distribution:
|
||||
self._scope.__exit__(exc_type, value, traceback)
|
||||
self._scope = None
|
@ -1,204 +0,0 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Exports a convolutional feature extractor for MNIST in SavedModel format.
|
||||
|
||||
The feature extractor is a convolutional neural network plus a hidden layer
|
||||
that gets trained as part of an MNIST classifier and then written to a
|
||||
SavedModel (without the classification layer). From there, use_mnist_cnn.py
|
||||
picks it up for transfer learning.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
import tensorflow.compat.v2 as tf
|
||||
|
||||
from tensorflow.examples.saved_model.integration_tests import mnist_util
|
||||
from tensorflow.python.util import tf_decorator
|
||||
from tensorflow.python.util import tf_inspect
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_string(
|
||||
'export_dir', None,
|
||||
'Directory of exported SavedModel.')
|
||||
flags.DEFINE_integer(
|
||||
'epochs', 10,
|
||||
'Number of epochs to train.')
|
||||
flags.DEFINE_bool(
|
||||
'use_keras_save_api', False,
|
||||
'Uses tf.keras.models.save_model() on the feature extractor '
|
||||
'instead of tf.saved_model.save() on a manually wrapped version. '
|
||||
'With this, the exported model has no hparams.')
|
||||
flags.DEFINE_bool(
|
||||
'fast_test_mode', False,
|
||||
'Shortcut training for running in unit tests.')
|
||||
flags.DEFINE_bool(
|
||||
'export_print_hparams', False,
|
||||
'If true, the exported function will print its effective hparams.')
|
||||
|
||||
|
||||
def make_feature_extractor(l2_strength, dropout_rate):
|
||||
"""Returns a Keras Model to compute a feature vector from MNIST images."""
|
||||
regularizer = lambda: tf.keras.regularizers.l2(l2_strength)
|
||||
net = inp = tf.keras.Input(mnist_util.INPUT_SHAPE)
|
||||
net = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', name='conv1',
|
||||
kernel_regularizer=regularizer())(net)
|
||||
net = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', name='conv2',
|
||||
kernel_regularizer=regularizer())(net)
|
||||
net = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), name='pool1')(net)
|
||||
net = tf.keras.layers.Dropout(dropout_rate, name='dropout1')(net)
|
||||
net = tf.keras.layers.Flatten(name='flatten')(net)
|
||||
net = tf.keras.layers.Dense(10, activation='relu', name='dense1',
|
||||
kernel_regularizer=regularizer())(net)
|
||||
return tf.keras.Model(inputs=inp, outputs=net)
|
||||
|
||||
|
||||
def set_feature_extractor_hparams(model, dropout_rate):
|
||||
model.get_layer('dropout1').rate = dropout_rate
|
||||
|
||||
|
||||
def make_classifier(feature_extractor, l2_strength, dropout_rate=0.5):
|
||||
"""Returns a Keras Model to classify MNIST using feature_extractor."""
|
||||
regularizer = lambda: tf.keras.regularizers.l2(l2_strength)
|
||||
net = inp = tf.keras.Input(mnist_util.INPUT_SHAPE)
|
||||
net = feature_extractor(net)
|
||||
net = tf.keras.layers.Dropout(dropout_rate)(net)
|
||||
net = tf.keras.layers.Dense(mnist_util.NUM_CLASSES, activation='softmax',
|
||||
kernel_regularizer=regularizer())(net)
|
||||
return tf.keras.Model(inputs=inp, outputs=net)
|
||||
|
||||
|
||||
def wrap_keras_model_for_export(model, batch_input_shape,
|
||||
set_hparams, default_hparams):
|
||||
"""Wraps `model` for saving and loading as SavedModel."""
|
||||
# The primary input to the module is a Tensor with a batch of images.
|
||||
# Here we determine its spec.
|
||||
inputs_spec = tf.TensorSpec(shape=batch_input_shape, dtype=tf.float32)
|
||||
|
||||
# The module also accepts certain hparams as optional Tensor inputs.
|
||||
# Here, we cut all the relevant slices from `default_hparams`
|
||||
# (and don't worry if anyone accidentally modifies it later).
|
||||
if default_hparams is None: default_hparams = {}
|
||||
hparam_keys = list(default_hparams.keys())
|
||||
hparam_defaults = tuple(default_hparams.values())
|
||||
hparams_spec = {name: tf.TensorSpec.from_tensor(tf.constant(value))
|
||||
for name, value in default_hparams.items()}
|
||||
|
||||
# The goal is to save a function with this argspec...
|
||||
argspec = tf_inspect.FullArgSpec(
|
||||
args=(['inputs', 'training'] + hparam_keys),
|
||||
defaults=((False,) + hparam_defaults),
|
||||
varargs=None, varkw=None,
|
||||
kwonlyargs=[], kwonlydefaults=None,
|
||||
annotations={})
|
||||
# ...and this behavior:
|
||||
def call_fn(inputs, training, *args):
|
||||
if FLAGS.export_print_hparams:
|
||||
args = [tf.keras.backend.print_tensor(args[i], 'training=%s and %s='
|
||||
% (training, hparam_keys[i]))
|
||||
for i in range(len(args))]
|
||||
kwargs = dict(zip(hparam_keys, args))
|
||||
if kwargs: set_hparams(model, **kwargs)
|
||||
return model(inputs, training=training)
|
||||
|
||||
# We cannot spell out `args` in def statement for call_fn, but since
|
||||
# tf.function uses tf_inspect, we can use tf_decorator to wrap it with
|
||||
# the desired argspec.
|
||||
def wrapped(*args, **kwargs): # TODO(arnoegw): Can we use call_fn itself?
|
||||
return call_fn(*args, **kwargs)
|
||||
traced_call_fn = tf.function(
|
||||
tf_decorator.make_decorator(call_fn, wrapped, decorator_argspec=argspec))
|
||||
|
||||
# Now we need to trigger traces for all supported combinations of the
|
||||
# non-Tensor-value inputs.
|
||||
for training in (True, False):
|
||||
traced_call_fn.get_concrete_function(inputs_spec, training, **hparams_spec)
|
||||
|
||||
# Finally, we assemble the object for tf.saved_model.save().
|
||||
obj = tf.train.Checkpoint()
|
||||
obj.__call__ = traced_call_fn
|
||||
obj.trainable_variables = model.trainable_variables
|
||||
obj.variables = model.trainable_variables + model.non_trainable_variables
|
||||
# Make tf.functions for the regularization terms of the loss.
|
||||
obj.regularization_losses = [_get_traced_loss(model, i)
|
||||
for i in range(len(model.losses))]
|
||||
return obj
|
||||
|
||||
|
||||
def _get_traced_loss(model, i):
|
||||
"""Returns tf.function for model.losses[i] with a trace for zero args.
|
||||
|
||||
The intended usage is
|
||||
[_get_traced_loss(model, i) for i in range(len(model.losses))]
|
||||
This is better than
|
||||
[tf.function(lambda: model.losses[i], input_signature=[]) for i ...]
|
||||
because it avoids capturing a loop index in a lambda, and removes any
|
||||
chance of deferring the trace.
|
||||
|
||||
Args:
|
||||
model: a Keras Model.
|
||||
i: an integer between from 0 up to but to len(model.losses).
|
||||
"""
|
||||
f = tf.function(lambda: model.losses[i])
|
||||
_ = f.get_concrete_function()
|
||||
return f
|
||||
|
||||
|
||||
def main(argv):
|
||||
del argv
|
||||
|
||||
# Build a complete classifier model using a feature extractor.
|
||||
default_hparams = dict(dropout_rate=0.25)
|
||||
l2_strength = 0.01 # Not a hparam for inputs -> outputs.
|
||||
feature_extractor = make_feature_extractor(l2_strength=l2_strength,
|
||||
**default_hparams)
|
||||
classifier = make_classifier(feature_extractor, l2_strength=l2_strength)
|
||||
|
||||
# Train the complete model.
|
||||
(x_train, y_train), (x_test, y_test) = mnist_util.load_reshaped_data(
|
||||
fake_tiny_data=FLAGS.fast_test_mode)
|
||||
classifier.compile(loss=tf.keras.losses.categorical_crossentropy,
|
||||
optimizer=tf.keras.optimizers.SGD(),
|
||||
metrics=['accuracy'])
|
||||
classifier.fit(x_train, y_train,
|
||||
batch_size=128,
|
||||
epochs=FLAGS.epochs,
|
||||
verbose=1,
|
||||
validation_data=(x_test, y_test))
|
||||
|
||||
# Save the feature extractor to a framework-agnostic SavedModel for reuse.
|
||||
# Note that the feature_extractor object has not been compiled or fitted,
|
||||
# so it does not contain an optimizer and related state.
|
||||
if FLAGS.use_keras_save_api:
|
||||
# Use Keras' built-in way of creating reusable SavedModels.
|
||||
# This has no support for adjustable hparams at this time (July 2019).
|
||||
# (We could also call tf.saved_model.save(feature_extractor, ...),
|
||||
# point is we're passing a Keras model, not a plain Checkpoint.)
|
||||
tf.keras.models.save_model(feature_extractor, FLAGS.export_dir)
|
||||
else:
|
||||
# Assemble a reusable SavedModel manually, with adjustable hparams.
|
||||
exportable = wrap_keras_model_for_export(feature_extractor,
|
||||
(None,) + mnist_util.INPUT_SHAPE,
|
||||
set_feature_extractor_hparams,
|
||||
default_hparams)
|
||||
tf.saved_model.save(exportable, FLAGS.export_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
@ -1,63 +0,0 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Export an RNN cell in SavedModel format."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
import numpy as np
|
||||
|
||||
import tensorflow.compat.v2 as tf
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_string("export_dir", None, "Directory to export SavedModel.")
|
||||
|
||||
|
||||
def main(argv):
|
||||
del argv
|
||||
|
||||
root = tf.train.Checkpoint()
|
||||
# Create a cell and attach to our trackable.
|
||||
root.rnn_cell = tf.keras.layers.LSTMCell(units=10, recurrent_initializer=None)
|
||||
|
||||
# Wrap the rnn_cell.__call__ function and assign to next_state.
|
||||
root.next_state = tf.function(root.rnn_cell.__call__)
|
||||
|
||||
# Wrap the rnn_cell.get_initial_function using a decorator and assign to an
|
||||
# attribute with the same name.
|
||||
@tf.function(input_signature=[tf.TensorSpec([None, None], tf.float32)])
|
||||
def get_initial_state(tensor):
|
||||
return root.rnn_cell.get_initial_state(tensor, None, None)
|
||||
|
||||
root.get_initial_state = get_initial_state
|
||||
|
||||
# Construct an initial_state, then call next_state explicitly to trigger a
|
||||
# trace for serialization (we need an explicit call, because next_state has
|
||||
# not been annotated with an input_signature).
|
||||
initial_state = root.get_initial_state(
|
||||
tf.constant(np.random.uniform(size=[3, 10]).astype(np.float32)))
|
||||
root.next_state(
|
||||
tf.constant(np.random.uniform(size=[3, 19]).astype(np.float32)),
|
||||
initial_state)
|
||||
|
||||
tf.saved_model.save(root, FLAGS.export_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
@ -1,105 +0,0 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Text embedding model stored as a SavedModel."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
import tensorflow.compat.v2 as tf
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_string("export_dir", None, "Directory to export SavedModel.")
|
||||
|
||||
|
||||
def write_vocabulary_file(vocabulary):
|
||||
"""Write temporary vocab file for module construction."""
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
vocabulary_file = os.path.join(tmpdir, "tokens.txt")
|
||||
with tf.io.gfile.GFile(vocabulary_file, "w") as f:
|
||||
for entry in vocabulary:
|
||||
f.write(entry + "\n")
|
||||
return vocabulary_file
|
||||
|
||||
|
||||
class TextEmbeddingModel(tf.train.Checkpoint):
|
||||
"""Text embedding model.
|
||||
|
||||
A text embeddings model that takes a sentences on input and outputs the
|
||||
sentence embedding.
|
||||
"""
|
||||
|
||||
def __init__(self, vocabulary, emb_dim, oov_buckets):
|
||||
super(TextEmbeddingModel, self).__init__()
|
||||
self._oov_buckets = oov_buckets
|
||||
self._total_size = len(vocabulary) + oov_buckets
|
||||
# Assign the table initializer to this instance to ensure the asset
|
||||
# it depends on is saved with the SavedModel.
|
||||
self._table_initializer = tf.lookup.TextFileInitializer(
|
||||
write_vocabulary_file(vocabulary), tf.string,
|
||||
tf.lookup.TextFileIndex.WHOLE_LINE, tf.int64,
|
||||
tf.lookup.TextFileIndex.LINE_NUMBER)
|
||||
self._table = tf.lookup.StaticVocabularyTable(
|
||||
self._table_initializer, num_oov_buckets=self._oov_buckets)
|
||||
self.embeddings = tf.Variable(
|
||||
tf.random.uniform(shape=[self._total_size, emb_dim]))
|
||||
self.variables = [self.embeddings]
|
||||
self.trainable_variables = self.variables
|
||||
|
||||
def _tokenize(self, sentences):
|
||||
# Perform a minimalistic text preprocessing by removing punctuation and
|
||||
# splitting on spaces.
|
||||
normalized_sentences = tf.strings.regex_replace(
|
||||
input=sentences, pattern=r"\pP", rewrite="")
|
||||
normalized_sentences = tf.reshape(normalized_sentences, [-1])
|
||||
sparse_tokens = tf.strings.split(normalized_sentences, " ").to_sparse()
|
||||
|
||||
# Deal with a corner case: there is one empty sentence.
|
||||
sparse_tokens, _ = tf.sparse.fill_empty_rows(sparse_tokens, tf.constant(""))
|
||||
# Deal with a corner case: all sentences are empty.
|
||||
sparse_tokens = tf.sparse.reset_shape(sparse_tokens)
|
||||
sparse_token_ids = self._table.lookup(sparse_tokens.values)
|
||||
|
||||
return (sparse_tokens.indices, sparse_token_ids, sparse_tokens.dense_shape)
|
||||
|
||||
@tf.function(input_signature=[tf.TensorSpec([None], tf.dtypes.string)])
|
||||
def __call__(self, sentences):
|
||||
token_ids, token_values, token_dense_shape = self._tokenize(sentences)
|
||||
|
||||
return tf.nn.safe_embedding_lookup_sparse(
|
||||
embedding_weights=self.embeddings,
|
||||
sparse_ids=tf.sparse.SparseTensor(token_ids, token_values,
|
||||
token_dense_shape),
|
||||
sparse_weights=None,
|
||||
combiner="sqrtn")
|
||||
|
||||
|
||||
def main(argv):
|
||||
del argv
|
||||
|
||||
vocabulary = ["cat", "is", "on", "the", "mat"]
|
||||
module = TextEmbeddingModel(vocabulary=vocabulary, emb_dim=10, oov_buckets=10)
|
||||
tf.saved_model.save(module, FLAGS.export_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
@ -1,193 +0,0 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Text RNN model stored as a SavedModel."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
import tensorflow.compat.v2 as tf
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_string("export_dir", None, "Directory to export SavedModel.")
|
||||
|
||||
|
||||
class TextRnnModel(tf.train.Checkpoint):
|
||||
"""Text RNN model.
|
||||
|
||||
A full generative text RNN model that can train and decode sentences from a
|
||||
starting word.
|
||||
"""
|
||||
|
||||
def __init__(self, vocab, emb_dim, buckets, state_size):
|
||||
super(TextRnnModel, self).__init__()
|
||||
self._buckets = buckets
|
||||
self._lstm_cell = tf.keras.layers.LSTMCell(units=state_size)
|
||||
self._rnn_layer = tf.keras.layers.RNN(
|
||||
self._lstm_cell, return_sequences=True)
|
||||
self._embeddings = tf.Variable(tf.random.uniform(shape=[buckets, emb_dim]))
|
||||
self._logit_layer = tf.keras.layers.Dense(buckets)
|
||||
self._set_up_vocab(vocab)
|
||||
|
||||
def _tokenize(self, sentences):
|
||||
# Perform a minimalistic text preprocessing by removing punctuation and
|
||||
# splitting on spaces.
|
||||
normalized_sentences = tf.strings.regex_replace(
|
||||
input=sentences, pattern=r"\pP", rewrite="")
|
||||
sparse_tokens = tf.strings.split(normalized_sentences, " ").to_sparse()
|
||||
|
||||
# Deal with a corner case: there is one empty sentence.
|
||||
sparse_tokens, _ = tf.sparse.fill_empty_rows(sparse_tokens, tf.constant(""))
|
||||
# Deal with a corner case: all sentences are empty.
|
||||
sparse_tokens = tf.sparse.reset_shape(sparse_tokens)
|
||||
|
||||
return (sparse_tokens.indices, sparse_tokens.values,
|
||||
sparse_tokens.dense_shape)
|
||||
|
||||
def _set_up_vocab(self, vocab_tokens):
|
||||
# TODO(vbardiovsky): Currently there is no real vocabulary, because
|
||||
# saved_model serialization does not support trackable resources. Add a real
|
||||
# vocabulary when it does.
|
||||
vocab_list = ["UNK"] * self._buckets
|
||||
for vocab_token in vocab_tokens:
|
||||
index = self._words_to_indices(vocab_token).numpy()
|
||||
vocab_list[index] = vocab_token
|
||||
# This is a variable representing an inverse index.
|
||||
self._vocab_tensor = tf.Variable(vocab_list)
|
||||
|
||||
def _indices_to_words(self, indices):
|
||||
return tf.gather(self._vocab_tensor, indices)
|
||||
|
||||
def _words_to_indices(self, words):
|
||||
return tf.strings.to_hash_bucket(words, self._buckets)
|
||||
|
||||
@tf.function(input_signature=[tf.TensorSpec([None], tf.dtypes.string)])
|
||||
def train(self, sentences):
|
||||
token_ids, token_values, token_dense_shape = self._tokenize(sentences)
|
||||
tokens_sparse = tf.sparse.SparseTensor(
|
||||
indices=token_ids, values=token_values, dense_shape=token_dense_shape)
|
||||
tokens = tf.sparse.to_dense(tokens_sparse, default_value="")
|
||||
|
||||
sparse_lookup_ids = tf.sparse.SparseTensor(
|
||||
indices=tokens_sparse.indices,
|
||||
values=self._words_to_indices(tokens_sparse.values),
|
||||
dense_shape=tokens_sparse.dense_shape)
|
||||
lookup_ids = tf.sparse.to_dense(sparse_lookup_ids, default_value=0)
|
||||
|
||||
# Targets are the next word for each word of the sentence.
|
||||
tokens_ids_seq = lookup_ids[:, 0:-1]
|
||||
tokens_ids_target = lookup_ids[:, 1:]
|
||||
|
||||
tokens_prefix = tokens[:, 0:-1]
|
||||
|
||||
# Mask determining which positions we care about for a loss: all positions
|
||||
# that have a valid non-terminal token.
|
||||
mask = tf.logical_and(
|
||||
tf.logical_not(tf.equal(tokens_prefix, "")),
|
||||
tf.logical_not(tf.equal(tokens_prefix, "<E>")))
|
||||
|
||||
input_mask = tf.cast(mask, tf.int32)
|
||||
|
||||
with tf.GradientTape() as t:
|
||||
sentence_embeddings = tf.nn.embedding_lookup(self._embeddings,
|
||||
tokens_ids_seq)
|
||||
|
||||
lstm_initial_state = self._lstm_cell.get_initial_state(
|
||||
sentence_embeddings)
|
||||
|
||||
lstm_output = self._rnn_layer(
|
||||
inputs=sentence_embeddings, initial_state=lstm_initial_state)
|
||||
|
||||
# Stack LSTM outputs into a batch instead of a 2D array.
|
||||
lstm_output = tf.reshape(lstm_output, [-1, self._lstm_cell.output_size])
|
||||
|
||||
logits = self._logit_layer(lstm_output)
|
||||
|
||||
targets = tf.reshape(tokens_ids_target, [-1])
|
||||
weights = tf.cast(tf.reshape(input_mask, [-1]), tf.float32)
|
||||
|
||||
losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
|
||||
labels=targets, logits=logits)
|
||||
|
||||
# Final loss is the mean loss for all token losses.
|
||||
final_loss = tf.math.divide(
|
||||
tf.reduce_sum(tf.multiply(losses, weights)),
|
||||
tf.reduce_sum(weights),
|
||||
name="final_loss")
|
||||
|
||||
watched = t.watched_variables()
|
||||
gradients = t.gradient(final_loss, watched)
|
||||
|
||||
for w, g in zip(watched, gradients):
|
||||
w.assign_sub(g)
|
||||
|
||||
return final_loss
|
||||
|
||||
@tf.function
|
||||
def decode_greedy(self, sequence_length, first_word):
|
||||
initial_state = self._lstm_cell.get_initial_state(
|
||||
dtype=tf.float32, batch_size=1)
|
||||
|
||||
sequence = [first_word]
|
||||
current_word = first_word
|
||||
current_id = tf.expand_dims(self._words_to_indices(current_word), 0)
|
||||
current_state = initial_state
|
||||
|
||||
for _ in range(sequence_length):
|
||||
token_embeddings = tf.nn.embedding_lookup(self._embeddings, current_id)
|
||||
lstm_outputs, current_state = self._lstm_cell(token_embeddings,
|
||||
current_state)
|
||||
lstm_outputs = tf.reshape(lstm_outputs, [-1, self._lstm_cell.output_size])
|
||||
logits = self._logit_layer(lstm_outputs)
|
||||
softmax = tf.nn.softmax(logits)
|
||||
|
||||
next_ids = tf.math.argmax(softmax, axis=1)
|
||||
next_words = self._indices_to_words(next_ids)[0]
|
||||
|
||||
current_id = next_ids
|
||||
current_word = next_words
|
||||
sequence.append(current_word)
|
||||
|
||||
return sequence
|
||||
|
||||
|
||||
def main(argv):
|
||||
del argv
|
||||
|
||||
sentences = ["<S> hello there <E>", "<S> how are you doing today <E>"]
|
||||
vocab = [
|
||||
"<S>", "<E>", "hello", "there", "how", "are", "you", "doing", "today"
|
||||
]
|
||||
|
||||
module = TextRnnModel(vocab=vocab, emb_dim=10, buckets=100, state_size=128)
|
||||
|
||||
for _ in range(100):
|
||||
_ = module.train(tf.constant(sentences))
|
||||
|
||||
# We have to call this function explicitly if we want it exported, because it
|
||||
# has no input_signature in the @tf.function decorator.
|
||||
decoded = module.decode_greedy(
|
||||
sequence_length=10, first_word=tf.constant("<S>"))
|
||||
_ = [d.numpy() for d in decoded]
|
||||
|
||||
tf.saved_model.save(module, FLAGS.export_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
@ -1,68 +0,0 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Utility to write SavedModel integration tests.
|
||||
|
||||
SavedModel testing requires isolation between the process that creates and
|
||||
consumes it. This file helps doing that by relaunching the same binary that
|
||||
calls `assertCommandSucceeded` with an environment flag indicating what source
|
||||
file to execute. That binary must start by calling `MaybeRunScriptInstead`.
|
||||
|
||||
This allows to wire this into existing building systems without having to depend
|
||||
on data dependencies. And as so allow to keep a fixed binary size and allows
|
||||
interop with GPU tests.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from absl import app
|
||||
import tensorflow.compat.v2 as tf
|
||||
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
|
||||
class TestCase(tf.test.TestCase):
|
||||
"""Base class to write SavedModel integration tests."""
|
||||
|
||||
def assertCommandSucceeded(self, script_name, **flags):
|
||||
"""Runs an integration test script with given flags."""
|
||||
run_script = sys.argv[0]
|
||||
if run_script.endswith(".py"):
|
||||
command_parts = [sys.executable, run_script]
|
||||
else:
|
||||
command_parts = [run_script]
|
||||
command_parts.append("--alsologtostderr") # For visibility in sponge.
|
||||
for flag_key, flag_value in flags.items():
|
||||
command_parts.append("--%s=%s" % (flag_key, flag_value))
|
||||
|
||||
env = dict(TF2_BEHAVIOR="enabled", SCRIPT_NAME=script_name)
|
||||
logging.info("Running %s with added environment variables %s" %
|
||||
(command_parts, env))
|
||||
subprocess.check_call(command_parts, env=dict(os.environ, **env))
|
||||
|
||||
|
||||
def MaybeRunScriptInstead():
|
||||
if "SCRIPT_NAME" in os.environ:
|
||||
# Append current path to import path and execute `SCRIPT_NAME` main.
|
||||
sys.path.extend([os.path.dirname(__file__)])
|
||||
module_name = os.environ["SCRIPT_NAME"]
|
||||
retval = app.run(importlib.import_module(module_name).main) # pylint: disable=assignment-from-no-return
|
||||
sys.exit(retval)
|
@ -1,51 +0,0 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Convenience wrapper around Keras' MNIST and Fashion MNIST data."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow.compat.v2 as tf
|
||||
|
||||
INPUT_SHAPE = (28, 28, 1)
|
||||
NUM_CLASSES = 10
|
||||
|
||||
|
||||
def _load_random_data(num_train_and_test):
|
||||
return ((np.random.randint(0, 256, (num, 28, 28), dtype=np.uint8),
|
||||
np.random.randint(0, 10, (num,), dtype=np.int64))
|
||||
for num in num_train_and_test)
|
||||
|
||||
|
||||
def load_reshaped_data(use_fashion_mnist=False, fake_tiny_data=False):
|
||||
"""Returns MNIST or Fashion MNIST or fake train and test data."""
|
||||
load = ((lambda: _load_random_data([128, 128])) if fake_tiny_data else
|
||||
tf.keras.datasets.fashion_mnist.load_data if use_fashion_mnist else
|
||||
tf.keras.datasets.mnist.load_data)
|
||||
(x_train, y_train), (x_test, y_test) = load()
|
||||
return ((_prepare_image(x_train), _prepare_label(y_train)),
|
||||
(_prepare_image(x_test), _prepare_label(y_test)))
|
||||
|
||||
|
||||
def _prepare_image(x):
|
||||
"""Converts images to [n,h,w,c] format in range [0,1]."""
|
||||
return x[..., None].astype('float32') / 255.
|
||||
|
||||
|
||||
def _prepare_label(y):
|
||||
"""Conerts labels to one-hot encoding."""
|
||||
return tf.keras.utils.to_categorical(y, NUM_CLASSES)
|
@ -1,133 +0,0 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""SavedModel integration tests."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from absl.testing import parameterized
|
||||
import tensorflow.compat.v2 as tf
|
||||
|
||||
from tensorflow.examples.saved_model.integration_tests import distribution_strategy_utils as ds_utils
|
||||
from tensorflow.examples.saved_model.integration_tests import integration_scripts as scripts
|
||||
from tensorflow.python.distribute import combinations as distribute_combinations
|
||||
from tensorflow.python.framework import combinations
|
||||
|
||||
|
||||
class SavedModelTest(scripts.TestCase, parameterized.TestCase):
|
||||
|
||||
def __init__(self, method_name="runTest", has_extra_deps=False):
|
||||
super(SavedModelTest, self).__init__(method_name)
|
||||
self.has_extra_deps = has_extra_deps
|
||||
|
||||
def skipIfMissingExtraDeps(self):
|
||||
"""Skip test if it requires extra dependencies.
|
||||
|
||||
b/132234211: The extra dependencies are not available in all environments
|
||||
that run the tests, e.g. "tensorflow_hub" is not available from tests
|
||||
within "tensorflow" alone. Those tests are instead run by another
|
||||
internal test target.
|
||||
"""
|
||||
if not self.has_extra_deps:
|
||||
self.skipTest("Missing extra dependencies")
|
||||
|
||||
def test_text_rnn(self):
|
||||
export_dir = self.get_temp_dir()
|
||||
self.assertCommandSucceeded("export_text_rnn_model", export_dir=export_dir)
|
||||
self.assertCommandSucceeded("use_text_rnn_model", model_dir=export_dir)
|
||||
|
||||
def test_rnn_cell(self):
|
||||
export_dir = self.get_temp_dir()
|
||||
self.assertCommandSucceeded("export_rnn_cell", export_dir=export_dir)
|
||||
self.assertCommandSucceeded("use_rnn_cell", model_dir=export_dir)
|
||||
|
||||
def test_text_embedding_in_sequential_keras(self):
|
||||
self.skipIfMissingExtraDeps()
|
||||
export_dir = self.get_temp_dir()
|
||||
self.assertCommandSucceeded(
|
||||
"export_simple_text_embedding", export_dir=export_dir)
|
||||
self.assertCommandSucceeded(
|
||||
"use_model_in_sequential_keras", model_dir=export_dir)
|
||||
|
||||
def test_text_embedding_in_dataset(self):
|
||||
export_dir = self.get_temp_dir()
|
||||
self.assertCommandSucceeded(
|
||||
"export_simple_text_embedding", export_dir=export_dir)
|
||||
self.assertCommandSucceeded(
|
||||
"use_text_embedding_in_dataset", model_dir=export_dir)
|
||||
|
||||
TEST_MNIST_CNN_GENERATE_KWARGS = dict(
|
||||
combinations=(
|
||||
combinations.combine(
|
||||
# Test all combinations with tf.saved_model.save().
|
||||
# Test all combinations using tf.keras.models.save_model()
|
||||
# for both the reusable and the final full model.
|
||||
use_keras_save_api=True,
|
||||
named_strategy=list(ds_utils.named_strategies.values()),
|
||||
retrain_flag_value=["true", "false"],
|
||||
regularization_loss_multiplier=[None, 2], # Test for b/134528831.
|
||||
) + combinations.combine(
|
||||
# Test few critcial combinations with raw tf.saved_model.save(),
|
||||
# including export of a reusable SavedModel that gets assembled
|
||||
# manually, including support for adjustable hparams.
|
||||
use_keras_save_api=False,
|
||||
named_strategy=None,
|
||||
retrain_flag_value=["true", "false"],
|
||||
regularization_loss_multiplier=[None, 2], # Test for b/134528831.
|
||||
)),
|
||||
test_combinations=(distribute_combinations.GPUCombination(),
|
||||
distribute_combinations.TPUCombination()))
|
||||
|
||||
@combinations.generate(**TEST_MNIST_CNN_GENERATE_KWARGS)
|
||||
def test_mnist_cnn(self, use_keras_save_api, named_strategy,
|
||||
retrain_flag_value, regularization_loss_multiplier):
|
||||
|
||||
self.skipIfMissingExtraDeps()
|
||||
|
||||
fast_test_mode = True
|
||||
temp_dir = self.get_temp_dir()
|
||||
feature_extrator_dir = os.path.join(temp_dir, "mnist_feature_extractor")
|
||||
full_model_dir = os.path.join(temp_dir, "full_model")
|
||||
|
||||
self.assertCommandSucceeded(
|
||||
"export_mnist_cnn",
|
||||
fast_test_mode=fast_test_mode,
|
||||
export_dir=feature_extrator_dir,
|
||||
use_keras_save_api=use_keras_save_api)
|
||||
|
||||
use_kwargs = dict(fast_test_mode=fast_test_mode,
|
||||
input_saved_model_dir=feature_extrator_dir,
|
||||
retrain=retrain_flag_value,
|
||||
output_saved_model_dir=full_model_dir,
|
||||
use_keras_save_api=use_keras_save_api)
|
||||
if named_strategy:
|
||||
use_kwargs["strategy"] = str(named_strategy)
|
||||
if regularization_loss_multiplier is not None:
|
||||
use_kwargs[
|
||||
"regularization_loss_multiplier"] = regularization_loss_multiplier
|
||||
self.assertCommandSucceeded("use_mnist_cnn", **use_kwargs)
|
||||
|
||||
self.assertCommandSucceeded(
|
||||
"deploy_mnist_cnn",
|
||||
fast_test_mode=fast_test_mode,
|
||||
saved_model_dir=full_model_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
scripts.MaybeRunScriptInstead()
|
||||
tf.test.main()
|
@ -1,147 +0,0 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Imports a convolutional feature extractor for MNIST in SavedModel format.
|
||||
|
||||
This program picks up the SavedModel written by export_mnist_cnn.py and
|
||||
uses the feature extractor contained in it to do classification on either
|
||||
classic MNIST (digits) or Fashion MNIST (thumbnails of apparel). Optionally,
|
||||
it trains the feature extractor further as part of the new classifier.
|
||||
As expected, that makes training slower but does not help much for the
|
||||
original training dataset but helps a lot for transfer to the other dataset.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
import tensorflow.compat.v2 as tf
|
||||
import tensorflow_hub as hub
|
||||
|
||||
from tensorflow.examples.saved_model.integration_tests import distribution_strategy_utils as ds_utils
|
||||
from tensorflow.examples.saved_model.integration_tests import mnist_util
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_string(
|
||||
'input_saved_model_dir', None,
|
||||
'Directory of the reusable SavedModel that is imported into this program.')
|
||||
flags.DEFINE_integer(
|
||||
'epochs', 5,
|
||||
'Number of epochs to train.')
|
||||
flags.DEFINE_bool(
|
||||
'retrain', False,
|
||||
'If set, the imported SavedModel is trained further.')
|
||||
flags.DEFINE_float(
|
||||
'dropout_rate', None,
|
||||
'If set, dropout rate passed to the SavedModel. '
|
||||
'Requires a SavedModel with support for adjustable hyperparameters.')
|
||||
flags.DEFINE_float(
|
||||
'regularization_loss_multiplier', None,
|
||||
'If set, multiplier for the regularization losses in the SavedModel.')
|
||||
flags.DEFINE_bool(
|
||||
'use_fashion_mnist', False,
|
||||
'Use Fashion MNIST (products) instead of the real MNIST (digits). '
|
||||
'With this, --retrain gains a lot.')
|
||||
flags.DEFINE_bool(
|
||||
'fast_test_mode', False,
|
||||
'Shortcut training for running in unit tests.')
|
||||
flags.DEFINE_string(
|
||||
'output_saved_model_dir', None,
|
||||
'Directory of the SavedModel that was exported for reuse.')
|
||||
flags.DEFINE_bool(
|
||||
'use_keras_save_api', False,
|
||||
'Uses tf.keras.models.save_model() instead of tf.saved_model.save().')
|
||||
flags.DEFINE_string('strategy', None,
|
||||
'Name of the distribution strategy to use.')
|
||||
|
||||
|
||||
def make_feature_extractor(saved_model_path, trainable,
|
||||
regularization_loss_multiplier):
|
||||
"""Load a pre-trained feature extractor and wrap it for use in Keras."""
|
||||
if regularization_loss_multiplier is not None:
|
||||
# TODO(b/63257857): Scaling regularization losses requires manual loading
|
||||
# and modification of the SavedModel
|
||||
obj = tf.saved_model.load(saved_model_path)
|
||||
def _scale_one_loss(l): # Separate def avoids lambda capture of loop var.
|
||||
f = tf.function(lambda: tf.multiply(regularization_loss_multiplier, l()))
|
||||
_ = f.get_concrete_function()
|
||||
return f
|
||||
obj.regularization_losses = [_scale_one_loss(l)
|
||||
for l in obj.regularization_losses]
|
||||
# The modified object is then passed to hub.KerasLayer instead of the
|
||||
# string handle. That prevents it from saving a Keras config (b/134528831).
|
||||
handle = obj
|
||||
else:
|
||||
# If possible, we exercise the more common case of passing a string handle
|
||||
# such that hub.KerasLayer can save a Keras config (b/134528831).
|
||||
handle = saved_model_path
|
||||
|
||||
arguments = {}
|
||||
if FLAGS.dropout_rate is not None:
|
||||
arguments['dropout_rate'] = FLAGS.dropout_rate
|
||||
|
||||
return hub.KerasLayer(handle, trainable=trainable, arguments=arguments)
|
||||
|
||||
|
||||
def make_classifier(feature_extractor, l2_strength=0.01, dropout_rate=0.5):
|
||||
"""Returns a Keras Model to classify MNIST using feature_extractor."""
|
||||
regularizer = lambda: tf.keras.regularizers.l2(l2_strength)
|
||||
net = inp = tf.keras.Input(mnist_util.INPUT_SHAPE)
|
||||
net = feature_extractor(net)
|
||||
if dropout_rate:
|
||||
net = tf.keras.layers.Dropout(dropout_rate)(net)
|
||||
net = tf.keras.layers.Dense(mnist_util.NUM_CLASSES, activation='softmax',
|
||||
kernel_regularizer=regularizer())(net)
|
||||
return tf.keras.Model(inputs=inp, outputs=net)
|
||||
|
||||
|
||||
def main(argv):
|
||||
del argv
|
||||
|
||||
with ds_utils.MaybeDistributionScope.from_name(FLAGS.strategy):
|
||||
feature_extractor = make_feature_extractor(
|
||||
FLAGS.input_saved_model_dir,
|
||||
FLAGS.retrain,
|
||||
FLAGS.regularization_loss_multiplier)
|
||||
model = make_classifier(feature_extractor)
|
||||
|
||||
model.compile(loss=tf.keras.losses.categorical_crossentropy,
|
||||
optimizer=tf.keras.optimizers.SGD(),
|
||||
metrics=['accuracy'])
|
||||
|
||||
# Train the classifier (possibly on a different dataset).
|
||||
(x_train, y_train), (x_test, y_test) = mnist_util.load_reshaped_data(
|
||||
use_fashion_mnist=FLAGS.use_fashion_mnist,
|
||||
fake_tiny_data=FLAGS.fast_test_mode)
|
||||
print('Training on %s with %d trainable and %d untrainable variables.' %
|
||||
('Fashion MNIST' if FLAGS.use_fashion_mnist else 'MNIST',
|
||||
len(model.trainable_variables), len(model.non_trainable_variables)))
|
||||
model.fit(x_train, y_train,
|
||||
batch_size=128,
|
||||
epochs=FLAGS.epochs,
|
||||
verbose=1,
|
||||
validation_data=(x_test, y_test))
|
||||
|
||||
if FLAGS.output_saved_model_dir:
|
||||
if FLAGS.use_keras_save_api:
|
||||
tf.keras.models.save_model(model, FLAGS.output_saved_model_dir)
|
||||
else:
|
||||
tf.saved_model.save(model, FLAGS.output_saved_model_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
@ -1,75 +0,0 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Load and use text embedding module in sequential Keras."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tempfile
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
import numpy as np
|
||||
import tensorflow.compat.v2 as tf
|
||||
import tensorflow_hub as hub
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_string("model_dir", None, "Directory to load SavedModel from.")
|
||||
|
||||
|
||||
def train(fine_tuning):
|
||||
"""Build a Keras model and train with mock data."""
|
||||
features = np.array(["my first sentence", "my second sentence"])
|
||||
labels = np.array([1, 0])
|
||||
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
|
||||
|
||||
module = tf.saved_model.load(FLAGS.model_dir)
|
||||
|
||||
# Create the sequential keras model.
|
||||
l = tf.keras.layers
|
||||
model = tf.keras.Sequential()
|
||||
model.add(l.Reshape((), batch_input_shape=[None, 1], dtype=tf.string))
|
||||
# TODO(b/124219898): output_shape should be optional.
|
||||
model.add(hub.KerasLayer(module, output_shape=[10], trainable=fine_tuning))
|
||||
model.add(l.Dense(100, activation="relu"))
|
||||
model.add(l.Dense(50, activation="relu"))
|
||||
model.add(l.Dense(1, activation="sigmoid"))
|
||||
|
||||
model.compile(
|
||||
optimizer="adam",
|
||||
loss="binary_crossentropy",
|
||||
metrics=["accuracy"],
|
||||
# TODO(b/124446120): Remove after fixed.
|
||||
run_eagerly=True)
|
||||
|
||||
model.fit_generator(generator=dataset.batch(1), epochs=5)
|
||||
|
||||
# This is testing that a model using a SavedModel can be re-exported again,
|
||||
# e.g. to catch issues such as b/142231881.
|
||||
tf.saved_model.save(model, tempfile.mkdtemp())
|
||||
|
||||
|
||||
def main(argv):
|
||||
del argv
|
||||
|
||||
train(fine_tuning=False)
|
||||
train(fine_tuning=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
@ -1,50 +0,0 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Load and use an RNN cell stored as a SavedModel."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tempfile
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
import numpy as np
|
||||
import tensorflow.compat.v2 as tf
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_string("model_dir", None, "Directory to load SavedModel from.")
|
||||
|
||||
|
||||
def main(argv):
|
||||
del argv
|
||||
cell = tf.saved_model.load(FLAGS.model_dir)
|
||||
|
||||
initial_state = cell.get_initial_state(
|
||||
tf.constant(np.random.uniform(size=[3, 10]).astype(np.float32)))
|
||||
|
||||
cell.next_state(
|
||||
tf.constant(np.random.uniform(size=[3, 19]).astype(np.float32)),
|
||||
initial_state)
|
||||
|
||||
# This is testing that a model using a SavedModel can be re-exported again,
|
||||
# e.g. to catch issues such as b/142231881.
|
||||
tf.saved_model.save(cell, tempfile.mkdtemp())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
@ -1,73 +0,0 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Load and use text embedding module in a Dataset map function."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tempfile
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
import numpy as np
|
||||
import tensorflow.compat.v2 as tf
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_string("model_dir", None, "Directory to load SavedModel from.")
|
||||
|
||||
|
||||
def train():
|
||||
"""Build a Keras model and train with mock data."""
|
||||
module = tf.saved_model.load(FLAGS.model_dir)
|
||||
def _map_fn(features, labels):
|
||||
features = tf.expand_dims(features, 0)
|
||||
features = module(features)
|
||||
features = tf.squeeze(features, 0)
|
||||
return features, labels
|
||||
|
||||
features = np.array(["my first sentence", "my second sentence"])
|
||||
labels = np.array([1, 0])
|
||||
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).map(_map_fn)
|
||||
|
||||
# Create the sequential keras model.
|
||||
l = tf.keras.layers
|
||||
model = tf.keras.Sequential()
|
||||
model.add(l.Dense(10, activation="relu"))
|
||||
model.add(l.Dense(1, activation="sigmoid"))
|
||||
|
||||
model.compile(
|
||||
optimizer="adam",
|
||||
loss="binary_crossentropy",
|
||||
metrics=["accuracy"])
|
||||
|
||||
model.fit_generator(generator=dataset.batch(10), epochs=5)
|
||||
|
||||
# This is testing that a model using a SavedModel can be re-exported again,
|
||||
# e.g. to catch issues such as b/142231881.
|
||||
tf.saved_model.save(model, tempfile.mkdtemp())
|
||||
|
||||
|
||||
def main(argv):
|
||||
del argv
|
||||
|
||||
train()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.enable_v2_behavior()
|
||||
app.run(main)
|
@ -1,50 +0,0 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Load and use RNN model stored as a SavedModel."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tempfile
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
import tensorflow.compat.v2 as tf
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_string("model_dir", None, "Directory to load SavedModel from.")
|
||||
|
||||
|
||||
def main(argv):
|
||||
del argv
|
||||
|
||||
sentences = [
|
||||
"<S> sentence <E>", "<S> second sentence <E>", "<S> third sentence<E>"
|
||||
]
|
||||
|
||||
model = tf.saved_model.load(FLAGS.model_dir)
|
||||
model.train(tf.constant(sentences))
|
||||
decoded = model.decode_greedy(
|
||||
sequence_length=10, first_word=tf.constant("<S>"))
|
||||
_ = [d.numpy() for d in decoded]
|
||||
|
||||
# This is testing that a model using a SavedModel can be re-exported again,
|
||||
# e.g. to catch issues such as b/142231881.
|
||||
tf.saved_model.save(model, tempfile.mkdtemp())
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(main)
|
@ -98,7 +98,6 @@ COMMON_PIP_DEPS = [
|
||||
"//tensorflow/compiler/tf2xla:xla_compiled_cpu_runtime_srcs",
|
||||
"//tensorflow/compiler/mlir/tensorflow:gen_mlir_passthrough_op_py",
|
||||
"//tensorflow/core:protos_all_proto_srcs",
|
||||
"//tensorflow/examples/saved_model/integration_tests:mnist_util",
|
||||
"//tensorflow/lite/python/testdata:interpreter_test_data",
|
||||
"//tensorflow/lite/python:tflite_convert",
|
||||
"//tensorflow/lite/toco/python:toco_from_protos",
|
||||
|
Loading…
Reference in New Issue
Block a user