Explicitly specify input array dtype to TensorFlow's assertAllEqual() test util function.

PiperOrigin-RevId: 313616979
Change-Id: Id3aabf89bb7a05e1f338fb05b20da3a0848a0440
This commit is contained in:
Hye Soo Yang 2020-05-28 10:44:42 -07:00 committed by TensorFlower Gardener
parent f1e137db12
commit 393e92ae5f

View File

@ -29,6 +29,7 @@ from tensorflow.python.keras.layers.preprocessing import image_preprocessing
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
from tensorflow.python.ops import gen_stateful_random_ops
from tensorflow.python.ops import image_ops_impl as image_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.platform import test
@ -1114,7 +1115,10 @@ class RandomHeightTest(keras_parameterized.TestCase):
with tf_test_util.use_gpu():
input_image = np.reshape(np.arange(0, 6), (2, 3, 1)).astype(dtype)
layer = image_preprocessing.RandomHeight(factor=(1., 1.))
output_image = layer(np.expand_dims(input_image, axis=0))
# Return type of RandomHeight() is float32 if `interpolation` is not
# set to `ResizeMethod.NEAREST_NEIGHBOR`; cast `layer` to desired dtype.
output_image = math_ops.cast(layer(np.expand_dims(input_image, axis=0)),
dtype=dtype)
# pyformat: disable
expected_output = np.asarray([
[0, 1, 2],
@ -1202,7 +1206,10 @@ class RandomWidthTest(keras_parameterized.TestCase):
with tf_test_util.use_gpu():
input_image = np.reshape(np.arange(0, 6), (3, 2, 1)).astype(dtype)
layer = image_preprocessing.RandomWidth(factor=(1., 1.))
output_image = layer(np.expand_dims(input_image, axis=0))
# Return type of RandomWidth() is float32 if `interpolation` is not
# set to `ResizeMethod.NEAREST_NEIGHBOR`; cast `layer` to desired dtype.
output_image = math_ops.cast(layer(np.expand_dims(input_image, axis=0)),
dtype=dtype)
# pyformat: disable
expected_output = np.asarray([
[0, 0.25, 0.75, 1],