Integration test for use of a text embedding saved model in a dataset map function.
PiperOrigin-RevId: 241487445
This commit is contained in:
parent
729e994ec0
commit
b7465d7c7f
@ -17,6 +17,7 @@ py_binary(
|
|||||||
"use_mnist_cnn.py",
|
"use_mnist_cnn.py",
|
||||||
"use_model_in_sequential_keras.py",
|
"use_model_in_sequential_keras.py",
|
||||||
"use_rnn_cell.py",
|
"use_rnn_cell.py",
|
||||||
|
"use_text_embedding_in_dataset.py",
|
||||||
"use_text_rnn_model.py",
|
"use_text_rnn_model.py",
|
||||||
],
|
],
|
||||||
visibility = ["//tensorflow:internal"],
|
visibility = ["//tensorflow:internal"],
|
||||||
|
@ -56,6 +56,13 @@ class SavedModelTest(tf.test.TestCase):
|
|||||||
self.assertCommandSucceeded(
|
self.assertCommandSucceeded(
|
||||||
"use_model_in_sequential_keras", model_dir=export_dir)
|
"use_model_in_sequential_keras", model_dir=export_dir)
|
||||||
|
|
||||||
|
def test_text_embedding_in_dataset(self):
|
||||||
|
export_dir = self.get_temp_dir()
|
||||||
|
self.assertCommandSucceeded(
|
||||||
|
"export_simple_text_embedding", export_dir=export_dir)
|
||||||
|
self.assertCommandSucceeded(
|
||||||
|
"use_text_embedding_in_dataset", model_dir=export_dir)
|
||||||
|
|
||||||
def test_mnist_cnn(self):
|
def test_mnist_cnn(self):
|
||||||
export_dir = self.get_temp_dir()
|
export_dir = self.get_temp_dir()
|
||||||
self.assertCommandSucceeded(
|
self.assertCommandSucceeded(
|
||||||
|
@ -0,0 +1,67 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Load and use text embedding module in a Dataset map function."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from absl import app
|
||||||
|
from absl import flags
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow.compat.v2 as tf
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
|
||||||
|
flags.DEFINE_string("model_dir", None, "Directory to load SavedModel from.")
|
||||||
|
|
||||||
|
|
||||||
|
def train():
|
||||||
|
"""Build a Keras model and train with mock data."""
|
||||||
|
module = tf.saved_model.load(FLAGS.model_dir)
|
||||||
|
def _map_fn(features, labels):
|
||||||
|
features = tf.expand_dims(features, 0)
|
||||||
|
features = module(features)
|
||||||
|
features = tf.squeeze(features, 0)
|
||||||
|
return features, labels
|
||||||
|
|
||||||
|
features = np.array(["my first sentence", "my second sentence"])
|
||||||
|
labels = np.array([1, 0])
|
||||||
|
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).map(_map_fn)
|
||||||
|
|
||||||
|
# Create the sequential keras model.
|
||||||
|
l = tf.keras.layers
|
||||||
|
model = tf.keras.Sequential()
|
||||||
|
model.add(l.Dense(10, activation="relu"))
|
||||||
|
model.add(l.Dense(1, activation="sigmoid"))
|
||||||
|
|
||||||
|
model.compile(
|
||||||
|
optimizer="adam",
|
||||||
|
loss="binary_crossentropy",
|
||||||
|
metrics=["accuracy"])
|
||||||
|
|
||||||
|
model.fit_generator(generator=dataset.batch(10), epochs=5)
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv):
|
||||||
|
del argv
|
||||||
|
|
||||||
|
train()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tf.enable_v2_behavior()
|
||||||
|
app.run(main)
|
Loading…
Reference in New Issue
Block a user