Integration test for use of a text embedding saved model in a dataset map function.

PiperOrigin-RevId: 241487445
This commit is contained in:
Jeremiah Harmsen 2019-04-02 02:16:14 -07:00 committed by TensorFlower Gardener
parent 729e994ec0
commit b7465d7c7f
3 changed files with 75 additions and 0 deletions

View File

@ -17,6 +17,7 @@ py_binary(
"use_mnist_cnn.py",
"use_model_in_sequential_keras.py",
"use_rnn_cell.py",
"use_text_embedding_in_dataset.py",
"use_text_rnn_model.py",
],
visibility = ["//tensorflow:internal"],

View File

@ -56,6 +56,13 @@ class SavedModelTest(tf.test.TestCase):
self.assertCommandSucceeded(
"use_model_in_sequential_keras", model_dir=export_dir)
def test_text_embedding_in_dataset(self):
export_dir = self.get_temp_dir()
self.assertCommandSucceeded(
"export_simple_text_embedding", export_dir=export_dir)
self.assertCommandSucceeded(
"use_text_embedding_in_dataset", model_dir=export_dir)
def test_mnist_cnn(self):
export_dir = self.get_temp_dir()
self.assertCommandSucceeded(

View File

@ -0,0 +1,67 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Load and use text embedding module in a Dataset map function."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import app
from absl import flags
import numpy as np
import tensorflow.compat.v2 as tf
FLAGS = flags.FLAGS
flags.DEFINE_string("model_dir", None, "Directory to load SavedModel from.")
def train():
"""Build a Keras model and train with mock data."""
module = tf.saved_model.load(FLAGS.model_dir)
def _map_fn(features, labels):
features = tf.expand_dims(features, 0)
features = module(features)
features = tf.squeeze(features, 0)
return features, labels
features = np.array(["my first sentence", "my second sentence"])
labels = np.array([1, 0])
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).map(_map_fn)
# Create the sequential keras model.
l = tf.keras.layers
model = tf.keras.Sequential()
model.add(l.Dense(10, activation="relu"))
model.add(l.Dense(1, activation="sigmoid"))
model.compile(
optimizer="adam",
loss="binary_crossentropy",
metrics=["accuracy"])
model.fit_generator(generator=dataset.batch(10), epochs=5)
def main(argv):
del argv
train()
if __name__ == "__main__":
tf.enable_v2_behavior()
app.run(main)