Merge pull request #39243 from duncanriach:deterministic-image-resize-bilinear
PiperOrigin-RevId: 333087088 Change-Id: Ia85994cc8bf62f1bd21b821f035dea0941d12f93
This commit is contained in:
commit
2ef1ff69d2
@ -1,4 +1,4 @@
|
|||||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
/* Copyright 2016-2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
you may not use this file except in compliance with the License.
|
you may not use this file except in compliance with the License.
|
||||||
@ -23,6 +23,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/tensor_types.h"
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
#include "tensorflow/core/kernels/image/resize_bilinear_op.h"
|
#include "tensorflow/core/kernels/image/resize_bilinear_op.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
#include "tensorflow/core/util/env_var.h"
|
||||||
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -228,6 +229,59 @@ __global__ void ResizeBilinearGradKernel(const int32 nthreads,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void ResizeBilinearDeterministicGradKernel(
|
||||||
|
const int32 nthreads, const float* __restrict__ input_grad,
|
||||||
|
float height_scale, float inverse_height_scale, float width_scale,
|
||||||
|
float inverse_width_scale, int batch, int original_height,
|
||||||
|
int original_width, int channels, int resized_height, int resized_width,
|
||||||
|
float offset, T* __restrict__ output_grad) {
|
||||||
|
GPU_1D_KERNEL_LOOP(out_idx, nthreads) {
|
||||||
|
// out_idx = c + channels * (x + original_width * (y + original_height * b))
|
||||||
|
int idx = out_idx;
|
||||||
|
const int c = idx % channels;
|
||||||
|
idx /= channels;
|
||||||
|
const int out_x_center = idx % original_width;
|
||||||
|
idx /= original_width;
|
||||||
|
const int out_y_center = idx % original_height;
|
||||||
|
const int b = idx / original_height;
|
||||||
|
|
||||||
|
int in_y_start = max(
|
||||||
|
0, __float2int_ru((out_y_center - 1 + offset) * inverse_height_scale -
|
||||||
|
offset));
|
||||||
|
const float out_y_start = (in_y_start + offset) * height_scale - offset;
|
||||||
|
int in_x_start =
|
||||||
|
max(0, __float2int_ru(
|
||||||
|
(out_x_center - 1 + offset) * inverse_width_scale - offset));
|
||||||
|
const float out_x_start = (in_x_start + offset) * width_scale - offset;
|
||||||
|
T acc = 0;
|
||||||
|
// For clarity, prior to C++17, while loops are preferable to for loops here
|
||||||
|
float out_y = out_y_start;
|
||||||
|
int in_y = in_y_start;
|
||||||
|
while (out_y < out_y_center + 1 && in_y < resized_height) {
|
||||||
|
float out_x = out_x_start;
|
||||||
|
int in_x = in_x_start;
|
||||||
|
while (out_x < out_x_center + 1 && in_x < resized_width) {
|
||||||
|
int in_idx =
|
||||||
|
((b * resized_height + in_y) * resized_width + in_x) * channels + c;
|
||||||
|
// Clamping to zero is necessary because out_x and out_y can be negative
|
||||||
|
// due to half-pixel adjustments to out_y_start and out_x_start.
|
||||||
|
// Clamping to height/width is necessary when upscaling.
|
||||||
|
float out_y_clamped = fmaxf(0, fminf(out_y, original_height - 1));
|
||||||
|
float out_x_clamped = fmaxf(0, fminf(out_x, original_width - 1));
|
||||||
|
float y_lerp = (1 - fabsf(out_y_clamped - out_y_center));
|
||||||
|
float x_lerp = (1 - fabsf(out_x_clamped - out_x_center));
|
||||||
|
acc += static_cast<T>(input_grad[in_idx] * y_lerp * x_lerp);
|
||||||
|
out_x += width_scale;
|
||||||
|
in_x++;
|
||||||
|
}
|
||||||
|
out_y += height_scale;
|
||||||
|
in_y++;
|
||||||
|
}
|
||||||
|
output_grad[out_idx] = acc;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__global__ void LegacyResizeBilinearKernel(
|
__global__ void LegacyResizeBilinearKernel(
|
||||||
const int32 nthreads, const T* __restrict__ images, float height_scale,
|
const int32 nthreads, const T* __restrict__ images, float height_scale,
|
||||||
@ -394,6 +448,17 @@ struct ResizeBilinear<GPUDevice, T> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
bool RequireDeterminism() {
|
||||||
|
static bool require_determinism = [] {
|
||||||
|
bool deterministic_ops = false;
|
||||||
|
TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS",
|
||||||
|
/*default_val=*/false,
|
||||||
|
&deterministic_ops));
|
||||||
|
return deterministic_ops;
|
||||||
|
}();
|
||||||
|
return require_determinism;
|
||||||
|
}
|
||||||
|
|
||||||
// Partial specialization of ResizeBilinearGrad functor for a GPUDevice.
|
// Partial specialization of ResizeBilinearGrad functor for a GPUDevice.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct ResizeBilinearGrad<GPUDevice, T> {
|
struct ResizeBilinearGrad<GPUDevice, T> {
|
||||||
@ -413,14 +478,27 @@ struct ResizeBilinearGrad<GPUDevice, T> {
|
|||||||
int total_count;
|
int total_count;
|
||||||
GpuLaunchConfig config;
|
GpuLaunchConfig config;
|
||||||
|
|
||||||
// Initialize output_grad with all zeros.
|
|
||||||
total_count = batch * original_height * original_width * channels;
|
total_count = batch * original_height * original_width * channels;
|
||||||
if (total_count == 0) return;
|
if (total_count == 0) return;
|
||||||
config = GetGpuLaunchConfig(total_count, d);
|
config = GetGpuLaunchConfig(total_count, d);
|
||||||
TF_CHECK_OK(GpuLaunchKernel(
|
|
||||||
SetZero<T>, config.block_count, config.thread_per_block, 0, d.stream(),
|
|
||||||
config.virtual_thread_count, output_grad.data()));
|
|
||||||
|
|
||||||
|
if (RequireDeterminism()) {
|
||||||
|
// The scale values below should never be zero, enforced by
|
||||||
|
// ImageResizerGradientState
|
||||||
|
float inverse_height_scale = 1 / height_scale;
|
||||||
|
float inverse_width_scale = 1 / width_scale;
|
||||||
|
float offset = half_pixel_centers ? 0.5 : 0;
|
||||||
|
TF_CHECK_OK(GpuLaunchKernel(
|
||||||
|
ResizeBilinearDeterministicGradKernel<T>, config.block_count,
|
||||||
|
config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
|
||||||
|
input_grad.data(), height_scale, inverse_height_scale, width_scale,
|
||||||
|
inverse_width_scale, batch, original_height, original_width, channels,
|
||||||
|
resized_height, resized_width, offset, output_grad.data()));
|
||||||
|
} else {
|
||||||
|
// Initialize output_grad with all zeros.
|
||||||
|
TF_CHECK_OK(GpuLaunchKernel(
|
||||||
|
SetZero<T>, config.block_count, config.thread_per_block, 0,
|
||||||
|
d.stream(), config.virtual_thread_count, output_grad.data()));
|
||||||
// Accumulate.
|
// Accumulate.
|
||||||
total_count = batch * resized_height * resized_width * channels;
|
total_count = batch * resized_height * resized_width * channels;
|
||||||
config = GetGpuLaunchConfig(total_count, d);
|
config = GetGpuLaunchConfig(total_count, d);
|
||||||
@ -428,16 +506,17 @@ struct ResizeBilinearGrad<GPUDevice, T> {
|
|||||||
TF_CHECK_OK(GpuLaunchKernel(
|
TF_CHECK_OK(GpuLaunchKernel(
|
||||||
ResizeBilinearGradKernel<T>, config.block_count,
|
ResizeBilinearGradKernel<T>, config.block_count,
|
||||||
config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
|
config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
|
||||||
input_grad.data(), height_scale, width_scale, batch, original_height,
|
input_grad.data(), height_scale, width_scale, batch,
|
||||||
original_width, channels, resized_height, resized_width,
|
original_height, original_width, channels, resized_height,
|
||||||
output_grad.data()));
|
resized_width, output_grad.data()));
|
||||||
} else {
|
} else {
|
||||||
TF_CHECK_OK(GpuLaunchKernel(
|
TF_CHECK_OK(GpuLaunchKernel(
|
||||||
LegacyResizeBilinearGradKernel<T>, config.block_count,
|
LegacyResizeBilinearGradKernel<T>, config.block_count,
|
||||||
config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
|
config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
|
||||||
input_grad.data(), height_scale, width_scale, batch, original_height,
|
input_grad.data(), height_scale, width_scale, batch,
|
||||||
original_width, channels, resized_height, resized_width,
|
original_height, original_width, channels, resized_height,
|
||||||
output_grad.data()));
|
resized_width, output_grad.data()));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@ -192,6 +192,20 @@ struct ImageResizerGradientState {
|
|||||||
original_height = original_image.dim_size(1);
|
original_height = original_image.dim_size(1);
|
||||||
original_width = original_image.dim_size(2);
|
original_width = original_image.dim_size(2);
|
||||||
|
|
||||||
|
// The following check is also carried out for the forward op. It is added
|
||||||
|
// here to prevent a divide-by-zero exception when either height_scale or
|
||||||
|
// width_scale is being calculated.
|
||||||
|
OP_REQUIRES(context, resized_height > 0 && resized_width > 0,
|
||||||
|
errors::InvalidArgument("resized dimensions must be positive"));
|
||||||
|
|
||||||
|
// The following check is also carried out for the forward op. It is added
|
||||||
|
// here to prevent either height_scale or width_scale from being set to
|
||||||
|
// zero, which would cause a divide-by-zero exception in the deterministic
|
||||||
|
// back-prop path.
|
||||||
|
OP_REQUIRES(
|
||||||
|
context, original_height > 0 && original_width > 0,
|
||||||
|
errors::InvalidArgument("original dimensions must be positive"));
|
||||||
|
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
context,
|
context,
|
||||||
FastBoundsCheck(original_height, std::numeric_limits<int32>::max()) &&
|
FastBoundsCheck(original_height, std::numeric_limits<int32>::max()) &&
|
||||||
|
|||||||
@ -5145,12 +5145,30 @@ cuda_py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cuda_py_test(
|
||||||
|
name = "image_grad_deterministic_test",
|
||||||
|
size = "medium",
|
||||||
|
srcs = ["ops/image_grad_deterministic_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
deps = [
|
||||||
|
":image_grad_test_base",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
name = "image_grad_test",
|
name = "image_grad_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
srcs = ["ops/image_grad_test.py"],
|
srcs = ["ops/image_grad_test.py"],
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
tfrt_enabled = True,
|
tfrt_enabled = True,
|
||||||
|
deps = [
|
||||||
|
":image_grad_test_base",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "image_grad_test_base",
|
||||||
|
srcs = ["ops/image_grad_test_base.py"],
|
||||||
deps = [
|
deps = [
|
||||||
":client_testlib",
|
":client_testlib",
|
||||||
":framework_for_generated_wrappers",
|
":framework_for_generated_wrappers",
|
||||||
|
|||||||
142
tensorflow/python/ops/image_grad_deterministic_test.py
Normal file
142
tensorflow/python/ops/image_grad_deterministic_test.py
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Functional tests for deterministic image op gradient functions."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
|
||||||
|
from tensorflow.python.eager import backprop
|
||||||
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import gradients_impl
|
||||||
|
from tensorflow.python.ops import image_grad_test_base as test_base
|
||||||
|
from tensorflow.python.ops import image_ops
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class ResizeBilinearOpDeterministicTest(test_base.ResizeBilinearOpTestBase):
|
||||||
|
|
||||||
|
def _randomNDArray(self, shape):
|
||||||
|
return 2 * np.random.random_sample(shape) - 1
|
||||||
|
|
||||||
|
def _randomDataOp(self, shape, data_type):
|
||||||
|
return constant_op.constant(self._randomNDArray(shape), dtype=data_type)
|
||||||
|
|
||||||
|
@parameterized.parameters(
|
||||||
|
# Note that there is no 16-bit floating point format registered for GPU
|
||||||
|
{
|
||||||
|
'align_corners': False,
|
||||||
|
'half_pixel_centers': False,
|
||||||
|
'data_type': dtypes.float32
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'align_corners': False,
|
||||||
|
'half_pixel_centers': False,
|
||||||
|
'data_type': dtypes.float64
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'align_corners': True,
|
||||||
|
'half_pixel_centers': False,
|
||||||
|
'data_type': dtypes.float32
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'align_corners': False,
|
||||||
|
'half_pixel_centers': True,
|
||||||
|
'data_type': dtypes.float32
|
||||||
|
})
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
@test_util.run_cuda_only
|
||||||
|
def testDeterministicGradients(self, align_corners, half_pixel_centers,
|
||||||
|
data_type):
|
||||||
|
if not align_corners and test_util.is_xla_enabled():
|
||||||
|
# Align corners is deprecated in TF2.0, but align_corners==False is not
|
||||||
|
# supported by XLA.
|
||||||
|
self.skipTest('align_corners==False not currently supported by XLA')
|
||||||
|
with self.session(force_gpu=True):
|
||||||
|
seed = (
|
||||||
|
hash(align_corners) % 256 + hash(half_pixel_centers) % 256 +
|
||||||
|
hash(data_type) % 256)
|
||||||
|
np.random.seed(seed)
|
||||||
|
input_shape = (1, 25, 12, 3) # NHWC
|
||||||
|
output_shape = (1, 200, 250, 3)
|
||||||
|
input_image = self._randomDataOp(input_shape, data_type)
|
||||||
|
repeat_count = 3
|
||||||
|
if context.executing_eagerly():
|
||||||
|
|
||||||
|
def resize_bilinear_gradients(local_seed):
|
||||||
|
np.random.seed(local_seed)
|
||||||
|
upstream_gradients = self._randomDataOp(output_shape, dtypes.float32)
|
||||||
|
with backprop.GradientTape(persistent=True) as tape:
|
||||||
|
tape.watch(input_image)
|
||||||
|
output_image = image_ops.resize_bilinear(
|
||||||
|
input_image,
|
||||||
|
output_shape[1:3],
|
||||||
|
align_corners=align_corners,
|
||||||
|
half_pixel_centers=half_pixel_centers)
|
||||||
|
gradient_injector_output = output_image * upstream_gradients
|
||||||
|
return tape.gradient(gradient_injector_output, input_image)
|
||||||
|
|
||||||
|
for i in range(repeat_count):
|
||||||
|
local_seed = seed + i # select different upstream gradients
|
||||||
|
result_a = resize_bilinear_gradients(local_seed)
|
||||||
|
result_b = resize_bilinear_gradients(local_seed)
|
||||||
|
self.assertAllEqual(result_a, result_b)
|
||||||
|
else: # graph mode
|
||||||
|
upstream_gradients = array_ops.placeholder(
|
||||||
|
dtypes.float32, shape=output_shape, name='upstream_gradients')
|
||||||
|
output_image = image_ops.resize_bilinear(
|
||||||
|
input_image,
|
||||||
|
output_shape[1:3],
|
||||||
|
align_corners=align_corners,
|
||||||
|
half_pixel_centers=half_pixel_centers)
|
||||||
|
gradient_injector_output = output_image * upstream_gradients
|
||||||
|
# The gradient function behaves as if grad_ys is multiplied by the op
|
||||||
|
# gradient result, not passing the upstram gradients through the op's
|
||||||
|
# gradient generation graph. This is the reason for using the
|
||||||
|
# gradient injector
|
||||||
|
resize_bilinear_gradients = gradients_impl.gradients(
|
||||||
|
gradient_injector_output,
|
||||||
|
input_image,
|
||||||
|
grad_ys=None,
|
||||||
|
colocate_gradients_with_ops=True)[0]
|
||||||
|
for i in range(repeat_count):
|
||||||
|
feed_dict = {upstream_gradients: self._randomNDArray(output_shape)}
|
||||||
|
result_a = resize_bilinear_gradients.eval(feed_dict=feed_dict)
|
||||||
|
result_b = resize_bilinear_gradients.eval(feed_dict=feed_dict)
|
||||||
|
self.assertAllEqual(result_a, result_b)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# Note that the effect of setting the following environment variable to
|
||||||
|
# 'true' is not tested. Unless we can find a simpler pattern for testing these
|
||||||
|
# environment variables, it would require this file to be made into a base
|
||||||
|
# and then two more test files to be created.
|
||||||
|
#
|
||||||
|
# When deterministic op functionality can be enabled and disabled between test
|
||||||
|
# cases in the same process, then the tests for deterministic op
|
||||||
|
# functionality, for this op and for other ops, will be able to be included in
|
||||||
|
# the same file with the regular tests, simplifying the organization of tests
|
||||||
|
# and test files.
|
||||||
|
os.environ['TF_DETERMINISTIC_OPS'] = '1'
|
||||||
|
test.main()
|
||||||
@ -1,4 +1,4 @@
|
|||||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@ -12,536 +12,21 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Tests for Python ops defined in image_grad.py."""
|
"""Functional tests for Image Op Gradients."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import numpy as np
|
from tensorflow.python.ops import image_grad_test_base as test_base
|
||||||
|
|
||||||
from tensorflow.python.eager import backprop
|
|
||||||
from tensorflow.python.framework import constant_op
|
|
||||||
from tensorflow.python.framework import test_util
|
|
||||||
from tensorflow.python.ops import gradient_checker
|
|
||||||
from tensorflow.python.ops import gradient_checker_v2
|
|
||||||
from tensorflow.python.ops import gradients_impl
|
|
||||||
from tensorflow.python.ops import image_ops
|
|
||||||
from tensorflow.python.ops import gen_image_ops
|
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.ops import math_ops
|
|
||||||
from tensorflow.python.ops import array_ops
|
|
||||||
|
|
||||||
|
|
||||||
@test_util.for_all_test_methods(test_util.disable_xla,
|
|
||||||
'align_corners=False not supported by XLA')
|
|
||||||
class ResizeNearestNeighborOpTest(test.TestCase):
|
|
||||||
|
|
||||||
TYPES = [np.float32, np.float64]
|
|
||||||
|
|
||||||
def testShapeIsCorrectAfterOp(self):
|
|
||||||
in_shape = [1, 2, 2, 1]
|
|
||||||
out_shape = [1, 4, 6, 1]
|
|
||||||
|
|
||||||
for nptype in self.TYPES:
|
|
||||||
x = np.arange(0, 4).reshape(in_shape).astype(nptype)
|
|
||||||
|
|
||||||
input_tensor = constant_op.constant(x, shape=in_shape)
|
|
||||||
resize_out = image_ops.resize_nearest_neighbor(input_tensor,
|
|
||||||
out_shape[1:3])
|
|
||||||
with self.cached_session(use_gpu=True):
|
|
||||||
self.assertEqual(out_shape, list(resize_out.get_shape()))
|
|
||||||
resize_out = self.evaluate(resize_out)
|
|
||||||
self.assertEqual(out_shape, list(resize_out.shape))
|
|
||||||
|
|
||||||
def testGradFromResizeToLargerInBothDims(self):
|
|
||||||
in_shape = [1, 2, 3, 1]
|
|
||||||
out_shape = (1, 4, 6, 1)
|
|
||||||
|
|
||||||
for nptype in self.TYPES:
|
|
||||||
x = np.arange(0, 6).reshape(in_shape).astype(nptype)
|
|
||||||
|
|
||||||
def resize_nn(t, shape=out_shape):
|
|
||||||
return image_ops.resize_nearest_neighbor(t, shape[1:3])
|
|
||||||
|
|
||||||
with self.cached_session(use_gpu=True):
|
|
||||||
input_tensor = constant_op.constant(x, shape=in_shape)
|
|
||||||
err = gradient_checker_v2.max_error(
|
|
||||||
*gradient_checker_v2.compute_gradient(resize_nn, [input_tensor]))
|
|
||||||
self.assertLess(err, 1e-3)
|
|
||||||
|
|
||||||
def testGradFromResizeToSmallerInBothDims(self):
|
|
||||||
in_shape = [1, 4, 6, 1]
|
|
||||||
out_shape = (1, 2, 3, 1)
|
|
||||||
|
|
||||||
for nptype in self.TYPES:
|
|
||||||
x = np.arange(0, 24).reshape(in_shape).astype(nptype)
|
|
||||||
|
|
||||||
def resize_nn(t, shape=out_shape):
|
|
||||||
return image_ops.resize_nearest_neighbor(t, shape[1:3])
|
|
||||||
|
|
||||||
with self.cached_session(use_gpu=True):
|
|
||||||
input_tensor = constant_op.constant(x, shape=in_shape)
|
|
||||||
err = gradient_checker_v2.max_error(
|
|
||||||
*gradient_checker_v2.compute_gradient(resize_nn, [input_tensor]))
|
|
||||||
self.assertLess(err, 1e-3)
|
|
||||||
|
|
||||||
def testCompareGpuVsCpu(self):
|
|
||||||
in_shape = [1, 4, 6, 3]
|
|
||||||
out_shape = (1, 8, 16, 3)
|
|
||||||
|
|
||||||
for nptype in self.TYPES:
|
|
||||||
x = np.arange(0, np.prod(in_shape)).reshape(in_shape).astype(nptype)
|
|
||||||
for align_corners in [True, False]:
|
|
||||||
|
|
||||||
def resize_nn(t, shape=out_shape, align_corners=align_corners):
|
|
||||||
return image_ops.resize_nearest_neighbor(
|
|
||||||
t, shape[1:3], align_corners=align_corners)
|
|
||||||
|
|
||||||
with self.cached_session(use_gpu=False):
|
|
||||||
input_tensor = constant_op.constant(x, shape=in_shape)
|
|
||||||
grad_cpu = gradient_checker_v2.compute_gradient(resize_nn,
|
|
||||||
[input_tensor])
|
|
||||||
|
|
||||||
with self.cached_session(use_gpu=True):
|
|
||||||
input_tensor = constant_op.constant(x, shape=in_shape)
|
|
||||||
grad_gpu = gradient_checker_v2.compute_gradient(resize_nn,
|
|
||||||
[input_tensor])
|
|
||||||
|
|
||||||
self.assertAllClose(grad_cpu, grad_gpu, rtol=1e-5, atol=1e-5)
|
|
||||||
|
|
||||||
|
|
||||||
class ResizeBilinearOpTest(test.TestCase):
|
|
||||||
|
|
||||||
def testShapeIsCorrectAfterOp(self):
|
|
||||||
in_shape = [1, 2, 2, 1]
|
|
||||||
out_shape = [1, 4, 6, 1]
|
|
||||||
|
|
||||||
x = np.arange(0, 4).reshape(in_shape).astype(np.float32)
|
|
||||||
|
|
||||||
input_tensor = constant_op.constant(x, shape=in_shape)
|
|
||||||
resize_out = image_ops.resize_bilinear(input_tensor, out_shape[1:3])
|
|
||||||
with self.cached_session():
|
|
||||||
self.assertEqual(out_shape, list(resize_out.get_shape()))
|
|
||||||
resize_out = self.evaluate(resize_out)
|
|
||||||
self.assertEqual(out_shape, list(resize_out.shape))
|
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testGradFromResizeToLargerInBothDims(self):
|
|
||||||
in_shape = [1, 2, 3, 1]
|
|
||||||
out_shape = [1, 4, 6, 1]
|
|
||||||
|
|
||||||
x = np.arange(0, 6).reshape(in_shape).astype(np.float32)
|
|
||||||
|
|
||||||
with self.cached_session():
|
|
||||||
input_tensor = constant_op.constant(x, shape=in_shape)
|
|
||||||
resize_out = image_ops.resize_bilinear(input_tensor, out_shape[1:3])
|
|
||||||
err = gradient_checker.compute_gradient_error(
|
|
||||||
input_tensor, in_shape, resize_out, out_shape, x_init_value=x)
|
|
||||||
self.assertLess(err, 1e-3)
|
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testGradFromResizeToSmallerInBothDims(self):
|
|
||||||
in_shape = [1, 4, 6, 1]
|
|
||||||
out_shape = [1, 2, 3, 1]
|
|
||||||
|
|
||||||
x = np.arange(0, 24).reshape(in_shape).astype(np.float32)
|
|
||||||
|
|
||||||
with self.cached_session():
|
|
||||||
input_tensor = constant_op.constant(x, shape=in_shape)
|
|
||||||
resize_out = image_ops.resize_bilinear(input_tensor, out_shape[1:3])
|
|
||||||
err = gradient_checker.compute_gradient_error(
|
|
||||||
input_tensor, in_shape, resize_out, out_shape, x_init_value=x)
|
|
||||||
self.assertLess(err, 1e-3)
|
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testCompareGpuVsCpu(self):
|
|
||||||
in_shape = [2, 4, 6, 3]
|
|
||||||
out_shape = [2, 8, 16, 3]
|
|
||||||
|
|
||||||
size = np.prod(in_shape)
|
|
||||||
x = 1.0 / size * np.arange(0, size).reshape(in_shape).astype(np.float32)
|
|
||||||
|
|
||||||
# Align corners will be deprecated for tf2.0 and the false version is not
|
|
||||||
# supported by XLA.
|
|
||||||
align_corner_options = [True
|
|
||||||
] if test_util.is_xla_enabled() else [True, False]
|
|
||||||
for align_corners in align_corner_options:
|
|
||||||
grad = {}
|
|
||||||
for use_gpu in [False, True]:
|
|
||||||
with self.cached_session(use_gpu=use_gpu):
|
|
||||||
input_tensor = constant_op.constant(x, shape=in_shape)
|
|
||||||
resized_tensor = image_ops.resize_bilinear(
|
|
||||||
input_tensor, out_shape[1:3], align_corners=align_corners)
|
|
||||||
grad[use_gpu] = gradient_checker.compute_gradient(
|
|
||||||
input_tensor, in_shape, resized_tensor, out_shape, x_init_value=x)
|
|
||||||
|
|
||||||
self.assertAllClose(grad[False], grad[True], rtol=1e-4, atol=1e-4)
|
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testTypes(self):
|
|
||||||
in_shape = [1, 4, 6, 1]
|
|
||||||
out_shape = [1, 2, 3, 1]
|
|
||||||
x = np.arange(0, 24).reshape(in_shape)
|
|
||||||
|
|
||||||
for use_gpu in [False, True]:
|
|
||||||
with self.cached_session(use_gpu=use_gpu) as sess:
|
|
||||||
for dtype in [np.float16, np.float32, np.float64]:
|
|
||||||
input_tensor = constant_op.constant(x.astype(dtype), shape=in_shape)
|
|
||||||
resize_out = image_ops.resize_bilinear(input_tensor, out_shape[1:3])
|
|
||||||
grad = sess.run(gradients_impl.gradients(resize_out, input_tensor))[0]
|
|
||||||
self.assertAllEqual(in_shape, grad.shape)
|
|
||||||
# Not using gradient_checker.compute_gradient as I didn't work out
|
|
||||||
# the changes required to compensate for the lower precision of
|
|
||||||
# float16 when computing the numeric jacobian.
|
|
||||||
# Instead, we just test the theoretical jacobian.
|
|
||||||
self.assertAllEqual([[[[1.], [0.], [1.], [0.], [1.], [0.]],
|
|
||||||
[[0.], [0.], [0.], [0.], [0.], [0.]],
|
|
||||||
[[1.], [0.], [1.], [0.], [1.], [0.]],
|
|
||||||
[[0.], [0.], [0.], [0.], [0.], [0.]]]], grad)
|
|
||||||
|
|
||||||
|
|
||||||
class ResizeBicubicOpTest(test.TestCase):
|
|
||||||
|
|
||||||
def testShapeIsCorrectAfterOp(self):
|
|
||||||
in_shape = [1, 2, 2, 1]
|
|
||||||
out_shape = [1, 4, 6, 1]
|
|
||||||
|
|
||||||
x = np.arange(0, 4).reshape(in_shape).astype(np.float32)
|
|
||||||
|
|
||||||
for align_corners in [True, False]:
|
|
||||||
input_tensor = constant_op.constant(x, shape=in_shape)
|
|
||||||
resize_out = image_ops.resize_bicubic(
|
|
||||||
input_tensor, out_shape[1:3], align_corners=align_corners)
|
|
||||||
with self.cached_session():
|
|
||||||
self.assertEqual(out_shape, list(resize_out.get_shape()))
|
|
||||||
resize_out = self.evaluate(resize_out)
|
|
||||||
self.assertEqual(out_shape, list(resize_out.shape))
|
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testGradFromResizeToLargerInBothDims(self):
|
|
||||||
in_shape = [1, 2, 3, 1]
|
|
||||||
out_shape = [1, 4, 6, 1]
|
|
||||||
|
|
||||||
x = np.arange(0, 6).reshape(in_shape).astype(np.float32)
|
|
||||||
|
|
||||||
for align_corners in [True, False]:
|
|
||||||
with self.cached_session():
|
|
||||||
input_tensor = constant_op.constant(x, shape=in_shape)
|
|
||||||
resize_out = image_ops.resize_bicubic(input_tensor, out_shape[1:3],
|
|
||||||
align_corners=align_corners)
|
|
||||||
err = gradient_checker.compute_gradient_error(
|
|
||||||
input_tensor, in_shape, resize_out, out_shape, x_init_value=x)
|
|
||||||
self.assertLess(err, 1e-3)
|
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testGradFromResizeToSmallerInBothDims(self):
|
|
||||||
in_shape = [1, 4, 6, 1]
|
|
||||||
out_shape = [1, 2, 3, 1]
|
|
||||||
|
|
||||||
x = np.arange(0, 24).reshape(in_shape).astype(np.float32)
|
|
||||||
|
|
||||||
for align_corners in [True, False]:
|
|
||||||
input_tensor = constant_op.constant(x, shape=in_shape)
|
|
||||||
resize_out = image_ops.resize_bicubic(
|
|
||||||
input_tensor, out_shape[1:3], align_corners=align_corners)
|
|
||||||
with self.cached_session():
|
|
||||||
err = gradient_checker.compute_gradient_error(
|
|
||||||
input_tensor, in_shape, resize_out, out_shape, x_init_value=x)
|
|
||||||
self.assertLess(err, 1e-3)
|
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testGradOnUnsupportedType(self):
|
|
||||||
in_shape = [1, 4, 6, 1]
|
|
||||||
out_shape = [1, 2, 3, 1]
|
|
||||||
|
|
||||||
x = np.arange(0, 24).reshape(in_shape).astype(np.uint8)
|
|
||||||
|
|
||||||
input_tensor = constant_op.constant(x, shape=in_shape)
|
|
||||||
resize_out = image_ops.resize_bicubic(input_tensor, out_shape[1:3])
|
|
||||||
with self.cached_session():
|
|
||||||
grad = gradients_impl.gradients(input_tensor, [resize_out])
|
|
||||||
self.assertEqual([None], grad)
|
|
||||||
|
|
||||||
|
|
||||||
class ScaleAndTranslateOpTest(test.TestCase):
|
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testGrads(self):
|
|
||||||
in_shape = [1, 2, 3, 1]
|
|
||||||
out_shape = [1, 4, 6, 1]
|
|
||||||
|
|
||||||
x = np.arange(0, 6).reshape(in_shape).astype(np.float32)
|
|
||||||
|
|
||||||
kernel_types = [
|
|
||||||
'lanczos1', 'lanczos3', 'lanczos5', 'gaussian', 'box', 'triangle',
|
|
||||||
'keyscubic', 'mitchellcubic'
|
|
||||||
]
|
|
||||||
scales = [(1.0, 1.0), (0.37, 0.47), (2.1, 2.1)]
|
|
||||||
translations = [(0.0, 0.0), (3.14, 1.19), (2.1, 3.1), (100.0, 200.0)]
|
|
||||||
for scale in scales:
|
|
||||||
for translation in translations:
|
|
||||||
for kernel_type in kernel_types:
|
|
||||||
for antialias in [True, False]:
|
|
||||||
with self.cached_session():
|
|
||||||
input_tensor = constant_op.constant(x, shape=in_shape)
|
|
||||||
scale_and_translate_out = image_ops.scale_and_translate(
|
|
||||||
input_tensor,
|
|
||||||
out_shape[1:3],
|
|
||||||
scale=constant_op.constant(scale),
|
|
||||||
translation=constant_op.constant(translation),
|
|
||||||
kernel_type=kernel_type,
|
|
||||||
antialias=antialias)
|
|
||||||
err = gradient_checker.compute_gradient_error(
|
|
||||||
input_tensor,
|
|
||||||
in_shape,
|
|
||||||
scale_and_translate_out,
|
|
||||||
out_shape,
|
|
||||||
x_init_value=x)
|
|
||||||
self.assertLess(err, 1e-3)
|
|
||||||
|
|
||||||
def testIdentityGrads(self):
|
|
||||||
"""Tests that Gradients for 1.0 scale should be ones for some kernels."""
|
|
||||||
in_shape = [1, 2, 3, 1]
|
|
||||||
out_shape = [1, 4, 6, 1]
|
|
||||||
|
|
||||||
x = np.arange(0, 6).reshape(in_shape).astype(np.float32)
|
|
||||||
|
|
||||||
kernel_types = ['lanczos1', 'lanczos3', 'lanczos5', 'triangle', 'keyscubic']
|
|
||||||
scale = (1.0, 1.0)
|
|
||||||
translation = (0.0, 0.0)
|
|
||||||
antialias = True
|
|
||||||
for kernel_type in kernel_types:
|
|
||||||
with self.cached_session():
|
|
||||||
input_tensor = constant_op.constant(x, shape=in_shape)
|
|
||||||
with backprop.GradientTape() as tape:
|
|
||||||
tape.watch(input_tensor)
|
|
||||||
scale_and_translate_out = image_ops.scale_and_translate(
|
|
||||||
input_tensor,
|
|
||||||
out_shape[1:3],
|
|
||||||
scale=constant_op.constant(scale),
|
|
||||||
translation=constant_op.constant(translation),
|
|
||||||
kernel_type=kernel_type,
|
|
||||||
antialias=antialias)
|
|
||||||
grad = tape.gradient(scale_and_translate_out, input_tensor)[0]
|
|
||||||
grad_v = self.evaluate(grad)
|
|
||||||
self.assertAllClose(np.ones_like(grad_v), grad_v)
|
|
||||||
|
|
||||||
|
|
||||||
class CropAndResizeOpTest(test.TestCase):
|
|
||||||
|
|
||||||
def testShapeIsCorrectAfterOp(self):
|
|
||||||
batch = 2
|
|
||||||
image_height = 3
|
|
||||||
image_width = 4
|
|
||||||
crop_height = 4
|
|
||||||
crop_width = 5
|
|
||||||
depth = 2
|
|
||||||
num_boxes = 2
|
|
||||||
|
|
||||||
image_shape = [batch, image_height, image_width, depth]
|
|
||||||
crop_size = [crop_height, crop_width]
|
|
||||||
crops_shape = [num_boxes, crop_height, crop_width, depth]
|
|
||||||
|
|
||||||
image = np.arange(0, batch * image_height * image_width *
|
|
||||||
depth).reshape(image_shape).astype(np.float32)
|
|
||||||
boxes = np.array([[0, 0, 1, 1], [.1, .2, .7, .8]], dtype=np.float32)
|
|
||||||
box_ind = np.array([0, 1], dtype=np.int32)
|
|
||||||
|
|
||||||
crops = image_ops.crop_and_resize(
|
|
||||||
constant_op.constant(image, shape=image_shape),
|
|
||||||
constant_op.constant(boxes, shape=[num_boxes, 4]),
|
|
||||||
constant_op.constant(box_ind, shape=[num_boxes]),
|
|
||||||
constant_op.constant(crop_size, shape=[2]))
|
|
||||||
with self.session(use_gpu=True) as sess:
|
|
||||||
self.assertEqual(crops_shape, list(crops.get_shape()))
|
|
||||||
crops = self.evaluate(crops)
|
|
||||||
self.assertEqual(crops_shape, list(crops.shape))
|
|
||||||
|
|
||||||
def _randomUniformAvoidAnchors(self, low, high, anchors, radius, num_samples):
|
|
||||||
"""Generate samples that are far enough from a set of anchor points.
|
|
||||||
|
|
||||||
We generate uniform samples in [low, high], then reject those that are less
|
|
||||||
than radius away from any point in anchors. We stop after we have accepted
|
|
||||||
num_samples samples.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
low: The lower end of the interval.
|
|
||||||
high: The upper end of the interval.
|
|
||||||
anchors: A list of length num_crops with anchor points to avoid.
|
|
||||||
radius: Distance threshold for the samples from the anchors.
|
|
||||||
num_samples: How many samples to produce.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
samples: A list of length num_samples with the accepted samples.
|
|
||||||
"""
|
|
||||||
self.assertTrue(low < high)
|
|
||||||
self.assertTrue(radius >= 0)
|
|
||||||
num_anchors = len(anchors)
|
|
||||||
# Make sure that at least half of the interval is not forbidden.
|
|
||||||
self.assertTrue(2 * radius * num_anchors < 0.5 * (high - low))
|
|
||||||
anchors = np.reshape(anchors, num_anchors)
|
|
||||||
samples = []
|
|
||||||
while len(samples) < num_samples:
|
|
||||||
sample = np.random.uniform(low, high)
|
|
||||||
if np.all(np.fabs(sample - anchors) > radius):
|
|
||||||
samples.append(sample)
|
|
||||||
return samples
|
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testGradRandomBoxes(self):
|
|
||||||
"""Test that the gradient is correct for randomly generated boxes.
|
|
||||||
|
|
||||||
The mapping is piecewise differentiable with respect to the box coordinates.
|
|
||||||
The points where the function is not differentiable are those which are
|
|
||||||
mapped to image pixels, i.e., the normalized y coordinates in
|
|
||||||
np.linspace(0, 1, image_height) and normalized x coordinates in
|
|
||||||
np.linspace(0, 1, image_width). Make sure that the box coordinates are
|
|
||||||
sufficiently far away from those rectangular grid centers that are points of
|
|
||||||
discontinuity, so that the finite difference Jacobian is close to the
|
|
||||||
computed one.
|
|
||||||
"""
|
|
||||||
np.random.seed(1) # Make it reproducible.
|
|
||||||
delta = 1e-3
|
|
||||||
radius = 2 * delta
|
|
||||||
low, high = -0.5, 1.5 # Also covers the case of extrapolation.
|
|
||||||
|
|
||||||
image_height = 4
|
|
||||||
for image_width in range(1, 3):
|
|
||||||
for crop_height in range(1, 3):
|
|
||||||
for crop_width in range(2, 4):
|
|
||||||
for depth in range(1, 3):
|
|
||||||
for num_boxes in range(1, 3):
|
|
||||||
|
|
||||||
batch = num_boxes
|
|
||||||
image_shape = [batch, image_height, image_width, depth]
|
|
||||||
crop_size = [crop_height, crop_width]
|
|
||||||
crops_shape = [num_boxes, crop_height, crop_width, depth]
|
|
||||||
boxes_shape = [num_boxes, 4]
|
|
||||||
|
|
||||||
image = np.arange(0, batch * image_height * image_width *
|
|
||||||
depth).reshape(image_shape).astype(np.float32)
|
|
||||||
boxes = []
|
|
||||||
for _ in range(num_boxes):
|
|
||||||
# pylint: disable=unbalanced-tuple-unpacking
|
|
||||||
y1, y2 = self._randomUniformAvoidAnchors(
|
|
||||||
low, high, np.linspace(0, 1, image_height), radius, 2)
|
|
||||||
x1, x2 = self._randomUniformAvoidAnchors(
|
|
||||||
low, high, np.linspace(0, 1, image_width), radius, 2)
|
|
||||||
# pylint: enable=unbalanced-tuple-unpacking
|
|
||||||
boxes.append([y1, x1, y2, x2])
|
|
||||||
|
|
||||||
boxes = np.array(boxes, dtype=np.float32)
|
|
||||||
box_ind = np.arange(batch, dtype=np.int32)
|
|
||||||
|
|
||||||
with self.cached_session(use_gpu=True):
|
|
||||||
image_tensor = constant_op.constant(image, shape=image_shape)
|
|
||||||
boxes_tensor = constant_op.constant(boxes, shape=[num_boxes, 4])
|
|
||||||
box_ind_tensor = constant_op.constant(
|
|
||||||
box_ind, shape=[num_boxes])
|
|
||||||
crops = image_ops.crop_and_resize(
|
|
||||||
image_tensor,
|
|
||||||
boxes_tensor,
|
|
||||||
box_ind_tensor,
|
|
||||||
constant_op.constant(
|
|
||||||
crop_size, shape=[2]))
|
|
||||||
|
|
||||||
err = gradient_checker.compute_gradient_error(
|
|
||||||
[image_tensor, boxes_tensor], [image_shape, boxes_shape],
|
|
||||||
crops,
|
|
||||||
crops_shape,
|
|
||||||
delta=delta,
|
|
||||||
x_init_value=[image, boxes])
|
|
||||||
|
|
||||||
self.assertLess(err, 2e-3)
|
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_all_in_graph_and_eager_modes
|
|
||||||
class RGBToHSVOpTest(test.TestCase):
|
|
||||||
|
|
||||||
TYPES = [np.float32, np.float64]
|
|
||||||
|
|
||||||
def testShapeIsCorrectAfterOp(self):
|
|
||||||
in_shape = [2, 20, 30, 3]
|
|
||||||
out_shape = [2, 20, 30, 3]
|
|
||||||
|
|
||||||
for nptype in self.TYPES:
|
|
||||||
x = np.random.randint(0, high=255, size=[2, 20, 30, 3]).astype(nptype)
|
|
||||||
rgb_input_tensor = constant_op.constant(x, shape=in_shape)
|
|
||||||
hsv_out = gen_image_ops.rgb_to_hsv(rgb_input_tensor)
|
|
||||||
with self.cached_session(use_gpu=True):
|
|
||||||
self.assertEqual(out_shape, list(hsv_out.get_shape()))
|
|
||||||
hsv_out = self.evaluate(hsv_out)
|
|
||||||
self.assertEqual(out_shape, list(hsv_out.shape))
|
|
||||||
|
|
||||||
def testRGBToHSVGradSimpleCase(self):
|
|
||||||
|
|
||||||
def f(x):
|
|
||||||
return gen_image_ops.rgb_to_hsv(x)
|
|
||||||
|
|
||||||
# Building a simple input tensor to avoid any discontinuity
|
|
||||||
x = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8,
|
|
||||||
0.9]]).astype(np.float32)
|
|
||||||
rgb_input_tensor = constant_op.constant(x, shape=x.shape)
|
|
||||||
# Computing Analytical and Numerical gradients of f(x)
|
|
||||||
analytical, numerical = gradient_checker_v2.compute_gradient(
|
|
||||||
f, [rgb_input_tensor])
|
|
||||||
self.assertAllClose(numerical, analytical, atol=1e-4)
|
|
||||||
|
|
||||||
def testRGBToHSVGradRandomCase(self):
|
|
||||||
|
|
||||||
def f(x):
|
|
||||||
return gen_image_ops.rgb_to_hsv(x)
|
|
||||||
|
|
||||||
np.random.seed(0)
|
|
||||||
# Building a simple input tensor to avoid any discontinuity
|
|
||||||
x = np.random.rand(1, 5, 5, 3).astype(np.float32)
|
|
||||||
rgb_input_tensor = constant_op.constant(x, shape=x.shape)
|
|
||||||
# Computing Analytical and Numerical gradients of f(x)
|
|
||||||
self.assertLess(
|
|
||||||
gradient_checker_v2.max_error(
|
|
||||||
*gradient_checker_v2.compute_gradient(f, [rgb_input_tensor])), 1e-4)
|
|
||||||
|
|
||||||
def testRGBToHSVGradSpecialCaseRGreatest(self):
|
|
||||||
# This test tests a specific subset of the input space
|
|
||||||
# with a dummy function implemented with native TF operations.
|
|
||||||
in_shape = [2, 10, 20, 3]
|
|
||||||
|
|
||||||
def f(x):
|
|
||||||
return gen_image_ops.rgb_to_hsv(x)
|
|
||||||
|
|
||||||
def f_dummy(x):
|
|
||||||
# This dummy function is a implementation of RGB to HSV using
|
|
||||||
# primitive TF functions for one particular case when R>G>B.
|
|
||||||
r = x[..., 0]
|
|
||||||
g = x[..., 1]
|
|
||||||
b = x[..., 2]
|
|
||||||
# Since MAX = r and MIN = b, we get the following h,s,v values.
|
|
||||||
v = r
|
|
||||||
s = 1 - math_ops.div_no_nan(b, r)
|
|
||||||
h = 60 * math_ops.div_no_nan(g - b, r - b)
|
|
||||||
h = h / 360
|
|
||||||
return array_ops.stack([h, s, v], axis=-1)
|
|
||||||
|
|
||||||
# Building a custom input tensor where R>G>B
|
|
||||||
x_reds = np.ones((in_shape[0], in_shape[1], in_shape[2])).astype(np.float32)
|
|
||||||
x_greens = 0.5 * np.ones(
|
|
||||||
(in_shape[0], in_shape[1], in_shape[2])).astype(np.float32)
|
|
||||||
x_blues = 0.2 * np.ones(
|
|
||||||
(in_shape[0], in_shape[1], in_shape[2])).astype(np.float32)
|
|
||||||
x = np.stack([x_reds, x_greens, x_blues], axis=-1)
|
|
||||||
rgb_input_tensor = constant_op.constant(x, shape=in_shape)
|
|
||||||
|
|
||||||
# Computing Analytical and Numerical gradients of f(x)
|
|
||||||
analytical, numerical = gradient_checker_v2.compute_gradient(
|
|
||||||
f, [rgb_input_tensor])
|
|
||||||
# Computing Analytical and Numerical gradients of f_dummy(x)
|
|
||||||
analytical_dummy, numerical_dummy = gradient_checker_v2.compute_gradient(
|
|
||||||
f_dummy, [rgb_input_tensor])
|
|
||||||
self.assertAllClose(numerical, analytical, atol=1e-4)
|
|
||||||
self.assertAllClose(analytical_dummy, analytical, atol=1e-4)
|
|
||||||
self.assertAllClose(numerical_dummy, numerical, atol=1e-4)
|
|
||||||
|
|
||||||
|
ResizeNearestNeighborOpTest = test_base.ResizeNearestNeighborOpTestBase
|
||||||
|
ResizeBilinearOpTest = test_base.ResizeBilinearOpTestBase
|
||||||
|
ResizeBicubicOpTest = test_base.ResizeBicubicOpTestBase
|
||||||
|
ScaleAndTranslateOpTest = test_base.ScaleAndTranslateOpTestBase
|
||||||
|
CropAndResizeOpTest = test_base.CropAndResizeOpTestBase
|
||||||
|
RGBToHSVOpTest = test_base.RGBToHSVOpTestBase
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
|||||||
622
tensorflow/python/ops/image_grad_test_base.py
Normal file
622
tensorflow/python/ops/image_grad_test_base.py
Normal file
@ -0,0 +1,622 @@
|
|||||||
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for Python ops defined in image_grad.py."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
|
||||||
|
from tensorflow.python.eager import backprop
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.ops import gradient_checker
|
||||||
|
from tensorflow.python.ops import gradient_checker_v2
|
||||||
|
from tensorflow.python.ops import gradients_impl
|
||||||
|
from tensorflow.python.ops import image_ops
|
||||||
|
from tensorflow.python.ops import gen_image_ops
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.for_all_test_methods(test_util.disable_xla,
|
||||||
|
'align_corners=False not supported by XLA')
|
||||||
|
class ResizeNearestNeighborOpTestBase(test.TestCase):
|
||||||
|
|
||||||
|
TYPES = [np.float32, np.float64]
|
||||||
|
|
||||||
|
def testShapeIsCorrectAfterOp(self):
|
||||||
|
in_shape = [1, 2, 2, 1]
|
||||||
|
out_shape = [1, 4, 6, 1]
|
||||||
|
|
||||||
|
for nptype in self.TYPES:
|
||||||
|
x = np.arange(0, 4).reshape(in_shape).astype(nptype)
|
||||||
|
|
||||||
|
input_tensor = constant_op.constant(x, shape=in_shape)
|
||||||
|
resize_out = image_ops.resize_nearest_neighbor(input_tensor,
|
||||||
|
out_shape[1:3])
|
||||||
|
with self.cached_session(use_gpu=True):
|
||||||
|
self.assertEqual(out_shape, list(resize_out.get_shape()))
|
||||||
|
resize_out = self.evaluate(resize_out)
|
||||||
|
self.assertEqual(out_shape, list(resize_out.shape))
|
||||||
|
|
||||||
|
def testGradFromResizeToLargerInBothDims(self):
|
||||||
|
in_shape = [1, 2, 3, 1]
|
||||||
|
out_shape = (1, 4, 6, 1)
|
||||||
|
|
||||||
|
for nptype in self.TYPES:
|
||||||
|
x = np.arange(0, 6).reshape(in_shape).astype(nptype)
|
||||||
|
|
||||||
|
def resize_nn(t, shape=out_shape):
|
||||||
|
return image_ops.resize_nearest_neighbor(t, shape[1:3])
|
||||||
|
|
||||||
|
with self.cached_session(use_gpu=True):
|
||||||
|
input_tensor = constant_op.constant(x, shape=in_shape)
|
||||||
|
err = gradient_checker_v2.max_error(
|
||||||
|
*gradient_checker_v2.compute_gradient(resize_nn, [input_tensor]))
|
||||||
|
self.assertLess(err, 1e-3)
|
||||||
|
|
||||||
|
def testGradFromResizeToSmallerInBothDims(self):
|
||||||
|
in_shape = [1, 4, 6, 1]
|
||||||
|
out_shape = (1, 2, 3, 1)
|
||||||
|
|
||||||
|
for nptype in self.TYPES:
|
||||||
|
x = np.arange(0, 24).reshape(in_shape).astype(nptype)
|
||||||
|
|
||||||
|
def resize_nn(t, shape=out_shape):
|
||||||
|
return image_ops.resize_nearest_neighbor(t, shape[1:3])
|
||||||
|
|
||||||
|
with self.cached_session(use_gpu=True):
|
||||||
|
input_tensor = constant_op.constant(x, shape=in_shape)
|
||||||
|
err = gradient_checker_v2.max_error(
|
||||||
|
*gradient_checker_v2.compute_gradient(resize_nn, [input_tensor]))
|
||||||
|
self.assertLess(err, 1e-3)
|
||||||
|
|
||||||
|
def testCompareGpuVsCpu(self):
|
||||||
|
in_shape = [1, 4, 6, 3]
|
||||||
|
out_shape = (1, 8, 16, 3)
|
||||||
|
|
||||||
|
for nptype in self.TYPES:
|
||||||
|
x = np.arange(0, np.prod(in_shape)).reshape(in_shape).astype(nptype)
|
||||||
|
for align_corners in [True, False]:
|
||||||
|
|
||||||
|
def resize_nn(t, shape=out_shape, align_corners=align_corners):
|
||||||
|
return image_ops.resize_nearest_neighbor(
|
||||||
|
t, shape[1:3], align_corners=align_corners)
|
||||||
|
|
||||||
|
with self.cached_session(use_gpu=False):
|
||||||
|
input_tensor = constant_op.constant(x, shape=in_shape)
|
||||||
|
grad_cpu = gradient_checker_v2.compute_gradient(
|
||||||
|
resize_nn, [input_tensor])
|
||||||
|
|
||||||
|
with self.cached_session(use_gpu=True):
|
||||||
|
input_tensor = constant_op.constant(x, shape=in_shape)
|
||||||
|
grad_gpu = gradient_checker_v2.compute_gradient(
|
||||||
|
resize_nn, [input_tensor])
|
||||||
|
|
||||||
|
self.assertAllClose(grad_cpu, grad_gpu, rtol=1e-5, atol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
|
class ResizeBilinearOpTestBase(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
def _itGen(self, smaller_shape, larger_shape):
|
||||||
|
up_sample = (smaller_shape, larger_shape)
|
||||||
|
down_sample = (larger_shape, smaller_shape)
|
||||||
|
pass_through = (larger_shape, larger_shape)
|
||||||
|
shape_pairs = (up_sample, down_sample, pass_through)
|
||||||
|
# Align corners is deprecated in TF2.0, but align_corners==False is not
|
||||||
|
# supported by XLA.
|
||||||
|
options = [(True, False)]
|
||||||
|
if not test_util.is_xla_enabled():
|
||||||
|
options += [(False, True), (False, False)]
|
||||||
|
for align_corners, half_pixel_centers in options:
|
||||||
|
for in_shape, out_shape in shape_pairs:
|
||||||
|
yield in_shape, out_shape, align_corners, half_pixel_centers
|
||||||
|
|
||||||
|
def _getJacobians(self,
|
||||||
|
in_shape,
|
||||||
|
out_shape,
|
||||||
|
align_corners=False,
|
||||||
|
half_pixel_centers=False,
|
||||||
|
dtype=np.float32,
|
||||||
|
use_gpu=False,
|
||||||
|
force_gpu=False):
|
||||||
|
with self.cached_session(use_gpu=use_gpu, force_gpu=force_gpu) as sess:
|
||||||
|
# Input values should not influence gradients
|
||||||
|
x = np.arange(np.prod(in_shape)).reshape(in_shape).astype(dtype)
|
||||||
|
input_tensor = constant_op.constant(x, shape=in_shape)
|
||||||
|
resized_tensor = image_ops.resize_bilinear(
|
||||||
|
input_tensor,
|
||||||
|
out_shape[1:3],
|
||||||
|
align_corners=align_corners,
|
||||||
|
half_pixel_centers=half_pixel_centers)
|
||||||
|
# compute_gradient will use a random tensor as the init value
|
||||||
|
return gradient_checker.compute_gradient(input_tensor, in_shape,
|
||||||
|
resized_tensor, out_shape)
|
||||||
|
|
||||||
|
@parameterized.parameters({
|
||||||
|
'batch_size': 1,
|
||||||
|
'channel_count': 1
|
||||||
|
}, {
|
||||||
|
'batch_size': 2,
|
||||||
|
'channel_count': 3
|
||||||
|
}, {
|
||||||
|
'batch_size': 5,
|
||||||
|
'channel_count': 4
|
||||||
|
})
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testShapes(self, batch_size, channel_count):
|
||||||
|
smaller_shape = [batch_size, 2, 3, channel_count]
|
||||||
|
larger_shape = [batch_size, 4, 6, channel_count]
|
||||||
|
for in_shape, out_shape, align_corners, half_pixel_centers in \
|
||||||
|
self._itGen(smaller_shape, larger_shape):
|
||||||
|
# Input values should not influence shapes
|
||||||
|
x = np.arange(np.prod(in_shape)).reshape(in_shape).astype(np.float32)
|
||||||
|
input_tensor = constant_op.constant(x, shape=in_shape)
|
||||||
|
resized_tensor = image_ops.resize_bilinear(input_tensor, out_shape[1:3])
|
||||||
|
self.assertEqual(out_shape, list(resized_tensor.get_shape()))
|
||||||
|
grad_tensor = gradients_impl.gradients(resized_tensor, input_tensor)[0]
|
||||||
|
self.assertEqual(in_shape, list(grad_tensor.get_shape()))
|
||||||
|
with self.cached_session():
|
||||||
|
resized_values = self.evaluate(resized_tensor)
|
||||||
|
self.assertEqual(out_shape, list(resized_values.shape))
|
||||||
|
grad_values = self.evaluate(grad_tensor)
|
||||||
|
self.assertEqual(in_shape, list(grad_values.shape))
|
||||||
|
|
||||||
|
@parameterized.parameters({
|
||||||
|
'batch_size': 1,
|
||||||
|
'channel_count': 1
|
||||||
|
}, {
|
||||||
|
'batch_size': 4,
|
||||||
|
'channel_count': 3
|
||||||
|
}, {
|
||||||
|
'batch_size': 3,
|
||||||
|
'channel_count': 2
|
||||||
|
})
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testGradients(self, batch_size, channel_count):
|
||||||
|
smaller_shape = [batch_size, 2, 3, channel_count]
|
||||||
|
larger_shape = [batch_size, 5, 6, channel_count]
|
||||||
|
for in_shape, out_shape, align_corners, half_pixel_centers in \
|
||||||
|
self._itGen(smaller_shape, larger_shape):
|
||||||
|
jacob_a, jacob_n = self._getJacobians(in_shape, out_shape, align_corners,
|
||||||
|
half_pixel_centers)
|
||||||
|
threshold = 1e-4
|
||||||
|
self.assertAllClose(jacob_a, jacob_n, threshold, threshold)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testTypes(self):
|
||||||
|
in_shape = [1, 4, 6, 1]
|
||||||
|
out_shape = [1, 2, 3, 1]
|
||||||
|
for use_gpu in [False, True]:
|
||||||
|
for dtype in [np.float16, np.float32, np.float64]:
|
||||||
|
jacob_a, jacob_n = self._getJacobians(
|
||||||
|
in_shape, out_shape, dtype=dtype, use_gpu=use_gpu)
|
||||||
|
if dtype == np.float16:
|
||||||
|
# Compare fp16 analytical gradients to fp32 numerical gradients,
|
||||||
|
# since fp16 numerical gradients are too imprecise unless great
|
||||||
|
# care is taken with choosing the inputs and the delta. This is
|
||||||
|
# a weaker, but pragmatic, check (in particular, it does not test
|
||||||
|
# the op itself, only its gradient).
|
||||||
|
_, jacob_n = self._getJacobians(
|
||||||
|
in_shape, out_shape, dtype=np.float32, use_gpu=use_gpu)
|
||||||
|
threshold = 1e-3
|
||||||
|
if dtype == np.float64:
|
||||||
|
threshold = 1e-5
|
||||||
|
self.assertAllClose(jacob_a, jacob_n, threshold, threshold)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testGradOnUnsupportedType(self):
|
||||||
|
in_shape = [1, 4, 6, 1]
|
||||||
|
out_shape = [1, 2, 3, 1]
|
||||||
|
|
||||||
|
x = np.arange(0, 24).reshape(in_shape).astype(np.uint8)
|
||||||
|
|
||||||
|
input_tensor = constant_op.constant(x, shape=in_shape)
|
||||||
|
resize_out = image_ops.resize_bilinear(input_tensor, out_shape[1:3])
|
||||||
|
with self.cached_session():
|
||||||
|
grad = gradients_impl.gradients(resize_out, [input_tensor])
|
||||||
|
self.assertEqual([None], grad)
|
||||||
|
|
||||||
|
def _gpuVsCpuCase(self, in_shape, out_shape, align_corners,
|
||||||
|
half_pixel_centers, dtype):
|
||||||
|
grad = {}
|
||||||
|
for use_gpu in [False, True]:
|
||||||
|
grad[use_gpu] = self._getJacobians(
|
||||||
|
in_shape,
|
||||||
|
out_shape,
|
||||||
|
align_corners,
|
||||||
|
half_pixel_centers,
|
||||||
|
dtype=dtype,
|
||||||
|
use_gpu=use_gpu)
|
||||||
|
threshold = 1e-4
|
||||||
|
# Note that this is comparing both analytical and numerical Jacobians
|
||||||
|
self.assertAllClose(grad[False], grad[True], rtol=threshold, atol=threshold)
|
||||||
|
|
||||||
|
@parameterized.parameters({
|
||||||
|
'batch_size': 1,
|
||||||
|
'channel_count': 1
|
||||||
|
}, {
|
||||||
|
'batch_size': 2,
|
||||||
|
'channel_count': 3
|
||||||
|
}, {
|
||||||
|
'batch_size': 5,
|
||||||
|
'channel_count': 4
|
||||||
|
})
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testCompareGpuVsCpu(self, batch_size, channel_count):
|
||||||
|
smaller_shape = [batch_size, 4, 6, channel_count]
|
||||||
|
larger_shape = [batch_size, 8, 16, channel_count]
|
||||||
|
for params in self._itGen(smaller_shape, larger_shape):
|
||||||
|
self._gpuVsCpuCase(*params, dtype=np.float32)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testCompareGpuVsCpuFloat64(self):
|
||||||
|
in_shape = [1, 5, 7, 1]
|
||||||
|
out_shape = [1, 9, 11, 1]
|
||||||
|
# Note that there is no 16-bit floating-point format registered for GPU
|
||||||
|
self._gpuVsCpuCase(
|
||||||
|
in_shape,
|
||||||
|
out_shape,
|
||||||
|
align_corners=True,
|
||||||
|
half_pixel_centers=False,
|
||||||
|
dtype=np.float64)
|
||||||
|
|
||||||
|
|
||||||
|
class ResizeBicubicOpTestBase(test.TestCase):
|
||||||
|
|
||||||
|
def testShapeIsCorrectAfterOp(self):
|
||||||
|
in_shape = [1, 2, 2, 1]
|
||||||
|
out_shape = [1, 4, 6, 1]
|
||||||
|
|
||||||
|
x = np.arange(0, 4).reshape(in_shape).astype(np.float32)
|
||||||
|
|
||||||
|
for align_corners in [True, False]:
|
||||||
|
input_tensor = constant_op.constant(x, shape=in_shape)
|
||||||
|
resize_out = image_ops.resize_bicubic(
|
||||||
|
input_tensor, out_shape[1:3], align_corners=align_corners)
|
||||||
|
with self.cached_session():
|
||||||
|
self.assertEqual(out_shape, list(resize_out.get_shape()))
|
||||||
|
resize_out = self.evaluate(resize_out)
|
||||||
|
self.assertEqual(out_shape, list(resize_out.shape))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testGradFromResizeToLargerInBothDims(self):
|
||||||
|
in_shape = [1, 2, 3, 1]
|
||||||
|
out_shape = [1, 4, 6, 1]
|
||||||
|
|
||||||
|
x = np.arange(0, 6).reshape(in_shape).astype(np.float32)
|
||||||
|
|
||||||
|
for align_corners in [True, False]:
|
||||||
|
with self.cached_session():
|
||||||
|
input_tensor = constant_op.constant(x, shape=in_shape)
|
||||||
|
resize_out = image_ops.resize_bicubic(
|
||||||
|
input_tensor, out_shape[1:3], align_corners=align_corners)
|
||||||
|
err = gradient_checker.compute_gradient_error(
|
||||||
|
input_tensor, in_shape, resize_out, out_shape, x_init_value=x)
|
||||||
|
self.assertLess(err, 1e-3)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testGradFromResizeToSmallerInBothDims(self):
|
||||||
|
in_shape = [1, 4, 6, 1]
|
||||||
|
out_shape = [1, 2, 3, 1]
|
||||||
|
|
||||||
|
x = np.arange(0, 24).reshape(in_shape).astype(np.float32)
|
||||||
|
|
||||||
|
for align_corners in [True, False]:
|
||||||
|
input_tensor = constant_op.constant(x, shape=in_shape)
|
||||||
|
resize_out = image_ops.resize_bicubic(
|
||||||
|
input_tensor, out_shape[1:3], align_corners=align_corners)
|
||||||
|
with self.cached_session():
|
||||||
|
err = gradient_checker.compute_gradient_error(
|
||||||
|
input_tensor, in_shape, resize_out, out_shape, x_init_value=x)
|
||||||
|
self.assertLess(err, 1e-3)
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testGradOnUnsupportedType(self):
|
||||||
|
in_shape = [1, 4, 6, 1]
|
||||||
|
out_shape = [1, 2, 3, 1]
|
||||||
|
|
||||||
|
x = np.arange(0, 24).reshape(in_shape).astype(np.uint8)
|
||||||
|
|
||||||
|
input_tensor = constant_op.constant(x, shape=in_shape)
|
||||||
|
resize_out = image_ops.resize_bicubic(input_tensor, out_shape[1:3])
|
||||||
|
with self.cached_session():
|
||||||
|
grad = gradients_impl.gradients(resize_out, [input_tensor])
|
||||||
|
self.assertEqual([None], grad)
|
||||||
|
|
||||||
|
|
||||||
|
class ScaleAndTranslateOpTestBase(test.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testGrads(self):
|
||||||
|
in_shape = [1, 2, 3, 1]
|
||||||
|
out_shape = [1, 4, 6, 1]
|
||||||
|
|
||||||
|
x = np.arange(0, 6).reshape(in_shape).astype(np.float32)
|
||||||
|
|
||||||
|
kernel_types = [
|
||||||
|
'lanczos1', 'lanczos3', 'lanczos5', 'gaussian', 'box', 'triangle',
|
||||||
|
'keyscubic', 'mitchellcubic'
|
||||||
|
]
|
||||||
|
scales = [(1.0, 1.0), (0.37, 0.47), (2.1, 2.1)]
|
||||||
|
translations = [(0.0, 0.0), (3.14, 1.19), (2.1, 3.1), (100.0, 200.0)]
|
||||||
|
for scale in scales:
|
||||||
|
for translation in translations:
|
||||||
|
for kernel_type in kernel_types:
|
||||||
|
for antialias in [True, False]:
|
||||||
|
with self.cached_session():
|
||||||
|
input_tensor = constant_op.constant(x, shape=in_shape)
|
||||||
|
scale_and_translate_out = image_ops.scale_and_translate(
|
||||||
|
input_tensor,
|
||||||
|
out_shape[1:3],
|
||||||
|
scale=constant_op.constant(scale),
|
||||||
|
translation=constant_op.constant(translation),
|
||||||
|
kernel_type=kernel_type,
|
||||||
|
antialias=antialias)
|
||||||
|
err = gradient_checker.compute_gradient_error(
|
||||||
|
input_tensor,
|
||||||
|
in_shape,
|
||||||
|
scale_and_translate_out,
|
||||||
|
out_shape,
|
||||||
|
x_init_value=x)
|
||||||
|
self.assertLess(err, 1e-3)
|
||||||
|
|
||||||
|
def testIdentityGrads(self):
|
||||||
|
"""Tests that Gradients for 1.0 scale should be ones for some kernels."""
|
||||||
|
in_shape = [1, 2, 3, 1]
|
||||||
|
out_shape = [1, 4, 6, 1]
|
||||||
|
|
||||||
|
x = np.arange(0, 6).reshape(in_shape).astype(np.float32)
|
||||||
|
|
||||||
|
kernel_types = ['lanczos1', 'lanczos3', 'lanczos5', 'triangle', 'keyscubic']
|
||||||
|
scale = (1.0, 1.0)
|
||||||
|
translation = (0.0, 0.0)
|
||||||
|
antialias = True
|
||||||
|
for kernel_type in kernel_types:
|
||||||
|
with self.cached_session():
|
||||||
|
input_tensor = constant_op.constant(x, shape=in_shape)
|
||||||
|
with backprop.GradientTape() as tape:
|
||||||
|
tape.watch(input_tensor)
|
||||||
|
scale_and_translate_out = image_ops.scale_and_translate(
|
||||||
|
input_tensor,
|
||||||
|
out_shape[1:3],
|
||||||
|
scale=constant_op.constant(scale),
|
||||||
|
translation=constant_op.constant(translation),
|
||||||
|
kernel_type=kernel_type,
|
||||||
|
antialias=antialias)
|
||||||
|
grad = tape.gradient(scale_and_translate_out, input_tensor)[0]
|
||||||
|
grad_v = self.evaluate(grad)
|
||||||
|
self.assertAllClose(np.ones_like(grad_v), grad_v)
|
||||||
|
|
||||||
|
|
||||||
|
class CropAndResizeOpTestBase(test.TestCase):
|
||||||
|
|
||||||
|
def testShapeIsCorrectAfterOp(self):
|
||||||
|
batch = 2
|
||||||
|
image_height = 3
|
||||||
|
image_width = 4
|
||||||
|
crop_height = 4
|
||||||
|
crop_width = 5
|
||||||
|
depth = 2
|
||||||
|
num_boxes = 2
|
||||||
|
|
||||||
|
image_shape = [batch, image_height, image_width, depth]
|
||||||
|
crop_size = [crop_height, crop_width]
|
||||||
|
crops_shape = [num_boxes, crop_height, crop_width, depth]
|
||||||
|
|
||||||
|
image = np.arange(0, batch * image_height * image_width *
|
||||||
|
depth).reshape(image_shape).astype(np.float32)
|
||||||
|
boxes = np.array([[0, 0, 1, 1], [.1, .2, .7, .8]], dtype=np.float32)
|
||||||
|
box_ind = np.array([0, 1], dtype=np.int32)
|
||||||
|
|
||||||
|
crops = image_ops.crop_and_resize(
|
||||||
|
constant_op.constant(image, shape=image_shape),
|
||||||
|
constant_op.constant(boxes, shape=[num_boxes, 4]),
|
||||||
|
constant_op.constant(box_ind, shape=[num_boxes]),
|
||||||
|
constant_op.constant(crop_size, shape=[2]))
|
||||||
|
with self.session(use_gpu=True) as sess:
|
||||||
|
self.assertEqual(crops_shape, list(crops.get_shape()))
|
||||||
|
crops = self.evaluate(crops)
|
||||||
|
self.assertEqual(crops_shape, list(crops.shape))
|
||||||
|
|
||||||
|
def _randomUniformAvoidAnchors(self, low, high, anchors, radius, num_samples):
|
||||||
|
"""Generate samples that are far enough from a set of anchor points.
|
||||||
|
|
||||||
|
We generate uniform samples in [low, high], then reject those that are less
|
||||||
|
than radius away from any point in anchors. We stop after we have accepted
|
||||||
|
num_samples samples.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
low: The lower end of the interval.
|
||||||
|
high: The upper end of the interval.
|
||||||
|
anchors: A list of length num_crops with anchor points to avoid.
|
||||||
|
radius: Distance threshold for the samples from the anchors.
|
||||||
|
num_samples: How many samples to produce.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
samples: A list of length num_samples with the accepted samples.
|
||||||
|
"""
|
||||||
|
self.assertTrue(low < high)
|
||||||
|
self.assertTrue(radius >= 0)
|
||||||
|
num_anchors = len(anchors)
|
||||||
|
# Make sure that at least half of the interval is not forbidden.
|
||||||
|
self.assertTrue(2 * radius * num_anchors < 0.5 * (high - low))
|
||||||
|
anchors = np.reshape(anchors, num_anchors)
|
||||||
|
samples = []
|
||||||
|
while len(samples) < num_samples:
|
||||||
|
sample = np.random.uniform(low, high)
|
||||||
|
if np.all(np.fabs(sample - anchors) > radius):
|
||||||
|
samples.append(sample)
|
||||||
|
return samples
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testGradRandomBoxes(self):
|
||||||
|
"""Test that the gradient is correct for randomly generated boxes.
|
||||||
|
|
||||||
|
The mapping is piecewise differentiable with respect to the box coordinates.
|
||||||
|
The points where the function is not differentiable are those which are
|
||||||
|
mapped to image pixels, i.e., the normalized y coordinates in
|
||||||
|
np.linspace(0, 1, image_height) and normalized x coordinates in
|
||||||
|
np.linspace(0, 1, image_width). Make sure that the box coordinates are
|
||||||
|
sufficiently far away from those rectangular grid centers that are points of
|
||||||
|
discontinuity, so that the finite difference Jacobian is close to the
|
||||||
|
computed one.
|
||||||
|
"""
|
||||||
|
np.random.seed(1) # Make it reproducible.
|
||||||
|
delta = 1e-3
|
||||||
|
radius = 2 * delta
|
||||||
|
low, high = -0.5, 1.5 # Also covers the case of extrapolation.
|
||||||
|
|
||||||
|
image_height = 4
|
||||||
|
for image_width in range(1, 3):
|
||||||
|
for crop_height in range(1, 3):
|
||||||
|
for crop_width in range(2, 4):
|
||||||
|
for depth in range(1, 3):
|
||||||
|
for num_boxes in range(1, 3):
|
||||||
|
|
||||||
|
batch = num_boxes
|
||||||
|
image_shape = [batch, image_height, image_width, depth]
|
||||||
|
crop_size = [crop_height, crop_width]
|
||||||
|
crops_shape = [num_boxes, crop_height, crop_width, depth]
|
||||||
|
boxes_shape = [num_boxes, 4]
|
||||||
|
|
||||||
|
image = np.arange(0, batch * image_height * image_width *
|
||||||
|
depth).reshape(image_shape).astype(np.float32)
|
||||||
|
boxes = []
|
||||||
|
for _ in range(num_boxes):
|
||||||
|
# pylint: disable=unbalanced-tuple-unpacking
|
||||||
|
y1, y2 = self._randomUniformAvoidAnchors(
|
||||||
|
low, high, np.linspace(0, 1, image_height), radius, 2)
|
||||||
|
x1, x2 = self._randomUniformAvoidAnchors(
|
||||||
|
low, high, np.linspace(0, 1, image_width), radius, 2)
|
||||||
|
# pylint: enable=unbalanced-tuple-unpacking
|
||||||
|
boxes.append([y1, x1, y2, x2])
|
||||||
|
|
||||||
|
boxes = np.array(boxes, dtype=np.float32)
|
||||||
|
box_ind = np.arange(batch, dtype=np.int32)
|
||||||
|
|
||||||
|
with self.cached_session(use_gpu=True):
|
||||||
|
image_tensor = constant_op.constant(image, shape=image_shape)
|
||||||
|
boxes_tensor = constant_op.constant(boxes, shape=[num_boxes, 4])
|
||||||
|
box_ind_tensor = constant_op.constant(
|
||||||
|
box_ind, shape=[num_boxes])
|
||||||
|
crops = image_ops.crop_and_resize(
|
||||||
|
image_tensor, boxes_tensor, box_ind_tensor,
|
||||||
|
constant_op.constant(crop_size, shape=[2]))
|
||||||
|
|
||||||
|
err = gradient_checker.compute_gradient_error(
|
||||||
|
[image_tensor, boxes_tensor], [image_shape, boxes_shape],
|
||||||
|
crops,
|
||||||
|
crops_shape,
|
||||||
|
delta=delta,
|
||||||
|
x_init_value=[image, boxes])
|
||||||
|
|
||||||
|
self.assertLess(err, 2e-3)
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
|
class RGBToHSVOpTestBase(test.TestCase):
|
||||||
|
|
||||||
|
TYPES = [np.float32, np.float64]
|
||||||
|
|
||||||
|
def testShapeIsCorrectAfterOp(self):
|
||||||
|
in_shape = [2, 20, 30, 3]
|
||||||
|
out_shape = [2, 20, 30, 3]
|
||||||
|
|
||||||
|
for nptype in self.TYPES:
|
||||||
|
x = np.random.randint(0, high=255, size=[2, 20, 30, 3]).astype(nptype)
|
||||||
|
rgb_input_tensor = constant_op.constant(x, shape=in_shape)
|
||||||
|
hsv_out = gen_image_ops.rgb_to_hsv(rgb_input_tensor)
|
||||||
|
with self.cached_session(use_gpu=True):
|
||||||
|
self.assertEqual(out_shape, list(hsv_out.get_shape()))
|
||||||
|
hsv_out = self.evaluate(hsv_out)
|
||||||
|
self.assertEqual(out_shape, list(hsv_out.shape))
|
||||||
|
|
||||||
|
def testRGBToHSVGradSimpleCase(self):
|
||||||
|
|
||||||
|
def f(x):
|
||||||
|
return gen_image_ops.rgb_to_hsv(x)
|
||||||
|
|
||||||
|
# Building a simple input tensor to avoid any discontinuity
|
||||||
|
x = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8,
|
||||||
|
0.9]]).astype(np.float32)
|
||||||
|
rgb_input_tensor = constant_op.constant(x, shape=x.shape)
|
||||||
|
# Computing Analytical and Numerical gradients of f(x)
|
||||||
|
analytical, numerical = gradient_checker_v2.compute_gradient(
|
||||||
|
f, [rgb_input_tensor])
|
||||||
|
self.assertAllClose(numerical, analytical, atol=1e-4)
|
||||||
|
|
||||||
|
def testRGBToHSVGradRandomCase(self):
|
||||||
|
|
||||||
|
def f(x):
|
||||||
|
return gen_image_ops.rgb_to_hsv(x)
|
||||||
|
|
||||||
|
np.random.seed(0)
|
||||||
|
# Building a simple input tensor to avoid any discontinuity
|
||||||
|
x = np.random.rand(1, 5, 5, 3).astype(np.float32)
|
||||||
|
rgb_input_tensor = constant_op.constant(x, shape=x.shape)
|
||||||
|
# Computing Analytical and Numerical gradients of f(x)
|
||||||
|
self.assertLess(
|
||||||
|
gradient_checker_v2.max_error(
|
||||||
|
*gradient_checker_v2.compute_gradient(f, [rgb_input_tensor])), 1e-4)
|
||||||
|
|
||||||
|
def testRGBToHSVGradSpecialCaseRGreatest(self):
|
||||||
|
# This test tests a specific subset of the input space
|
||||||
|
# with a dummy function implemented with native TF operations.
|
||||||
|
in_shape = [2, 10, 20, 3]
|
||||||
|
|
||||||
|
def f(x):
|
||||||
|
return gen_image_ops.rgb_to_hsv(x)
|
||||||
|
|
||||||
|
def f_dummy(x):
|
||||||
|
# This dummy function is a implementation of RGB to HSV using
|
||||||
|
# primitive TF functions for one particular case when R>G>B.
|
||||||
|
r = x[..., 0]
|
||||||
|
g = x[..., 1]
|
||||||
|
b = x[..., 2]
|
||||||
|
# Since MAX = r and MIN = b, we get the following h,s,v values.
|
||||||
|
v = r
|
||||||
|
s = 1 - math_ops.div_no_nan(b, r)
|
||||||
|
h = 60 * math_ops.div_no_nan(g - b, r - b)
|
||||||
|
h = h / 360
|
||||||
|
return array_ops.stack([h, s, v], axis=-1)
|
||||||
|
|
||||||
|
# Building a custom input tensor where R>G>B
|
||||||
|
x_reds = np.ones((in_shape[0], in_shape[1], in_shape[2])).astype(np.float32)
|
||||||
|
x_greens = 0.5 * np.ones(
|
||||||
|
(in_shape[0], in_shape[1], in_shape[2])).astype(np.float32)
|
||||||
|
x_blues = 0.2 * np.ones(
|
||||||
|
(in_shape[0], in_shape[1], in_shape[2])).astype(np.float32)
|
||||||
|
x = np.stack([x_reds, x_greens, x_blues], axis=-1)
|
||||||
|
rgb_input_tensor = constant_op.constant(x, shape=in_shape)
|
||||||
|
|
||||||
|
# Computing Analytical and Numerical gradients of f(x)
|
||||||
|
analytical, numerical = gradient_checker_v2.compute_gradient(
|
||||||
|
f, [rgb_input_tensor])
|
||||||
|
# Computing Analytical and Numerical gradients of f_dummy(x)
|
||||||
|
analytical_dummy, numerical_dummy = gradient_checker_v2.compute_gradient(
|
||||||
|
f_dummy, [rgb_input_tensor])
|
||||||
|
self.assertAllClose(numerical, analytical, atol=1e-4)
|
||||||
|
self.assertAllClose(analytical_dummy, analytical, atol=1e-4)
|
||||||
|
self.assertAllClose(numerical_dummy, numerical, atol=1e-4)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test.main()
|
||||||
@ -144,6 +144,7 @@ COMMON_PIP_DEPS = [
|
|||||||
"//tensorflow/python/tools:tools_pip",
|
"//tensorflow/python/tools:tools_pip",
|
||||||
"//tensorflow/python/tools/api/generator:create_python_api",
|
"//tensorflow/python/tools/api/generator:create_python_api",
|
||||||
"//tensorflow/python/tpu",
|
"//tensorflow/python/tpu",
|
||||||
|
"//tensorflow/python:image_grad_test_base",
|
||||||
"//tensorflow/python:test_ops",
|
"//tensorflow/python:test_ops",
|
||||||
"//tensorflow/python:while_v2",
|
"//tensorflow/python:while_v2",
|
||||||
"//tensorflow/tools/common:public_api",
|
"//tensorflow/tools/common:public_api",
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user