Remove keras dependency from TFLite tests and BUILD files
PiperOrigin-RevId: 340033160 Change-Id: Ia7007ddb1bcd128c684712eafa284af79fedac05
This commit is contained in:
parent
89b078797d
commit
d3bcfc947d
@ -763,6 +763,7 @@ def gen_model_coverage_test(src, model_name, data, failure_type, tags, size = "m
|
||||
"no_windows",
|
||||
] + tags + coverage_tags,
|
||||
deps = [
|
||||
"//third_party/py/tensorflow",
|
||||
"//tensorflow/lite/testing/model_coverage:model_coverage_lib",
|
||||
"//tensorflow/lite/python:lite",
|
||||
"//tensorflow/python:client_testlib",
|
||||
|
@ -125,7 +125,6 @@ py_test(
|
||||
"//tensorflow/python:session",
|
||||
"//tensorflow/python:tf2",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/keras",
|
||||
"//tensorflow/python/saved_model",
|
||||
"//tensorflow/python/saved_model:save",
|
||||
"//tensorflow/python/training:training_util",
|
||||
@ -151,7 +150,6 @@ py_library(
|
||||
"//tensorflow/lite/experimental/tensorboard:ops_util",
|
||||
"//tensorflow/lite/python/optimize:calibrator",
|
||||
"//tensorflow/python:graph_util",
|
||||
"//tensorflow/python/keras",
|
||||
"//tensorflow/python/saved_model:constants",
|
||||
"//tensorflow/python/saved_model:loader",
|
||||
"@six_archive//:six",
|
||||
@ -170,6 +168,7 @@ py_test(
|
||||
],
|
||||
deps = [
|
||||
":lite",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"@six_archive//:six",
|
||||
@ -392,7 +391,6 @@ py_test(
|
||||
"//tensorflow/python:nn",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:session",
|
||||
"//tensorflow/python/keras",
|
||||
"//tensorflow/python/ops/losses",
|
||||
"//tensorflow/python/saved_model",
|
||||
],
|
||||
|
@ -28,13 +28,13 @@ from absl.testing import parameterized
|
||||
import numpy as np
|
||||
import six
|
||||
from six.moves import range
|
||||
from tensorflow import keras
|
||||
|
||||
from tensorflow.lite.python import lite
|
||||
from tensorflow.lite.python import lite_constants
|
||||
from tensorflow.lite.python.convert import ConverterError
|
||||
from tensorflow.lite.python.convert import mlir_quantize
|
||||
from tensorflow.lite.python.interpreter import Interpreter
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
|
@ -35,8 +35,6 @@ from tensorflow.lite.toco import types_pb2 as _types_pb2
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras.layers import recurrent
|
||||
from tensorflow.python.keras.layers import recurrent_v2
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import save_options
|
||||
@ -1147,9 +1145,9 @@ class ControlFlowTest(lite_v2_test_util.ModelTest):
|
||||
expected = expected.c
|
||||
self.assertAllClose(expected, actual)
|
||||
|
||||
@parameterized.named_parameters(('LSTM', recurrent_v2.LSTM),
|
||||
('SimpleRNN', recurrent.SimpleRNN),
|
||||
('GRU', recurrent_v2.GRU))
|
||||
@parameterized.named_parameters(('LSTM', tf.keras.layers.LSTM),
|
||||
('SimpleRNN', tf.keras.layers.SimpleRNN),
|
||||
('GRU', tf.keras.layers.GRU))
|
||||
@test_util.run_v2_only
|
||||
def testKerasRNN(self, rnn_layer):
|
||||
# This relies on TFLiteConverter to rewrite unknown batch size to 1. The
|
||||
@ -1171,9 +1169,9 @@ class ControlFlowTest(lite_v2_test_util.ModelTest):
|
||||
expected_value = model.predict(input_data)
|
||||
self.assertAllClose(expected_value, actual_value, atol=1e-05)
|
||||
|
||||
@parameterized.named_parameters(('LSTM', recurrent_v2.LSTM),
|
||||
('SimpleRNN', recurrent.SimpleRNN),
|
||||
('GRU', recurrent_v2.GRU))
|
||||
@parameterized.named_parameters(('LSTM', tf.keras.layers.LSTM),
|
||||
('SimpleRNN', tf.keras.layers.SimpleRNN),
|
||||
('GRU', tf.keras.layers.GRU))
|
||||
@test_util.run_v2_only
|
||||
def testKerasRNNMultiBatches(self, rnn_layer):
|
||||
input_data = tf.constant(
|
||||
@ -1200,7 +1198,7 @@ class ControlFlowTest(lite_v2_test_util.ModelTest):
|
||||
model.add(tf.keras.layers.Input(batch_size=1, shape=(10, 10), name='input'))
|
||||
model.add(
|
||||
tf.keras.layers.Bidirectional(
|
||||
recurrent_v2.LSTM(units=10, return_sequences=True),
|
||||
tf.keras.layers.LSTM(units=10, return_sequences=True),
|
||||
input_shape=(10, 10)))
|
||||
model.add(tf.keras.layers.Flatten())
|
||||
model.add(tf.keras.layers.Dense(5))
|
||||
@ -1221,7 +1219,7 @@ class ControlFlowTest(lite_v2_test_util.ModelTest):
|
||||
np.array(np.random.random_sample((1, 10, 10)), dtype=np.float32))
|
||||
model = tf.keras.models.Sequential()
|
||||
model.add(tf.keras.layers.Input(batch_size=1, shape=(10, 10), name='input'))
|
||||
model.add(tf.keras.layers.Bidirectional(recurrent_v2.LSTM(units=10)))
|
||||
model.add(tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(units=10)))
|
||||
model.add(tf.keras.layers.Dense(5))
|
||||
model.add(tf.keras.layers.Activation('softmax'))
|
||||
|
||||
|
@ -10,6 +10,7 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["no_pip"],
|
||||
deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/lite/python:lite",
|
||||
"//tensorflow/python:platform",
|
||||
],
|
||||
|
@ -22,6 +22,7 @@ import os
|
||||
|
||||
import numpy as np
|
||||
from six import PY2
|
||||
from tensorflow import keras
|
||||
|
||||
from google.protobuf import text_format as _text_format
|
||||
from google.protobuf.message import DecodeError
|
||||
@ -29,13 +30,11 @@ from tensorflow.core.framework import graph_pb2 as _graph_pb2
|
||||
from tensorflow.lite.python import convert_saved_model as _convert_saved_model
|
||||
from tensorflow.lite.python import lite as _lite
|
||||
from tensorflow.lite.python import util as _util
|
||||
from tensorflow.python import keras as _keras
|
||||
from tensorflow.python.client import session as _session
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
|
||||
from tensorflow.python.keras.preprocessing import image
|
||||
from tensorflow.python.lib.io import file_io as _file_io
|
||||
from tensorflow.python.platform import resource_loader as _resource_loader
|
||||
from tensorflow.python.saved_model import load as _load
|
||||
@ -71,8 +70,9 @@ def get_image(size):
|
||||
"""
|
||||
img_filename = _resource_loader.get_path_to_datafile(
|
||||
"testdata/grace_hopper.jpg")
|
||||
img = image.load_img(img_filename, target_size=(size, size))
|
||||
img_array = image.img_to_array(img)
|
||||
img = keras.preprocessing.image.load_img(
|
||||
img_filename, target_size=(size, size))
|
||||
img_array = keras.preprocessing.image.img_to_array(img)
|
||||
img_array = np.expand_dims(img_array, axis=0)
|
||||
return img_array
|
||||
|
||||
@ -342,7 +342,7 @@ def evaluate_keras_model(filename):
|
||||
Returns:
|
||||
Lambda function ([np.ndarray data] : [np.ndarray result]).
|
||||
"""
|
||||
keras_model = _keras.models.load_model(filename)
|
||||
keras_model = keras.models.load_model(filename)
|
||||
return lambda input_data: [keras_model.predict(input_data)]
|
||||
|
||||
|
||||
@ -740,7 +740,7 @@ def test_keras_model_v2(filename,
|
||||
(half-inclusive). (default None)
|
||||
**kwargs: Additional arguments to be passed into the converter.
|
||||
"""
|
||||
keras_model = _keras.models.load_model(filename)
|
||||
keras_model = keras.models.load_model(filename)
|
||||
if input_shapes:
|
||||
for tensor, shape in zip(keras_model.inputs, input_shapes):
|
||||
tensor.set_shape(shape)
|
||||
|
@ -22,10 +22,10 @@ import os
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
from tensorflow import keras
|
||||
|
||||
from tensorflow.lite.python import lite
|
||||
from tensorflow.lite.testing.model_coverage import model_coverage_lib as model_coverage
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
|
Loading…
x
Reference in New Issue
Block a user