use predict so test works for non-egaer execution

This commit is contained in:
Daniel Falbel 2019-12-09 15:23:02 -03:00
parent bd12651d73
commit c655cf70a7
No known key found for this signature in database
GPG Key ID: 86D67393B1F8D380

View File

@ -476,7 +476,11 @@ class TextVectorizationPreprocessingTest(
standardize = "".join(["lower", "_and_strip_punctuation"])
layer = get_layer_class()(standardize=standardize)
output = layer(input_array).numpy()
input_data = keras.Input(shape=(1,), dtype=dtypes.string)
output_data = layer(input_data)
model = keras.Model(inputs=input_data, outputs=output_data)
output = model.predict(input_array)
self.assertAllEqual(expected_output, output)
@ -486,7 +490,11 @@ class TextVectorizationPreprocessingTest(
split = "".join(["white", "space"])
layer = get_layer_class()(split=split)
output = layer(input_array).numpy()
input_data = keras.Input(shape=(1,), dtype=dtypes.string)
output_data = layer(input_data)
model = keras.Model(inputs=input_data, outputs=output_data)
output = model.predict(input_array)
self.assertAllEqual(expected_output, output)