Explicitly specify input array dtype
to TensorFlow's assertAllEqual()
test util function.
PiperOrigin-RevId: 313616979 Change-Id: Id3aabf89bb7a05e1f338fb05b20da3a0848a0440
This commit is contained in:
parent
f1e137db12
commit
393e92ae5f
@ -29,6 +29,7 @@ from tensorflow.python.keras.layers.preprocessing import image_preprocessing
|
||||
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
|
||||
from tensorflow.python.ops import gen_stateful_random_ops
|
||||
from tensorflow.python.ops import image_ops_impl as image_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import stateless_random_ops
|
||||
from tensorflow.python.platform import test
|
||||
@ -1114,7 +1115,10 @@ class RandomHeightTest(keras_parameterized.TestCase):
|
||||
with tf_test_util.use_gpu():
|
||||
input_image = np.reshape(np.arange(0, 6), (2, 3, 1)).astype(dtype)
|
||||
layer = image_preprocessing.RandomHeight(factor=(1., 1.))
|
||||
output_image = layer(np.expand_dims(input_image, axis=0))
|
||||
# Return type of RandomHeight() is float32 if `interpolation` is not
|
||||
# set to `ResizeMethod.NEAREST_NEIGHBOR`; cast `layer` to desired dtype.
|
||||
output_image = math_ops.cast(layer(np.expand_dims(input_image, axis=0)),
|
||||
dtype=dtype)
|
||||
# pyformat: disable
|
||||
expected_output = np.asarray([
|
||||
[0, 1, 2],
|
||||
@ -1202,7 +1206,10 @@ class RandomWidthTest(keras_parameterized.TestCase):
|
||||
with tf_test_util.use_gpu():
|
||||
input_image = np.reshape(np.arange(0, 6), (3, 2, 1)).astype(dtype)
|
||||
layer = image_preprocessing.RandomWidth(factor=(1., 1.))
|
||||
output_image = layer(np.expand_dims(input_image, axis=0))
|
||||
# Return type of RandomWidth() is float32 if `interpolation` is not
|
||||
# set to `ResizeMethod.NEAREST_NEIGHBOR`; cast `layer` to desired dtype.
|
||||
output_image = math_ops.cast(layer(np.expand_dims(input_image, axis=0)),
|
||||
dtype=dtype)
|
||||
# pyformat: disable
|
||||
expected_output = np.asarray([
|
||||
[0, 0.25, 0.75, 1],
|
||||
|
Loading…
Reference in New Issue
Block a user