diff --git a/tensorflow/core/kernels/image/resize_bilinear_op_gpu.cu.cc b/tensorflow/core/kernels/image/resize_bilinear_op_gpu.cu.cc
index c8dfe754060..82ed8b892de 100644
--- a/tensorflow/core/kernels/image/resize_bilinear_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/image/resize_bilinear_op_gpu.cu.cc
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/kernels/image/resize_bilinear_op.h"
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/env_var.h"
 #include "tensorflow/core/util/gpu_kernel_helper.h"
 
 namespace tensorflow {
@@ -228,6 +229,56 @@ __global__ void ResizeBilinearGradKernel(const int32 nthreads,
   }
 }
 
+template <typename T>
+__global__ void ResizeBilinearDeterministicGradKernel(
+    const int32 nthreads, const float* __restrict__ input_grad,
+    float height_scale, float inverse_height_scale, float width_scale,
+    float inverse_width_scale, int batch, int original_height,
+    int original_width, int channels, int resized_height, int resized_width,
+    T* __restrict__ output_grad) {
+  GPU_1D_KERNEL_LOOP(out_idx, nthreads) {
+    // out_idx = c + channels * (x + original_width * (y + original_height * b))
+    int idx = out_idx;
+    const int c = idx % channels;
+    idx /= channels;
+    const int out_x_center = idx % original_width;
+    idx /= original_width;
+    const int out_y_center = idx % original_height;
+    const int b = idx / original_height;
+
+    int in_y_start = max(0, __float2int_ru(
+        (out_y_center - 1 + 0.5) * inverse_height_scale - 0.5));
+    const float out_y_start = (in_y_start + 0.5) * height_scale - 0.5;
+    int in_x_start = max(0, __float2int_ru(
+        (out_x_center - 1 + 0.5) * inverse_width_scale - 0.5));
+    const float out_x_start = (in_x_start + 0.5) * width_scale - 0.5;
+    T acc = 0;
+    // Prior to C++17, while loops are clearer than for loops here, since
+    // each iteration must advance two loop variables of different types.
+    float out_y = out_y_start;
+    int in_y = in_y_start;
+    while (out_y < out_y_center + 1 && in_y < resized_height) {
+      float out_x = out_x_start;
+      int in_x = in_x_start;
+      while (out_x < out_x_center + 1 && in_x < resized_width) {
+        int in_idx = ((b * resized_height + in_y) * resized_width + in_x) *
+                     channels + c;
+        // Clamping to zero is necessary because out_x and out_y can be
+        // negative due to half-pixel adjustments to out_y_start and
+        // out_x_start. Clamping to height/width is necessary when upscaling.
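+        // Each input-gradient pixel (in_y, in_x) was resampled from the
+        // point (out_y, out_x), so it contributes to this output pixel with
+        // the bilinear weight (1 - |out_y_clamped - out_y_center|) *
+        // (1 - |out_x_clamped - out_x_center|), mirroring the forward kernel.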
+        float out_y_clamped = fmaxf(0, fminf(out_y, original_height - 1));
+        float out_x_clamped = fmaxf(0, fminf(out_x, original_width - 1));
+        float y_lerp = (1 - fabsf(out_y_clamped - out_y_center));
+        float x_lerp = (1 - fabsf(out_x_clamped - out_x_center));
+        acc += static_cast<T>(input_grad[in_idx] * y_lerp * x_lerp);
+        out_x += width_scale;
+        in_x++;
+      }
+      out_y += height_scale;
+      in_y++;
+    }
+    output_grad[out_idx] = acc;
+  }
+}
+
 template <typename T>
 __global__ void LegacyResizeBilinearKernel(
     const int32 nthreads, const T* __restrict__ images, float height_scale,
@@ -338,6 +389,55 @@ __global__ void LegacyResizeBilinearGradKernel(
   }
 }
 
+template <typename T>
+__global__ void LegacyResizeBilinearDeterministicGradKernel(
+    const int32 nthreads, const float* __restrict__ input_grad,
+    float height_scale, float inverse_height_scale, float width_scale,
+    float inverse_width_scale, int batch, int original_height,
+    int original_width, int channels, int resized_height, int resized_width,
+    T* __restrict__ output_grad) {
+  GPU_1D_KERNEL_LOOP(out_idx, nthreads) {
+    // out_idx = c + channels * (x + original_width * (y + original_height * b))
+    int idx = out_idx;
+    const int c = idx % channels;
+    idx /= channels;
+    const int out_x_center = idx % original_width;
+    idx /= original_width;
+    const int out_y_center = idx % original_height;
+    const int b = idx / original_height;
+
+    int in_y_start = max(0, __float2int_ru(
+        (out_y_center - 1) * inverse_height_scale));
+    const float out_y_start = in_y_start * height_scale;
+    int in_x_start = max(0, __float2int_ru(
+        (out_x_center - 1) * inverse_width_scale));
+    const float out_x_start = in_x_start * width_scale;
+    T acc = 0;
+    // Prior to C++17, while loops are clearer than for loops here, since
+    // each iteration must advance two loop variables of different types.
+    float out_y = out_y_start;
+    int in_y = in_y_start;
+    while (out_y < out_y_center + 1 && in_y < resized_height) {
+      float out_x = out_x_start;
+      int in_x = in_x_start;
+      while (out_x < out_x_center + 1 && in_x < resized_width) {
+        int in_idx = ((b * resized_height + in_y) * resized_width + in_x) *
+                     channels + c;
+        // Clamping to zero is unnecessary because out_x and out_y will never
+        // be less than zero in legacy mode. Clamping to height/width is
+        // necessary when upscaling.
+        float out_y_clamped = fminf(out_y, original_height - 1);
+        float out_x_clamped = fminf(out_x, original_width - 1);
+        float y_lerp = (1 - fabsf(out_y_clamped - out_y_center));
+        float x_lerp = (1 - fabsf(out_x_clamped - out_x_center));
+        acc += static_cast<T>(input_grad[in_idx] * y_lerp * x_lerp);
+        out_x += width_scale;
+        in_x++;
+      }
+      out_y += height_scale;
+      in_y++;
+    }
+    output_grad[out_idx] = acc;
+  }
+}
+
 }  // namespace
 
 namespace functor {
@@ -394,6 +494,17 @@ struct ResizeBilinear<GPUDevice, T> {
   }
 };
 
+bool RequireDeterminism() {
+  static bool require_determinism = [] {
+    bool deterministic_ops = false;
+    TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS",
+                                               /*default_val=*/false,
+                                               &deterministic_ops));
+    return deterministic_ops;
+  }();
+  return require_determinism;
+}
+
 // Partial specialization of ResizeBilinearGrad functor for a GPUDevice.
 template <typename T>
 struct ResizeBilinearGrad<GPUDevice, T> {
@@ -413,31 +524,53 @@ struct ResizeBilinearGrad<GPUDevice, T> {
     int total_count;
     GpuLaunchConfig config;
 
-    // Initialize output_grad with all zeros.
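+    // Note: no explicit zero-initialization of output_grad is needed on the
+    // deterministic path below; the deterministic kernels compute and write
+    // every element of output_grad, so SetZero is only launched on the
+    // non-deterministic path.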
     total_count = batch * original_height * original_width * channels;
     if (total_count == 0) return;
     config = GetGpuLaunchConfig(total_count, d);
-    TF_CHECK_OK(GpuLaunchKernel(
-        SetZero<T>, config.block_count, config.thread_per_block, 0, d.stream(),
-        config.virtual_thread_count, output_grad.data()));
-    // Accumulate.
-    total_count = batch * resized_height * resized_width * channels;
-    config = GetGpuLaunchConfig(total_count, d);
-    if (half_pixel_centers) {
-      TF_CHECK_OK(GpuLaunchKernel(
-          ResizeBilinearGradKernel<T>, config.block_count,
-          config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
-          input_grad.data(), height_scale, width_scale, batch, original_height,
-          original_width, channels, resized_height, resized_width,
-          output_grad.data()));
+    if (RequireDeterminism()) {
+      // The scale values below should never be zero; this is enforced by
+      // ImageResizerGradientState.
+      float inverse_height_scale = 1 / height_scale;
+      float inverse_width_scale = 1 / width_scale;
+      if (half_pixel_centers) {
+        TF_CHECK_OK(GpuLaunchKernel(
+            ResizeBilinearDeterministicGradKernel<T>, config.block_count,
+            config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
+            input_grad.data(), height_scale, inverse_height_scale, width_scale,
+            inverse_width_scale, batch, original_height, original_width,
+            channels, resized_height, resized_width, output_grad.data()));
+      } else {
+        TF_CHECK_OK(GpuLaunchKernel(
+            LegacyResizeBilinearDeterministicGradKernel<T>, config.block_count,
+            config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
+            input_grad.data(), height_scale, inverse_height_scale, width_scale,
+            inverse_width_scale, batch, original_height, original_width,
+            channels, resized_height, resized_width, output_grad.data()));
+      }
     } else {
+      // Initialize output_grad with all zeros.
       TF_CHECK_OK(GpuLaunchKernel(
-          LegacyResizeBilinearGradKernel<T>, config.block_count,
-          config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
-          input_grad.data(), height_scale, width_scale, batch, original_height,
-          original_width, channels, resized_height, resized_width,
-          output_grad.data()));
+          SetZero<T>, config.block_count, config.thread_per_block, 0,
+          d.stream(), config.virtual_thread_count, output_grad.data()));
+      // Accumulate.
+      total_count = batch * resized_height * resized_width * channels;
+      config = GetGpuLaunchConfig(total_count, d);
+      if (half_pixel_centers) {
+        TF_CHECK_OK(GpuLaunchKernel(
+            ResizeBilinearGradKernel<T>, config.block_count,
+            config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
+            input_grad.data(), height_scale, width_scale, batch,
+            original_height, original_width, channels, resized_height,
+            resized_width, output_grad.data()));
+      } else {
+        TF_CHECK_OK(GpuLaunchKernel(
+            LegacyResizeBilinearGradKernel<T>, config.block_count,
+            config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
+            input_grad.data(), height_scale, width_scale, batch,
+            original_height, original_width, channels, resized_height,
+            resized_width, output_grad.data()));
+      }
     }
   }
 };
diff --git a/tensorflow/core/util/image_resizer_state.h b/tensorflow/core/util/image_resizer_state.h
index b302021918d..db201c3ba2d 100644
--- a/tensorflow/core/util/image_resizer_state.h
+++ b/tensorflow/core/util/image_resizer_state.h
@@ -192,6 +192,19 @@ struct ImageResizerGradientState {
     original_height = original_image.dim_size(1);
     original_width = original_image.dim_size(2);
 
+    // The following check is also carried out for the forward op. It is
+    // repeated here to prevent a divide-by-zero exception when height_scale
+    // or width_scale is calculated.
+    OP_REQUIRES(context, resized_height > 0 && resized_width > 0,
+                errors::InvalidArgument("resized dimensions must be positive"));
+
+    // The following check is also carried out for the forward op. It is
+    // repeated here to prevent either height_scale or width_scale from being
+    // set to zero, which would cause a divide-by-zero exception in the
+    // deterministic back-prop path.
+    OP_REQUIRES(context, original_height > 0 && original_width > 0,
+                errors::InvalidArgument("original dimensions must be positive"));
+
     OP_REQUIRES(
         context,
         FastBoundsCheck(original_height, std::numeric_limits<int32>::max()) &&
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 65c77748c84..f3b1701a5e8 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -5174,12 +5174,30 @@ cuda_py_test(
     ],
 )
 
+cuda_py_test(
+    name = "image_grad_deterministic_test",
+    size = "medium",
+    srcs = ["ops/image_grad_deterministic_test.py"],
+    python_version = "PY3",
+    deps = [
+        ":image_grad_test_base",
+    ],
+)
+
 cuda_py_test(
     name = "image_grad_test",
     size = "medium",
     srcs = ["ops/image_grad_test.py"],
     python_version = "PY3",
     tfrt_enabled = True,
+    deps = [
+        ":image_grad_test_base",
+    ],
+)
+
+py_library(
+    name = "image_grad_test_base",
+    srcs = ["ops/image_grad_test_base.py"],
     deps = [
         ":client_testlib",
         ":framework_for_generated_wrappers",
diff --git a/tensorflow/python/ops/image_grad_deterministic_test.py b/tensorflow/python/ops/image_grad_deterministic_test.py
new file mode 100644
index 00000000000..5866114025f
--- /dev/null
+++ b/tensorflow/python/ops/image_grad_deterministic_test.py
@@ -0,0 +1,123 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for deterministic image op gradient functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import numpy as np
+
+from absl.testing import parameterized
+
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import image_grad_test_base as test_base
+from tensorflow.python.ops import image_ops
+from tensorflow.python.platform import test
+
+
+class ResizeBilinearOpDeterministicTest(test_base.ResizeBilinearOpTestBase):
+
+  def _randomNDArray(self, shape):
+    return 2 * np.random.random_sample(shape) - 1
+
+  def _randomDataOp(self, shape, data_type):
+    return constant_op.constant(self._randomNDArray(shape), dtype=data_type)
+
+  @parameterized.parameters(
+      # Note that there is no 16-bit floating-point format registered for GPU.
+      {'align_corners': False, 'half_pixel_centers': False,
+       'data_type': dtypes.float32},
+      {'align_corners': False, 'half_pixel_centers': False,
+       'data_type': dtypes.float64},
+      {'align_corners': True, 'half_pixel_centers': False,
+       'data_type': dtypes.float32},
+      {'align_corners': False, 'half_pixel_centers': True,
+       'data_type': dtypes.float32})
+  @test_util.run_in_graph_and_eager_modes
+  @test_util.run_cuda_only
+  def testDeterministicGradients(self, align_corners, half_pixel_centers,
+                                 data_type):
+    if not align_corners and test_util.is_xla_enabled():
+      # Align corners is deprecated in TF2.0, but align_corners==False is not
+      # supported by XLA.
+      self.skipTest("align_corners==False not currently supported by XLA")
+    with self.session(force_gpu=True):
+      seed = (hash(align_corners) % 256 + hash(half_pixel_centers) % 256 +
+              hash(data_type) % 256)
+      np.random.seed(seed)
+      input_shape = (1, 25, 12, 3)  # NHWC
+      output_shape = (1, 200, 250, 3)
+      input_image = self._randomDataOp(input_shape, data_type)
+      repeat_count = 3
+      if context.executing_eagerly():
+
+        def resize_bilinear_gradients(local_seed):
+          np.random.seed(local_seed)
+          upstream_gradients = self._randomDataOp(output_shape, dtypes.float32)
+          with backprop.GradientTape(persistent=True) as tape:
+            tape.watch(input_image)
+            output_image = image_ops.resize_bilinear(
+                input_image, output_shape[1:3], align_corners=align_corners,
+                half_pixel_centers=half_pixel_centers)
+            gradient_injector_output = output_image * upstream_gradients
+          return tape.gradient(gradient_injector_output, input_image)
+
+        for i in range(repeat_count):
+          local_seed = seed + i  # select different upstream gradients
+          result_a = resize_bilinear_gradients(local_seed)
+          result_b = resize_bilinear_gradients(local_seed)
+          self.assertAllEqual(result_a, result_b)
+      else:  # graph mode
+        upstream_gradients = array_ops.placeholder(
+            dtypes.float32, shape=output_shape, name='upstream_gradients')
+        output_image = image_ops.resize_bilinear(
+            input_image, output_shape[1:3], align_corners=align_corners,
+            half_pixel_centers=half_pixel_centers)
+        gradient_injector_output = output_image * upstream_gradients
+        # The gradient function behaves as if grad_ys is multiplied by the
+        # result of the op's gradient, rather than passing the upstream
+        # gradients through the op's gradient-generation graph. This is the
+        # reason for using the gradient injector.
+        resize_bilinear_gradients = gradients_impl.gradients(
+            gradient_injector_output, input_image, grad_ys=None,
+            colocate_gradients_with_ops=True)[0]
+        for i in range(repeat_count):
+          feed_dict = {upstream_gradients: self._randomNDArray(output_shape)}
+          result_a = resize_bilinear_gradients.eval(feed_dict=feed_dict)
+          result_b = resize_bilinear_gradients.eval(feed_dict=feed_dict)
+          self.assertAllEqual(result_a, result_b)
+
+
+if __name__ == '__main__':
+  # Note that the effect of setting the following environment variable to
+  # 'true' is not tested. Unless we can find a simpler pattern for testing
+  # these environment variables, doing so would require making this file into
+  # a base and creating two more test files.
+  #
+  # When deterministic op functionality can be enabled and disabled between
+  # test cases in the same process, the tests for deterministic op
+  # functionality, for this op and for other ops, will be able to be included
+  # in the same file with the regular tests, simplifying the organization of
+  # tests and test files.
+  os.environ['TF_DETERMINISTIC_OPS'] = '1'
+  test.main()
diff --git a/tensorflow/python/ops/image_grad_test.py b/tensorflow/python/ops/image_grad_test.py
new file mode 100644
index 00000000000..92e174b3703
--- /dev/null
+++ b/tensorflow/python/ops/image_grad_test.py
@@ -0,0 +1,32 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for Image Op Gradients."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.ops import image_grad_test_base as test_base
+from tensorflow.python.platform import test
+
+ResizeNearestNeighborOpTest = test_base.ResizeNearestNeighborOpTestBase
+ResizeBilinearOpTest = test_base.ResizeBilinearOpTestBase
+ResizeBicubicOpTest = test_base.ResizeBicubicOpTestBase
+ScaleAndTranslateOpTest = test_base.ScaleAndTranslateOpTestBase
+CropAndResizeOpTest = test_base.CropAndResizeOpTestBase
+RGBToHSVOpTest = test_base.RGBToHSVOpTestBase
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/python/ops/image_grad_test_base.py b/tensorflow/python/ops/image_grad_test_base.py
index f5d4fc29dfd..c325388add1 100644
--- a/tensorflow/python/ops/image_grad_test_base.py
+++ b/tensorflow/python/ops/image_grad_test_base.py
@@ -20,6 +20,8 @@ from __future__ import print_function
 
 import numpy as np
 
+from absl.testing import parameterized
+
 from tensorflow.python.eager import backprop
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import test_util
@@ -35,7 +37,7 @@ from tensorflow.python.ops import array_ops
 
 @test_util.for_all_test_methods(test_util.disable_xla,
                                 'align_corners=False not supported by XLA')
-class ResizeNearestNeighborOpTest(test.TestCase):
+class ResizeNearestNeighborOpTestBase(test.TestCase):
 
   TYPES = [np.float32, np.float64]
 
@@ -111,97 +113,140 @@ class ResizeNearestNeighborOpTest(test.TestCase):
       self.assertAllClose(grad_cpu, grad_gpu, rtol=1e-5, atol=1e-5)
 
 
-class ResizeBilinearOpTest(test.TestCase):
+class ResizeBilinearOpTestBase(test.TestCase, parameterized.TestCase):
 
-  def testShapeIsCorrectAfterOp(self):
-    in_shape = [1, 2, 2, 1]
-    out_shape = [1, 4, 6, 1]
-
-    x = np.arange(0, 4).reshape(in_shape).astype(np.float32)
-
-    input_tensor = constant_op.constant(x, shape=in_shape)
-    resize_out = image_ops.resize_bilinear(input_tensor, out_shape[1:3])
-    with self.cached_session():
-      self.assertEqual(out_shape, list(resize_out.get_shape()))
-      resize_out = self.evaluate(resize_out)
-      self.assertEqual(out_shape, list(resize_out.shape))
-
-  @test_util.run_deprecated_v1
-  def testGradFromResizeToLargerInBothDims(self):
-    in_shape = [1, 2, 3, 1]
-    out_shape = [1, 4, 6, 1]
-
-    x = np.arange(0, 6).reshape(in_shape).astype(np.float32)
-
-    with self.cached_session():
-      input_tensor = constant_op.constant(x, shape=in_shape)
-      resize_out = image_ops.resize_bilinear(input_tensor, out_shape[1:3])
-      err = gradient_checker.compute_gradient_error(
-          input_tensor, in_shape, resize_out, out_shape, x_init_value=x)
-    self.assertLess(err, 1e-3)
-
-  @test_util.run_deprecated_v1
-  def testGradFromResizeToSmallerInBothDims(self):
-    in_shape = [1, 4, 6, 1]
-    out_shape = [1, 2, 3, 1]
-
-    x = np.arange(0, 24).reshape(in_shape).astype(np.float32)
-
-    with self.cached_session():
-      input_tensor = constant_op.constant(x, shape=in_shape)
-      resize_out = image_ops.resize_bilinear(input_tensor,
-                                             out_shape[1:3])
-      err = gradient_checker.compute_gradient_error(
-          input_tensor, in_shape, resize_out, out_shape, x_init_value=x)
-    self.assertLess(err, 1e-3)
-
-  @test_util.run_deprecated_v1
-  def testCompareGpuVsCpu(self):
-    in_shape = [2, 4, 6, 3]
-    out_shape = [2, 8, 16, 3]
-
-    size = np.prod(in_shape)
-    x = 1.0 / size * np.arange(0, size).reshape(in_shape).astype(np.float32)
-
-    # Align corners will be deprecated for tf2.0 and the false version is not
+  def _itGen(self, smaller_shape, larger_shape):
+    up_sample = (smaller_shape, larger_shape)
+    down_sample = (larger_shape, smaller_shape)
+    pass_through = (larger_shape, larger_shape)
+    shape_pairs = (up_sample, down_sample, pass_through)
+    # Align corners is deprecated in TF2.0, but align_corners==False is not
     # supported by XLA.
-    align_corner_options = [True
-                           ] if test_util.is_xla_enabled() else [True, False]
-    for align_corners in align_corner_options:
-      grad = {}
-      for use_gpu in [False, True]:
-        with self.cached_session(use_gpu=use_gpu):
-          input_tensor = constant_op.constant(x, shape=in_shape)
-          resized_tensor = image_ops.resize_bilinear(
-              input_tensor, out_shape[1:3], align_corners=align_corners)
-          grad[use_gpu] = gradient_checker.compute_gradient(
-              input_tensor, in_shape, resized_tensor, out_shape, x_init_value=x)
+    options = [(True, False)]
+    if not test_util.is_xla_enabled():
+      options += [(False, True), (False, False)]
+    for align_corners, half_pixel_centers in options:
+      for in_shape, out_shape in shape_pairs:
+        yield in_shape, out_shape, align_corners, half_pixel_centers
 
-      self.assertAllClose(grad[False], grad[True], rtol=1e-4, atol=1e-4)
+  def _getJacobians(self, in_shape, out_shape, align_corners=False,
+                    half_pixel_centers=False, dtype=np.float32, use_gpu=False,
+                    force_gpu=False):
+    with self.cached_session(use_gpu=use_gpu, force_gpu=force_gpu):
+      # Input values should not influence gradients.
+      x = np.arange(np.prod(in_shape)).reshape(in_shape).astype(dtype)
+      input_tensor = constant_op.constant(x, shape=in_shape)
+      resized_tensor = image_ops.resize_bilinear(
+          input_tensor, out_shape[1:3], align_corners=align_corners,
+          half_pixel_centers=half_pixel_centers)
+      # compute_gradient will use a random tensor as the init value.
+      return gradient_checker.compute_gradient(
+          input_tensor, in_shape, resized_tensor, out_shape)
+
+  @parameterized.parameters(
+      {'batch_size': 1, 'channel_count': 1},
+      {'batch_size': 2, 'channel_count': 3},
+      {'batch_size': 5, 'channel_count': 4})
+  @test_util.run_deprecated_v1
+  def testShapes(self, batch_size, channel_count):
+    smaller_shape = [batch_size, 2, 3, channel_count]
+    larger_shape = [batch_size, 4, 6, channel_count]
+    for in_shape, out_shape, align_corners, half_pixel_centers in \
+        self._itGen(smaller_shape, larger_shape):
+      # Input values should not influence shapes.
+      x = np.arange(np.prod(in_shape)).reshape(in_shape).astype(np.float32)
+      input_tensor = constant_op.constant(x, shape=in_shape)
+      resized_tensor = image_ops.resize_bilinear(input_tensor, out_shape[1:3])
+      self.assertEqual(out_shape, list(resized_tensor.get_shape()))
+      grad_tensor = gradients_impl.gradients(resized_tensor, input_tensor)[0]
+      self.assertEqual(in_shape, list(grad_tensor.get_shape()))
+      with self.cached_session():
+        resized_values = self.evaluate(resized_tensor)
+        self.assertEqual(out_shape, list(resized_values.shape))
+        grad_values = self.evaluate(grad_tensor)
+        self.assertEqual(in_shape, list(grad_values.shape))
+
+  @parameterized.parameters(
+      {'batch_size': 1, 'channel_count': 1},
+      {'batch_size': 4, 'channel_count': 3},
+      {'batch_size': 3, 'channel_count': 2})
+  @test_util.run_deprecated_v1
+  def testGradients(self, batch_size, channel_count):
+    smaller_shape = [batch_size, 2, 3, channel_count]
+    larger_shape = [batch_size, 5, 6, channel_count]
+    for in_shape, out_shape, align_corners, half_pixel_centers in \
+        self._itGen(smaller_shape, larger_shape):
+      jacob_a, jacob_n = self._getJacobians(
+          in_shape, out_shape, align_corners, half_pixel_centers)
+      threshold = 1e-4
+      self.assertAllClose(jacob_a, jacob_n, threshold, threshold)
 
   @test_util.run_deprecated_v1
   def testTypes(self):
     in_shape = [1, 4, 6, 1]
     out_shape = [1, 2, 3, 1]
-    x = np.arange(0, 24).reshape(in_shape)
     for use_gpu in [False, True]:
-      with self.cached_session(use_gpu=use_gpu) as sess:
-        for dtype in [np.float16, np.float32, np.float64]:
-          input_tensor = constant_op.constant(x.astype(dtype), shape=in_shape)
-          resize_out = image_ops.resize_bilinear(input_tensor, out_shape[1:3])
-          grad = sess.run(gradients_impl.gradients(resize_out, input_tensor))[0]
-          self.assertAllEqual(in_shape, grad.shape)
-          # Not using gradient_checker.compute_gradient as I didn't work out
-          # the changes required to compensate for the lower precision of
-          # float16 when computing the numeric jacobian.
-          # Instead, we just test the theoretical jacobian.
-          self.assertAllEqual([[[[1.], [0.], [1.], [0.], [1.], [0.]],
-                                [[0.], [0.], [0.], [0.], [0.], [0.]],
-                                [[1.], [0.], [1.], [0.], [1.], [0.]],
-                                [[0.], [0.], [0.], [0.], [0.], [0.]]]], grad)
+      for dtype in [np.float16, np.float32, np.float64]:
+        jacob_a, jacob_n = self._getJacobians(
+            in_shape, out_shape, dtype=dtype, use_gpu=use_gpu)
+        if dtype == np.float16:
+          # Compare fp16 analytical gradients to fp32 numerical gradients,
+          # since fp16 numerical gradients are too imprecise unless great
+          # care is taken with choosing the inputs and the delta. This is
+          # a weaker, but pragmatic, check (in particular, it does not test
+          # the op itself, only its gradient).
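+          # Only the numerical Jacobian is recomputed in fp32 below; the
+          # analytical Jacobian under test remains the fp16 one from above.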
+          _, jacob_n = self._getJacobians(
+              in_shape, out_shape, dtype=np.float32, use_gpu=use_gpu)
+        threshold = 1e-3
+        if dtype == np.float64:
+          threshold = 1e-5
+        self.assertAllClose(jacob_a, jacob_n, threshold, threshold)
+
+  @test_util.run_deprecated_v1
+  def testGradOnUnsupportedType(self):
+    in_shape = [1, 4, 6, 1]
+    out_shape = [1, 2, 3, 1]
+
+    x = np.arange(0, 24).reshape(in_shape).astype(np.uint8)
+
+    input_tensor = constant_op.constant(x, shape=in_shape)
+    resize_out = image_ops.resize_bilinear(input_tensor, out_shape[1:3])
+    with self.cached_session():
+      grad = gradients_impl.gradients(resize_out, [input_tensor])
+      self.assertEqual([None], grad)
+
+  def _gpuVsCpuCase(self, in_shape, out_shape, align_corners,
+                    half_pixel_centers, dtype):
+    grad = {}
+    for use_gpu in [False, True]:
+      grad[use_gpu] = self._getJacobians(
+          in_shape, out_shape, align_corners, half_pixel_centers, dtype=dtype,
+          use_gpu=use_gpu)
+    threshold = 1e-4
+    # Note that this compares both the analytical and numerical Jacobians.
+    self.assertAllClose(
+        grad[False], grad[True], rtol=threshold, atol=threshold)
+
+  @parameterized.parameters(
+      {'batch_size': 1, 'channel_count': 1},
+      {'batch_size': 2, 'channel_count': 3},
+      {'batch_size': 5, 'channel_count': 4})
+  @test_util.run_deprecated_v1
+  def testCompareGpuVsCpu(self, batch_size, channel_count):
+    smaller_shape = [batch_size, 4, 6, channel_count]
+    larger_shape = [batch_size, 8, 16, channel_count]
+    for params in self._itGen(smaller_shape, larger_shape):
+      self._gpuVsCpuCase(*params, dtype=np.float32)
+
+  @test_util.run_deprecated_v1
+  def testCompareGpuVsCpuFloat64(self):
+    in_shape = [1, 5, 7, 1]
+    out_shape = [1, 9, 11, 1]
+    # Note that there is no 16-bit floating-point format registered for GPU.
+    self._gpuVsCpuCase(in_shape, out_shape, align_corners=True,
+                       half_pixel_centers=False, dtype=np.float64)
 
 
-class ResizeBicubicOpTest(test.TestCase):
+class ResizeBicubicOpTestBase(test.TestCase):
 
   def testShapeIsCorrectAfterOp(self):
     in_shape = [1, 2, 2, 1]
@@ -264,7 +309,7 @@ class ResizeBicubicOpTest(test.TestCase):
       self.assertEqual([None], grad)
 
 
-class ScaleAndTranslateOpTest(test.TestCase):
+class ScaleAndTranslateOpTestBase(test.TestCase):
 
   @test_util.run_deprecated_v1
   def testGrads(self):
@@ -328,7 +373,7 @@ class ScaleAndTranslateOpTest(test.TestCase):
     self.assertAllClose(np.ones_like(grad_v), grad_v)
 
 
-class CropAndResizeOpTest(test.TestCase):
+class CropAndResizeOpTestBase(test.TestCase):
 
   def testShapeIsCorrectAfterOp(self):
     batch = 2
@@ -457,7 +502,7 @@ class CropAndResizeOpTest(test.TestCase):
 
 
 @test_util.run_all_in_graph_and_eager_modes
-class RGBToHSVOpTest(test.TestCase):
+class RGBToHSVOpTestBase(test.TestCase):
 
   TYPES = [np.float32, np.float64]
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index 71b17271226..4f83fd0a194 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -142,6 +142,7 @@ COMMON_PIP_DEPS = [
     "//tensorflow/python/tools:tools_pip",
     "//tensorflow/python/tools/api/generator:create_python_api",
     "//tensorflow/python/tpu",
+    "//tensorflow/python:image_grad_test_base",
    "//tensorflow/python:test_ops",
     "//tensorflow/python:while_v2",
     "//tensorflow/tools/common:public_api",