Remove saved_model test from examples.

PiperOrigin-RevId: 336914830
Change-Id: Ib15d58225c837d9550901eddc623961be028cac7
This commit is contained in:
Mark Daoust 2020-10-13 11:11:21 -07:00 committed by TensorFlower Gardener
parent 12c7ef6bec
commit c8a9751c55
16 changed files with 0 additions and 1472 deletions

View File

@ -1,84 +0,0 @@
load("//tensorflow/core/platform/default:distribute.bzl", "distribute_py_test")
package(
licenses = ["notice"], # Apache 2.0
)
py_library(
name = "integration_scripts",
srcs = [
"deploy_mnist_cnn.py",
"export_mnist_cnn.py",
"export_rnn_cell.py",
"export_simple_text_embedding.py",
"export_text_rnn_model.py",
"integration_scripts.py",
"use_mnist_cnn.py",
"use_model_in_sequential_keras.py",
"use_rnn_cell.py",
"use_text_embedding_in_dataset.py",
"use_text_rnn_model.py",
],
visibility = ["//tensorflow:internal"],
deps = [
":distribution_strategy_utils",
":mnist_util",
"//tensorflow:tensorflow_py",
"@absl_py//absl/testing:parameterized",
],
)
py_library(
name = "mnist_util",
srcs = ["mnist_util.py"],
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow:tensorflow_py",
],
)
py_library(
name = "distribution_strategy_utils",
srcs = ["distribution_strategy_utils.py"],
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/python/distribute:strategy_combinations",
],
)
distribute_py_test(
name = "saved_model_test",
srcs = [
"saved_model_test.py",
],
shard_count = 4,
tags = [
"no_pip", # b/131697937 and b/132196869
"noasan", # forge input size exceeded
"nomsan", # forge input size exceeded
"notsan", # forge input size exceeded
],
tpu_tags = [
"no_oss", # Test infra collision (b/157754990)
],
deps = [
":distribution_strategy_utils",
":integration_scripts",
"//tensorflow:tensorflow_py",
"//tensorflow/python:framework_combinations",
"//tensorflow/python/distribute:combinations",
"@absl_py//absl/testing:parameterized",
],
)
# b/132234211: Target added to support internal test target that runs the test
# in an environment that has the extra dependencies required to test integration
# with non core tensorflow packages.
py_library(
name = "saved_model_test_lib",
srcs = [
"saved_model_test.py",
],
visibility = ["//tensorflow:internal"],
deps = [":integration_scripts"],
)

View File

@ -1,107 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Deploys a SavedModel with an MNIST classifier to TFLite."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app
from absl import flags
import numpy as np
import tensorflow.compat.v2 as tf
from tensorflow.examples.saved_model.integration_tests import mnist_util
FLAGS = flags.FLAGS
flags.DEFINE_string(
'saved_model_dir', None,
'Directory of the SavedModel to deploy.')
flags.DEFINE_bool(
'use_fashion_mnist', False,
'Use Fashion MNIST (products) instead of the real MNIST (digits).')
flags.DEFINE_bool(
'fast_test_mode', False,
'Limit amount of test data for running in unit tests.')
flags.DEFINE_string(
'tflite_output_file', None,
'The filename of the .tflite model file to write (optional).')
flags.DEFINE_bool(
'reload_as_keras_model', True,
'Also test tf.keras.models.load_model() on --saved_model_dir.')
def main(argv):
del argv
# First convert the SavedModel in a pristine environment.
converter = tf.lite.TFLiteConverter.from_saved_model(FLAGS.saved_model_dir)
lite_model_content = converter.convert()
# Here is how you can save it for actual deployment.
if FLAGS.tflite_output_file:
with open(FLAGS.tflite_output_file, 'wb') as outfile:
outfile.write(lite_model_content)
# For testing, the TFLite model can be executed like this.
interpreter = tf.lite.Interpreter(model_content=lite_model_content)
def lite_model(images):
interpreter.allocate_tensors()
interpreter.set_tensor(interpreter.get_input_details()[0]['index'], images)
interpreter.invoke()
return interpreter.get_tensor(interpreter.get_output_details()[0]['index'])
# Load the SavedModel again for use as a test baseline.
imported = tf.saved_model.load(FLAGS.saved_model_dir)
def tf_model(images):
output_dict = imported.signatures['serving_default'](tf.constant(images))
logits, = output_dict.values() # Unpack single value.
return logits
# Compare model outputs on the test inputs.
(_, _), (x_test, _) = mnist_util.load_reshaped_data(
use_fashion_mnist=FLAGS.use_fashion_mnist,
fake_tiny_data=FLAGS.fast_test_mode)
for i, x in enumerate(x_test):
x = x[None, ...] # Make batch of size 1.
y_lite = lite_model(x)
y_tf = tf_model(x)
# This numpy primitive uses plain `raise` and works outside tf.TestCase.
# Model outputs are probabilities that sum to 1, so atol makes sense here.
np.testing.assert_allclose(
y_lite, y_tf, rtol=0, atol=1e-5,
err_msg='Mismatch with TF Lite at test example %d' % i)
# Test that the SavedModel loads correctly with v1 load APIs as well.
with tf.compat.v1.Graph().as_default(), tf.compat.v1.Session() as session:
tf.compat.v1.saved_model.load(
session,
[tf.compat.v1.saved_model.SERVING],
FLAGS.saved_model_dir)
# The SavedModel actually was a Keras Model; test that it also loads as that.
if FLAGS.reload_as_keras_model:
keras_model = tf.keras.models.load_model(FLAGS.saved_model_dir)
for i, x in enumerate(x_test):
x = x[None, ...] # Make batch of size 1.
y_tf = tf_model(x)
y_keras = keras_model(x)
# This numpy primitive uses plain `raise` and works outside tf.TestCase.
# Model outputs are probabilities that sum to 1, so atol makes sense here.
np.testing.assert_allclose(
y_tf, y_keras, rtol=0, atol=1e-5,
err_msg='Mismatch with Keras at test example %d' % i)
if __name__ == '__main__':
app.run(main)

View File

@ -1,68 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils related to tf.distribute.strategy."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import sys
from tensorflow.python.distribute import strategy_combinations
_strategies = [
strategy_combinations.one_device_strategy,
strategy_combinations.mirrored_strategy_with_one_cpu,
strategy_combinations.mirrored_strategy_with_one_gpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.mirrored_strategy_with_two_gpus,
strategy_combinations.tpu_strategy,
]
# The presence of GPU strategies upsets TPU initialization,
# despite their test instances being skipped early. This is a workaround
# for b/145386854.
if "test_tpu" in sys.argv[0]:
_strategies = [s for s in _strategies if "GPU" not in str(s)]
named_strategies = collections.OrderedDict(
[(None, None)] +
[(str(s), s) for s in _strategies]
)
class MaybeDistributionScope(object):
"""Provides a context allowing no distribution strategy."""
@staticmethod
def from_name(name):
return MaybeDistributionScope(named_strategies[name].strategy if name
else None)
def __init__(self, distribution):
self._distribution = distribution
self._scope = None
def __enter__(self):
if self._distribution:
self._scope = self._distribution.scope()
self._scope.__enter__()
def __exit__(self, exc_type, value, traceback):
if self._distribution:
self._scope.__exit__(exc_type, value, traceback)
self._scope = None

View File

@ -1,204 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Exports a convolutional feature extractor for MNIST in SavedModel format.
The feature extractor is a convolutional neural network plus a hidden layer
that gets trained as part of an MNIST classifier and then written to a
SavedModel (without the classification layer). From there, use_mnist_cnn.py
picks it up for transfer learning.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app
from absl import flags
import tensorflow.compat.v2 as tf
from tensorflow.examples.saved_model.integration_tests import mnist_util
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
FLAGS = flags.FLAGS
flags.DEFINE_string(
'export_dir', None,
'Directory of exported SavedModel.')
flags.DEFINE_integer(
'epochs', 10,
'Number of epochs to train.')
flags.DEFINE_bool(
'use_keras_save_api', False,
'Uses tf.keras.models.save_model() on the feature extractor '
'instead of tf.saved_model.save() on a manually wrapped version. '
'With this, the exported model has no hparams.')
flags.DEFINE_bool(
'fast_test_mode', False,
'Shortcut training for running in unit tests.')
flags.DEFINE_bool(
'export_print_hparams', False,
'If true, the exported function will print its effective hparams.')
def make_feature_extractor(l2_strength, dropout_rate):
"""Returns a Keras Model to compute a feature vector from MNIST images."""
regularizer = lambda: tf.keras.regularizers.l2(l2_strength)
net = inp = tf.keras.Input(mnist_util.INPUT_SHAPE)
net = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', name='conv1',
kernel_regularizer=regularizer())(net)
net = tf.keras.layers.Conv2D(64, (3, 3), activation='relu', name='conv2',
kernel_regularizer=regularizer())(net)
net = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), name='pool1')(net)
net = tf.keras.layers.Dropout(dropout_rate, name='dropout1')(net)
net = tf.keras.layers.Flatten(name='flatten')(net)
net = tf.keras.layers.Dense(10, activation='relu', name='dense1',
kernel_regularizer=regularizer())(net)
return tf.keras.Model(inputs=inp, outputs=net)
def set_feature_extractor_hparams(model, dropout_rate):
model.get_layer('dropout1').rate = dropout_rate
def make_classifier(feature_extractor, l2_strength, dropout_rate=0.5):
"""Returns a Keras Model to classify MNIST using feature_extractor."""
regularizer = lambda: tf.keras.regularizers.l2(l2_strength)
net = inp = tf.keras.Input(mnist_util.INPUT_SHAPE)
net = feature_extractor(net)
net = tf.keras.layers.Dropout(dropout_rate)(net)
net = tf.keras.layers.Dense(mnist_util.NUM_CLASSES, activation='softmax',
kernel_regularizer=regularizer())(net)
return tf.keras.Model(inputs=inp, outputs=net)
def wrap_keras_model_for_export(model, batch_input_shape,
set_hparams, default_hparams):
"""Wraps `model` for saving and loading as SavedModel."""
# The primary input to the module is a Tensor with a batch of images.
# Here we determine its spec.
inputs_spec = tf.TensorSpec(shape=batch_input_shape, dtype=tf.float32)
# The module also accepts certain hparams as optional Tensor inputs.
# Here, we cut all the relevant slices from `default_hparams`
# (and don't worry if anyone accidentally modifies it later).
if default_hparams is None: default_hparams = {}
hparam_keys = list(default_hparams.keys())
hparam_defaults = tuple(default_hparams.values())
hparams_spec = {name: tf.TensorSpec.from_tensor(tf.constant(value))
for name, value in default_hparams.items()}
# The goal is to save a function with this argspec...
argspec = tf_inspect.FullArgSpec(
args=(['inputs', 'training'] + hparam_keys),
defaults=((False,) + hparam_defaults),
varargs=None, varkw=None,
kwonlyargs=[], kwonlydefaults=None,
annotations={})
# ...and this behavior:
def call_fn(inputs, training, *args):
if FLAGS.export_print_hparams:
args = [tf.keras.backend.print_tensor(args[i], 'training=%s and %s='
% (training, hparam_keys[i]))
for i in range(len(args))]
kwargs = dict(zip(hparam_keys, args))
if kwargs: set_hparams(model, **kwargs)
return model(inputs, training=training)
# We cannot spell out `args` in def statement for call_fn, but since
# tf.function uses tf_inspect, we can use tf_decorator to wrap it with
# the desired argspec.
def wrapped(*args, **kwargs): # TODO(arnoegw): Can we use call_fn itself?
return call_fn(*args, **kwargs)
traced_call_fn = tf.function(
tf_decorator.make_decorator(call_fn, wrapped, decorator_argspec=argspec))
# Now we need to trigger traces for all supported combinations of the
# non-Tensor-value inputs.
for training in (True, False):
traced_call_fn.get_concrete_function(inputs_spec, training, **hparams_spec)
# Finally, we assemble the object for tf.saved_model.save().
obj = tf.train.Checkpoint()
obj.__call__ = traced_call_fn
obj.trainable_variables = model.trainable_variables
obj.variables = model.trainable_variables + model.non_trainable_variables
# Make tf.functions for the regularization terms of the loss.
obj.regularization_losses = [_get_traced_loss(model, i)
for i in range(len(model.losses))]
return obj
def _get_traced_loss(model, i):
"""Returns tf.function for model.losses[i] with a trace for zero args.
The intended usage is
[_get_traced_loss(model, i) for i in range(len(model.losses))]
This is better than
[tf.function(lambda: model.losses[i], input_signature=[]) for i ...]
because it avoids capturing a loop index in a lambda, and removes any
chance of deferring the trace.
Args:
model: a Keras Model.
i: an integer between from 0 up to but to len(model.losses).
"""
f = tf.function(lambda: model.losses[i])
_ = f.get_concrete_function()
return f
def main(argv):
del argv
# Build a complete classifier model using a feature extractor.
default_hparams = dict(dropout_rate=0.25)
l2_strength = 0.01 # Not a hparam for inputs -> outputs.
feature_extractor = make_feature_extractor(l2_strength=l2_strength,
**default_hparams)
classifier = make_classifier(feature_extractor, l2_strength=l2_strength)
# Train the complete model.
(x_train, y_train), (x_test, y_test) = mnist_util.load_reshaped_data(
fake_tiny_data=FLAGS.fast_test_mode)
classifier.compile(loss=tf.keras.losses.categorical_crossentropy,
optimizer=tf.keras.optimizers.SGD(),
metrics=['accuracy'])
classifier.fit(x_train, y_train,
batch_size=128,
epochs=FLAGS.epochs,
verbose=1,
validation_data=(x_test, y_test))
# Save the feature extractor to a framework-agnostic SavedModel for reuse.
# Note that the feature_extractor object has not been compiled or fitted,
# so it does not contain an optimizer and related state.
if FLAGS.use_keras_save_api:
# Use Keras' built-in way of creating reusable SavedModels.
# This has no support for adjustable hparams at this time (July 2019).
# (We could also call tf.saved_model.save(feature_extractor, ...),
# point is we're passing a Keras model, not a plain Checkpoint.)
tf.keras.models.save_model(feature_extractor, FLAGS.export_dir)
else:
# Assemble a reusable SavedModel manually, with adjustable hparams.
exportable = wrap_keras_model_for_export(feature_extractor,
(None,) + mnist_util.INPUT_SHAPE,
set_feature_extractor_hparams,
default_hparams)
tf.saved_model.save(exportable, FLAGS.export_dir)
if __name__ == '__main__':
app.run(main)

View File

@ -1,63 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Export an RNN cell in SavedModel format."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app
from absl import flags
import numpy as np
import tensorflow.compat.v2 as tf
FLAGS = flags.FLAGS
flags.DEFINE_string("export_dir", None, "Directory to export SavedModel.")
def main(argv):
del argv
root = tf.train.Checkpoint()
# Create a cell and attach to our trackable.
root.rnn_cell = tf.keras.layers.LSTMCell(units=10, recurrent_initializer=None)
# Wrap the rnn_cell.__call__ function and assign to next_state.
root.next_state = tf.function(root.rnn_cell.__call__)
# Wrap the rnn_cell.get_initial_function using a decorator and assign to an
# attribute with the same name.
@tf.function(input_signature=[tf.TensorSpec([None, None], tf.float32)])
def get_initial_state(tensor):
return root.rnn_cell.get_initial_state(tensor, None, None)
root.get_initial_state = get_initial_state
# Construct an initial_state, then call next_state explicitly to trigger a
# trace for serialization (we need an explicit call, because next_state has
# not been annotated with an input_signature).
initial_state = root.get_initial_state(
tf.constant(np.random.uniform(size=[3, 10]).astype(np.float32)))
root.next_state(
tf.constant(np.random.uniform(size=[3, 19]).astype(np.float32)),
initial_state)
tf.saved_model.save(root, FLAGS.export_dir)
if __name__ == "__main__":
app.run(main)

View File

@ -1,105 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Text embedding model stored as a SavedModel."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tempfile
from absl import app
from absl import flags
import tensorflow.compat.v2 as tf
FLAGS = flags.FLAGS
flags.DEFINE_string("export_dir", None, "Directory to export SavedModel.")
def write_vocabulary_file(vocabulary):
"""Write temporary vocab file for module construction."""
tmpdir = tempfile.mkdtemp()
vocabulary_file = os.path.join(tmpdir, "tokens.txt")
with tf.io.gfile.GFile(vocabulary_file, "w") as f:
for entry in vocabulary:
f.write(entry + "\n")
return vocabulary_file
class TextEmbeddingModel(tf.train.Checkpoint):
"""Text embedding model.
A text embeddings model that takes a sentences on input and outputs the
sentence embedding.
"""
def __init__(self, vocabulary, emb_dim, oov_buckets):
super(TextEmbeddingModel, self).__init__()
self._oov_buckets = oov_buckets
self._total_size = len(vocabulary) + oov_buckets
# Assign the table initializer to this instance to ensure the asset
# it depends on is saved with the SavedModel.
self._table_initializer = tf.lookup.TextFileInitializer(
write_vocabulary_file(vocabulary), tf.string,
tf.lookup.TextFileIndex.WHOLE_LINE, tf.int64,
tf.lookup.TextFileIndex.LINE_NUMBER)
self._table = tf.lookup.StaticVocabularyTable(
self._table_initializer, num_oov_buckets=self._oov_buckets)
self.embeddings = tf.Variable(
tf.random.uniform(shape=[self._total_size, emb_dim]))
self.variables = [self.embeddings]
self.trainable_variables = self.variables
def _tokenize(self, sentences):
# Perform a minimalistic text preprocessing by removing punctuation and
# splitting on spaces.
normalized_sentences = tf.strings.regex_replace(
input=sentences, pattern=r"\pP", rewrite="")
normalized_sentences = tf.reshape(normalized_sentences, [-1])
sparse_tokens = tf.strings.split(normalized_sentences, " ").to_sparse()
# Deal with a corner case: there is one empty sentence.
sparse_tokens, _ = tf.sparse.fill_empty_rows(sparse_tokens, tf.constant(""))
# Deal with a corner case: all sentences are empty.
sparse_tokens = tf.sparse.reset_shape(sparse_tokens)
sparse_token_ids = self._table.lookup(sparse_tokens.values)
return (sparse_tokens.indices, sparse_token_ids, sparse_tokens.dense_shape)
@tf.function(input_signature=[tf.TensorSpec([None], tf.dtypes.string)])
def __call__(self, sentences):
token_ids, token_values, token_dense_shape = self._tokenize(sentences)
return tf.nn.safe_embedding_lookup_sparse(
embedding_weights=self.embeddings,
sparse_ids=tf.sparse.SparseTensor(token_ids, token_values,
token_dense_shape),
sparse_weights=None,
combiner="sqrtn")
def main(argv):
del argv
vocabulary = ["cat", "is", "on", "the", "mat"]
module = TextEmbeddingModel(vocabulary=vocabulary, emb_dim=10, oov_buckets=10)
tf.saved_model.save(module, FLAGS.export_dir)
if __name__ == "__main__":
app.run(main)

View File

@ -1,193 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Text RNN model stored as a SavedModel."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app
from absl import flags
import tensorflow.compat.v2 as tf
FLAGS = flags.FLAGS
flags.DEFINE_string("export_dir", None, "Directory to export SavedModel.")
class TextRnnModel(tf.train.Checkpoint):
"""Text RNN model.
A full generative text RNN model that can train and decode sentences from a
starting word.
"""
def __init__(self, vocab, emb_dim, buckets, state_size):
super(TextRnnModel, self).__init__()
self._buckets = buckets
self._lstm_cell = tf.keras.layers.LSTMCell(units=state_size)
self._rnn_layer = tf.keras.layers.RNN(
self._lstm_cell, return_sequences=True)
self._embeddings = tf.Variable(tf.random.uniform(shape=[buckets, emb_dim]))
self._logit_layer = tf.keras.layers.Dense(buckets)
self._set_up_vocab(vocab)
def _tokenize(self, sentences):
# Perform a minimalistic text preprocessing by removing punctuation and
# splitting on spaces.
normalized_sentences = tf.strings.regex_replace(
input=sentences, pattern=r"\pP", rewrite="")
sparse_tokens = tf.strings.split(normalized_sentences, " ").to_sparse()
# Deal with a corner case: there is one empty sentence.
sparse_tokens, _ = tf.sparse.fill_empty_rows(sparse_tokens, tf.constant(""))
# Deal with a corner case: all sentences are empty.
sparse_tokens = tf.sparse.reset_shape(sparse_tokens)
return (sparse_tokens.indices, sparse_tokens.values,
sparse_tokens.dense_shape)
def _set_up_vocab(self, vocab_tokens):
# TODO(vbardiovsky): Currently there is no real vocabulary, because
# saved_model serialization does not support trackable resources. Add a real
# vocabulary when it does.
vocab_list = ["UNK"] * self._buckets
for vocab_token in vocab_tokens:
index = self._words_to_indices(vocab_token).numpy()
vocab_list[index] = vocab_token
# This is a variable representing an inverse index.
self._vocab_tensor = tf.Variable(vocab_list)
def _indices_to_words(self, indices):
return tf.gather(self._vocab_tensor, indices)
def _words_to_indices(self, words):
return tf.strings.to_hash_bucket(words, self._buckets)
@tf.function(input_signature=[tf.TensorSpec([None], tf.dtypes.string)])
def train(self, sentences):
token_ids, token_values, token_dense_shape = self._tokenize(sentences)
tokens_sparse = tf.sparse.SparseTensor(
indices=token_ids, values=token_values, dense_shape=token_dense_shape)
tokens = tf.sparse.to_dense(tokens_sparse, default_value="")
sparse_lookup_ids = tf.sparse.SparseTensor(
indices=tokens_sparse.indices,
values=self._words_to_indices(tokens_sparse.values),
dense_shape=tokens_sparse.dense_shape)
lookup_ids = tf.sparse.to_dense(sparse_lookup_ids, default_value=0)
# Targets are the next word for each word of the sentence.
tokens_ids_seq = lookup_ids[:, 0:-1]
tokens_ids_target = lookup_ids[:, 1:]
tokens_prefix = tokens[:, 0:-1]
# Mask determining which positions we care about for a loss: all positions
# that have a valid non-terminal token.
mask = tf.logical_and(
tf.logical_not(tf.equal(tokens_prefix, "")),
tf.logical_not(tf.equal(tokens_prefix, "<E>")))
input_mask = tf.cast(mask, tf.int32)
with tf.GradientTape() as t:
sentence_embeddings = tf.nn.embedding_lookup(self._embeddings,
tokens_ids_seq)
lstm_initial_state = self._lstm_cell.get_initial_state(
sentence_embeddings)
lstm_output = self._rnn_layer(
inputs=sentence_embeddings, initial_state=lstm_initial_state)
# Stack LSTM outputs into a batch instead of a 2D array.
lstm_output = tf.reshape(lstm_output, [-1, self._lstm_cell.output_size])
logits = self._logit_layer(lstm_output)
targets = tf.reshape(tokens_ids_target, [-1])
weights = tf.cast(tf.reshape(input_mask, [-1]), tf.float32)
losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=targets, logits=logits)
# Final loss is the mean loss for all token losses.
final_loss = tf.math.divide(
tf.reduce_sum(tf.multiply(losses, weights)),
tf.reduce_sum(weights),
name="final_loss")
watched = t.watched_variables()
gradients = t.gradient(final_loss, watched)
for w, g in zip(watched, gradients):
w.assign_sub(g)
return final_loss
@tf.function
def decode_greedy(self, sequence_length, first_word):
initial_state = self._lstm_cell.get_initial_state(
dtype=tf.float32, batch_size=1)
sequence = [first_word]
current_word = first_word
current_id = tf.expand_dims(self._words_to_indices(current_word), 0)
current_state = initial_state
for _ in range(sequence_length):
token_embeddings = tf.nn.embedding_lookup(self._embeddings, current_id)
lstm_outputs, current_state = self._lstm_cell(token_embeddings,
current_state)
lstm_outputs = tf.reshape(lstm_outputs, [-1, self._lstm_cell.output_size])
logits = self._logit_layer(lstm_outputs)
softmax = tf.nn.softmax(logits)
next_ids = tf.math.argmax(softmax, axis=1)
next_words = self._indices_to_words(next_ids)[0]
current_id = next_ids
current_word = next_words
sequence.append(current_word)
return sequence
def main(argv):
del argv
sentences = ["<S> hello there <E>", "<S> how are you doing today <E>"]
vocab = [
"<S>", "<E>", "hello", "there", "how", "are", "you", "doing", "today"
]
module = TextRnnModel(vocab=vocab, emb_dim=10, buckets=100, state_size=128)
for _ in range(100):
_ = module.train(tf.constant(sentences))
# We have to call this function explicitly if we want it exported, because it
# has no input_signature in the @tf.function decorator.
decoded = module.decode_greedy(
sequence_length=10, first_word=tf.constant("<S>"))
_ = [d.numpy() for d in decoded]
tf.saved_model.save(module, FLAGS.export_dir)
if __name__ == "__main__":
app.run(main)

View File

@ -1,68 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility to write SavedModel integration tests.
SavedModel testing requires isolation between the process that creates and
consumes it. This file helps doing that by relaunching the same binary that
calls `assertCommandSucceeded` with an environment flag indicating what source
file to execute. That binary must start by calling `MaybeRunScriptInstead`.
This allows to wire this into existing building systems without having to depend
on data dependencies. And as so allow to keep a fixed binary size and allows
interop with GPU tests.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import importlib
import os
import subprocess
import sys
from absl import app
import tensorflow.compat.v2 as tf
from tensorflow.python.platform import tf_logging as logging
class TestCase(tf.test.TestCase):
"""Base class to write SavedModel integration tests."""
def assertCommandSucceeded(self, script_name, **flags):
"""Runs an integration test script with given flags."""
run_script = sys.argv[0]
if run_script.endswith(".py"):
command_parts = [sys.executable, run_script]
else:
command_parts = [run_script]
command_parts.append("--alsologtostderr") # For visibility in sponge.
for flag_key, flag_value in flags.items():
command_parts.append("--%s=%s" % (flag_key, flag_value))
env = dict(TF2_BEHAVIOR="enabled", SCRIPT_NAME=script_name)
logging.info("Running %s with added environment variables %s" %
(command_parts, env))
subprocess.check_call(command_parts, env=dict(os.environ, **env))
def MaybeRunScriptInstead():
if "SCRIPT_NAME" in os.environ:
# Append current path to import path and execute `SCRIPT_NAME` main.
sys.path.extend([os.path.dirname(__file__)])
module_name = os.environ["SCRIPT_NAME"]
retval = app.run(importlib.import_module(module_name).main) # pylint: disable=assignment-from-no-return
sys.exit(retval)

View File

@ -1,51 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Convenience wrapper around Keras' MNIST and Fashion MNIST data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow.compat.v2 as tf
INPUT_SHAPE = (28, 28, 1)
NUM_CLASSES = 10
def _load_random_data(num_train_and_test):
return ((np.random.randint(0, 256, (num, 28, 28), dtype=np.uint8),
np.random.randint(0, 10, (num,), dtype=np.int64))
for num in num_train_and_test)
def load_reshaped_data(use_fashion_mnist=False, fake_tiny_data=False):
"""Returns MNIST or Fashion MNIST or fake train and test data."""
load = ((lambda: _load_random_data([128, 128])) if fake_tiny_data else
tf.keras.datasets.fashion_mnist.load_data if use_fashion_mnist else
tf.keras.datasets.mnist.load_data)
(x_train, y_train), (x_test, y_test) = load()
return ((_prepare_image(x_train), _prepare_label(y_train)),
(_prepare_image(x_test), _prepare_label(y_test)))
def _prepare_image(x):
"""Converts images to [n,h,w,c] format in range [0,1]."""
return x[..., None].astype('float32') / 255.
def _prepare_label(y):
"""Conerts labels to one-hot encoding."""
return tf.keras.utils.to_categorical(y, NUM_CLASSES)

View File

@ -1,133 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SavedModel integration tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl.testing import parameterized
import tensorflow.compat.v2 as tf
from tensorflow.examples.saved_model.integration_tests import distribution_strategy_utils as ds_utils
from tensorflow.examples.saved_model.integration_tests import integration_scripts as scripts
from tensorflow.python.distribute import combinations as distribute_combinations
from tensorflow.python.framework import combinations
class SavedModelTest(scripts.TestCase, parameterized.TestCase):
def __init__(self, method_name="runTest", has_extra_deps=False):
super(SavedModelTest, self).__init__(method_name)
self.has_extra_deps = has_extra_deps
def skipIfMissingExtraDeps(self):
"""Skip test if it requires extra dependencies.
b/132234211: The extra dependencies are not available in all environments
that run the tests, e.g. "tensorflow_hub" is not available from tests
within "tensorflow" alone. Those tests are instead run by another
internal test target.
"""
if not self.has_extra_deps:
self.skipTest("Missing extra dependencies")
def test_text_rnn(self):
export_dir = self.get_temp_dir()
self.assertCommandSucceeded("export_text_rnn_model", export_dir=export_dir)
self.assertCommandSucceeded("use_text_rnn_model", model_dir=export_dir)
def test_rnn_cell(self):
export_dir = self.get_temp_dir()
self.assertCommandSucceeded("export_rnn_cell", export_dir=export_dir)
self.assertCommandSucceeded("use_rnn_cell", model_dir=export_dir)
def test_text_embedding_in_sequential_keras(self):
self.skipIfMissingExtraDeps()
export_dir = self.get_temp_dir()
self.assertCommandSucceeded(
"export_simple_text_embedding", export_dir=export_dir)
self.assertCommandSucceeded(
"use_model_in_sequential_keras", model_dir=export_dir)
def test_text_embedding_in_dataset(self):
export_dir = self.get_temp_dir()
self.assertCommandSucceeded(
"export_simple_text_embedding", export_dir=export_dir)
self.assertCommandSucceeded(
"use_text_embedding_in_dataset", model_dir=export_dir)
TEST_MNIST_CNN_GENERATE_KWARGS = dict(
combinations=(
combinations.combine(
# Test all combinations with tf.saved_model.save().
# Test all combinations using tf.keras.models.save_model()
# for both the reusable and the final full model.
use_keras_save_api=True,
named_strategy=list(ds_utils.named_strategies.values()),
retrain_flag_value=["true", "false"],
regularization_loss_multiplier=[None, 2], # Test for b/134528831.
) + combinations.combine(
# Test few critcial combinations with raw tf.saved_model.save(),
# including export of a reusable SavedModel that gets assembled
# manually, including support for adjustable hparams.
use_keras_save_api=False,
named_strategy=None,
retrain_flag_value=["true", "false"],
regularization_loss_multiplier=[None, 2], # Test for b/134528831.
)),
test_combinations=(distribute_combinations.GPUCombination(),
distribute_combinations.TPUCombination()))
@combinations.generate(**TEST_MNIST_CNN_GENERATE_KWARGS)
def test_mnist_cnn(self, use_keras_save_api, named_strategy,
retrain_flag_value, regularization_loss_multiplier):
self.skipIfMissingExtraDeps()
fast_test_mode = True
temp_dir = self.get_temp_dir()
feature_extrator_dir = os.path.join(temp_dir, "mnist_feature_extractor")
full_model_dir = os.path.join(temp_dir, "full_model")
self.assertCommandSucceeded(
"export_mnist_cnn",
fast_test_mode=fast_test_mode,
export_dir=feature_extrator_dir,
use_keras_save_api=use_keras_save_api)
use_kwargs = dict(fast_test_mode=fast_test_mode,
input_saved_model_dir=feature_extrator_dir,
retrain=retrain_flag_value,
output_saved_model_dir=full_model_dir,
use_keras_save_api=use_keras_save_api)
if named_strategy:
use_kwargs["strategy"] = str(named_strategy)
if regularization_loss_multiplier is not None:
use_kwargs[
"regularization_loss_multiplier"] = regularization_loss_multiplier
self.assertCommandSucceeded("use_mnist_cnn", **use_kwargs)
self.assertCommandSucceeded(
"deploy_mnist_cnn",
fast_test_mode=fast_test_mode,
saved_model_dir=full_model_dir)
if __name__ == "__main__":
scripts.MaybeRunScriptInstead()
tf.test.main()

View File

@ -1,147 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Imports a convolutional feature extractor for MNIST in SavedModel format.
This program picks up the SavedModel written by export_mnist_cnn.py and
uses the feature extractor contained in it to do classification on either
classic MNIST (digits) or Fashion MNIST (thumbnails of apparel). Optionally,
it trains the feature extractor further as part of the new classifier.
As expected, that makes training slower but does not help much for the
original training dataset but helps a lot for transfer to the other dataset.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app
from absl import flags
import tensorflow.compat.v2 as tf
import tensorflow_hub as hub
from tensorflow.examples.saved_model.integration_tests import distribution_strategy_utils as ds_utils
from tensorflow.examples.saved_model.integration_tests import mnist_util
FLAGS = flags.FLAGS
flags.DEFINE_string(
'input_saved_model_dir', None,
'Directory of the reusable SavedModel that is imported into this program.')
flags.DEFINE_integer(
'epochs', 5,
'Number of epochs to train.')
flags.DEFINE_bool(
'retrain', False,
'If set, the imported SavedModel is trained further.')
flags.DEFINE_float(
'dropout_rate', None,
'If set, dropout rate passed to the SavedModel. '
'Requires a SavedModel with support for adjustable hyperparameters.')
flags.DEFINE_float(
'regularization_loss_multiplier', None,
'If set, multiplier for the regularization losses in the SavedModel.')
flags.DEFINE_bool(
'use_fashion_mnist', False,
'Use Fashion MNIST (products) instead of the real MNIST (digits). '
'With this, --retrain gains a lot.')
flags.DEFINE_bool(
'fast_test_mode', False,
'Shortcut training for running in unit tests.')
flags.DEFINE_string(
'output_saved_model_dir', None,
'Directory of the SavedModel that was exported for reuse.')
flags.DEFINE_bool(
'use_keras_save_api', False,
'Uses tf.keras.models.save_model() instead of tf.saved_model.save().')
flags.DEFINE_string('strategy', None,
'Name of the distribution strategy to use.')
def make_feature_extractor(saved_model_path, trainable,
regularization_loss_multiplier):
"""Load a pre-trained feature extractor and wrap it for use in Keras."""
if regularization_loss_multiplier is not None:
# TODO(b/63257857): Scaling regularization losses requires manual loading
# and modification of the SavedModel
obj = tf.saved_model.load(saved_model_path)
def _scale_one_loss(l): # Separate def avoids lambda capture of loop var.
f = tf.function(lambda: tf.multiply(regularization_loss_multiplier, l()))
_ = f.get_concrete_function()
return f
obj.regularization_losses = [_scale_one_loss(l)
for l in obj.regularization_losses]
# The modified object is then passed to hub.KerasLayer instead of the
# string handle. That prevents it from saving a Keras config (b/134528831).
handle = obj
else:
# If possible, we exercise the more common case of passing a string handle
# such that hub.KerasLayer can save a Keras config (b/134528831).
handle = saved_model_path
arguments = {}
if FLAGS.dropout_rate is not None:
arguments['dropout_rate'] = FLAGS.dropout_rate
return hub.KerasLayer(handle, trainable=trainable, arguments=arguments)
def make_classifier(feature_extractor, l2_strength=0.01, dropout_rate=0.5):
"""Returns a Keras Model to classify MNIST using feature_extractor."""
regularizer = lambda: tf.keras.regularizers.l2(l2_strength)
net = inp = tf.keras.Input(mnist_util.INPUT_SHAPE)
net = feature_extractor(net)
if dropout_rate:
net = tf.keras.layers.Dropout(dropout_rate)(net)
net = tf.keras.layers.Dense(mnist_util.NUM_CLASSES, activation='softmax',
kernel_regularizer=regularizer())(net)
return tf.keras.Model(inputs=inp, outputs=net)
def main(argv):
del argv
with ds_utils.MaybeDistributionScope.from_name(FLAGS.strategy):
feature_extractor = make_feature_extractor(
FLAGS.input_saved_model_dir,
FLAGS.retrain,
FLAGS.regularization_loss_multiplier)
model = make_classifier(feature_extractor)
model.compile(loss=tf.keras.losses.categorical_crossentropy,
optimizer=tf.keras.optimizers.SGD(),
metrics=['accuracy'])
# Train the classifier (possibly on a different dataset).
(x_train, y_train), (x_test, y_test) = mnist_util.load_reshaped_data(
use_fashion_mnist=FLAGS.use_fashion_mnist,
fake_tiny_data=FLAGS.fast_test_mode)
print('Training on %s with %d trainable and %d untrainable variables.' %
('Fashion MNIST' if FLAGS.use_fashion_mnist else 'MNIST',
len(model.trainable_variables), len(model.non_trainable_variables)))
model.fit(x_train, y_train,
batch_size=128,
epochs=FLAGS.epochs,
verbose=1,
validation_data=(x_test, y_test))
if FLAGS.output_saved_model_dir:
if FLAGS.use_keras_save_api:
tf.keras.models.save_model(model, FLAGS.output_saved_model_dir)
else:
tf.saved_model.save(model, FLAGS.output_saved_model_dir)
if __name__ == '__main__':
app.run(main)

View File

@ -1,75 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Load and use text embedding module in sequential Keras."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tempfile
from absl import app
from absl import flags
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_hub as hub
FLAGS = flags.FLAGS
flags.DEFINE_string("model_dir", None, "Directory to load SavedModel from.")
def train(fine_tuning):
"""Build a Keras model and train with mock data."""
features = np.array(["my first sentence", "my second sentence"])
labels = np.array([1, 0])
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
module = tf.saved_model.load(FLAGS.model_dir)
# Create the sequential keras model.
l = tf.keras.layers
model = tf.keras.Sequential()
model.add(l.Reshape((), batch_input_shape=[None, 1], dtype=tf.string))
# TODO(b/124219898): output_shape should be optional.
model.add(hub.KerasLayer(module, output_shape=[10], trainable=fine_tuning))
model.add(l.Dense(100, activation="relu"))
model.add(l.Dense(50, activation="relu"))
model.add(l.Dense(1, activation="sigmoid"))
model.compile(
optimizer="adam",
loss="binary_crossentropy",
metrics=["accuracy"],
# TODO(b/124446120): Remove after fixed.
run_eagerly=True)
model.fit_generator(generator=dataset.batch(1), epochs=5)
# This is testing that a model using a SavedModel can be re-exported again,
# e.g. to catch issues such as b/142231881.
tf.saved_model.save(model, tempfile.mkdtemp())
def main(argv):
del argv
train(fine_tuning=False)
train(fine_tuning=True)
if __name__ == "__main__":
app.run(main)

View File

@ -1,50 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Load and use an RNN cell stored as a SavedModel."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tempfile
from absl import app
from absl import flags
import numpy as np
import tensorflow.compat.v2 as tf
FLAGS = flags.FLAGS
flags.DEFINE_string("model_dir", None, "Directory to load SavedModel from.")
def main(argv):
del argv
cell = tf.saved_model.load(FLAGS.model_dir)
initial_state = cell.get_initial_state(
tf.constant(np.random.uniform(size=[3, 10]).astype(np.float32)))
cell.next_state(
tf.constant(np.random.uniform(size=[3, 19]).astype(np.float32)),
initial_state)
# This is testing that a model using a SavedModel can be re-exported again,
# e.g. to catch issues such as b/142231881.
tf.saved_model.save(cell, tempfile.mkdtemp())
if __name__ == "__main__":
app.run(main)

View File

@ -1,73 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Load and use text embedding module in a Dataset map function."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tempfile
from absl import app
from absl import flags
import numpy as np
import tensorflow.compat.v2 as tf
FLAGS = flags.FLAGS
flags.DEFINE_string("model_dir", None, "Directory to load SavedModel from.")
def train():
"""Build a Keras model and train with mock data."""
module = tf.saved_model.load(FLAGS.model_dir)
def _map_fn(features, labels):
features = tf.expand_dims(features, 0)
features = module(features)
features = tf.squeeze(features, 0)
return features, labels
features = np.array(["my first sentence", "my second sentence"])
labels = np.array([1, 0])
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).map(_map_fn)
# Create the sequential keras model.
l = tf.keras.layers
model = tf.keras.Sequential()
model.add(l.Dense(10, activation="relu"))
model.add(l.Dense(1, activation="sigmoid"))
model.compile(
optimizer="adam",
loss="binary_crossentropy",
metrics=["accuracy"])
model.fit_generator(generator=dataset.batch(10), epochs=5)
# This is testing that a model using a SavedModel can be re-exported again,
# e.g. to catch issues such as b/142231881.
tf.saved_model.save(model, tempfile.mkdtemp())
def main(argv):
del argv
train()
if __name__ == "__main__":
tf.enable_v2_behavior()
app.run(main)

View File

@ -1,50 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Load and use RNN model stored as a SavedModel."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tempfile
from absl import app
from absl import flags
import tensorflow.compat.v2 as tf
FLAGS = flags.FLAGS
flags.DEFINE_string("model_dir", None, "Directory to load SavedModel from.")
def main(argv):
del argv
sentences = [
"<S> sentence <E>", "<S> second sentence <E>", "<S> third sentence<E>"
]
model = tf.saved_model.load(FLAGS.model_dir)
model.train(tf.constant(sentences))
decoded = model.decode_greedy(
sequence_length=10, first_word=tf.constant("<S>"))
_ = [d.numpy() for d in decoded]
# This is testing that a model using a SavedModel can be re-exported again,
# e.g. to catch issues such as b/142231881.
tf.saved_model.save(model, tempfile.mkdtemp())
if __name__ == "__main__":
app.run(main)

View File

@ -98,7 +98,6 @@ COMMON_PIP_DEPS = [
"//tensorflow/compiler/tf2xla:xla_compiled_cpu_runtime_srcs",
"//tensorflow/compiler/mlir/tensorflow:gen_mlir_passthrough_op_py",
"//tensorflow/core:protos_all_proto_srcs",
"//tensorflow/examples/saved_model/integration_tests:mnist_util",
"//tensorflow/lite/python/testdata:interpreter_test_data",
"//tensorflow/lite/python:tflite_convert",
"//tensorflow/lite/toco/python:toco_from_protos",