Merge pull request #23394 from zldrobit:patch-2
PiperOrigin-RevId: 221654513
This commit is contained in:
commit
e500ab5f8b
@ -24,7 +24,7 @@ from tensorflow.contrib.image.python.ops import dense_image_warp
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gradients
|
||||
@ -259,7 +259,7 @@ class DenseImageWarpTest(test_util.TensorFlowTestCase):
|
||||
|
||||
shape = [1, 2, 1, 1]
|
||||
msg = 'Should have raised an exception for invalid image size'
|
||||
with self.assertRaises(ValueError, msg=msg):
|
||||
with self.assertRaises(errors.InvalidArgumentError, msg=msg):
|
||||
self.check_interpolation_correctness(shape, 'float32', 'float32')
|
||||
|
||||
|
||||
|
@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
|
||||
|
||||
@ -60,28 +61,38 @@ def _interpolate_bilinear(grid,
|
||||
msg = 'Grid must be 4 dimensional. Received size: '
|
||||
raise ValueError(msg + str(grid.get_shape()))
|
||||
|
||||
batch_size, height, width, channels = shape
|
||||
batch_size, height, width, channels = (array_ops.shape(grid)[0],
|
||||
array_ops.shape(grid)[1],
|
||||
array_ops.shape(grid)[2],
|
||||
array_ops.shape(grid)[3])
|
||||
|
||||
shape = [batch_size, height, width, channels]
|
||||
query_type = query_points.dtype
|
||||
grid_type = grid.dtype
|
||||
|
||||
if (query_points.shape.rank != 3 or
|
||||
query_points.shape.dims[2].value != 2):
|
||||
msg = ('Query points must be 3 dimensional and size 2 in dim 2. Received '
|
||||
'size: ')
|
||||
raise ValueError(msg + str(query_points.get_shape()))
|
||||
with ops.control_dependencies([
|
||||
check_ops.assert_equal(
|
||||
len(query_points.get_shape()),
|
||||
3,
|
||||
message='Query points must be 3 dimensional.'),
|
||||
check_ops.assert_equal(
|
||||
array_ops.shape(query_points)[2],
|
||||
2,
|
||||
message='Query points must be size 2 in dim 2.')
|
||||
]):
|
||||
num_queries = array_ops.shape(query_points)[1]
|
||||
|
||||
_, num_queries, _ = query_points.get_shape().as_list()
|
||||
|
||||
if height < 2 or width < 2:
|
||||
msg = 'Grid must be at least batch_size x 2 x 2 in size. Received size: '
|
||||
raise ValueError(msg + str(grid.get_shape()))
|
||||
|
||||
alphas = []
|
||||
floors = []
|
||||
ceils = []
|
||||
|
||||
index_order = [0, 1] if indexing == 'ij' else [1, 0]
|
||||
unstacked_query_points = array_ops.unstack(query_points, axis=2)
|
||||
with ops.control_dependencies([
|
||||
check_ops.assert_greater_equal(
|
||||
height, 2, message='Grid height must be at least 2.'),
|
||||
check_ops.assert_greater_equal(
|
||||
width, 2, message='Grid width must be at least 2.')
|
||||
]):
|
||||
alphas = []
|
||||
floors = []
|
||||
ceils = []
|
||||
index_order = [0, 1] if indexing == 'ij' else [1, 0]
|
||||
unstacked_query_points = array_ops.unstack(query_points, axis=2)
|
||||
|
||||
for dim in index_order:
|
||||
with ops.name_scope('dim-' + str(dim)):
|
||||
@ -112,16 +123,18 @@ def _interpolate_bilinear(grid,
|
||||
alpha = array_ops.expand_dims(alpha, 2)
|
||||
alphas.append(alpha)
|
||||
|
||||
if batch_size * height * width > np.iinfo(np.int32).max / 8:
|
||||
error_msg = """The image size or batch size is sufficiently large
|
||||
that the linearized addresses used by array_ops.gather
|
||||
may exceed the int32 limit."""
|
||||
raise ValueError(error_msg)
|
||||
|
||||
flattened_grid = array_ops.reshape(grid,
|
||||
[batch_size * height * width, channels])
|
||||
batch_offsets = array_ops.reshape(
|
||||
math_ops.range(batch_size) * height * width, [batch_size, 1])
|
||||
with ops.control_dependencies([
|
||||
check_ops.assert_less_equal(
|
||||
math_ops.cast(batch_size * height * width, dtype=dtypes.float32),
|
||||
np.iinfo(np.int32).max / 8,
|
||||
message="""The image size or batch size is sufficiently large
|
||||
that the linearized addresses used by array_ops.gather
|
||||
may exceed the int32 limit.""")
|
||||
]):
|
||||
flattened_grid = array_ops.reshape(
|
||||
grid, [batch_size * height * width, channels])
|
||||
batch_offsets = array_ops.reshape(
|
||||
math_ops.range(batch_size) * height * width, [batch_size, 1])
|
||||
|
||||
# This wraps array_ops.gather. We reshape the image data such that the
|
||||
# batch, y, and x coordinates are pulled into the first dimension.
|
||||
@ -182,7 +195,11 @@ def dense_image_warp(image, flow, name='dense_image_warp'):
|
||||
of dimensions.
|
||||
"""
|
||||
with ops.name_scope(name):
|
||||
batch_size, height, width, channels = image.get_shape().as_list()
|
||||
batch_size, height, width, channels = (array_ops.shape(image)[0],
|
||||
array_ops.shape(image)[1],
|
||||
array_ops.shape(image)[2],
|
||||
array_ops.shape(image)[3])
|
||||
|
||||
# The flow is defined on the image grid. Turn the flow into a list of query
|
||||
# points in the grid space.
|
||||
grid_x, grid_y = array_ops.meshgrid(
|
||||
|
Loading…
Reference in New Issue
Block a user