Add deterministic mode for resize_bilinear back-prop

Duncan Riach 2020-04-13 12:07:47 -07:00
parent 038edfbddd
commit 116db3235a
7 changed files with 468 additions and 103 deletions

View File

@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/image/resize_bilinear_op.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
namespace tensorflow {
@@ -228,6 +229,56 @@ __global__ void ResizeBilinearGradKernel(const int32 nthreads,
}
}
template <typename T>
__global__ void ResizeBilinearDeterministicGradKernel(
const int32 nthreads, const float* __restrict__ input_grad,
float height_scale, float inverse_height_scale, float width_scale,
float inverse_width_scale, int batch, int original_height,
int original_width, int channels, int resized_height, int resized_width,
T* __restrict__ output_grad) {
GPU_1D_KERNEL_LOOP(out_idx, nthreads) {
// out_idx = c + channels * (x + original_width * (y + original_height * b))
int idx = out_idx;
const int c = idx % channels;
idx /= channels;
const int out_x_center = idx % original_width;
idx /= original_width;
const int out_y_center = idx % original_height;
const int b = idx / original_height;
int in_y_start = max(0, __float2int_ru(
(out_y_center - 1 + 0.5) * inverse_height_scale - 0.5));
const float out_y_start = (in_y_start + 0.5) * height_scale - 0.5;
int in_x_start = max(0, __float2int_ru(
(out_x_center - 1 + 0.5) * inverse_width_scale - 0.5));
const float out_x_start = (in_x_start + 0.5) * width_scale - 0.5;
T acc = 0;
    // While loops are used here (rather than for loops) because each loop
    // advances two variables of different types, which a single for-init
    // cannot cleanly declare prior to C++17.
    float out_y = out_y_start;
    int in_y = in_y_start;
    while (out_y < out_y_center + 1 && in_y < resized_height) {
      float out_x = out_x_start;
      int in_x = in_x_start;
      while (out_x < out_x_center + 1 && in_x < resized_width) {
int in_idx = ((b * resized_height + in_y) * resized_width + in_x) *
channels + c;
// Clamping to zero is necessary because out_x and out_y can be negative
// due to half-pixel adjustments to out_y_start and out_x_start.
// Clamping to height/width is necessary when upscaling.
float out_y_clamped = fmaxf(0, fminf(out_y, original_height - 1));
float out_x_clamped = fmaxf(0, fminf(out_x, original_width - 1));
float y_lerp = (1 - fabsf(out_y_clamped - out_y_center));
float x_lerp = (1 - fabsf(out_x_clamped - out_x_center));
acc += static_cast<T>(input_grad[in_idx] * y_lerp * x_lerp);
out_x += width_scale;
in_x++;
}
out_y += height_scale;
in_y++;
}
output_grad[out_idx] = acc;
}
}
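The kernel above is what makes the back-prop deterministic: it replaces the scatter in ResizeBilinearGradKernel, where each resized-gradient element atomically adds its four weighted contributions to output_grad in a run-to-run varying order, with a gather in which each output_grad element is owned by exactly one thread that accumulates its contributions in a fixed order. A minimal 1-D NumPy sketch of the same idea (function and variable names are illustrative, not part of the commit):

import numpy as np

def deterministic_resize_bilinear_grad_1d(input_grad, original_size,
                                          half_pixel_centers=True):
  # input_grad: gradient w.r.t. the resized signal, shape (resized_size,).
  # Returns the gradient w.r.t. the original signal, shape (original_size,).
  resized_size = input_grad.shape[0]
  scale = original_size / resized_size  # analogue of height_scale/width_scale
  output_grad = np.zeros(original_size, dtype=input_grad.dtype)
  for center in range(original_size):  # one GPU thread per output element
    acc = 0.0
    for i in range(resized_size):  # the kernel bounds this scan; see below
      if half_pixel_centers:
        src = (i + 0.5) * scale - 0.5  # source coordinate of resized element i
      else:
        src = i * scale  # legacy mapping; never negative
      src = min(max(src, 0.0), original_size - 1)  # same clamping as the kernel
      weight = 1.0 - abs(src - center)  # bilinear hat function
      if weight > 0.0:
        acc += input_grad[i] * weight  # fixed order => bitwise reproducible
    output_grad[center] = acc
  return output_grad

The O(original_size * resized_size) double loop is for clarity only; the CUDA kernel uses the inverse scales to start each scan at the first contributing resized element and stops as soon as the footprint is exhausted (see the sketch after the legacy kernel below).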
template <typename T>
__global__ void LegacyResizeBilinearKernel(
const int32 nthreads, const T* __restrict__ images, float height_scale,
@@ -338,6 +389,55 @@ __global__ void LegacyResizeBilinearGradKernel(
}
}
template <typename T>
__global__ void LegacyResizeBilinearDeterministicGradKernel(
const int32 nthreads, const float* __restrict__ input_grad,
float height_scale, float inverse_height_scale, float width_scale,
float inverse_width_scale, int batch, int original_height,
int original_width, int channels, int resized_height, int resized_width,
T* __restrict__ output_grad) {
GPU_1D_KERNEL_LOOP(out_idx, nthreads) {
// out_idx = c + channels * (x + original_width * (y + original_height * b))
int idx = out_idx;
const int c = idx % channels;
idx /= channels;
const int out_x_center = idx % original_width;
idx /= original_width;
const int out_y_center = idx % original_height;
const int b = idx / original_height;
int in_y_start = max(0, __float2int_ru(
(out_y_center - 1) * inverse_height_scale));
const float out_y_start = in_y_start * height_scale;
int in_x_start = max(0, __float2int_ru(
(out_x_center - 1) * inverse_width_scale));
const float out_x_start = in_x_start * width_scale;
T acc = 0;
    // While loops are used here (rather than for loops) because each loop
    // advances two variables of different types, which a single for-init
    // cannot cleanly declare prior to C++17.
    float out_y = out_y_start;
    int in_y = in_y_start;
    while (out_y < out_y_center + 1 && in_y < resized_height) {
      float out_x = out_x_start;
      int in_x = in_x_start;
      while (out_x < out_x_center + 1 && in_x < resized_width) {
int in_idx = ((b * resized_height + in_y) * resized_width + in_x) *
channels + c;
// Clamping to zero is unnecessary because out_x and out_y will never
// be less than zero in legacy mode.
// Clamping to height/width is necessary when upscaling.
float out_y_clamped = fminf(out_y, original_height - 1);
float out_x_clamped = fminf(out_x, original_width - 1);
float y_lerp = (1 - fabsf(out_y_clamped - out_y_center));
float x_lerp = (1 - fabsf(out_x_clamped - out_x_center));
acc += static_cast<T>(input_grad[in_idx] * y_lerp * x_lerp);
out_x += width_scale;
in_x++;
}
out_y += height_scale;
in_y++;
}
output_grad[out_idx] = acc;
}
}
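Both deterministic kernels bound their scans with in_y_start/in_x_start, computed by inverting the index-to-coordinate mapping to find the first resized element whose source coordinate exceeds center - 1 (the left edge of the bilinear hat function); __float2int_ru is a round-up, i.e. a ceil. A small sketch of the derivation, with illustrative names:

import math

def first_contributing_index(center, inverse_scale, half_pixel_centers=True):
  # Smallest resized index i whose source coordinate lies above center - 1.
  if half_pixel_centers:
    # (i + 0.5) * scale - 0.5 > center - 1  <=>  i > (center - 0.5) / scale - 0.5
    return max(0, math.ceil((center - 1 + 0.5) * inverse_scale - 0.5))
  else:
    # i * scale > center - 1  <=>  i > (center - 1) / scale
    return max(0, math.ceil((center - 1) * inverse_scale))

Multiplying by a precomputed inverse scale instead of dividing by the scale keeps the per-thread work cheap, which is why the launch code below passes inverse_height_scale and inverse_width_scale into the kernels.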
} // namespace
namespace functor {
@@ -394,6 +494,17 @@ struct ResizeBilinear<GPUDevice, T> {
}
};
bool RequireDeterminism() {
static bool require_determinism = [] {
bool deterministic_ops = false;
TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS",
/*default_val=*/false,
&deterministic_ops));
return deterministic_ops;
}();
return require_determinism;
}
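The lambda runs once and its result is cached in a function-local static, so TF_DETERMINISTIC_OPS is read a single time per process; changing it after the first gradient launch has no effect. A minimal usage sketch (the shapes are arbitrary):

import os
os.environ['TF_DETERMINISTIC_OPS'] = '1'  # read once, cached for the process

import tensorflow as tf

image = tf.random.normal((1, 25, 12, 3))
with tf.GradientTape() as tape:
  tape.watch(image)
  resized = tf.compat.v1.image.resize_bilinear(
      image, (200, 250), half_pixel_centers=True)
grad = tape.gradient(resized, image)  # takes the deterministic GPU path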
// Partial specialization of ResizeBilinearGrad functor for a GPUDevice.
template <typename T>
struct ResizeBilinearGrad<GPUDevice, T> {
@@ -413,31 +524,53 @@ struct ResizeBilinearGrad<GPUDevice, T> {
    int total_count;
    GpuLaunchConfig config;

    total_count = batch * original_height * original_width * channels;
    if (total_count == 0) return;
    config = GetGpuLaunchConfig(total_count, d);
    if (RequireDeterminism()) {
      // The scale values below should never be zero, which is enforced by
      // ImageResizerGradientState.
      float inverse_height_scale = 1 / height_scale;
      float inverse_width_scale = 1 / width_scale;
      if (half_pixel_centers) {
        TF_CHECK_OK(GpuLaunchKernel(
            ResizeBilinearDeterministicGradKernel<T>, config.block_count,
            config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
            input_grad.data(), height_scale, inverse_height_scale, width_scale,
            inverse_width_scale, batch, original_height, original_width,
            channels, resized_height, resized_width, output_grad.data()));
      } else {
        TF_CHECK_OK(GpuLaunchKernel(
            LegacyResizeBilinearDeterministicGradKernel<T>, config.block_count,
            config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
            input_grad.data(), height_scale, inverse_height_scale, width_scale,
            inverse_width_scale, batch, original_height, original_width,
            channels, resized_height, resized_width, output_grad.data()));
      }
    } else {
      // Initialize output_grad with all zeros.
      TF_CHECK_OK(GpuLaunchKernel(
          SetZero<T>, config.block_count, config.thread_per_block, 0,
          d.stream(), config.virtual_thread_count, output_grad.data()));
      // Accumulate.
      total_count = batch * resized_height * resized_width * channels;
      config = GetGpuLaunchConfig(total_count, d);
      if (half_pixel_centers) {
        TF_CHECK_OK(GpuLaunchKernel(
            ResizeBilinearGradKernel<T>, config.block_count,
            config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
            input_grad.data(), height_scale, width_scale, batch,
            original_height, original_width, channels, resized_height,
            resized_width, output_grad.data()));
      } else {
        TF_CHECK_OK(GpuLaunchKernel(
            LegacyResizeBilinearGradKernel<T>, config.block_count,
            config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
            input_grad.data(), height_scale, width_scale, batch,
            original_height, original_width, channels, resized_height,
            resized_width, output_grad.data()));
      }
    }
  }
};

View File

@@ -192,6 +192,19 @@ struct ImageResizerGradientState {
original_height = original_image.dim_size(1);
original_width = original_image.dim_size(2);
// The following check is also carried out for the forward op. It is added
// here to prevent a divide-by-zero exception when either height_scale or
// width_scale is being calculated.
OP_REQUIRES(context, resized_height > 0 && resized_width > 0,
errors::InvalidArgument("resized dimensions must be positive"));
// The following check is also carried out for the forward op. It is added
// here to prevent either height_scale or width_scale from being set to
// zero, which would cause a divide-by-zero exception in the deterministic
// back-prop path.
OP_REQUIRES(context, original_height > 0 && original_width > 0,
errors::InvalidArgument("original dimensions must be positive"));
OP_REQUIRES(
context,
FastBoundsCheck(original_height, std::numeric_limits<int32>::max()) &&

View File

@@ -5174,12 +5174,30 @@ cuda_py_test(
],
)
cuda_py_test(
name = "image_grad_deterministic_test",
size = "medium",
srcs = ["ops/image_grad_deterministic_test.py"],
python_version = "PY3",
deps = [
":image_grad_test_base",
],
)
cuda_py_test(
name = "image_grad_test",
size = "medium",
srcs = ["ops/image_grad_test.py"],
python_version = "PY3",
tfrt_enabled = True,
deps = [
":image_grad_test_base",
],
)
py_library(
name = "image_grad_test_base",
srcs = ["ops/image_grad_test_base.py"],
deps = [
":client_testlib",
":framework_for_generated_wrappers",

View File

@@ -0,0 +1,123 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional tests for deterministic image op gradient functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
from absl.testing import parameterized
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import image_grad_test_base as test_base
from tensorflow.python.ops import image_ops
from tensorflow.python.platform import test
class ResizeBilinearOpDeterministicTest(test_base.ResizeBilinearOpTestBase):
def _randomNDArray(self, shape):
return 2 * np.random.random_sample(shape) - 1
def _randomDataOp(self, shape, data_type):
return constant_op.constant(self._randomNDArray(shape), dtype=data_type)
@parameterized.parameters(
# Note that there is no 16-bit floating point format registered for GPU
{'align_corners': False, 'half_pixel_centers': False,
'data_type': dtypes.float32},
{'align_corners': False, 'half_pixel_centers': False,
'data_type': dtypes.float64},
{'align_corners': True, 'half_pixel_centers': False,
'data_type': dtypes.float32},
{'align_corners': False, 'half_pixel_centers': True,
'data_type': dtypes.float32})
@test_util.run_in_graph_and_eager_modes
@test_util.run_cuda_only
def testDeterministicGradients(self, align_corners, half_pixel_centers,
data_type):
if not align_corners and test_util.is_xla_enabled():
# Align corners is deprecated in TF2.0, but align_corners==False is not
# supported by XLA.
self.skipTest("align_corners==False not currently supported by XLA")
with self.session(force_gpu=True):
seed = (hash(align_corners) % 256 + hash(half_pixel_centers) % 256 +
hash(data_type) % 256)
np.random.seed(seed)
input_shape = (1, 25, 12, 3) # NHWC
output_shape = (1, 200, 250, 3)
input_image = self._randomDataOp(input_shape, data_type)
repeat_count = 3
if context.executing_eagerly():
def resize_bilinear_gradients(local_seed):
np.random.seed(local_seed)
upstream_gradients = self._randomDataOp(output_shape, dtypes.float32)
with backprop.GradientTape(persistent=True) as tape:
tape.watch(input_image)
output_image = image_ops.resize_bilinear(
input_image, output_shape[1:3], align_corners=align_corners,
half_pixel_centers=half_pixel_centers)
gradient_injector_output = output_image * upstream_gradients
return tape.gradient(gradient_injector_output, input_image)
for i in range(repeat_count):
local_seed = seed + i # select different upstream gradients
result_a = resize_bilinear_gradients(local_seed)
result_b = resize_bilinear_gradients(local_seed)
self.assertAllEqual(result_a, result_b)
else: # graph mode
upstream_gradients = array_ops.placeholder(
dtypes.float32, shape=output_shape, name='upstream_gradients')
output_image = image_ops.resize_bilinear(
input_image, output_shape[1:3], align_corners=align_corners,
half_pixel_centers=half_pixel_centers)
gradient_injector_output = output_image * upstream_gradients
# The gradient function behaves as if grad_ys is multiplied by the result
# of the op's gradient, rather than passing the upstream gradients through
# the op's gradient-generation graph. This is the reason for using the
# gradient injector.
resize_bilinear_gradients = gradients_impl.gradients(
gradient_injector_output, input_image, grad_ys=None,
colocate_gradients_with_ops=True)[0]
for i in range(repeat_count):
feed_dict = {upstream_gradients: self._randomNDArray(output_shape)}
result_a = resize_bilinear_gradients.eval(feed_dict=feed_dict)
result_b = resize_bilinear_gradients.eval(feed_dict=feed_dict)
self.assertAllEqual(result_a, result_b)
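Restating the injector comment above: per the observed behavior of the gradient function, supplying the upstream gradients via grad_ys only scales the op's gradient result, whereas multiplying them into the forward graph forces them through the op's registered gradient function, which is where any nondeterministic accumulation would occur. Schematically, with g the upstream-gradient tensor:

      # gradients_impl.gradients(y, x, grad_ys=g)  # g scales the result
      # gradients_impl.gradients(y * g, x)         # g flows through the op's
      #                                            # gradient kernel (injector)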
if __name__ == '__main__':
# Note that the effect of setting the following environment variable to
# 'true' is not tested. Unless we can find a simpler pattern for testing these
# environment variables, it would require this file to be made into a base
# and then two more test files to be created.
#
# When deterministic op functionality can be enabled and disabled between test
# cases in the same process, then the tests for deterministic op
# functionality, for this op and for other ops, will be able to be included in
# the same file with the regular tests, simplifying the organization of tests
# and test files.
os.environ['TF_DETERMINISTIC_OPS'] = '1'
test.main()

View File

@@ -0,0 +1,32 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional tests for Image Op Gradients."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.ops import image_grad_test_base as test_base
from tensorflow.python.platform import test
ResizeNearestNeighborOpTest = test_base.ResizeNearestNeighborOpTestBase
ResizeBilinearOpTest = test_base.ResizeBilinearOpTestBase
ResizeBicubicOpTest = test_base.ResizeBicubicOpTestBase
ScaleAndTranslateOpTest = test_base.ScaleAndTranslateOpTestBase
CropAndResizeOpTest = test_base.CropAndResizeOpTestBase
RGBToHSVOpTest = test_base.RGBToHSVOpTestBase
if __name__ == "__main__":
test.main()

View File

@@ -20,6 +20,8 @@ from __future__ import print_function
import numpy as np
from absl.testing import parameterized
from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
@@ -35,7 +37,7 @@ from tensorflow.python.ops import array_ops
@test_util.for_all_test_methods(test_util.disable_xla,
'align_corners=False not supported by XLA')
class ResizeNearestNeighborOpTest(test.TestCase):
class ResizeNearestNeighborOpTestBase(test.TestCase):
TYPES = [np.float32, np.float64]
@@ -111,97 +113,140 @@ class ResizeNearestNeighborOpTest(test.TestCase):
self.assertAllClose(grad_cpu, grad_gpu, rtol=1e-5, atol=1e-5)
class ResizeBilinearOpTest(test.TestCase):
class ResizeBilinearOpTestBase(test.TestCase, parameterized.TestCase):
def testShapeIsCorrectAfterOp(self):
in_shape = [1, 2, 2, 1]
out_shape = [1, 4, 6, 1]
x = np.arange(0, 4).reshape(in_shape).astype(np.float32)
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_bilinear(input_tensor, out_shape[1:3])
with self.cached_session():
self.assertEqual(out_shape, list(resize_out.get_shape()))
resize_out = self.evaluate(resize_out)
self.assertEqual(out_shape, list(resize_out.shape))
@test_util.run_deprecated_v1
def testGradFromResizeToLargerInBothDims(self):
in_shape = [1, 2, 3, 1]
out_shape = [1, 4, 6, 1]
x = np.arange(0, 6).reshape(in_shape).astype(np.float32)
with self.cached_session():
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_bilinear(input_tensor, out_shape[1:3])
err = gradient_checker.compute_gradient_error(
input_tensor, in_shape, resize_out, out_shape, x_init_value=x)
self.assertLess(err, 1e-3)
@test_util.run_deprecated_v1
def testGradFromResizeToSmallerInBothDims(self):
in_shape = [1, 4, 6, 1]
out_shape = [1, 2, 3, 1]
x = np.arange(0, 24).reshape(in_shape).astype(np.float32)
with self.cached_session():
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_bilinear(input_tensor, out_shape[1:3])
err = gradient_checker.compute_gradient_error(
input_tensor, in_shape, resize_out, out_shape, x_init_value=x)
self.assertLess(err, 1e-3)
  def _itGen(self, smaller_shape, larger_shape):
    up_sample = (smaller_shape, larger_shape)
    down_sample = (larger_shape, smaller_shape)
    pass_through = (larger_shape, larger_shape)
    shape_pairs = (up_sample, down_sample, pass_through)
    # Align corners is deprecated in TF2.0, but align_corners==False is not
    # supported by XLA.
    options = [(True, False)]
    if not test_util.is_xla_enabled():
      options += [(False, True), (False, False)]
    for align_corners, half_pixel_centers in options:
      for in_shape, out_shape in shape_pairs:
        yield in_shape, out_shape, align_corners, half_pixel_centers
def _getJacobians(self, in_shape, out_shape, align_corners=False,
half_pixel_centers=False, dtype=np.float32, use_gpu=False,
force_gpu=False):
with self.cached_session(use_gpu=use_gpu, force_gpu=force_gpu) as sess:
# Input values should not influence gradients
x = np.arange(np.prod(in_shape)).reshape(in_shape).astype(dtype)
input_tensor = constant_op.constant(x, shape=in_shape)
resized_tensor = image_ops.resize_bilinear(
input_tensor, out_shape[1:3], align_corners=align_corners,
half_pixel_centers=half_pixel_centers)
# compute_gradient will use a random tensor as the init value
return gradient_checker.compute_gradient(
input_tensor, in_shape, resized_tensor, out_shape)
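For reference, _getJacobians wraps the TF1 gradient_checker API: compute_gradient returns a pair of Jacobians, the theoretical one from the registered gradient and a numerical finite-difference estimate, and the tests below compare the two.

  # jacob_a, jacob_n = gradient_checker.compute_gradient(...)
  # jacob_a.shape == jacob_n.shape == (np.prod(in_shape), np.prod(out_shape))
  # jacob_a: theoretical, from the registered gradient
  # jacob_n: numerical, from finite differences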
@parameterized.parameters(
{'batch_size': 1, 'channel_count': 1},
{'batch_size': 2, 'channel_count': 3},
{'batch_size': 5, 'channel_count': 4})
@test_util.run_deprecated_v1
def testShapes(self, batch_size, channel_count):
smaller_shape = [batch_size, 2, 3, channel_count]
larger_shape = [batch_size, 4, 6, channel_count]
for in_shape, out_shape, align_corners, half_pixel_centers in \
self._itGen(smaller_shape, larger_shape):
# Input values should not influence shapes
x = np.arange(np.prod(in_shape)).reshape(in_shape).astype(np.float32)
input_tensor = constant_op.constant(x, shape=in_shape)
resized_tensor = image_ops.resize_bilinear(input_tensor, out_shape[1:3])
self.assertEqual(out_shape, list(resized_tensor.get_shape()))
grad_tensor = gradients_impl.gradients(resized_tensor, input_tensor)[0]
self.assertEqual(in_shape, list(grad_tensor.get_shape()))
with self.cached_session():
resized_values = self.evaluate(resized_tensor)
self.assertEqual(out_shape, list(resized_values.shape))
grad_values = self.evaluate(grad_tensor)
self.assertEqual(in_shape, list(grad_values.shape))
@parameterized.parameters(
{'batch_size': 1, 'channel_count': 1},
{'batch_size': 4, 'channel_count': 3},
{'batch_size': 3, 'channel_count': 2})
@test_util.run_deprecated_v1
def testGradients(self, batch_size, channel_count):
smaller_shape = [batch_size, 2, 3, channel_count]
larger_shape = [batch_size, 5, 6, channel_count]
for in_shape, out_shape, align_corners, half_pixel_centers in \
self._itGen(smaller_shape, larger_shape):
jacob_a, jacob_n = self._getJacobians(
in_shape, out_shape, align_corners, half_pixel_centers)
threshold = 1e-4
self.assertAllClose(jacob_a, jacob_n, threshold, threshold)
  @test_util.run_deprecated_v1
  def testTypes(self):
    in_shape = [1, 4, 6, 1]
    out_shape = [1, 2, 3, 1]
    for use_gpu in [False, True]:
      for dtype in [np.float16, np.float32, np.float64]:
        jacob_a, jacob_n = self._getJacobians(
            in_shape, out_shape, dtype=dtype, use_gpu=use_gpu)
        if dtype == np.float16:
          # Compare fp16 analytical gradients to fp32 numerical gradients,
          # since fp16 numerical gradients are too imprecise unless great
          # care is taken with choosing the inputs and the delta. This is
          # a weaker, but pragmatic, check (in particular, it does not test
          # the op itself, only its gradient).
          _, jacob_n = self._getJacobians(
              in_shape, out_shape, dtype=np.float32, use_gpu=use_gpu)
        threshold = 1e-3
        if dtype == np.float64:
          threshold = 1e-5
        self.assertAllClose(jacob_a, jacob_n, threshold, threshold)
@test_util.run_deprecated_v1
def testGradOnUnsupportedType(self):
in_shape = [1, 4, 6, 1]
out_shape = [1, 2, 3, 1]
x = np.arange(0, 24).reshape(in_shape).astype(np.uint8)
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_bilinear(input_tensor, out_shape[1:3])
with self.cached_session():
grad = gradients_impl.gradients(resize_out, [input_tensor])
self.assertEqual([None], grad)
def _gpuVsCpuCase(self, in_shape, out_shape, align_corners,
half_pixel_centers, dtype):
grad = {}
for use_gpu in [False, True]:
grad[use_gpu] = self._getJacobians(
in_shape, out_shape, align_corners, half_pixel_centers, dtype=dtype,
use_gpu=use_gpu)
threshold = 1e-4
# Note that this is comparing both analytical and numerical Jacobians
self.assertAllClose(grad[False], grad[True], rtol=threshold, atol=threshold)
@parameterized.parameters(
{'batch_size': 1, 'channel_count': 1},
{'batch_size': 2, 'channel_count': 3},
{'batch_size': 5, 'channel_count': 4})
@test_util.run_deprecated_v1
def testCompareGpuVsCpu(self, batch_size, channel_count):
smaller_shape = [batch_size, 4, 6, channel_count]
larger_shape = [batch_size, 8, 16, channel_count]
for params in self._itGen(smaller_shape, larger_shape):
self._gpuVsCpuCase(*params, dtype=np.float32)
@test_util.run_deprecated_v1
def testCompareGpuVsCpuFloat64(self):
in_shape = [1, 5, 7, 1]
out_shape = [1, 9, 11, 1]
# Note that there is no 16-bit floating-point format registered for GPU
self._gpuVsCpuCase(in_shape, out_shape, align_corners=True,
half_pixel_centers=False, dtype=np.float64)
class ResizeBicubicOpTest(test.TestCase):
class ResizeBicubicOpTestBase(test.TestCase):
def testShapeIsCorrectAfterOp(self):
in_shape = [1, 2, 2, 1]
@@ -264,7 +309,7 @@ class ResizeBicubicOpTest(test.TestCase):
self.assertEqual([None], grad)
class ScaleAndTranslateOpTest(test.TestCase):
class ScaleAndTranslateOpTestBase(test.TestCase):
@test_util.run_deprecated_v1
def testGrads(self):
@@ -328,7 +373,7 @@ class ScaleAndTranslateOpTest(test.TestCase):
self.assertAllClose(np.ones_like(grad_v), grad_v)
class CropAndResizeOpTest(test.TestCase):
class CropAndResizeOpTestBase(test.TestCase):
def testShapeIsCorrectAfterOp(self):
batch = 2
@@ -457,7 +502,7 @@ class CropAndResizeOpTest(test.TestCase):
@test_util.run_all_in_graph_and_eager_modes
class RGBToHSVOpTest(test.TestCase):
class RGBToHSVOpTestBase(test.TestCase):
TYPES = [np.float32, np.float64]

View File

@@ -142,6 +142,7 @@ COMMON_PIP_DEPS = [
"//tensorflow/python/tools:tools_pip",
"//tensorflow/python/tools/api/generator:create_python_api",
"//tensorflow/python/tpu",
"//tensorflow/python:image_grad_test_base",
"//tensorflow/python:test_ops",
"//tensorflow/python:while_v2",
"//tensorflow/tools/common:public_api",