Merge pull request #39243 from duncanriach:deterministic-image-resize-bilinear

PiperOrigin-RevId: 333087088
Change-Id: Ia85994cc8bf62f1bd21b821f035dea0941d12f93
TensorFlower Gardener 2020-09-22 08:54:54 -07:00
commit 2ef1ff69d2
7 changed files with 903 additions and 542 deletions

tensorflow/core/kernels/image/resize_bilinear_op_gpu.cu.cc

@@ -1,4 +1,4 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2016-2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/image/resize_bilinear_op.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
namespace tensorflow {
@@ -228,6 +229,59 @@ __global__ void ResizeBilinearGradKernel(const int32 nthreads,
}
}
template <typename T>
__global__ void ResizeBilinearDeterministicGradKernel(
const int32 nthreads, const float* __restrict__ input_grad,
float height_scale, float inverse_height_scale, float width_scale,
float inverse_width_scale, int batch, int original_height,
int original_width, int channels, int resized_height, int resized_width,
float offset, T* __restrict__ output_grad) {
GPU_1D_KERNEL_LOOP(out_idx, nthreads) {
// out_idx = c + channels * (x + original_width * (y + original_height * b))
int idx = out_idx;
const int c = idx % channels;
idx /= channels;
const int out_x_center = idx % original_width;
idx /= original_width;
const int out_y_center = idx % original_height;
const int b = idx / original_height;
int in_y_start = max(
0, __float2int_ru((out_y_center - 1 + offset) * inverse_height_scale -
offset));
const float out_y_start = (in_y_start + offset) * height_scale - offset;
int in_x_start =
max(0, __float2int_ru(
(out_x_center - 1 + offset) * inverse_width_scale - offset));
const float out_x_start = (in_x_start + offset) * width_scale - offset;
T acc = 0;
// For clarity, prior to C++17, while loops are preferable to for loops here
float out_y = out_y_start;
int in_y = in_y_start;
while (out_y < out_y_center + 1 && in_y < resized_height) {
float out_x = out_x_start;
int in_x = in_x_start;
while (out_x < out_x_center + 1 && in_x < resized_width) {
int in_idx =
((b * resized_height + in_y) * resized_width + in_x) * channels + c;
// Clamping to zero is necessary because out_x and out_y can be negative
// due to half-pixel adjustments to out_y_start and out_x_start.
// Clamping to height/width is necessary when upscaling.
float out_y_clamped = fmaxf(0, fminf(out_y, original_height - 1));
float out_x_clamped = fmaxf(0, fminf(out_x, original_width - 1));
float y_lerp = (1 - fabsf(out_y_clamped - out_y_center));
float x_lerp = (1 - fabsf(out_x_clamped - out_x_center));
acc += static_cast<T>(input_grad[in_idx] * y_lerp * x_lerp);
out_x += width_scale;
in_x++;
}
out_y += height_scale;
in_y++;
}
output_grad[out_idx] = acc;
}
}
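For intuition, here is a minimal NumPy sketch of the gather strategy the kernel above implements (illustrative only; the helper below is hypothetical and not part of this commit). Instead of scattering each upstream-gradient element into four output pixels with atomic adds, whose accumulation order varies from run to run, every output (original-image) pixel walks the resized-image pixels whose bilinear footprint covers it and accumulates their weighted contributions in a fixed order:

import math
import numpy as np

def resize_bilinear_grad_deterministic(input_grad, original_height,
                                       original_width, half_pixel_centers):
  """Hypothetical NumPy model of the gather-based gradient kernel above."""
  batch, resized_height, resized_width, channels = input_grad.shape
  height_scale = original_height / resized_height  # maps resized -> original
  width_scale = original_width / resized_width
  offset = 0.5 if half_pixel_centers else 0.0
  output_grad = np.zeros(
      (batch, original_height, original_width, channels), input_grad.dtype)
  for b in range(batch):
    for y in range(original_height):
      for x in range(original_width):
        # First resized row whose interpolation footprint can reach original
        # row y; mirrors the __float2int_ru computation in the kernel.
        in_y = max(0, math.ceil((y - 1 + offset) / height_scale - offset))
        out_y = (in_y + offset) * height_scale - offset
        acc = np.zeros(channels, input_grad.dtype)
        while out_y < y + 1 and in_y < resized_height:
          y_lerp = 1.0 - abs(min(max(out_y, 0.0), original_height - 1.0) - y)
          in_x = max(0, math.ceil((x - 1 + offset) / width_scale - offset))
          out_x = (in_x + offset) * width_scale - offset
          while out_x < x + 1 and in_x < resized_width:
            x_lerp = 1.0 - abs(min(max(out_x, 0.0), original_width - 1.0) - x)
            acc += input_grad[b, in_y, in_x, :] * y_lerp * x_lerp
            out_x += width_scale
            in_x += 1
          out_y += height_scale
          in_y += 1
        output_grad[b, y, x, :] = acc
  return output_grad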
template <typename T>
__global__ void LegacyResizeBilinearKernel(
const int32 nthreads, const T* __restrict__ images, float height_scale,
@@ -394,6 +448,17 @@ struct ResizeBilinear<GPUDevice, T> {
}
};
bool RequireDeterminism() {
static bool require_determinism = [] {
bool deterministic_ops = false;
TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS",
/*default_val=*/false,
&deterministic_ops));
return deterministic_ops;
}();
return require_determinism;
}
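RequireDeterminism() reads TF_DETERMINISTIC_OPS exactly once per process; the static lambda caches the parsed value, so the variable must be set before the first kernel launch. A minimal Python usage sketch, assuming a CUDA build of TensorFlow that contains this change:

import os
os.environ['TF_DETERMINISTIC_OPS'] = '1'  # read once, then cached

import tensorflow as tf

image = tf.random.normal([1, 25, 12, 3])
with tf.GradientTape() as tape:
  tape.watch(image)
  resized = tf.image.resize(image, [200, 250], method='bilinear')
# With the variable set, repeated runs produce bit-identical gradients.
grad = tape.gradient(resized, image)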
// Partial specialization of ResizeBilinearGrad functor for a GPUDevice.
template <typename T>
struct ResizeBilinearGrad<GPUDevice, T> {
@@ -413,31 +478,45 @@ struct ResizeBilinearGrad<GPUDevice, T> {
int total_count;
GpuLaunchConfig config;
// Initialize output_grad with all zeros.
total_count = batch * original_height * original_width * channels;
if (total_count == 0) return;
config = GetGpuLaunchConfig(total_count, d);
TF_CHECK_OK(GpuLaunchKernel(
SetZero<T>, config.block_count, config.thread_per_block, 0, d.stream(),
config.virtual_thread_count, output_grad.data()));
// Accumulate.
total_count = batch * resized_height * resized_width * channels;
config = GetGpuLaunchConfig(total_count, d);
if (half_pixel_centers) {
if (RequireDeterminism()) {
// The scale values below should never be zero, enforced by
// ImageResizerGradientState
float inverse_height_scale = 1 / height_scale;
float inverse_width_scale = 1 / width_scale;
float offset = half_pixel_centers ? 0.5 : 0;
TF_CHECK_OK(GpuLaunchKernel(
ResizeBilinearGradKernel<T>, config.block_count,
ResizeBilinearDeterministicGradKernel<T>, config.block_count,
config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
input_grad.data(), height_scale, width_scale, batch, original_height,
original_width, channels, resized_height, resized_width,
output_grad.data()));
input_grad.data(), height_scale, inverse_height_scale, width_scale,
inverse_width_scale, batch, original_height, original_width, channels,
resized_height, resized_width, offset, output_grad.data()));
} else {
// Initialize output_grad with all zeros.
TF_CHECK_OK(GpuLaunchKernel(
LegacyResizeBilinearGradKernel<T>, config.block_count,
config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
input_grad.data(), height_scale, width_scale, batch, original_height,
original_width, channels, resized_height, resized_width,
output_grad.data()));
SetZero<T>, config.block_count, config.thread_per_block, 0,
d.stream(), config.virtual_thread_count, output_grad.data()));
// Accumulate.
total_count = batch * resized_height * resized_width * channels;
config = GetGpuLaunchConfig(total_count, d);
if (half_pixel_centers) {
TF_CHECK_OK(GpuLaunchKernel(
ResizeBilinearGradKernel<T>, config.block_count,
config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
input_grad.data(), height_scale, width_scale, batch,
original_height, original_width, channels, resized_height,
resized_width, output_grad.data()));
} else {
TF_CHECK_OK(GpuLaunchKernel(
LegacyResizeBilinearGradKernel<T>, config.block_count,
config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
input_grad.data(), height_scale, width_scale, batch,
original_height, original_width, channels, resized_height,
resized_width, output_grad.data()));
}
}
}
};

tensorflow/core/kernels/image/image_resizer_state.h

@@ -192,6 +192,20 @@ struct ImageResizerGradientState {
original_height = original_image.dim_size(1);
original_width = original_image.dim_size(2);
// The following check is also carried out for the forward op. It is added
// here to prevent a divide-by-zero exception when either height_scale or
// width_scale is being calculated.
OP_REQUIRES(context, resized_height > 0 && resized_width > 0,
errors::InvalidArgument("resized dimensions must be positive"));
// The following check is also carried out for the forward op. It is added
// here to prevent either height_scale or width_scale from being set to
// zero, which would cause a divide-by-zero exception in the deterministic
// back-prop path.
OP_REQUIRES(
context, original_height > 0 && original_width > 0,
errors::InvalidArgument("original dimensions must be positive"));
OP_REQUIRES(
context,
FastBoundsCheck(original_height, std::numeric_limits<int32>::max()) &&

tensorflow/python/BUILD

@@ -5145,12 +5145,30 @@ cuda_py_test(
],
)
cuda_py_test(
name = "image_grad_deterministic_test",
size = "medium",
srcs = ["ops/image_grad_deterministic_test.py"],
python_version = "PY3",
deps = [
":image_grad_test_base",
],
)
cuda_py_test(
name = "image_grad_test",
size = "medium",
srcs = ["ops/image_grad_test.py"],
python_version = "PY3",
tfrt_enabled = True,
deps = [
":image_grad_test_base",
],
)
py_library(
name = "image_grad_test_base",
srcs = ["ops/image_grad_test_base.py"],
deps = [
":client_testlib",
":framework_for_generated_wrappers",

tensorflow/python/ops/image_grad_deterministic_test.py

@@ -0,0 +1,142 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional tests for deterministic image op gradient functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
from absl.testing import parameterized
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import image_grad_test_base as test_base
from tensorflow.python.ops import image_ops
from tensorflow.python.platform import test
class ResizeBilinearOpDeterministicTest(test_base.ResizeBilinearOpTestBase):
def _randomNDArray(self, shape):
return 2 * np.random.random_sample(shape) - 1
def _randomDataOp(self, shape, data_type):
return constant_op.constant(self._randomNDArray(shape), dtype=data_type)
@parameterized.parameters(
# Note that there is no 16-bit floating-point format registered for GPU
{
'align_corners': False,
'half_pixel_centers': False,
'data_type': dtypes.float32
},
{
'align_corners': False,
'half_pixel_centers': False,
'data_type': dtypes.float64
},
{
'align_corners': True,
'half_pixel_centers': False,
'data_type': dtypes.float32
},
{
'align_corners': False,
'half_pixel_centers': True,
'data_type': dtypes.float32
})
@test_util.run_in_graph_and_eager_modes
@test_util.run_cuda_only
def testDeterministicGradients(self, align_corners, half_pixel_centers,
data_type):
if not align_corners and test_util.is_xla_enabled():
# Align corners is deprecated in TF2.0, but align_corners==False is not
# supported by XLA.
self.skipTest('align_corners==False not currently supported by XLA')
with self.session(force_gpu=True):
seed = (
hash(align_corners) % 256 + hash(half_pixel_centers) % 256 +
hash(data_type) % 256)
np.random.seed(seed)
input_shape = (1, 25, 12, 3) # NHWC
output_shape = (1, 200, 250, 3)
input_image = self._randomDataOp(input_shape, data_type)
repeat_count = 3
if context.executing_eagerly():
def resize_bilinear_gradients(local_seed):
np.random.seed(local_seed)
upstream_gradients = self._randomDataOp(output_shape, dtypes.float32)
with backprop.GradientTape(persistent=True) as tape:
tape.watch(input_image)
output_image = image_ops.resize_bilinear(
input_image,
output_shape[1:3],
align_corners=align_corners,
half_pixel_centers=half_pixel_centers)
gradient_injector_output = output_image * upstream_gradients
return tape.gradient(gradient_injector_output, input_image)
for i in range(repeat_count):
local_seed = seed + i # select different upstream gradients
result_a = resize_bilinear_gradients(local_seed)
result_b = resize_bilinear_gradients(local_seed)
self.assertAllEqual(result_a, result_b)
else: # graph mode
upstream_gradients = array_ops.placeholder(
dtypes.float32, shape=output_shape, name='upstream_gradients')
output_image = image_ops.resize_bilinear(
input_image,
output_shape[1:3],
align_corners=align_corners,
half_pixel_centers=half_pixel_centers)
gradient_injector_output = output_image * upstream_gradients
# The gradient function behaves as if grad_ys is multiplied by the op
# gradient result, not passing the upstream gradients through the op's
# gradient generation graph. This is the reason for using the gradient
# injector.
resize_bilinear_gradients = gradients_impl.gradients(
gradient_injector_output,
input_image,
grad_ys=None,
colocate_gradients_with_ops=True)[0]
for i in range(repeat_count):
feed_dict = {upstream_gradients: self._randomNDArray(output_shape)}
result_a = resize_bilinear_gradients.eval(feed_dict=feed_dict)
result_b = resize_bilinear_gradients.eval(feed_dict=feed_dict)
self.assertAllEqual(result_a, result_b)
if __name__ == '__main__':
# Note that the effect of setting the following environment variable to
# 'true' is not tested. Unless we can find a simpler pattern for testing
# these environment variables, doing so would require this file to be made
# into a base and then two more test files to be created.
#
# When deterministic op functionality can be enabled and disabled between test
# cases in the same process, then the tests for deterministic op
# functionality, for this op and for other ops, will be able to be included in
# the same file with the regular tests, simplifying the organization of tests
# and test files.
os.environ['TF_DETERMINISTIC_OPS'] = '1'
test.main()
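The gradient-injector comment in the graph-mode branch above deserves a standalone illustration: tf.gradients(..., grad_ys=dy) multiplies dy into the already-computed gradient rather than feeding it through the op's gradient kernel, so the test differentiates output_image * upstream_gradients instead. A condensed sketch of the pattern (TF1-style API; illustrative, not part of this commit):

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

image = tf.constant(np.random.random((1, 4, 4, 1)), dtype=tf.float32)
upstream = tf.placeholder(tf.float32, shape=(1, 8, 8, 1))
resized = tf.image.resize_bilinear(image, (8, 8))
# Differentiating the product pushes `upstream` through ResizeBilinearGrad
# itself, exercising the backward kernel with arbitrary upstream values.
injected = resized * upstream
grad = tf.gradients(injected, image)[0]

with tf.Session() as sess:
  dy = np.random.random((1, 8, 8, 1)).astype(np.float32)
  result_a = sess.run(grad, feed_dict={upstream: dy})
  result_b = sess.run(grad, feed_dict={upstream: dy})
  # Bitwise-equal when TF_DETERMINISTIC_OPS is enabled.
  np.testing.assert_array_equal(result_a, result_b)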

tensorflow/python/ops/image_grad_test.py

@@ -1,4 +1,4 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,536 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Python ops defined in image_grad.py."""
"""Functional tests for Image Op Gradients."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import gen_image_ops
from tensorflow.python.ops import image_grad_test_base as test_base
from tensorflow.python.platform import test
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import array_ops
@test_util.for_all_test_methods(test_util.disable_xla,
'align_corners=False not supported by XLA')
class ResizeNearestNeighborOpTest(test.TestCase):
TYPES = [np.float32, np.float64]
def testShapeIsCorrectAfterOp(self):
in_shape = [1, 2, 2, 1]
out_shape = [1, 4, 6, 1]
for nptype in self.TYPES:
x = np.arange(0, 4).reshape(in_shape).astype(nptype)
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_nearest_neighbor(input_tensor,
out_shape[1:3])
with self.cached_session(use_gpu=True):
self.assertEqual(out_shape, list(resize_out.get_shape()))
resize_out = self.evaluate(resize_out)
self.assertEqual(out_shape, list(resize_out.shape))
def testGradFromResizeToLargerInBothDims(self):
in_shape = [1, 2, 3, 1]
out_shape = (1, 4, 6, 1)
for nptype in self.TYPES:
x = np.arange(0, 6).reshape(in_shape).astype(nptype)
def resize_nn(t, shape=out_shape):
return image_ops.resize_nearest_neighbor(t, shape[1:3])
with self.cached_session(use_gpu=True):
input_tensor = constant_op.constant(x, shape=in_shape)
err = gradient_checker_v2.max_error(
*gradient_checker_v2.compute_gradient(resize_nn, [input_tensor]))
self.assertLess(err, 1e-3)
def testGradFromResizeToSmallerInBothDims(self):
in_shape = [1, 4, 6, 1]
out_shape = (1, 2, 3, 1)
for nptype in self.TYPES:
x = np.arange(0, 24).reshape(in_shape).astype(nptype)
def resize_nn(t, shape=out_shape):
return image_ops.resize_nearest_neighbor(t, shape[1:3])
with self.cached_session(use_gpu=True):
input_tensor = constant_op.constant(x, shape=in_shape)
err = gradient_checker_v2.max_error(
*gradient_checker_v2.compute_gradient(resize_nn, [input_tensor]))
self.assertLess(err, 1e-3)
def testCompareGpuVsCpu(self):
in_shape = [1, 4, 6, 3]
out_shape = (1, 8, 16, 3)
for nptype in self.TYPES:
x = np.arange(0, np.prod(in_shape)).reshape(in_shape).astype(nptype)
for align_corners in [True, False]:
def resize_nn(t, shape=out_shape, align_corners=align_corners):
return image_ops.resize_nearest_neighbor(
t, shape[1:3], align_corners=align_corners)
with self.cached_session(use_gpu=False):
input_tensor = constant_op.constant(x, shape=in_shape)
grad_cpu = gradient_checker_v2.compute_gradient(resize_nn,
[input_tensor])
with self.cached_session(use_gpu=True):
input_tensor = constant_op.constant(x, shape=in_shape)
grad_gpu = gradient_checker_v2.compute_gradient(resize_nn,
[input_tensor])
self.assertAllClose(grad_cpu, grad_gpu, rtol=1e-5, atol=1e-5)
class ResizeBilinearOpTest(test.TestCase):
def testShapeIsCorrectAfterOp(self):
in_shape = [1, 2, 2, 1]
out_shape = [1, 4, 6, 1]
x = np.arange(0, 4).reshape(in_shape).astype(np.float32)
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_bilinear(input_tensor, out_shape[1:3])
with self.cached_session():
self.assertEqual(out_shape, list(resize_out.get_shape()))
resize_out = self.evaluate(resize_out)
self.assertEqual(out_shape, list(resize_out.shape))
@test_util.run_deprecated_v1
def testGradFromResizeToLargerInBothDims(self):
in_shape = [1, 2, 3, 1]
out_shape = [1, 4, 6, 1]
x = np.arange(0, 6).reshape(in_shape).astype(np.float32)
with self.cached_session():
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_bilinear(input_tensor, out_shape[1:3])
err = gradient_checker.compute_gradient_error(
input_tensor, in_shape, resize_out, out_shape, x_init_value=x)
self.assertLess(err, 1e-3)
@test_util.run_deprecated_v1
def testGradFromResizeToSmallerInBothDims(self):
in_shape = [1, 4, 6, 1]
out_shape = [1, 2, 3, 1]
x = np.arange(0, 24).reshape(in_shape).astype(np.float32)
with self.cached_session():
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_bilinear(input_tensor, out_shape[1:3])
err = gradient_checker.compute_gradient_error(
input_tensor, in_shape, resize_out, out_shape, x_init_value=x)
self.assertLess(err, 1e-3)
@test_util.run_deprecated_v1
def testCompareGpuVsCpu(self):
in_shape = [2, 4, 6, 3]
out_shape = [2, 8, 16, 3]
size = np.prod(in_shape)
x = 1.0 / size * np.arange(0, size).reshape(in_shape).astype(np.float32)
# Align corners will be deprecated for tf2.0 and the false version is not
# supported by XLA.
align_corner_options = [True
] if test_util.is_xla_enabled() else [True, False]
for align_corners in align_corner_options:
grad = {}
for use_gpu in [False, True]:
with self.cached_session(use_gpu=use_gpu):
input_tensor = constant_op.constant(x, shape=in_shape)
resized_tensor = image_ops.resize_bilinear(
input_tensor, out_shape[1:3], align_corners=align_corners)
grad[use_gpu] = gradient_checker.compute_gradient(
input_tensor, in_shape, resized_tensor, out_shape, x_init_value=x)
self.assertAllClose(grad[False], grad[True], rtol=1e-4, atol=1e-4)
@test_util.run_deprecated_v1
def testTypes(self):
in_shape = [1, 4, 6, 1]
out_shape = [1, 2, 3, 1]
x = np.arange(0, 24).reshape(in_shape)
for use_gpu in [False, True]:
with self.cached_session(use_gpu=use_gpu) as sess:
for dtype in [np.float16, np.float32, np.float64]:
input_tensor = constant_op.constant(x.astype(dtype), shape=in_shape)
resize_out = image_ops.resize_bilinear(input_tensor, out_shape[1:3])
grad = sess.run(gradients_impl.gradients(resize_out, input_tensor))[0]
self.assertAllEqual(in_shape, grad.shape)
# Not using gradient_checker.compute_gradient as I didn't work out
# the changes required to compensate for the lower precision of
# float16 when computing the numeric jacobian.
# Instead, we just test the theoretical jacobian.
self.assertAllEqual([[[[1.], [0.], [1.], [0.], [1.], [0.]],
[[0.], [0.], [0.], [0.], [0.], [0.]],
[[1.], [0.], [1.], [0.], [1.], [0.]],
[[0.], [0.], [0.], [0.], [0.], [0.]]]], grad)
class ResizeBicubicOpTest(test.TestCase):
def testShapeIsCorrectAfterOp(self):
in_shape = [1, 2, 2, 1]
out_shape = [1, 4, 6, 1]
x = np.arange(0, 4).reshape(in_shape).astype(np.float32)
for align_corners in [True, False]:
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_bicubic(
input_tensor, out_shape[1:3], align_corners=align_corners)
with self.cached_session():
self.assertEqual(out_shape, list(resize_out.get_shape()))
resize_out = self.evaluate(resize_out)
self.assertEqual(out_shape, list(resize_out.shape))
@test_util.run_deprecated_v1
def testGradFromResizeToLargerInBothDims(self):
in_shape = [1, 2, 3, 1]
out_shape = [1, 4, 6, 1]
x = np.arange(0, 6).reshape(in_shape).astype(np.float32)
for align_corners in [True, False]:
with self.cached_session():
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_bicubic(input_tensor, out_shape[1:3],
align_corners=align_corners)
err = gradient_checker.compute_gradient_error(
input_tensor, in_shape, resize_out, out_shape, x_init_value=x)
self.assertLess(err, 1e-3)
@test_util.run_deprecated_v1
def testGradFromResizeToSmallerInBothDims(self):
in_shape = [1, 4, 6, 1]
out_shape = [1, 2, 3, 1]
x = np.arange(0, 24).reshape(in_shape).astype(np.float32)
for align_corners in [True, False]:
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_bicubic(
input_tensor, out_shape[1:3], align_corners=align_corners)
with self.cached_session():
err = gradient_checker.compute_gradient_error(
input_tensor, in_shape, resize_out, out_shape, x_init_value=x)
self.assertLess(err, 1e-3)
@test_util.run_deprecated_v1
def testGradOnUnsupportedType(self):
in_shape = [1, 4, 6, 1]
out_shape = [1, 2, 3, 1]
x = np.arange(0, 24).reshape(in_shape).astype(np.uint8)
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_bicubic(input_tensor, out_shape[1:3])
with self.cached_session():
grad = gradients_impl.gradients(input_tensor, [resize_out])
self.assertEqual([None], grad)
class ScaleAndTranslateOpTest(test.TestCase):
@test_util.run_deprecated_v1
def testGrads(self):
in_shape = [1, 2, 3, 1]
out_shape = [1, 4, 6, 1]
x = np.arange(0, 6).reshape(in_shape).astype(np.float32)
kernel_types = [
'lanczos1', 'lanczos3', 'lanczos5', 'gaussian', 'box', 'triangle',
'keyscubic', 'mitchellcubic'
]
scales = [(1.0, 1.0), (0.37, 0.47), (2.1, 2.1)]
translations = [(0.0, 0.0), (3.14, 1.19), (2.1, 3.1), (100.0, 200.0)]
for scale in scales:
for translation in translations:
for kernel_type in kernel_types:
for antialias in [True, False]:
with self.cached_session():
input_tensor = constant_op.constant(x, shape=in_shape)
scale_and_translate_out = image_ops.scale_and_translate(
input_tensor,
out_shape[1:3],
scale=constant_op.constant(scale),
translation=constant_op.constant(translation),
kernel_type=kernel_type,
antialias=antialias)
err = gradient_checker.compute_gradient_error(
input_tensor,
in_shape,
scale_and_translate_out,
out_shape,
x_init_value=x)
self.assertLess(err, 1e-3)
def testIdentityGrads(self):
"""Tests that Gradients for 1.0 scale should be ones for some kernels."""
in_shape = [1, 2, 3, 1]
out_shape = [1, 4, 6, 1]
x = np.arange(0, 6).reshape(in_shape).astype(np.float32)
kernel_types = ['lanczos1', 'lanczos3', 'lanczos5', 'triangle', 'keyscubic']
scale = (1.0, 1.0)
translation = (0.0, 0.0)
antialias = True
for kernel_type in kernel_types:
with self.cached_session():
input_tensor = constant_op.constant(x, shape=in_shape)
with backprop.GradientTape() as tape:
tape.watch(input_tensor)
scale_and_translate_out = image_ops.scale_and_translate(
input_tensor,
out_shape[1:3],
scale=constant_op.constant(scale),
translation=constant_op.constant(translation),
kernel_type=kernel_type,
antialias=antialias)
grad = tape.gradient(scale_and_translate_out, input_tensor)[0]
grad_v = self.evaluate(grad)
self.assertAllClose(np.ones_like(grad_v), grad_v)
class CropAndResizeOpTest(test.TestCase):
def testShapeIsCorrectAfterOp(self):
batch = 2
image_height = 3
image_width = 4
crop_height = 4
crop_width = 5
depth = 2
num_boxes = 2
image_shape = [batch, image_height, image_width, depth]
crop_size = [crop_height, crop_width]
crops_shape = [num_boxes, crop_height, crop_width, depth]
image = np.arange(0, batch * image_height * image_width *
depth).reshape(image_shape).astype(np.float32)
boxes = np.array([[0, 0, 1, 1], [.1, .2, .7, .8]], dtype=np.float32)
box_ind = np.array([0, 1], dtype=np.int32)
crops = image_ops.crop_and_resize(
constant_op.constant(image, shape=image_shape),
constant_op.constant(boxes, shape=[num_boxes, 4]),
constant_op.constant(box_ind, shape=[num_boxes]),
constant_op.constant(crop_size, shape=[2]))
with self.session(use_gpu=True) as sess:
self.assertEqual(crops_shape, list(crops.get_shape()))
crops = self.evaluate(crops)
self.assertEqual(crops_shape, list(crops.shape))
def _randomUniformAvoidAnchors(self, low, high, anchors, radius, num_samples):
"""Generate samples that are far enough from a set of anchor points.
We generate uniform samples in [low, high], then reject those that are less
than radius away from any point in anchors. We stop after we have accepted
num_samples samples.
Args:
low: The lower end of the interval.
high: The upper end of the interval.
anchors: A list of length num_crops with anchor points to avoid.
radius: Distance threshold for the samples from the anchors.
num_samples: How many samples to produce.
Returns:
samples: A list of length num_samples with the accepted samples.
"""
self.assertTrue(low < high)
self.assertTrue(radius >= 0)
num_anchors = len(anchors)
# Make sure that at least half of the interval is not forbidden.
self.assertTrue(2 * radius * num_anchors < 0.5 * (high - low))
anchors = np.reshape(anchors, num_anchors)
samples = []
while len(samples) < num_samples:
sample = np.random.uniform(low, high)
if np.all(np.fabs(sample - anchors) > radius):
samples.append(sample)
return samples
@test_util.run_deprecated_v1
def testGradRandomBoxes(self):
"""Test that the gradient is correct for randomly generated boxes.
The mapping is piecewise differentiable with respect to the box coordinates.
The points where the function is not differentiable are those which are
mapped to image pixels, i.e., the normalized y coordinates in
np.linspace(0, 1, image_height) and normalized x coordinates in
np.linspace(0, 1, image_width). Make sure that the box coordinates are
sufficiently far away from those rectangular grid centers that are points of
discontinuity, so that the finite difference Jacobian is close to the
computed one.
"""
np.random.seed(1) # Make it reproducible.
delta = 1e-3
radius = 2 * delta
low, high = -0.5, 1.5 # Also covers the case of extrapolation.
image_height = 4
for image_width in range(1, 3):
for crop_height in range(1, 3):
for crop_width in range(2, 4):
for depth in range(1, 3):
for num_boxes in range(1, 3):
batch = num_boxes
image_shape = [batch, image_height, image_width, depth]
crop_size = [crop_height, crop_width]
crops_shape = [num_boxes, crop_height, crop_width, depth]
boxes_shape = [num_boxes, 4]
image = np.arange(0, batch * image_height * image_width *
depth).reshape(image_shape).astype(np.float32)
boxes = []
for _ in range(num_boxes):
# pylint: disable=unbalanced-tuple-unpacking
y1, y2 = self._randomUniformAvoidAnchors(
low, high, np.linspace(0, 1, image_height), radius, 2)
x1, x2 = self._randomUniformAvoidAnchors(
low, high, np.linspace(0, 1, image_width), radius, 2)
# pylint: enable=unbalanced-tuple-unpacking
boxes.append([y1, x1, y2, x2])
boxes = np.array(boxes, dtype=np.float32)
box_ind = np.arange(batch, dtype=np.int32)
with self.cached_session(use_gpu=True):
image_tensor = constant_op.constant(image, shape=image_shape)
boxes_tensor = constant_op.constant(boxes, shape=[num_boxes, 4])
box_ind_tensor = constant_op.constant(
box_ind, shape=[num_boxes])
crops = image_ops.crop_and_resize(
image_tensor,
boxes_tensor,
box_ind_tensor,
constant_op.constant(
crop_size, shape=[2]))
err = gradient_checker.compute_gradient_error(
[image_tensor, boxes_tensor], [image_shape, boxes_shape],
crops,
crops_shape,
delta=delta,
x_init_value=[image, boxes])
self.assertLess(err, 2e-3)
@test_util.run_all_in_graph_and_eager_modes
class RGBToHSVOpTest(test.TestCase):
TYPES = [np.float32, np.float64]
def testShapeIsCorrectAfterOp(self):
in_shape = [2, 20, 30, 3]
out_shape = [2, 20, 30, 3]
for nptype in self.TYPES:
x = np.random.randint(0, high=255, size=[2, 20, 30, 3]).astype(nptype)
rgb_input_tensor = constant_op.constant(x, shape=in_shape)
hsv_out = gen_image_ops.rgb_to_hsv(rgb_input_tensor)
with self.cached_session(use_gpu=True):
self.assertEqual(out_shape, list(hsv_out.get_shape()))
hsv_out = self.evaluate(hsv_out)
self.assertEqual(out_shape, list(hsv_out.shape))
def testRGBToHSVGradSimpleCase(self):
def f(x):
return gen_image_ops.rgb_to_hsv(x)
# Building a simple input tensor to avoid any discontinuity
x = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8,
0.9]]).astype(np.float32)
rgb_input_tensor = constant_op.constant(x, shape=x.shape)
# Computing Analytical and Numerical gradients of f(x)
analytical, numerical = gradient_checker_v2.compute_gradient(
f, [rgb_input_tensor])
self.assertAllClose(numerical, analytical, atol=1e-4)
def testRGBToHSVGradRandomCase(self):
def f(x):
return gen_image_ops.rgb_to_hsv(x)
np.random.seed(0)
# Building a simple input tensor to avoid any discontinuity
x = np.random.rand(1, 5, 5, 3).astype(np.float32)
rgb_input_tensor = constant_op.constant(x, shape=x.shape)
# Computing Analytical and Numerical gradients of f(x)
self.assertLess(
gradient_checker_v2.max_error(
*gradient_checker_v2.compute_gradient(f, [rgb_input_tensor])), 1e-4)
def testRGBToHSVGradSpecialCaseRGreatest(self):
# This test tests a specific subset of the input space
# with a dummy function implemented with native TF operations.
in_shape = [2, 10, 20, 3]
def f(x):
return gen_image_ops.rgb_to_hsv(x)
def f_dummy(x):
# This dummy function is an implementation of RGB to HSV using
# primitive TF functions for one particular case when R>G>B.
r = x[..., 0]
g = x[..., 1]
b = x[..., 2]
# Since MAX = r and MIN = b, we get the following h,s,v values.
v = r
s = 1 - math_ops.div_no_nan(b, r)
h = 60 * math_ops.div_no_nan(g - b, r - b)
h = h / 360
return array_ops.stack([h, s, v], axis=-1)
# Building a custom input tensor where R>G>B
x_reds = np.ones((in_shape[0], in_shape[1], in_shape[2])).astype(np.float32)
x_greens = 0.5 * np.ones(
(in_shape[0], in_shape[1], in_shape[2])).astype(np.float32)
x_blues = 0.2 * np.ones(
(in_shape[0], in_shape[1], in_shape[2])).astype(np.float32)
x = np.stack([x_reds, x_greens, x_blues], axis=-1)
rgb_input_tensor = constant_op.constant(x, shape=in_shape)
# Computing Analytical and Numerical gradients of f(x)
analytical, numerical = gradient_checker_v2.compute_gradient(
f, [rgb_input_tensor])
# Computing Analytical and Numerical gradients of f_dummy(x)
analytical_dummy, numerical_dummy = gradient_checker_v2.compute_gradient(
f_dummy, [rgb_input_tensor])
self.assertAllClose(numerical, analytical, atol=1e-4)
self.assertAllClose(analytical_dummy, analytical, atol=1e-4)
self.assertAllClose(numerical_dummy, numerical, atol=1e-4)
ResizeNearestNeighborOpTest = test_base.ResizeNearestNeighborOpTestBase
ResizeBilinearOpTest = test_base.ResizeBilinearOpTestBase
ResizeBicubicOpTest = test_base.ResizeBicubicOpTestBase
ScaleAndTranslateOpTest = test_base.ScaleAndTranslateOpTestBase
CropAndResizeOpTest = test_base.CropAndResizeOpTestBase
RGBToHSVOpTest = test_base.RGBToHSVOpTestBase
if __name__ == "__main__":
test.main()

tensorflow/python/ops/image_grad_test_base.py

@@ -0,0 +1,622 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Python ops defined in image_grad.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from absl.testing import parameterized
from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import gen_image_ops
from tensorflow.python.platform import test
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import array_ops
@test_util.for_all_test_methods(test_util.disable_xla,
'align_corners=False not supported by XLA')
class ResizeNearestNeighborOpTestBase(test.TestCase):
TYPES = [np.float32, np.float64]
def testShapeIsCorrectAfterOp(self):
in_shape = [1, 2, 2, 1]
out_shape = [1, 4, 6, 1]
for nptype in self.TYPES:
x = np.arange(0, 4).reshape(in_shape).astype(nptype)
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_nearest_neighbor(input_tensor,
out_shape[1:3])
with self.cached_session(use_gpu=True):
self.assertEqual(out_shape, list(resize_out.get_shape()))
resize_out = self.evaluate(resize_out)
self.assertEqual(out_shape, list(resize_out.shape))
def testGradFromResizeToLargerInBothDims(self):
in_shape = [1, 2, 3, 1]
out_shape = (1, 4, 6, 1)
for nptype in self.TYPES:
x = np.arange(0, 6).reshape(in_shape).astype(nptype)
def resize_nn(t, shape=out_shape):
return image_ops.resize_nearest_neighbor(t, shape[1:3])
with self.cached_session(use_gpu=True):
input_tensor = constant_op.constant(x, shape=in_shape)
err = gradient_checker_v2.max_error(
*gradient_checker_v2.compute_gradient(resize_nn, [input_tensor]))
self.assertLess(err, 1e-3)
def testGradFromResizeToSmallerInBothDims(self):
in_shape = [1, 4, 6, 1]
out_shape = (1, 2, 3, 1)
for nptype in self.TYPES:
x = np.arange(0, 24).reshape(in_shape).astype(nptype)
def resize_nn(t, shape=out_shape):
return image_ops.resize_nearest_neighbor(t, shape[1:3])
with self.cached_session(use_gpu=True):
input_tensor = constant_op.constant(x, shape=in_shape)
err = gradient_checker_v2.max_error(
*gradient_checker_v2.compute_gradient(resize_nn, [input_tensor]))
self.assertLess(err, 1e-3)
def testCompareGpuVsCpu(self):
in_shape = [1, 4, 6, 3]
out_shape = (1, 8, 16, 3)
for nptype in self.TYPES:
x = np.arange(0, np.prod(in_shape)).reshape(in_shape).astype(nptype)
for align_corners in [True, False]:
def resize_nn(t, shape=out_shape, align_corners=align_corners):
return image_ops.resize_nearest_neighbor(
t, shape[1:3], align_corners=align_corners)
with self.cached_session(use_gpu=False):
input_tensor = constant_op.constant(x, shape=in_shape)
grad_cpu = gradient_checker_v2.compute_gradient(
resize_nn, [input_tensor])
with self.cached_session(use_gpu=True):
input_tensor = constant_op.constant(x, shape=in_shape)
grad_gpu = gradient_checker_v2.compute_gradient(
resize_nn, [input_tensor])
self.assertAllClose(grad_cpu, grad_gpu, rtol=1e-5, atol=1e-5)
class ResizeBilinearOpTestBase(test.TestCase, parameterized.TestCase):
def _itGen(self, smaller_shape, larger_shape):
up_sample = (smaller_shape, larger_shape)
down_sample = (larger_shape, smaller_shape)
pass_through = (larger_shape, larger_shape)
shape_pairs = (up_sample, down_sample, pass_through)
# Align corners is deprecated in TF2.0, but align_corners==False is not
# supported by XLA.
options = [(True, False)]
if not test_util.is_xla_enabled():
options += [(False, True), (False, False)]
for align_corners, half_pixel_centers in options:
for in_shape, out_shape in shape_pairs:
yield in_shape, out_shape, align_corners, half_pixel_centers
def _getJacobians(self,
in_shape,
out_shape,
align_corners=False,
half_pixel_centers=False,
dtype=np.float32,
use_gpu=False,
force_gpu=False):
with self.cached_session(use_gpu=use_gpu, force_gpu=force_gpu) as sess:
# Input values should not influence gradients
x = np.arange(np.prod(in_shape)).reshape(in_shape).astype(dtype)
input_tensor = constant_op.constant(x, shape=in_shape)
resized_tensor = image_ops.resize_bilinear(
input_tensor,
out_shape[1:3],
align_corners=align_corners,
half_pixel_centers=half_pixel_centers)
# compute_gradient will use a random tensor as the init value
return gradient_checker.compute_gradient(input_tensor, in_shape,
resized_tensor, out_shape)
@parameterized.parameters({
'batch_size': 1,
'channel_count': 1
}, {
'batch_size': 2,
'channel_count': 3
}, {
'batch_size': 5,
'channel_count': 4
})
@test_util.run_deprecated_v1
def testShapes(self, batch_size, channel_count):
smaller_shape = [batch_size, 2, 3, channel_count]
larger_shape = [batch_size, 4, 6, channel_count]
for in_shape, out_shape, align_corners, half_pixel_centers in \
self._itGen(smaller_shape, larger_shape):
# Input values should not influence shapes
x = np.arange(np.prod(in_shape)).reshape(in_shape).astype(np.float32)
input_tensor = constant_op.constant(x, shape=in_shape)
resized_tensor = image_ops.resize_bilinear(input_tensor, out_shape[1:3])
self.assertEqual(out_shape, list(resized_tensor.get_shape()))
grad_tensor = gradients_impl.gradients(resized_tensor, input_tensor)[0]
self.assertEqual(in_shape, list(grad_tensor.get_shape()))
with self.cached_session():
resized_values = self.evaluate(resized_tensor)
self.assertEqual(out_shape, list(resized_values.shape))
grad_values = self.evaluate(grad_tensor)
self.assertEqual(in_shape, list(grad_values.shape))
@parameterized.parameters({
'batch_size': 1,
'channel_count': 1
}, {
'batch_size': 4,
'channel_count': 3
}, {
'batch_size': 3,
'channel_count': 2
})
@test_util.run_deprecated_v1
def testGradients(self, batch_size, channel_count):
smaller_shape = [batch_size, 2, 3, channel_count]
larger_shape = [batch_size, 5, 6, channel_count]
for in_shape, out_shape, align_corners, half_pixel_centers in \
self._itGen(smaller_shape, larger_shape):
jacob_a, jacob_n = self._getJacobians(in_shape, out_shape, align_corners,
half_pixel_centers)
threshold = 1e-4
self.assertAllClose(jacob_a, jacob_n, threshold, threshold)
@test_util.run_deprecated_v1
def testTypes(self):
in_shape = [1, 4, 6, 1]
out_shape = [1, 2, 3, 1]
for use_gpu in [False, True]:
for dtype in [np.float16, np.float32, np.float64]:
jacob_a, jacob_n = self._getJacobians(
in_shape, out_shape, dtype=dtype, use_gpu=use_gpu)
if dtype == np.float16:
# Compare fp16 analytical gradients to fp32 numerical gradients,
# since fp16 numerical gradients are too imprecise unless great
# care is taken with choosing the inputs and the delta. This is
# a weaker, but pragmatic, check (in particular, it does not test
# the op itself, only its gradient).
_, jacob_n = self._getJacobians(
in_shape, out_shape, dtype=np.float32, use_gpu=use_gpu)
threshold = 1e-3
if dtype == np.float64:
threshold = 1e-5
self.assertAllClose(jacob_a, jacob_n, threshold, threshold)
@test_util.run_deprecated_v1
def testGradOnUnsupportedType(self):
in_shape = [1, 4, 6, 1]
out_shape = [1, 2, 3, 1]
x = np.arange(0, 24).reshape(in_shape).astype(np.uint8)
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_bilinear(input_tensor, out_shape[1:3])
with self.cached_session():
grad = gradients_impl.gradients(resize_out, [input_tensor])
self.assertEqual([None], grad)
def _gpuVsCpuCase(self, in_shape, out_shape, align_corners,
half_pixel_centers, dtype):
grad = {}
for use_gpu in [False, True]:
grad[use_gpu] = self._getJacobians(
in_shape,
out_shape,
align_corners,
half_pixel_centers,
dtype=dtype,
use_gpu=use_gpu)
threshold = 1e-4
# Note that this is comparing both analytical and numerical Jacobians
self.assertAllClose(grad[False], grad[True], rtol=threshold, atol=threshold)
@parameterized.parameters({
'batch_size': 1,
'channel_count': 1
}, {
'batch_size': 2,
'channel_count': 3
}, {
'batch_size': 5,
'channel_count': 4
})
@test_util.run_deprecated_v1
def testCompareGpuVsCpu(self, batch_size, channel_count):
smaller_shape = [batch_size, 4, 6, channel_count]
larger_shape = [batch_size, 8, 16, channel_count]
for params in self._itGen(smaller_shape, larger_shape):
self._gpuVsCpuCase(*params, dtype=np.float32)
@test_util.run_deprecated_v1
def testCompareGpuVsCpuFloat64(self):
in_shape = [1, 5, 7, 1]
out_shape = [1, 9, 11, 1]
# Note that there is no 16-bit floating-point format registered for GPU
self._gpuVsCpuCase(
in_shape,
out_shape,
align_corners=True,
half_pixel_centers=False,
dtype=np.float64)
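# Why the float16 branch of testTypes above substitutes a float32 numerical
# Jacobian: central differences computed wholly in half precision lose most
# of their significant digits. A self-contained illustration (a sketch, not
# part of this test; `central_diff` is a hypothetical helper):

import numpy as np

def central_diff(f, x, delta):
  # Second-order finite-difference estimate of df/dx, evaluated entirely in
  # the dtype of `x` and `delta`.
  return (f(x + delta) - f(x - delta)) / (2 * delta)

square = lambda x: x * x  # true derivative at x = 1.0 is exactly 2
est16 = central_diff(square, np.float16(1.0), np.float16(1e-3))
est32 = central_diff(square, np.float32(1.0), np.float32(1e-3))
print(est16, est32)  # the float16 estimate is visibly off; float32 is not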
class ResizeBicubicOpTestBase(test.TestCase):
def testShapeIsCorrectAfterOp(self):
in_shape = [1, 2, 2, 1]
out_shape = [1, 4, 6, 1]
x = np.arange(0, 4).reshape(in_shape).astype(np.float32)
for align_corners in [True, False]:
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_bicubic(
input_tensor, out_shape[1:3], align_corners=align_corners)
with self.cached_session():
self.assertEqual(out_shape, list(resize_out.get_shape()))
resize_out = self.evaluate(resize_out)
self.assertEqual(out_shape, list(resize_out.shape))
@test_util.run_deprecated_v1
def testGradFromResizeToLargerInBothDims(self):
in_shape = [1, 2, 3, 1]
out_shape = [1, 4, 6, 1]
x = np.arange(0, 6).reshape(in_shape).astype(np.float32)
for align_corners in [True, False]:
with self.cached_session():
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_bicubic(
input_tensor, out_shape[1:3], align_corners=align_corners)
err = gradient_checker.compute_gradient_error(
input_tensor, in_shape, resize_out, out_shape, x_init_value=x)
self.assertLess(err, 1e-3)
@test_util.run_deprecated_v1
def testGradFromResizeToSmallerInBothDims(self):
in_shape = [1, 4, 6, 1]
out_shape = [1, 2, 3, 1]
x = np.arange(0, 24).reshape(in_shape).astype(np.float32)
for align_corners in [True, False]:
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_bicubic(
input_tensor, out_shape[1:3], align_corners=align_corners)
with self.cached_session():
err = gradient_checker.compute_gradient_error(
input_tensor, in_shape, resize_out, out_shape, x_init_value=x)
self.assertLess(err, 1e-3)
@test_util.run_deprecated_v1
def testGradOnUnsupportedType(self):
in_shape = [1, 4, 6, 1]
out_shape = [1, 2, 3, 1]
x = np.arange(0, 24).reshape(in_shape).astype(np.uint8)
input_tensor = constant_op.constant(x, shape=in_shape)
resize_out = image_ops.resize_bicubic(input_tensor, out_shape[1:3])
with self.cached_session():
grad = gradients_impl.gradients(resize_out, [input_tensor])
self.assertEqual([None], grad)
class ScaleAndTranslateOpTestBase(test.TestCase):
@test_util.run_deprecated_v1
def testGrads(self):
in_shape = [1, 2, 3, 1]
out_shape = [1, 4, 6, 1]
x = np.arange(0, 6).reshape(in_shape).astype(np.float32)
kernel_types = [
'lanczos1', 'lanczos3', 'lanczos5', 'gaussian', 'box', 'triangle',
'keyscubic', 'mitchellcubic'
]
scales = [(1.0, 1.0), (0.37, 0.47), (2.1, 2.1)]
translations = [(0.0, 0.0), (3.14, 1.19), (2.1, 3.1), (100.0, 200.0)]
for scale in scales:
for translation in translations:
for kernel_type in kernel_types:
for antialias in [True, False]:
with self.cached_session():
input_tensor = constant_op.constant(x, shape=in_shape)
scale_and_translate_out = image_ops.scale_and_translate(
input_tensor,
out_shape[1:3],
scale=constant_op.constant(scale),
translation=constant_op.constant(translation),
kernel_type=kernel_type,
antialias=antialias)
err = gradient_checker.compute_gradient_error(
input_tensor,
in_shape,
scale_and_translate_out,
out_shape,
x_init_value=x)
self.assertLess(err, 1e-3)
def testIdentityGrads(self):
"""Tests that Gradients for 1.0 scale should be ones for some kernels."""
in_shape = [1, 2, 3, 1]
out_shape = [1, 4, 6, 1]
x = np.arange(0, 6).reshape(in_shape).astype(np.float32)
kernel_types = ['lanczos1', 'lanczos3', 'lanczos5', 'triangle', 'keyscubic']
scale = (1.0, 1.0)
translation = (0.0, 0.0)
antialias = True
for kernel_type in kernel_types:
with self.cached_session():
input_tensor = constant_op.constant(x, shape=in_shape)
with backprop.GradientTape() as tape:
tape.watch(input_tensor)
scale_and_translate_out = image_ops.scale_and_translate(
input_tensor,
out_shape[1:3],
scale=constant_op.constant(scale),
translation=constant_op.constant(translation),
kernel_type=kernel_type,
antialias=antialias)
grad = tape.gradient(scale_and_translate_out, input_tensor)[0]
grad_v = self.evaluate(grad)
self.assertAllClose(np.ones_like(grad_v), grad_v)
class CropAndResizeOpTestBase(test.TestCase):
def testShapeIsCorrectAfterOp(self):
batch = 2
image_height = 3
image_width = 4
crop_height = 4
crop_width = 5
depth = 2
num_boxes = 2
image_shape = [batch, image_height, image_width, depth]
crop_size = [crop_height, crop_width]
crops_shape = [num_boxes, crop_height, crop_width, depth]
image = np.arange(0, batch * image_height * image_width *
depth).reshape(image_shape).astype(np.float32)
boxes = np.array([[0, 0, 1, 1], [.1, .2, .7, .8]], dtype=np.float32)
box_ind = np.array([0, 1], dtype=np.int32)
crops = image_ops.crop_and_resize(
constant_op.constant(image, shape=image_shape),
constant_op.constant(boxes, shape=[num_boxes, 4]),
constant_op.constant(box_ind, shape=[num_boxes]),
constant_op.constant(crop_size, shape=[2]))
with self.session(use_gpu=True) as sess:
self.assertEqual(crops_shape, list(crops.get_shape()))
crops = self.evaluate(crops)
self.assertEqual(crops_shape, list(crops.shape))
def _randomUniformAvoidAnchors(self, low, high, anchors, radius, num_samples):
"""Generate samples that are far enough from a set of anchor points.
We generate uniform samples in [low, high], then reject those that are less
than radius away from any point in anchors. We stop after we have accepted
num_samples samples.
Args:
low: The lower end of the interval.
high: The upper end of the interval.
anchors: A list of length num_crops with anchor points to avoid.
radius: Distance threshold for the samples from the anchors.
num_samples: How many samples to produce.
Returns:
samples: A list of length num_samples with the accepted samples.
"""
self.assertTrue(low < high)
self.assertTrue(radius >= 0)
num_anchors = len(anchors)
# Make sure that at least half of the interval is not forbidden.
self.assertTrue(2 * radius * num_anchors < 0.5 * (high - low))
anchors = np.reshape(anchors, num_anchors)
samples = []
while len(samples) < num_samples:
sample = np.random.uniform(low, high)
if np.all(np.fabs(sample - anchors) > radius):
samples.append(sample)
return samples
@test_util.run_deprecated_v1
def testGradRandomBoxes(self):
"""Test that the gradient is correct for randomly generated boxes.
The mapping is piecewise differentiable with respect to the box coordinates.
The points where the function is not differentiable are those which are
mapped to image pixels, i.e., the normalized y coordinates in
np.linspace(0, 1, image_height) and normalized x coordinates in
np.linspace(0, 1, image_width). Make sure that the box coordinates are
sufficiently far away from those rectangular grid centers that are points of
discontinuity, so that the finite difference Jacobian is close to the
computed one.
"""
np.random.seed(1) # Make it reproducible.
delta = 1e-3
radius = 2 * delta
low, high = -0.5, 1.5 # Also covers the case of extrapolation.
image_height = 4
for image_width in range(1, 3):
for crop_height in range(1, 3):
for crop_width in range(2, 4):
for depth in range(1, 3):
for num_boxes in range(1, 3):
batch = num_boxes
image_shape = [batch, image_height, image_width, depth]
crop_size = [crop_height, crop_width]
crops_shape = [num_boxes, crop_height, crop_width, depth]
boxes_shape = [num_boxes, 4]
image = np.arange(0, batch * image_height * image_width *
depth).reshape(image_shape).astype(np.float32)
boxes = []
for _ in range(num_boxes):
# pylint: disable=unbalanced-tuple-unpacking
y1, y2 = self._randomUniformAvoidAnchors(
low, high, np.linspace(0, 1, image_height), radius, 2)
x1, x2 = self._randomUniformAvoidAnchors(
low, high, np.linspace(0, 1, image_width), radius, 2)
# pylint: enable=unbalanced-tuple-unpacking
boxes.append([y1, x1, y2, x2])
boxes = np.array(boxes, dtype=np.float32)
box_ind = np.arange(batch, dtype=np.int32)
with self.cached_session(use_gpu=True):
image_tensor = constant_op.constant(image, shape=image_shape)
boxes_tensor = constant_op.constant(boxes, shape=[num_boxes, 4])
box_ind_tensor = constant_op.constant(
box_ind, shape=[num_boxes])
crops = image_ops.crop_and_resize(
image_tensor, boxes_tensor, box_ind_tensor,
constant_op.constant(crop_size, shape=[2]))
err = gradient_checker.compute_gradient_error(
[image_tensor, boxes_tensor], [image_shape, boxes_shape],
crops,
crops_shape,
delta=delta,
x_init_value=[image, boxes])
self.assertLess(err, 2e-3)
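# _randomUniformAvoidAnchors above is plain rejection sampling: draw
# uniformly, discard any candidate within `radius` of a grid anchor, and
# stop once enough samples are accepted. A standalone sketch with
# illustrative values (not part of this test):

import numpy as np

np.random.seed(0)
anchors = np.linspace(0, 1, 4)  # normalized pixel centers to avoid
radius, num_samples = 2e-3, 2
samples = []
while len(samples) < num_samples:
  candidate = np.random.uniform(-0.5, 1.5)
  if np.all(np.fabs(candidate - anchors) > radius):
    samples.append(candidate)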
@test_util.run_all_in_graph_and_eager_modes
class RGBToHSVOpTestBase(test.TestCase):
TYPES = [np.float32, np.float64]
def testShapeIsCorrectAfterOp(self):
in_shape = [2, 20, 30, 3]
out_shape = [2, 20, 30, 3]
for nptype in self.TYPES:
x = np.random.randint(0, high=255, size=[2, 20, 30, 3]).astype(nptype)
rgb_input_tensor = constant_op.constant(x, shape=in_shape)
hsv_out = gen_image_ops.rgb_to_hsv(rgb_input_tensor)
with self.cached_session(use_gpu=True):
self.assertEqual(out_shape, list(hsv_out.get_shape()))
hsv_out = self.evaluate(hsv_out)
self.assertEqual(out_shape, list(hsv_out.shape))
def testRGBToHSVGradSimpleCase(self):
def f(x):
return gen_image_ops.rgb_to_hsv(x)
# Building a simple input tensor to avoid any discontinuity
x = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8,
0.9]]).astype(np.float32)
rgb_input_tensor = constant_op.constant(x, shape=x.shape)
# Computing Analytical and Numerical gradients of f(x)
analytical, numerical = gradient_checker_v2.compute_gradient(
f, [rgb_input_tensor])
self.assertAllClose(numerical, analytical, atol=1e-4)
def testRGBToHSVGradRandomCase(self):
def f(x):
return gen_image_ops.rgb_to_hsv(x)
np.random.seed(0)
# Building a simple input tensor to avoid any discontinuity
x = np.random.rand(1, 5, 5, 3).astype(np.float32)
rgb_input_tensor = constant_op.constant(x, shape=x.shape)
# Computing Analytical and Numerical gradients of f(x)
self.assertLess(
gradient_checker_v2.max_error(
*gradient_checker_v2.compute_gradient(f, [rgb_input_tensor])), 1e-4)
def testRGBToHSVGradSpecialCaseRGreatest(self):
# This test tests a specific subset of the input space
# with a dummy function implemented with native TF operations.
in_shape = [2, 10, 20, 3]
def f(x):
return gen_image_ops.rgb_to_hsv(x)
def f_dummy(x):
# This dummy function is an implementation of RGB to HSV using
# primitive TF functions for one particular case when R>G>B.
r = x[..., 0]
g = x[..., 1]
b = x[..., 2]
# Since MAX = r and MIN = b, we get the following h,s,v values.
v = r
s = 1 - math_ops.div_no_nan(b, r)
h = 60 * math_ops.div_no_nan(g - b, r - b)
h = h / 360
return array_ops.stack([h, s, v], axis=-1)
# Building a custom input tensor where R>G>B
x_reds = np.ones((in_shape[0], in_shape[1], in_shape[2])).astype(np.float32)
x_greens = 0.5 * np.ones(
(in_shape[0], in_shape[1], in_shape[2])).astype(np.float32)
x_blues = 0.2 * np.ones(
(in_shape[0], in_shape[1], in_shape[2])).astype(np.float32)
x = np.stack([x_reds, x_greens, x_blues], axis=-1)
rgb_input_tensor = constant_op.constant(x, shape=in_shape)
# Computing Analytical and Numerical gradients of f(x)
analytical, numerical = gradient_checker_v2.compute_gradient(
f, [rgb_input_tensor])
# Computing Analytical and Numerical gradients of f_dummy(x)
analytical_dummy, numerical_dummy = gradient_checker_v2.compute_gradient(
f_dummy, [rgb_input_tensor])
self.assertAllClose(numerical, analytical, atol=1e-4)
self.assertAllClose(analytical_dummy, analytical, atol=1e-4)
self.assertAllClose(numerical_dummy, numerical, atol=1e-4)
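# A quick sanity check on the R>G>B branch above: for r=1.0, g=0.5, b=0.2
# the formulas give v = 1.0, s = 1 - 0.2/1.0 = 0.8, and
# h = 60*(0.5-0.2)/(1.0-0.2)/360 = 0.0625, matching the standard library
# (a sketch, not part of the test):

import colorsys
import math

h, s, v = colorsys.rgb_to_hsv(1.0, 0.5, 0.2)
assert math.isclose(h, 0.0625) and math.isclose(s, 0.8) and math.isclose(v, 1.0)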
if __name__ == '__main__':
test.main()

tensorflow/tools/pip_package/BUILD

@@ -144,6 +144,7 @@ COMMON_PIP_DEPS = [
"//tensorflow/python/tools:tools_pip",
"//tensorflow/python/tools/api/generator:create_python_api",
"//tensorflow/python/tpu",
"//tensorflow/python:image_grad_test_base",
"//tensorflow/python:test_ops",
"//tensorflow/python:while_v2",
"//tensorflow/tools/common:public_api",