use predict so test works for non-egaer execution
This commit is contained in:
parent
bd12651d73
commit
c655cf70a7
@ -476,7 +476,11 @@ class TextVectorizationPreprocessingTest(
|
||||
|
||||
standardize = "".join(["lower", "_and_strip_punctuation"])
|
||||
layer = get_layer_class()(standardize=standardize)
|
||||
output = layer(input_array).numpy()
|
||||
|
||||
input_data = keras.Input(shape=(1,), dtype=dtypes.string)
|
||||
output_data = layer(input_data)
|
||||
model = keras.Model(inputs=input_data, outputs=output_data)
|
||||
output = model.predict(input_array)
|
||||
|
||||
self.assertAllEqual(expected_output, output)
|
||||
|
||||
@ -486,7 +490,11 @@ class TextVectorizationPreprocessingTest(
|
||||
|
||||
split = "".join(["white", "space"])
|
||||
layer = get_layer_class()(split=split)
|
||||
output = layer(input_array).numpy()
|
||||
|
||||
input_data = keras.Input(shape=(1,), dtype=dtypes.string)
|
||||
output_data = layer(input_data)
|
||||
model = keras.Model(inputs=input_data, outputs=output_data)
|
||||
output = model.predict(input_array)
|
||||
|
||||
self.assertAllEqual(expected_output, output)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user