Remove saved_model test from examples.
PiperOrigin-RevId: 336914830 Change-Id: Ib15d58225c837d9550901eddc623961be028cac7
This commit is contained in:
parent
12c7ef6bec
commit
c8a9751c55
@ -1,84 +0,0 @@
|
|||||||
load("//tensorflow/core/platform/default:distribute.bzl", "distribute_py_test")
|
|
||||||
|
|
||||||
package(
|
|
||||||
licenses = ["notice"], # Apache 2.0
|
|
||||||
)
|
|
||||||
|
|
||||||
py_library(
|
|
||||||
name = "integration_scripts",
|
|
||||||
srcs = [
|
|
||||||
"deploy_mnist_cnn.py",
|
|
||||||
"export_mnist_cnn.py",
|
|
||||||
"export_rnn_cell.py",
|
|
||||||
"export_simple_text_embedding.py",
|
|
||||||
"export_text_rnn_model.py",
|
|
||||||
"integration_scripts.py",
|
|
||||||
"use_mnist_cnn.py",
|
|
||||||
"use_model_in_sequential_keras.py",
|
|
||||||
"use_rnn_cell.py",
|
|
||||||
"use_text_embedding_in_dataset.py",
|
|
||||||
"use_text_rnn_model.py",
|
|
||||||
],
|
|
||||||
visibility = ["//tensorflow:internal"],
|
|
||||||
deps = [
|
|
||||||
":distribution_strategy_utils",
|
|
||||||
":mnist_util",
|
|
||||||
"//tensorflow:tensorflow_py",
|
|
||||||
"@absl_py//absl/testing:parameterized",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
py_library(
|
|
||||||
name = "mnist_util",
|
|
||||||
srcs = ["mnist_util.py"],
|
|
||||||
visibility = ["//tensorflow:internal"],
|
|
||||||
deps = [
|
|
||||||
"//tensorflow:tensorflow_py",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
py_library(
|
|
||||||
name = "distribution_strategy_utils",
|
|
||||||
srcs = ["distribution_strategy_utils.py"],
|
|
||||||
visibility = ["//tensorflow:internal"],
|
|
||||||
deps = [
|
|
||||||
"//tensorflow/python/distribute:strategy_combinations",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
distribute_py_test(
|
|
||||||
name = "saved_model_test",
|
|
||||||
srcs = [
|
|
||||||
"saved_model_test.py",
|
|
||||||
],
|
|
||||||
shard_count = 4,
|
|
||||||
tags = [
|
|
||||||
"no_pip", # b/131697937 and b/132196869
|
|
||||||
"noasan", # forge input size exceeded
|
|
||||||
"nomsan", # forge input size exceeded
|
|
||||||
"notsan", # forge input size exceeded
|
|
||||||
],
|
|
||||||
tpu_tags = [
|
|
||||||
"no_oss", # Test infra collision (b/157754990)
|
|
||||||
],
|
|
||||||
deps = [
|
|
||||||
":distribution_strategy_utils",
|
|
||||||
":integration_scripts",
|
|
||||||
"//tensorflow:tensorflow_py",
|
|
||||||
"//tensorflow/python:framework_combinations",
|
|
||||||
"//tensorflow/python/distribute:combinations",
|
|
||||||
"@absl_py//absl/testing:parameterized",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
# b/132234211: Target added to support internal test target that runs the test
|
|
||||||
# in an environment that has the extra dependencies required to test integration
|
|
||||||
# with non core tensorflow packages.
|
|
||||||
py_library(
|
|
||||||
name = "saved_model_test_lib",
|
|
||||||
srcs = [
|
|
||||||
"saved_model_test.py",
|
|
||||||
],
|
|
||||||
visibility = ["//tensorflow:internal"],
|
|
||||||
deps = [":integration_scripts"],
|
|
||||||
)
|
|
@ -1,107 +0,0 @@
|
|||||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""Deploys a SavedModel with an MNIST classifier to TFLite."""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
from absl import app
|
|
||||||
from absl import flags
|
|
||||||
import numpy as np
|
|
||||||
import tensorflow.compat.v2 as tf
|
|
||||||
|
|
||||||
from tensorflow.examples.saved_model.integration_tests import mnist_util
|
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
|
||||||
|
|
||||||
flags.DEFINE_string(
|
|
||||||
'saved_model_dir', None,
|
|
||||||
'Directory of the SavedModel to deploy.')
|
|
||||||
flags.DEFINE_bool(
|
|
||||||
'use_fashion_mnist', False,
|
|
||||||
'Use Fashion MNIST (products) instead of the real MNIST (digits).')
|
|
||||||
flags.DEFINE_bool(
|
|
||||||
'fast_test_mode', False,
|
|
||||||
'Limit amount of test data for running in unit tests.')
|
|
||||||
flags.DEFINE_string(
|
|
||||||
'tflite_output_file', None,
|
|
||||||
'The filename of the .tflite model file to write (optional).')
|
|
||||||
flags.DEFINE_bool(
|
|
||||||
'reload_as_keras_model', True,
|
|
||||||
'Also test tf.keras.models.load_model() on --saved_model_dir.')
|
|
||||||
|
|
||||||
|
|
||||||
def main(argv):
|
|
||||||
del argv
|
|
||||||
|
|
||||||
# First convert the SavedModel in a pristine environment.
|
|
||||||
converter = tf.lite.TFLiteConverter.from_saved_model(FLAGS.saved_model_dir)
|
|
||||||
lite_model_content = converter.convert()
|
|
||||||
# Here is how you can save it for actual deployment.
|
|
||||||
if FLAGS.tflite_output_file:
|
|
||||||
with open(FLAGS.tflite_output_file, 'wb') as outfile:
|
|
||||||
outfile.write(lite_model_content)
|
|
||||||
# For testing, the TFLite model can be executed like this.
|
|
||||||
interpreter = tf.lite.Interpreter(model_content=lite_model_content)
|
|
||||||
def lite_model(images):
|
|
||||||
interpreter.allocate_tensors()
|
|
||||||
interpreter.set_tensor(interpreter.get_input_details()[0]['index'], images)
|
|
||||||
interpreter.invoke()
|
|
||||||
return interpreter.get_tensor(interpreter.get_output_details()[0]['index'])
|
|
||||||
|
|
||||||
# Load the SavedModel again for use as a test baseline.
|
|
||||||
imported = tf.saved_model.load(FLAGS.saved_model_dir)
|
|
||||||
def tf_model(images):
|
|
||||||
output_dict = imported.signatures['serving_default'](tf.constant(images))
|
|
||||||
logits, = output_dict.values() # Unpack single value.
|
|
||||||
return logits
|
|
||||||
|
|
||||||
# Compare model outputs on the test inputs.
|
|
||||||
(_, _), (x_test, _) = mnist_util.load_reshaped_data(
|
|
||||||
use_fashion_mnist=FLAGS.use_fashion_mnist,
|
|
||||||
fake_tiny_data=FLAGS.fast_test_mode)
|
|
||||||
for i, x in enumerate(x_test):
|
|
||||||
x = x[None, ...] # Make batch of size 1.
|
|
||||||
y_lite = lite_model(x)
|
|
||||||
y_tf = tf_model(x)
|
|
||||||
# This numpy primitive uses plain `raise` and works outside tf.TestCase.
|
|
||||||
# Model outputs are probabilities that sum to 1, so atol makes sense here.
|
|
||||||
np.testing.assert_allclose(
|
|
||||||
y_lite, y_tf, rtol=0, atol=1e-5,
|
|
||||||
err_msg='Mismatch with TF Lite at test example %d' % i)
|
|
||||||
|
|
||||||
# Test that the SavedModel loads correctly with v1 load APIs as well.
|
|
||||||
with tf.compat.v1.Graph().as_default(), tf.compat.v1.Session() as session:
|
|
||||||
tf.compat.v1.saved_model.load(
|
|
||||||
session,
|
|
||||||
[tf.compat.v1.saved_model.SERVING],
|
|
||||||
FLAGS.saved_model_dir)
|
|
||||||
|
|
||||||
# The SavedModel actually was a Keras Model; test that it also loads as that.
|
|
||||||
if FLAGS.reload_as_keras_model:
|
|
||||||
keras_model = tf.keras.models.load_model(FLAGS.saved_model_dir)
|
|
||||||
for i, x in enumerate(x_test):
|
|
||||||
x = x[None, ...] # Make batch of size 1.
|
|
||||||
y_tf = tf_model(x)
|
|
||||||
y_keras = keras_model(x)
|
|
||||||
# This numpy primitive uses plain `raise` and works outside tf.TestCase.
|
|
||||||
# Model outputs are probabilities that sum to 1, so atol makes sense here.
|
|
||||||
np.testing.assert_allclose(
|
|
||||||
y_tf, y_keras, rtol=0, atol=1e-5,
|
|
||||||
err_msg='Mismatch with Keras at test example %d' % i)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
app.run(main)
|
|
@ -1,68 +0,0 @@
|
|||||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""Utils related to tf.distribute.strategy."""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import collections
|
|
||||||
import sys
|
|
||||||
|
|
||||||
from tensorflow.python.distribute import strategy_combinations
|
|
||||||
|
|
||||||
_strategies = [
|
|
||||||
strategy_combinations.one_device_strategy,
|
|
||||||
strategy_combinations.mirrored_strategy_with_one_cpu,
|
|
||||||
strategy_combinations.mirrored_strategy_with_one_gpu,
|
|
||||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
|
||||||
strategy_combinations.mirrored_strategy_with_two_gpus,
|
|
||||||
strategy_combinations.tpu_strategy,
|
|
||||||
]
|
|
||||||
|
|
||||||
# The presence of GPU strategies upsets TPU initialization,
|
|
||||||
# despite their test instances being skipped early. This is a workaround
|
|
||||||
# for b/145386854.
|
|
||||||
if "test_tpu" in sys.argv[0]:
|
|
||||||
_strategies = [s for s in _strategies if "GPU" not in str(s)]
|
|
||||||
|
|
||||||
|
|
||||||
named_strategies = collections.OrderedDict(
|
|
||||||
[(None, None)] +
|
|
||||||
[(str(s), s) for s in _strategies]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MaybeDistributionScope(object):
|
|
||||||
"""Provides a context allowing no distribution strategy."""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_name(name):
|
|
||||||
return MaybeDistributionScope(named_strategies[name].strategy if name
|
|
||||||
else None)
|
|
||||||
|
|
||||||
def __init__(self, distribution):
|
|
||||||
self._distribution = distribution
|
|
||||||
self._scope = None
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
if self._distribution:
|
|
||||||
self._scope = self._distribution.scope()
|
|
||||||
self._scope.__enter__()
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, value, traceback):
|
|
||||||
if self._distribution:
|
|
||||||
self._scope.__exit__(exc_type, value, traceback)
|
|
||||||
self._scope = None
|
|
@ -1,204 +0,0 @@
|
|||||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""Exports a convolutional feature extractor for MNIST in SavedModel format.
|
|
||||||
|
|
||||||
The feature extractor is a convolutional neural network plus a hidden layer
|
|
||||||
that gets trained as part of an MNIST classifier and then written to a
|
|
||||||
SavedModel (without the classification layer). From there, use_mnist_cnn.py
|
|
||||||
picks it up for transfer learning.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
from absl import app
|
|
||||||
from absl import flags
|
|
||||||
import tensorflow.compat.v2 as tf
|
|
||||||
|
|
||||||
from tensorflow.examples.saved_model.integration_tests import mnist_util
|
|
||||||
from tensorflow.python.util import tf_decorator
|
|
||||||
from tensorflow.python.util import tf_inspect
|
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
|
||||||
|
|
||||||
flags.DEFINE_string(
|
|
||||||
'export_dir', None,
|
|
||||||
'Directory of exported SavedModel.')
|
|
||||||
flags.DEFINE_integer(
|
|
||||||
'epochs', 10,
|
|
||||||
'Number of epochs to train.')
|
|
||||||
flags.DEFINE_bool(
|
|
||||||
'use_keras_save_api', False,
|
|
||||||
'Uses tf.keras.models.save_model() on the feature extractor '
|
|
||||||
'instead of tf.saved_model.save() on a manually wrapped version. '
|
|
||||||
'With this, the exported model has no hparams.')
|
|
||||||
flags.DEFINE_bool(
|
|
||||||
'fast_test_mode', False,
|
|
||||||
'Shortcut training for running in unit tests.')
|
|
||||||
flags.DEFINE_bool(
|
|
||||||
'export_print_hparams', False,
|
|
||||||
'If true, the exported function will print its effective hparams.')
|
|
||||||
|
|
||||||
|
|
||||||
def make_feature_extractor(l2_strength, dropout_rate):
|
|
||||||
"""Returns a Keras Model to compute a feature vector from MNIST images."""
|
|
||||||
regularizer = lambda: tf.keras.regularizers.l2(l2_strength)
|
|
||||||
net = inp = tf.keras.Input(mnist_util.INPUT_SHAPE)
|
|
||||||
net = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', name='conv1',
|
|
||||||
kernel_regularizer=regularizer())(net)
|
|
||||||
net = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', name='conv2',
|
|
||||||
kernel_regularizer=regularizer())(net)
|
|
||||||
net = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), name='pool1')(net)
|
|
||||||
net = tf.keras.layers.Dropout(dropout_rate, name='dropout1')(net)
|
|
||||||
net = tf.keras.layers.Flatten(name='flatten')(net)
|
|
||||||
net = tf.keras.layers.Dense(10, activation='relu', name='dense1',
|
|
||||||
kernel_regularizer=regularizer())(net)
|
|
||||||
return tf.keras.Model(inputs=inp, outputs=net)
|
|
||||||
|
|
||||||
|
|
||||||
def set_feature_extractor_hparams(model, dropout_rate):
|
|
||||||
model.get_layer('dropout1').rate = dropout_rate
|
|
||||||
|
|
||||||
|
|
||||||
def make_classifier(feature_extractor, l2_strength, dropout_rate=0.5):
|
|
||||||
"""Returns a Keras Model to classify MNIST using feature_extractor."""
|
|
||||||
regularizer = lambda: tf.keras.regularizers.l2(l2_strength)
|
|
||||||
net = inp = tf.keras.Input(mnist_util.INPUT_SHAPE)
|
|
||||||
net = feature_extractor(net)
|
|
||||||
net = tf.keras.layers.Dropout(dropout_rate)(net)
|
|
||||||
net = tf.keras.layers.Dense(mnist_util.NUM_CLASSES, activation='softmax',
|
|
||||||
kernel_regularizer=regularizer())(net)
|
|
||||||
return tf.keras.Model(inputs=inp, outputs=net)
|
|
||||||
|
|
||||||
|
|
||||||
def wrap_keras_model_for_export(model, batch_input_shape,
|
|
||||||
set_hparams, default_hparams):
|
|
||||||
"""Wraps `model` for saving and loading as SavedModel."""
|
|
||||||
# The primary input to the module is a Tensor with a batch of images.
|
|
||||||
# Here we determine its spec.
|
|
||||||
inputs_spec = tf.TensorSpec(shape=batch_input_shape, dtype=tf.float32)
|
|
||||||
|
|
||||||
# The module also accepts certain hparams as optional Tensor inputs.
|
|
||||||
# Here, we cut all the relevant slices from `default_hparams`
|
|
||||||
# (and don't worry if anyone accidentally modifies it later).
|
|
||||||
if default_hparams is None: default_hparams = {}
|
|
||||||
hparam_keys = list(default_hparams.keys())
|
|
||||||
hparam_defaults = tuple(default_hparams.values())
|
|
||||||
hparams_spec = {name: tf.TensorSpec.from_tensor(tf.constant(value))
|
|
||||||
for name, value in default_hparams.items()}
|
|
||||||
|
|
||||||
# The goal is to save a function with this argspec...
|
|
||||||
argspec = tf_inspect.FullArgSpec(
|
|
||||||
args=(['inputs', 'training'] + hparam_keys),
|
|
||||||
defaults=((False,) + hparam_defaults),
|
|
||||||
varargs=None, varkw=None,
|
|
||||||
kwonlyargs=[], kwonlydefaults=None,
|
|
||||||
annotations={})
|
|
||||||
# ...and this behavior:
|
|
||||||
def call_fn(inputs, training, *args):
|
|
||||||
if FLAGS.export_print_hparams:
|
|
||||||
args = [tf.keras.backend.print_tensor(args[i], 'training=%s and %s='
|
|
||||||
% (training, hparam_keys[i]))
|
|
||||||
for i in range(len(args))]
|
|
||||||
kwargs = dict(zip(hparam_keys, args))
|
|
||||||
if kwargs: set_hparams(model, **kwargs)
|
|
||||||
return model(inputs, training=training)
|
|
||||||
|
|
||||||
# We cannot spell out `args` in def statement for call_fn, but since
|
|
||||||
# tf.function uses tf_inspect, we can use tf_decorator to wrap it with
|
|
||||||
# the desired argspec.
|
|
||||||
def wrapped(*args, **kwargs): # TODO(arnoegw): Can we use call_fn itself?
|
|
||||||
return call_fn(*args, **kwargs)
|
|
||||||
traced_call_fn = tf.function(
|
|
||||||
tf_decorator.make_decorator(call_fn, wrapped, decorator_argspec=argspec))
|
|
||||||
|
|
||||||
# Now we need to trigger traces for all supported combinations of the
|
|
||||||
# non-Tensor-value inputs.
|
|
||||||
for training in (True, False):
|
|
||||||
traced_call_fn.get_concrete_function(inputs_spec, training, **hparams_spec)
|
|
||||||
|
|
||||||
# Finally, we assemble the object for tf.saved_model.save().
|
|
||||||
obj = tf.train.Checkpoint()
|
|
||||||
obj.__call__ = traced_call_fn
|
|
||||||
obj.trainable_variables = model.trainable_variables
|
|
||||||
obj.variables = model.trainable_variables + model.non_trainable_variables
|
|
||||||
# Make tf.functions for the regularization terms of the loss.
|
|
||||||
obj.regularization_losses = [_get_traced_loss(model, i)
|
|
||||||
for i in range(len(model.losses))]
|
|
||||||
return obj
|
|
||||||
|
|
||||||
|
|
||||||
def _get_traced_loss(model, i):
|
|
||||||
"""Returns tf.function for model.losses[i] with a trace for zero args.
|
|
||||||
|
|
||||||
The intended usage is
|
|
||||||
[_get_traced_loss(model, i) for i in range(len(model.losses))]
|
|
||||||
This is better than
|
|
||||||
[tf.function(lambda: model.losses[i], input_signature=[]) for i ...]
|
|
||||||
because it avoids capturing a loop index in a lambda, and removes any
|
|
||||||
chance of deferring the trace.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: a Keras Model.
|
|
||||||
i: an integer between from 0 up to but to len(model.losses).
|
|
||||||
"""
|
|
||||||
f = tf.function(lambda: model.losses[i])
|
|
||||||
_ = f.get_concrete_function()
|
|
||||||
return f
|
|
||||||
|
|
||||||
|
|
||||||
def main(argv):
|
|
||||||
del argv
|
|
||||||
|
|
||||||
# Build a complete classifier model using a feature extractor.
|
|
||||||
default_hparams = dict(dropout_rate=0.25)
|
|
||||||
l2_strength = 0.01 # Not a hparam for inputs -> outputs.
|
|
||||||
feature_extractor = make_feature_extractor(l2_strength=l2_strength,
|
|
||||||
**default_hparams)
|
|
||||||
classifier = make_classifier(feature_extractor, l2_strength=l2_strength)
|
|
||||||
|
|
||||||
# Train the complete model.
|
|
||||||
(x_train, y_train), (x_test, y_test) = mnist_util.load_reshaped_data(
|
|
||||||
fake_tiny_data=FLAGS.fast_test_mode)
|
|
||||||
classifier.compile(loss=tf.keras.losses.categorical_crossentropy,
|
|
||||||
optimizer=tf.keras.optimizers.SGD(),
|
|
||||||
metrics=['accuracy'])
|
|
||||||
classifier.fit(x_train, y_train,
|
|
||||||
batch_size=128,
|
|
||||||
epochs=FLAGS.epochs,
|
|
||||||
verbose=1,
|
|
||||||
validation_data=(x_test, y_test))
|
|
||||||
|
|
||||||
# Save the feature extractor to a framework-agnostic SavedModel for reuse.
|
|
||||||
# Note that the feature_extractor object has not been compiled or fitted,
|
|
||||||
# so it does not contain an optimizer and related state.
|
|
||||||
if FLAGS.use_keras_save_api:
|
|
||||||
# Use Keras' built-in way of creating reusable SavedModels.
|
|
||||||
# This has no support for adjustable hparams at this time (July 2019).
|
|
||||||
# (We could also call tf.saved_model.save(feature_extractor, ...),
|
|
||||||
# point is we're passing a Keras model, not a plain Checkpoint.)
|
|
||||||
tf.keras.models.save_model(feature_extractor, FLAGS.export_dir)
|
|
||||||
else:
|
|
||||||
# Assemble a reusable SavedModel manually, with adjustable hparams.
|
|
||||||
exportable = wrap_keras_model_for_export(feature_extractor,
|
|
||||||
(None,) + mnist_util.INPUT_SHAPE,
|
|
||||||
set_feature_extractor_hparams,
|
|
||||||
default_hparams)
|
|
||||||
tf.saved_model.save(exportable, FLAGS.export_dir)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
app.run(main)
|
|
@ -1,63 +0,0 @@
|
|||||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""Export an RNN cell in SavedModel format."""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
from absl import app
|
|
||||||
from absl import flags
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import tensorflow.compat.v2 as tf
|
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
|
||||||
|
|
||||||
flags.DEFINE_string("export_dir", None, "Directory to export SavedModel.")
|
|
||||||
|
|
||||||
|
|
||||||
def main(argv):
|
|
||||||
del argv
|
|
||||||
|
|
||||||
root = tf.train.Checkpoint()
|
|
||||||
# Create a cell and attach to our trackable.
|
|
||||||
root.rnn_cell = tf.keras.layers.LSTMCell(units=10, recurrent_initializer=None)
|
|
||||||
|
|
||||||
# Wrap the rnn_cell.__call__ function and assign to next_state.
|
|
||||||
root.next_state = tf.function(root.rnn_cell.__call__)
|
|
||||||
|
|
||||||
# Wrap the rnn_cell.get_initial_function using a decorator and assign to an
|
|
||||||
# attribute with the same name.
|
|
||||||
@tf.function(input_signature=[tf.TensorSpec([None, None], tf.float32)])
|
|
||||||
def get_initial_state(tensor):
|
|
||||||
return root.rnn_cell.get_initial_state(tensor, None, None)
|
|
||||||
|
|
||||||
root.get_initial_state = get_initial_state
|
|
||||||
|
|
||||||
# Construct an initial_state, then call next_state explicitly to trigger a
|
|
||||||
# trace for serialization (we need an explicit call, because next_state has
|
|
||||||
# not been annotated with an input_signature).
|
|
||||||
initial_state = root.get_initial_state(
|
|
||||||
tf.constant(np.random.uniform(size=[3, 10]).astype(np.float32)))
|
|
||||||
root.next_state(
|
|
||||||
tf.constant(np.random.uniform(size=[3, 19]).astype(np.float32)),
|
|
||||||
initial_state)
|
|
||||||
|
|
||||||
tf.saved_model.save(root, FLAGS.export_dir)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
app.run(main)
|
|
@ -1,105 +0,0 @@
|
|||||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""Text embedding model stored as a SavedModel."""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import os
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
from absl import app
|
|
||||||
from absl import flags
|
|
||||||
|
|
||||||
import tensorflow.compat.v2 as tf
|
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
|
||||||
|
|
||||||
flags.DEFINE_string("export_dir", None, "Directory to export SavedModel.")
|
|
||||||
|
|
||||||
|
|
||||||
def write_vocabulary_file(vocabulary):
|
|
||||||
"""Write temporary vocab file for module construction."""
|
|
||||||
tmpdir = tempfile.mkdtemp()
|
|
||||||
vocabulary_file = os.path.join(tmpdir, "tokens.txt")
|
|
||||||
with tf.io.gfile.GFile(vocabulary_file, "w") as f:
|
|
||||||
for entry in vocabulary:
|
|
||||||
f.write(entry + "\n")
|
|
||||||
return vocabulary_file
|
|
||||||
|
|
||||||
|
|
||||||
class TextEmbeddingModel(tf.train.Checkpoint):
|
|
||||||
"""Text embedding model.
|
|
||||||
|
|
||||||
A text embeddings model that takes a sentences on input and outputs the
|
|
||||||
sentence embedding.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, vocabulary, emb_dim, oov_buckets):
|
|
||||||
super(TextEmbeddingModel, self).__init__()
|
|
||||||
self._oov_buckets = oov_buckets
|
|
||||||
self._total_size = len(vocabulary) + oov_buckets
|
|
||||||
# Assign the table initializer to this instance to ensure the asset
|
|
||||||
# it depends on is saved with the SavedModel.
|
|
||||||
self._table_initializer = tf.lookup.TextFileInitializer(
|
|
||||||
write_vocabulary_file(vocabulary), tf.string,
|
|
||||||
tf.lookup.TextFileIndex.WHOLE_LINE, tf.int64,
|
|
||||||
tf.lookup.TextFileIndex.LINE_NUMBER)
|
|
||||||
self._table = tf.lookup.StaticVocabularyTable(
|
|
||||||
self._table_initializer, num_oov_buckets=self._oov_buckets)
|
|
||||||
self.embeddings = tf.Variable(
|
|
||||||
tf.random.uniform(shape=[self._total_size, emb_dim]))
|
|
||||||
self.variables = [self.embeddings]
|
|
||||||
self.trainable_variables = self.variables
|
|
||||||
|
|
||||||
def _tokenize(self, sentences):
|
|
||||||
# Perform a minimalistic text preprocessing by removing punctuation and
|
|
||||||
# splitting on spaces.
|
|
||||||
normalized_sentences = tf.strings.regex_replace(
|
|
||||||
input=sentences, pattern=r"\pP", rewrite="")
|
|
||||||
normalized_sentences = tf.reshape(normalized_sentences, [-1])
|
|
||||||
sparse_tokens = tf.strings.split(normalized_sentences, " ").to_sparse()
|
|
||||||
|
|
||||||
# Deal with a corner case: there is one empty sentence.
|
|
||||||
sparse_tokens, _ = tf.sparse.fill_empty_rows(sparse_tokens, tf.constant(""))
|
|
||||||
# Deal with a corner case: all sentences are empty.
|
|
||||||
sparse_tokens = tf.sparse.reset_shape(sparse_tokens)
|
|
||||||
sparse_token_ids = self._table.lookup(sparse_tokens.values)
|
|
||||||
|
|
||||||
return (sparse_tokens.indices, sparse_token_ids, sparse_tokens.dense_shape)
|
|
||||||
|
|
||||||
@tf.function(input_signature=[tf.TensorSpec([None], tf.dtypes.string)])
|
|
||||||
def __call__(self, sentences):
|
|
||||||
token_ids, token_values, token_dense_shape = self._tokenize(sentences)
|
|
||||||
|
|
||||||
return tf.nn.safe_embedding_lookup_sparse(
|
|
||||||
embedding_weights=self.embeddings,
|
|
||||||
sparse_ids=tf.sparse.SparseTensor(token_ids, token_values,
|
|
||||||
token_dense_shape),
|
|
||||||
sparse_weights=None,
|
|
||||||
combiner="sqrtn")
|
|
||||||
|
|
||||||
|
|
||||||
def main(argv):
|
|
||||||
del argv
|
|
||||||
|
|
||||||
vocabulary = ["cat", "is", "on", "the", "mat"]
|
|
||||||
module = TextEmbeddingModel(vocabulary=vocabulary, emb_dim=10, oov_buckets=10)
|
|
||||||
tf.saved_model.save(module, FLAGS.export_dir)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
app.run(main)
|
|
@ -1,193 +0,0 @@
|
|||||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""Text RNN model stored as a SavedModel."""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
from absl import app
|
|
||||||
from absl import flags
|
|
||||||
|
|
||||||
import tensorflow.compat.v2 as tf
|
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
|
||||||
|
|
||||||
flags.DEFINE_string("export_dir", None, "Directory to export SavedModel.")
|
|
||||||
|
|
||||||
|
|
||||||
class TextRnnModel(tf.train.Checkpoint):
|
|
||||||
"""Text RNN model.
|
|
||||||
|
|
||||||
A full generative text RNN model that can train and decode sentences from a
|
|
||||||
starting word.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, vocab, emb_dim, buckets, state_size):
|
|
||||||
super(TextRnnModel, self).__init__()
|
|
||||||
self._buckets = buckets
|
|
||||||
self._lstm_cell = tf.keras.layers.LSTMCell(units=state_size)
|
|
||||||
self._rnn_layer = tf.keras.layers.RNN(
|
|
||||||
self._lstm_cell, return_sequences=True)
|
|
||||||
self._embeddings = tf.Variable(tf.random.uniform(shape=[buckets, emb_dim]))
|
|
||||||
self._logit_layer = tf.keras.layers.Dense(buckets)
|
|
||||||
self._set_up_vocab(vocab)
|
|
||||||
|
|
||||||
def _tokenize(self, sentences):
|
|
||||||
# Perform a minimalistic text preprocessing by removing punctuation and
|
|
||||||
# splitting on spaces.
|
|
||||||
normalized_sentences = tf.strings.regex_replace(
|
|
||||||
input=sentences, pattern=r"\pP", rewrite="")
|
|
||||||
sparse_tokens = tf.strings.split(normalized_sentences, " ").to_sparse()
|
|
||||||
|
|
||||||
# Deal with a corner case: there is one empty sentence.
|
|
||||||
sparse_tokens, _ = tf.sparse.fill_empty_rows(sparse_tokens, tf.constant(""))
|
|
||||||
# Deal with a corner case: all sentences are empty.
|
|
||||||
sparse_tokens = tf.sparse.reset_shape(sparse_tokens)
|
|
||||||
|
|
||||||
return (sparse_tokens.indices, sparse_tokens.values,
|
|
||||||
sparse_tokens.dense_shape)
|
|
||||||
|
|
||||||
def _set_up_vocab(self, vocab_tokens):
|
|
||||||
# TODO(vbardiovsky): Currently there is no real vocabulary, because
|
|
||||||
# saved_model serialization does not support trackable resources. Add a real
|
|
||||||
# vocabulary when it does.
|
|
||||||
vocab_list = ["UNK"] * self._buckets
|
|
||||||
for vocab_token in vocab_tokens:
|
|
||||||
index = self._words_to_indices(vocab_token).numpy()
|
|
||||||
vocab_list[index] = vocab_token
|
|
||||||
# This is a variable representing an inverse index.
|
|
||||||
self._vocab_tensor = tf.Variable(vocab_list)
|
|
||||||
|
|
||||||
def _indices_to_words(self, indices):
|
|
||||||
return tf.gather(self._vocab_tensor, indices)
|
|
||||||
|
|
||||||
def _words_to_indices(self, words):
|
|
||||||
return tf.strings.to_hash_bucket(words, self._buckets)
|
|
||||||
|
|
||||||
@tf.function(input_signature=[tf.TensorSpec([None], tf.dtypes.string)])
|
|
||||||
def train(self, sentences):
|
|
||||||
token_ids, token_values, token_dense_shape = self._tokenize(sentences)
|
|
||||||
tokens_sparse = tf.sparse.SparseTensor(
|
|
||||||
indices=token_ids, values=token_values, dense_shape=token_dense_shape)
|
|
||||||
tokens = tf.sparse.to_dense(tokens_sparse, default_value="")
|
|
||||||
|
|
||||||
sparse_lookup_ids = tf.sparse.SparseTensor(
|
|
||||||
indices=tokens_sparse.indices,
|
|
||||||
values=self._words_to_indices(tokens_sparse.values),
|
|
||||||
dense_shape=tokens_sparse.dense_shape)
|
|
||||||
lookup_ids = tf.sparse.to_dense(sparse_lookup_ids, default_value=0)
|
|
||||||
|
|
||||||
# Targets are the next word for each word of the sentence.
|
|
||||||
tokens_ids_seq = lookup_ids[:, 0:-1]
|
|
||||||
tokens_ids_target = lookup_ids[:, 1:]
|
|
||||||
|
|
||||||
tokens_prefix = tokens[:, 0:-1]
|
|
||||||
|
|
||||||
# Mask determining which positions we care about for a loss: all positions
|
|
||||||
# that have a valid non-terminal token.
|
|
||||||
mask = tf.logical_and(
|
|
||||||
tf.logical_not(tf.equal(tokens_prefix, "")),
|
|
||||||
tf.logical_not(tf.equal(tokens_prefix, "<E>")))
|
|
||||||
|
|
||||||
input_mask = tf.cast(mask, tf.int32)
|
|
||||||
|
|
||||||
with tf.GradientTape() as t:
|
|
||||||
sentence_embeddings = tf.nn.embedding_lookup(self._embeddings,
|
|
||||||
tokens_ids_seq)
|
|
||||||
|
|
||||||
lstm_initial_state = self._lstm_cell.get_initial_state(
|
|
||||||
sentence_embeddings)
|
|
||||||
|
|
||||||
lstm_output = self._rnn_layer(
|
|
||||||
inputs=sentence_embeddings, initial_state=lstm_initial_state)
|
|
||||||
|
|
||||||
# Stack LSTM outputs into a batch instead of a 2D array.
|
|
||||||
lstm_output = tf.reshape(lstm_output, [-1, self._lstm_cell.output_size])
|
|
||||||
|
|
||||||
logits = self._logit_layer(lstm_output)
|
|
||||||
|
|
||||||
targets = tf.reshape(tokens_ids_target, [-1])
|
|
||||||
weights = tf.cast(tf.reshape(input_mask, [-1]), tf.float32)
|
|
||||||
|
|
||||||
losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
|
|
||||||
labels=targets, logits=logits)
|
|
||||||
|
|
||||||
# Final loss is the mean loss for all token losses.
|
|
||||||
final_loss = tf.math.divide(
|
|
||||||
tf.reduce_sum(tf.multiply(losses, weights)),
|
|
||||||
tf.reduce_sum(weights),
|
|
||||||
name="final_loss")
|
|
||||||
|
|
||||||
watched = t.watched_variables()
|
|
||||||
gradients = t.gradient(final_loss, watched)
|
|
||||||
|
|
||||||
for w, g in zip(watched, gradients):
|
|
||||||
w.assign_sub(g)
|
|
||||||
|
|
||||||
return final_loss
|
|
||||||
|
|
||||||
@tf.function
|
|
||||||
def decode_greedy(self, sequence_length, first_word):
|
|
||||||
initial_state = self._lstm_cell.get_initial_state(
|
|
||||||
dtype=tf.float32, batch_size=1)
|
|
||||||
|
|
||||||
sequence = [first_word]
|
|
||||||
current_word = first_word
|
|
||||||
current_id = tf.expand_dims(self._words_to_indices(current_word), 0)
|
|
||||||
current_state = initial_state
|
|
||||||
|
|
||||||
for _ in range(sequence_length):
|
|
||||||
token_embeddings = tf.nn.embedding_lookup(self._embeddings, current_id)
|
|
||||||
lstm_outputs, current_state = self._lstm_cell(token_embeddings,
|
|
||||||
current_state)
|
|
||||||
lstm_outputs = tf.reshape(lstm_outputs, [-1, self._lstm_cell.output_size])
|
|
||||||
logits = self._logit_layer(lstm_outputs)
|
|
||||||
softmax = tf.nn.softmax(logits)
|
|
||||||
|
|
||||||
next_ids = tf.math.argmax(softmax, axis=1)
|
|
||||||
next_words = self._indices_to_words(next_ids)[0]
|
|
||||||
|
|
||||||
current_id = next_ids
|
|
||||||
current_word = next_words
|
|
||||||
sequence.append(current_word)
|
|
||||||
|
|
||||||
return sequence
|
|
||||||
|
|
||||||
|
|
||||||
def main(argv):
|
|
||||||
del argv
|
|
||||||
|
|
||||||
sentences = ["<S> hello there <E>", "<S> how are you doing today <E>"]
|
|
||||||
vocab = [
|
|
||||||
"<S>", "<E>", "hello", "there", "how", "are", "you", "doing", "today"
|
|
||||||
]
|
|
||||||
|
|
||||||
module = TextRnnModel(vocab=vocab, emb_dim=10, buckets=100, state_size=128)
|
|
||||||
|
|
||||||
for _ in range(100):
|
|
||||||
_ = module.train(tf.constant(sentences))
|
|
||||||
|
|
||||||
# We have to call this function explicitly if we want it exported, because it
|
|
||||||
# has no input_signature in the @tf.function decorator.
|
|
||||||
decoded = module.decode_greedy(
|
|
||||||
sequence_length=10, first_word=tf.constant("<S>"))
|
|
||||||
_ = [d.numpy() for d in decoded]
|
|
||||||
|
|
||||||
tf.saved_model.save(module, FLAGS.export_dir)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
app.run(main)
|
|
@ -1,68 +0,0 @@
|
|||||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""Utility to write SavedModel integration tests.
|
|
||||||
|
|
||||||
SavedModel testing requires isolation between the process that creates and
|
|
||||||
consumes it. This file helps doing that by relaunching the same binary that
|
|
||||||
calls `assertCommandSucceeded` with an environment flag indicating what source
|
|
||||||
file to execute. That binary must start by calling `MaybeRunScriptInstead`.
|
|
||||||
|
|
||||||
This allows to wire this into existing building systems without having to depend
|
|
||||||
on data dependencies. And as so allow to keep a fixed binary size and allows
|
|
||||||
interop with GPU tests.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import importlib
|
|
||||||
import os
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
|
|
||||||
from absl import app
|
|
||||||
import tensorflow.compat.v2 as tf
|
|
||||||
|
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
|
||||||
|
|
||||||
|
|
||||||
class TestCase(tf.test.TestCase):
|
|
||||||
"""Base class to write SavedModel integration tests."""
|
|
||||||
|
|
||||||
def assertCommandSucceeded(self, script_name, **flags):
|
|
||||||
"""Runs an integration test script with given flags."""
|
|
||||||
run_script = sys.argv[0]
|
|
||||||
if run_script.endswith(".py"):
|
|
||||||
command_parts = [sys.executable, run_script]
|
|
||||||
else:
|
|
||||||
command_parts = [run_script]
|
|
||||||
command_parts.append("--alsologtostderr") # For visibility in sponge.
|
|
||||||
for flag_key, flag_value in flags.items():
|
|
||||||
command_parts.append("--%s=%s" % (flag_key, flag_value))
|
|
||||||
|
|
||||||
env = dict(TF2_BEHAVIOR="enabled", SCRIPT_NAME=script_name)
|
|
||||||
logging.info("Running %s with added environment variables %s" %
|
|
||||||
(command_parts, env))
|
|
||||||
subprocess.check_call(command_parts, env=dict(os.environ, **env))
|
|
||||||
|
|
||||||
|
|
||||||
def MaybeRunScriptInstead():
|
|
||||||
if "SCRIPT_NAME" in os.environ:
|
|
||||||
# Append current path to import path and execute `SCRIPT_NAME` main.
|
|
||||||
sys.path.extend([os.path.dirname(__file__)])
|
|
||||||
module_name = os.environ["SCRIPT_NAME"]
|
|
||||||
retval = app.run(importlib.import_module(module_name).main) # pylint: disable=assignment-from-no-return
|
|
||||||
sys.exit(retval)
|
|
@ -1,51 +0,0 @@
|
|||||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""Convenience wrapper around Keras' MNIST and Fashion MNIST data."""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import tensorflow.compat.v2 as tf
|
|
||||||
|
|
||||||
INPUT_SHAPE = (28, 28, 1)
|
|
||||||
NUM_CLASSES = 10
|
|
||||||
|
|
||||||
|
|
||||||
def _load_random_data(num_train_and_test):
|
|
||||||
return ((np.random.randint(0, 256, (num, 28, 28), dtype=np.uint8),
|
|
||||||
np.random.randint(0, 10, (num,), dtype=np.int64))
|
|
||||||
for num in num_train_and_test)
|
|
||||||
|
|
||||||
|
|
||||||
def load_reshaped_data(use_fashion_mnist=False, fake_tiny_data=False):
|
|
||||||
"""Returns MNIST or Fashion MNIST or fake train and test data."""
|
|
||||||
load = ((lambda: _load_random_data([128, 128])) if fake_tiny_data else
|
|
||||||
tf.keras.datasets.fashion_mnist.load_data if use_fashion_mnist else
|
|
||||||
tf.keras.datasets.mnist.load_data)
|
|
||||||
(x_train, y_train), (x_test, y_test) = load()
|
|
||||||
return ((_prepare_image(x_train), _prepare_label(y_train)),
|
|
||||||
(_prepare_image(x_test), _prepare_label(y_test)))
|
|
||||||
|
|
||||||
|
|
||||||
def _prepare_image(x):
|
|
||||||
"""Converts images to [n,h,w,c] format in range [0,1]."""
|
|
||||||
return x[..., None].astype('float32') / 255.
|
|
||||||
|
|
||||||
|
|
||||||
def _prepare_label(y):
|
|
||||||
"""Conerts labels to one-hot encoding."""
|
|
||||||
return tf.keras.utils.to_categorical(y, NUM_CLASSES)
|
|
@ -1,133 +0,0 @@
|
|||||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""SavedModel integration tests."""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
from absl.testing import parameterized
|
|
||||||
import tensorflow.compat.v2 as tf
|
|
||||||
|
|
||||||
from tensorflow.examples.saved_model.integration_tests import distribution_strategy_utils as ds_utils
|
|
||||||
from tensorflow.examples.saved_model.integration_tests import integration_scripts as scripts
|
|
||||||
from tensorflow.python.distribute import combinations as distribute_combinations
|
|
||||||
from tensorflow.python.framework import combinations
|
|
||||||
|
|
||||||
|
|
||||||
class SavedModelTest(scripts.TestCase, parameterized.TestCase):
|
|
||||||
|
|
||||||
def __init__(self, method_name="runTest", has_extra_deps=False):
|
|
||||||
super(SavedModelTest, self).__init__(method_name)
|
|
||||||
self.has_extra_deps = has_extra_deps
|
|
||||||
|
|
||||||
def skipIfMissingExtraDeps(self):
|
|
||||||
"""Skip test if it requires extra dependencies.
|
|
||||||
|
|
||||||
b/132234211: The extra dependencies are not available in all environments
|
|
||||||
that run the tests, e.g. "tensorflow_hub" is not available from tests
|
|
||||||
within "tensorflow" alone. Those tests are instead run by another
|
|
||||||
internal test target.
|
|
||||||
"""
|
|
||||||
if not self.has_extra_deps:
|
|
||||||
self.skipTest("Missing extra dependencies")
|
|
||||||
|
|
||||||
def test_text_rnn(self):
|
|
||||||
export_dir = self.get_temp_dir()
|
|
||||||
self.assertCommandSucceeded("export_text_rnn_model", export_dir=export_dir)
|
|
||||||
self.assertCommandSucceeded("use_text_rnn_model", model_dir=export_dir)
|
|
||||||
|
|
||||||
def test_rnn_cell(self):
|
|
||||||
export_dir = self.get_temp_dir()
|
|
||||||
self.assertCommandSucceeded("export_rnn_cell", export_dir=export_dir)
|
|
||||||
self.assertCommandSucceeded("use_rnn_cell", model_dir=export_dir)
|
|
||||||
|
|
||||||
def test_text_embedding_in_sequential_keras(self):
|
|
||||||
self.skipIfMissingExtraDeps()
|
|
||||||
export_dir = self.get_temp_dir()
|
|
||||||
self.assertCommandSucceeded(
|
|
||||||
"export_simple_text_embedding", export_dir=export_dir)
|
|
||||||
self.assertCommandSucceeded(
|
|
||||||
"use_model_in_sequential_keras", model_dir=export_dir)
|
|
||||||
|
|
||||||
def test_text_embedding_in_dataset(self):
|
|
||||||
export_dir = self.get_temp_dir()
|
|
||||||
self.assertCommandSucceeded(
|
|
||||||
"export_simple_text_embedding", export_dir=export_dir)
|
|
||||||
self.assertCommandSucceeded(
|
|
||||||
"use_text_embedding_in_dataset", model_dir=export_dir)
|
|
||||||
|
|
||||||
TEST_MNIST_CNN_GENERATE_KWARGS = dict(
|
|
||||||
combinations=(
|
|
||||||
combinations.combine(
|
|
||||||
# Test all combinations with tf.saved_model.save().
|
|
||||||
# Test all combinations using tf.keras.models.save_model()
|
|
||||||
# for both the reusable and the final full model.
|
|
||||||
use_keras_save_api=True,
|
|
||||||
named_strategy=list(ds_utils.named_strategies.values()),
|
|
||||||
retrain_flag_value=["true", "false"],
|
|
||||||
regularization_loss_multiplier=[None, 2], # Test for b/134528831.
|
|
||||||
) + combinations.combine(
|
|
||||||
# Test few critcial combinations with raw tf.saved_model.save(),
|
|
||||||
# including export of a reusable SavedModel that gets assembled
|
|
||||||
# manually, including support for adjustable hparams.
|
|
||||||
use_keras_save_api=False,
|
|
||||||
named_strategy=None,
|
|
||||||
retrain_flag_value=["true", "false"],
|
|
||||||
regularization_loss_multiplier=[None, 2], # Test for b/134528831.
|
|
||||||
)),
|
|
||||||
test_combinations=(distribute_combinations.GPUCombination(),
|
|
||||||
distribute_combinations.TPUCombination()))
|
|
||||||
|
|
||||||
@combinations.generate(**TEST_MNIST_CNN_GENERATE_KWARGS)
|
|
||||||
def test_mnist_cnn(self, use_keras_save_api, named_strategy,
|
|
||||||
retrain_flag_value, regularization_loss_multiplier):
|
|
||||||
|
|
||||||
self.skipIfMissingExtraDeps()
|
|
||||||
|
|
||||||
fast_test_mode = True
|
|
||||||
temp_dir = self.get_temp_dir()
|
|
||||||
feature_extrator_dir = os.path.join(temp_dir, "mnist_feature_extractor")
|
|
||||||
full_model_dir = os.path.join(temp_dir, "full_model")
|
|
||||||
|
|
||||||
self.assertCommandSucceeded(
|
|
||||||
"export_mnist_cnn",
|
|
||||||
fast_test_mode=fast_test_mode,
|
|
||||||
export_dir=feature_extrator_dir,
|
|
||||||
use_keras_save_api=use_keras_save_api)
|
|
||||||
|
|
||||||
use_kwargs = dict(fast_test_mode=fast_test_mode,
|
|
||||||
input_saved_model_dir=feature_extrator_dir,
|
|
||||||
retrain=retrain_flag_value,
|
|
||||||
output_saved_model_dir=full_model_dir,
|
|
||||||
use_keras_save_api=use_keras_save_api)
|
|
||||||
if named_strategy:
|
|
||||||
use_kwargs["strategy"] = str(named_strategy)
|
|
||||||
if regularization_loss_multiplier is not None:
|
|
||||||
use_kwargs[
|
|
||||||
"regularization_loss_multiplier"] = regularization_loss_multiplier
|
|
||||||
self.assertCommandSucceeded("use_mnist_cnn", **use_kwargs)
|
|
||||||
|
|
||||||
self.assertCommandSucceeded(
|
|
||||||
"deploy_mnist_cnn",
|
|
||||||
fast_test_mode=fast_test_mode,
|
|
||||||
saved_model_dir=full_model_dir)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
scripts.MaybeRunScriptInstead()
|
|
||||||
tf.test.main()
|
|
@ -1,147 +0,0 @@
|
|||||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""Imports a convolutional feature extractor for MNIST in SavedModel format.
|
|
||||||
|
|
||||||
This program picks up the SavedModel written by export_mnist_cnn.py and
|
|
||||||
uses the feature extractor contained in it to do classification on either
|
|
||||||
classic MNIST (digits) or Fashion MNIST (thumbnails of apparel). Optionally,
|
|
||||||
it trains the feature extractor further as part of the new classifier.
|
|
||||||
As expected, that makes training slower but does not help much for the
|
|
||||||
original training dataset but helps a lot for transfer to the other dataset.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
from absl import app
|
|
||||||
from absl import flags
|
|
||||||
import tensorflow.compat.v2 as tf
|
|
||||||
import tensorflow_hub as hub
|
|
||||||
|
|
||||||
from tensorflow.examples.saved_model.integration_tests import distribution_strategy_utils as ds_utils
|
|
||||||
from tensorflow.examples.saved_model.integration_tests import mnist_util
|
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
|
||||||
|
|
||||||
flags.DEFINE_string(
|
|
||||||
'input_saved_model_dir', None,
|
|
||||||
'Directory of the reusable SavedModel that is imported into this program.')
|
|
||||||
flags.DEFINE_integer(
|
|
||||||
'epochs', 5,
|
|
||||||
'Number of epochs to train.')
|
|
||||||
flags.DEFINE_bool(
|
|
||||||
'retrain', False,
|
|
||||||
'If set, the imported SavedModel is trained further.')
|
|
||||||
flags.DEFINE_float(
|
|
||||||
'dropout_rate', None,
|
|
||||||
'If set, dropout rate passed to the SavedModel. '
|
|
||||||
'Requires a SavedModel with support for adjustable hyperparameters.')
|
|
||||||
flags.DEFINE_float(
|
|
||||||
'regularization_loss_multiplier', None,
|
|
||||||
'If set, multiplier for the regularization losses in the SavedModel.')
|
|
||||||
flags.DEFINE_bool(
|
|
||||||
'use_fashion_mnist', False,
|
|
||||||
'Use Fashion MNIST (products) instead of the real MNIST (digits). '
|
|
||||||
'With this, --retrain gains a lot.')
|
|
||||||
flags.DEFINE_bool(
|
|
||||||
'fast_test_mode', False,
|
|
||||||
'Shortcut training for running in unit tests.')
|
|
||||||
flags.DEFINE_string(
|
|
||||||
'output_saved_model_dir', None,
|
|
||||||
'Directory of the SavedModel that was exported for reuse.')
|
|
||||||
flags.DEFINE_bool(
|
|
||||||
'use_keras_save_api', False,
|
|
||||||
'Uses tf.keras.models.save_model() instead of tf.saved_model.save().')
|
|
||||||
flags.DEFINE_string('strategy', None,
|
|
||||||
'Name of the distribution strategy to use.')
|
|
||||||
|
|
||||||
|
|
||||||
def make_feature_extractor(saved_model_path, trainable,
|
|
||||||
regularization_loss_multiplier):
|
|
||||||
"""Load a pre-trained feature extractor and wrap it for use in Keras."""
|
|
||||||
if regularization_loss_multiplier is not None:
|
|
||||||
# TODO(b/63257857): Scaling regularization losses requires manual loading
|
|
||||||
# and modification of the SavedModel
|
|
||||||
obj = tf.saved_model.load(saved_model_path)
|
|
||||||
def _scale_one_loss(l): # Separate def avoids lambda capture of loop var.
|
|
||||||
f = tf.function(lambda: tf.multiply(regularization_loss_multiplier, l()))
|
|
||||||
_ = f.get_concrete_function()
|
|
||||||
return f
|
|
||||||
obj.regularization_losses = [_scale_one_loss(l)
|
|
||||||
for l in obj.regularization_losses]
|
|
||||||
# The modified object is then passed to hub.KerasLayer instead of the
|
|
||||||
# string handle. That prevents it from saving a Keras config (b/134528831).
|
|
||||||
handle = obj
|
|
||||||
else:
|
|
||||||
# If possible, we exercise the more common case of passing a string handle
|
|
||||||
# such that hub.KerasLayer can save a Keras config (b/134528831).
|
|
||||||
handle = saved_model_path
|
|
||||||
|
|
||||||
arguments = {}
|
|
||||||
if FLAGS.dropout_rate is not None:
|
|
||||||
arguments['dropout_rate'] = FLAGS.dropout_rate
|
|
||||||
|
|
||||||
return hub.KerasLayer(handle, trainable=trainable, arguments=arguments)
|
|
||||||
|
|
||||||
|
|
||||||
def make_classifier(feature_extractor, l2_strength=0.01, dropout_rate=0.5):
|
|
||||||
"""Returns a Keras Model to classify MNIST using feature_extractor."""
|
|
||||||
regularizer = lambda: tf.keras.regularizers.l2(l2_strength)
|
|
||||||
net = inp = tf.keras.Input(mnist_util.INPUT_SHAPE)
|
|
||||||
net = feature_extractor(net)
|
|
||||||
if dropout_rate:
|
|
||||||
net = tf.keras.layers.Dropout(dropout_rate)(net)
|
|
||||||
net = tf.keras.layers.Dense(mnist_util.NUM_CLASSES, activation='softmax',
|
|
||||||
kernel_regularizer=regularizer())(net)
|
|
||||||
return tf.keras.Model(inputs=inp, outputs=net)
|
|
||||||
|
|
||||||
|
|
||||||
def main(argv):
|
|
||||||
del argv
|
|
||||||
|
|
||||||
with ds_utils.MaybeDistributionScope.from_name(FLAGS.strategy):
|
|
||||||
feature_extractor = make_feature_extractor(
|
|
||||||
FLAGS.input_saved_model_dir,
|
|
||||||
FLAGS.retrain,
|
|
||||||
FLAGS.regularization_loss_multiplier)
|
|
||||||
model = make_classifier(feature_extractor)
|
|
||||||
|
|
||||||
model.compile(loss=tf.keras.losses.categorical_crossentropy,
|
|
||||||
optimizer=tf.keras.optimizers.SGD(),
|
|
||||||
metrics=['accuracy'])
|
|
||||||
|
|
||||||
# Train the classifier (possibly on a different dataset).
|
|
||||||
(x_train, y_train), (x_test, y_test) = mnist_util.load_reshaped_data(
|
|
||||||
use_fashion_mnist=FLAGS.use_fashion_mnist,
|
|
||||||
fake_tiny_data=FLAGS.fast_test_mode)
|
|
||||||
print('Training on %s with %d trainable and %d untrainable variables.' %
|
|
||||||
('Fashion MNIST' if FLAGS.use_fashion_mnist else 'MNIST',
|
|
||||||
len(model.trainable_variables), len(model.non_trainable_variables)))
|
|
||||||
model.fit(x_train, y_train,
|
|
||||||
batch_size=128,
|
|
||||||
epochs=FLAGS.epochs,
|
|
||||||
verbose=1,
|
|
||||||
validation_data=(x_test, y_test))
|
|
||||||
|
|
||||||
if FLAGS.output_saved_model_dir:
|
|
||||||
if FLAGS.use_keras_save_api:
|
|
||||||
tf.keras.models.save_model(model, FLAGS.output_saved_model_dir)
|
|
||||||
else:
|
|
||||||
tf.saved_model.save(model, FLAGS.output_saved_model_dir)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
app.run(main)
|
|
@ -1,75 +0,0 @@
|
|||||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""Load and use text embedding module in sequential Keras."""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
from absl import app
|
|
||||||
from absl import flags
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import tensorflow.compat.v2 as tf
|
|
||||||
import tensorflow_hub as hub
|
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
|
||||||
|
|
||||||
flags.DEFINE_string("model_dir", None, "Directory to load SavedModel from.")
|
|
||||||
|
|
||||||
|
|
||||||
def train(fine_tuning):
|
|
||||||
"""Build a Keras model and train with mock data."""
|
|
||||||
features = np.array(["my first sentence", "my second sentence"])
|
|
||||||
labels = np.array([1, 0])
|
|
||||||
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
|
|
||||||
|
|
||||||
module = tf.saved_model.load(FLAGS.model_dir)
|
|
||||||
|
|
||||||
# Create the sequential keras model.
|
|
||||||
l = tf.keras.layers
|
|
||||||
model = tf.keras.Sequential()
|
|
||||||
model.add(l.Reshape((), batch_input_shape=[None, 1], dtype=tf.string))
|
|
||||||
# TODO(b/124219898): output_shape should be optional.
|
|
||||||
model.add(hub.KerasLayer(module, output_shape=[10], trainable=fine_tuning))
|
|
||||||
model.add(l.Dense(100, activation="relu"))
|
|
||||||
model.add(l.Dense(50, activation="relu"))
|
|
||||||
model.add(l.Dense(1, activation="sigmoid"))
|
|
||||||
|
|
||||||
model.compile(
|
|
||||||
optimizer="adam",
|
|
||||||
loss="binary_crossentropy",
|
|
||||||
metrics=["accuracy"],
|
|
||||||
# TODO(b/124446120): Remove after fixed.
|
|
||||||
run_eagerly=True)
|
|
||||||
|
|
||||||
model.fit_generator(generator=dataset.batch(1), epochs=5)
|
|
||||||
|
|
||||||
# This is testing that a model using a SavedModel can be re-exported again,
|
|
||||||
# e.g. to catch issues such as b/142231881.
|
|
||||||
tf.saved_model.save(model, tempfile.mkdtemp())
|
|
||||||
|
|
||||||
|
|
||||||
def main(argv):
|
|
||||||
del argv
|
|
||||||
|
|
||||||
train(fine_tuning=False)
|
|
||||||
train(fine_tuning=True)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
app.run(main)
|
|
@ -1,50 +0,0 @@
|
|||||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""Load and use an RNN cell stored as a SavedModel."""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
from absl import app
|
|
||||||
from absl import flags
|
|
||||||
import numpy as np
|
|
||||||
import tensorflow.compat.v2 as tf
|
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
|
||||||
|
|
||||||
flags.DEFINE_string("model_dir", None, "Directory to load SavedModel from.")
|
|
||||||
|
|
||||||
|
|
||||||
def main(argv):
|
|
||||||
del argv
|
|
||||||
cell = tf.saved_model.load(FLAGS.model_dir)
|
|
||||||
|
|
||||||
initial_state = cell.get_initial_state(
|
|
||||||
tf.constant(np.random.uniform(size=[3, 10]).astype(np.float32)))
|
|
||||||
|
|
||||||
cell.next_state(
|
|
||||||
tf.constant(np.random.uniform(size=[3, 19]).astype(np.float32)),
|
|
||||||
initial_state)
|
|
||||||
|
|
||||||
# This is testing that a model using a SavedModel can be re-exported again,
|
|
||||||
# e.g. to catch issues such as b/142231881.
|
|
||||||
tf.saved_model.save(cell, tempfile.mkdtemp())
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
app.run(main)
|
|
@ -1,73 +0,0 @@
|
|||||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""Load and use text embedding module in a Dataset map function."""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
from absl import app
|
|
||||||
from absl import flags
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import tensorflow.compat.v2 as tf
|
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
|
||||||
|
|
||||||
flags.DEFINE_string("model_dir", None, "Directory to load SavedModel from.")
|
|
||||||
|
|
||||||
|
|
||||||
def train():
|
|
||||||
"""Build a Keras model and train with mock data."""
|
|
||||||
module = tf.saved_model.load(FLAGS.model_dir)
|
|
||||||
def _map_fn(features, labels):
|
|
||||||
features = tf.expand_dims(features, 0)
|
|
||||||
features = module(features)
|
|
||||||
features = tf.squeeze(features, 0)
|
|
||||||
return features, labels
|
|
||||||
|
|
||||||
features = np.array(["my first sentence", "my second sentence"])
|
|
||||||
labels = np.array([1, 0])
|
|
||||||
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).map(_map_fn)
|
|
||||||
|
|
||||||
# Create the sequential keras model.
|
|
||||||
l = tf.keras.layers
|
|
||||||
model = tf.keras.Sequential()
|
|
||||||
model.add(l.Dense(10, activation="relu"))
|
|
||||||
model.add(l.Dense(1, activation="sigmoid"))
|
|
||||||
|
|
||||||
model.compile(
|
|
||||||
optimizer="adam",
|
|
||||||
loss="binary_crossentropy",
|
|
||||||
metrics=["accuracy"])
|
|
||||||
|
|
||||||
model.fit_generator(generator=dataset.batch(10), epochs=5)
|
|
||||||
|
|
||||||
# This is testing that a model using a SavedModel can be re-exported again,
|
|
||||||
# e.g. to catch issues such as b/142231881.
|
|
||||||
tf.saved_model.save(model, tempfile.mkdtemp())
|
|
||||||
|
|
||||||
|
|
||||||
def main(argv):
|
|
||||||
del argv
|
|
||||||
|
|
||||||
train()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
tf.enable_v2_behavior()
|
|
||||||
app.run(main)
|
|
@ -1,50 +0,0 @@
|
|||||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""Load and use RNN model stored as a SavedModel."""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
from absl import app
|
|
||||||
from absl import flags
|
|
||||||
import tensorflow.compat.v2 as tf
|
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
|
||||||
|
|
||||||
flags.DEFINE_string("model_dir", None, "Directory to load SavedModel from.")
|
|
||||||
|
|
||||||
|
|
||||||
def main(argv):
|
|
||||||
del argv
|
|
||||||
|
|
||||||
sentences = [
|
|
||||||
"<S> sentence <E>", "<S> second sentence <E>", "<S> third sentence<E>"
|
|
||||||
]
|
|
||||||
|
|
||||||
model = tf.saved_model.load(FLAGS.model_dir)
|
|
||||||
model.train(tf.constant(sentences))
|
|
||||||
decoded = model.decode_greedy(
|
|
||||||
sequence_length=10, first_word=tf.constant("<S>"))
|
|
||||||
_ = [d.numpy() for d in decoded]
|
|
||||||
|
|
||||||
# This is testing that a model using a SavedModel can be re-exported again,
|
|
||||||
# e.g. to catch issues such as b/142231881.
|
|
||||||
tf.saved_model.save(model, tempfile.mkdtemp())
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
app.run(main)
|
|
@ -98,7 +98,6 @@ COMMON_PIP_DEPS = [
|
|||||||
"//tensorflow/compiler/tf2xla:xla_compiled_cpu_runtime_srcs",
|
"//tensorflow/compiler/tf2xla:xla_compiled_cpu_runtime_srcs",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:gen_mlir_passthrough_op_py",
|
"//tensorflow/compiler/mlir/tensorflow:gen_mlir_passthrough_op_py",
|
||||||
"//tensorflow/core:protos_all_proto_srcs",
|
"//tensorflow/core:protos_all_proto_srcs",
|
||||||
"//tensorflow/examples/saved_model/integration_tests:mnist_util",
|
|
||||||
"//tensorflow/lite/python/testdata:interpreter_test_data",
|
"//tensorflow/lite/python/testdata:interpreter_test_data",
|
||||||
"//tensorflow/lite/python:tflite_convert",
|
"//tensorflow/lite/python:tflite_convert",
|
||||||
"//tensorflow/lite/toco/python:toco_from_protos",
|
"//tensorflow/lite/toco/python:toco_from_protos",
|
||||||
|
Loading…
Reference in New Issue
Block a user