Integration test for use of a text embedding saved model in a dataset map function.
PiperOrigin-RevId: 241487445
This commit is contained in:
parent
729e994ec0
commit
b7465d7c7f
@ -17,6 +17,7 @@ py_binary(
|
||||
"use_mnist_cnn.py",
|
||||
"use_model_in_sequential_keras.py",
|
||||
"use_rnn_cell.py",
|
||||
"use_text_embedding_in_dataset.py",
|
||||
"use_text_rnn_model.py",
|
||||
],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
|
@ -56,6 +56,13 @@ class SavedModelTest(tf.test.TestCase):
|
||||
self.assertCommandSucceeded(
|
||||
"use_model_in_sequential_keras", model_dir=export_dir)
|
||||
|
||||
def test_text_embedding_in_dataset(self):
|
||||
export_dir = self.get_temp_dir()
|
||||
self.assertCommandSucceeded(
|
||||
"export_simple_text_embedding", export_dir=export_dir)
|
||||
self.assertCommandSucceeded(
|
||||
"use_text_embedding_in_dataset", model_dir=export_dir)
|
||||
|
||||
def test_mnist_cnn(self):
|
||||
export_dir = self.get_temp_dir()
|
||||
self.assertCommandSucceeded(
|
||||
|
@ -0,0 +1,67 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Load and use text embedding module in a Dataset map function."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
import numpy as np
|
||||
import tensorflow.compat.v2 as tf
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_string("model_dir", None, "Directory to load SavedModel from.")
|
||||
|
||||
|
||||
def train():
|
||||
"""Build a Keras model and train with mock data."""
|
||||
module = tf.saved_model.load(FLAGS.model_dir)
|
||||
def _map_fn(features, labels):
|
||||
features = tf.expand_dims(features, 0)
|
||||
features = module(features)
|
||||
features = tf.squeeze(features, 0)
|
||||
return features, labels
|
||||
|
||||
features = np.array(["my first sentence", "my second sentence"])
|
||||
labels = np.array([1, 0])
|
||||
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).map(_map_fn)
|
||||
|
||||
# Create the sequential keras model.
|
||||
l = tf.keras.layers
|
||||
model = tf.keras.Sequential()
|
||||
model.add(l.Dense(10, activation="relu"))
|
||||
model.add(l.Dense(1, activation="sigmoid"))
|
||||
|
||||
model.compile(
|
||||
optimizer="adam",
|
||||
loss="binary_crossentropy",
|
||||
metrics=["accuracy"])
|
||||
|
||||
model.fit_generator(generator=dataset.batch(10), epochs=5)
|
||||
|
||||
|
||||
def main(argv):
|
||||
del argv
|
||||
|
||||
train()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.enable_v2_behavior()
|
||||
app.run(main)
|
Loading…
Reference in New Issue
Block a user