Merge pull request #23394 from zldrobit:patch-2

PiperOrigin-RevId: 221654513
commit e500ab5f8b
Author: TensorFlower Gardener
Date:   2018-11-15 11:19:26 -08:00
2 changed files with 48 additions and 31 deletions

tensorflow/contrib/image/python/kernel_tests/dense_image_warp_test.py

@@ -24,7 +24,7 @@ from tensorflow.contrib.image.python.ops import dense_image_warp
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gradients
@@ -259,7 +259,7 @@ class DenseImageWarpTest(test_util.TensorFlowTestCase):
     shape = [1, 2, 1, 1]
     msg = 'Should have raised an exception for invalid image size'
-    with self.assertRaises(ValueError, msg=msg):
+    with self.assertRaises(errors.InvalidArgumentError, msg=msg):
       self.check_interpolation_correctness(shape, 'float32', 'float32')
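
Note: because the grid-size check is now a runtime assertion on dynamic shapes (see the second file below), an invalid image no longer fails at graph-construction time with ValueError; it fails when the graph executes, as errors.InvalidArgumentError. A minimal graph-mode sketch of that behaviour, assuming the public TF 1.x endpoint tf.contrib.image.dense_image_warp (shapes and values are illustrative only):

import numpy as np
import tensorflow as tf

# A 1x2x1x1 image violates the "at least batch_size x 2 x 2" grid requirement.
# With dynamic-shape checks, the graph still builds; the check fires at run time.
image = tf.constant(np.zeros([1, 2, 1, 1], np.float32))
flow = tf.constant(np.zeros([1, 2, 1, 2], np.float32))
warped = tf.contrib.image.dense_image_warp(image, flow)

with tf.Session() as sess:
  try:
    sess.run(warped)
  except tf.errors.InvalidArgumentError as e:
    print('Runtime shape check fired:', e)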

tensorflow/contrib/image/python/ops/dense_image_warp.py

@@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import math_ops
@@ -60,28 +61,38 @@ def _interpolate_bilinear(grid,
       msg = 'Grid must be 4 dimensional. Received size: '
       raise ValueError(msg + str(grid.get_shape()))
-    batch_size, height, width, channels = shape
+    batch_size, height, width, channels = (array_ops.shape(grid)[0],
+                                           array_ops.shape(grid)[1],
+                                           array_ops.shape(grid)[2],
+                                           array_ops.shape(grid)[3])
+    shape = [batch_size, height, width, channels]
     query_type = query_points.dtype
     grid_type = grid.dtype
-    if (query_points.shape.rank != 3 or
-        query_points.shape.dims[2].value != 2):
-      msg = ('Query points must be 3 dimensional and size 2 in dim 2. Received '
-             'size: ')
-      raise ValueError(msg + str(query_points.get_shape()))
+    with ops.control_dependencies([
+        check_ops.assert_equal(
+            len(query_points.get_shape()),
+            3,
+            message='Query points must be 3 dimensional.'),
+        check_ops.assert_equal(
+            array_ops.shape(query_points)[2],
+            2,
+            message='Query points must be size 2 in dim 2.')
+    ]):
+      num_queries = array_ops.shape(query_points)[1]
-    _, num_queries, _ = query_points.get_shape().as_list()
-    if height < 2 or width < 2:
-      msg = 'Grid must be at least batch_size x 2 x 2 in size. Received size: '
-      raise ValueError(msg + str(grid.get_shape()))
-    alphas = []
-    floors = []
-    ceils = []
-    index_order = [0, 1] if indexing == 'ij' else [1, 0]
-    unstacked_query_points = array_ops.unstack(query_points, axis=2)
+    with ops.control_dependencies([
+        check_ops.assert_greater_equal(
+            height, 2, message='Grid height must be at least 2.'),
+        check_ops.assert_greater_equal(
+            width, 2, message='Grid width must be at least 2.')
+    ]):
+      alphas = []
+      floors = []
+      ceils = []
+      index_order = [0, 1] if indexing == 'ij' else [1, 0]
+      unstacked_query_points = array_ops.unstack(query_points, axis=2)
     for dim in index_order:
       with ops.name_scope('dim-' + str(dim)):
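
This hunk replaces Python-level shape values and if-checks with graph tensors from array_ops.shape and check_ops assertions attached via ops.control_dependencies, so validation happens when the graph runs rather than when it is traced. A self-contained sketch of the same pattern using the public graph-mode API (the helper name _checked_height_width is made up for illustration):

import tensorflow as tf

def _checked_height_width(grid):
  # Dynamic (run-time) spatial dimensions of a [batch, height, width, channels] tensor.
  height = tf.shape(grid)[1]
  width = tf.shape(grid)[2]
  with tf.control_dependencies([
      tf.assert_greater_equal(height, 2, message='Grid height must be at least 2.'),
      tf.assert_greater_equal(width, 2, message='Grid width must be at least 2.')
  ]):
    # Ops created inside the block carry the control dependencies, so the
    # assertions execute before anything that consumes these values.
    return tf.identity(height), tf.identity(width)
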
@@ -112,16 +123,18 @@ def _interpolate_bilinear(grid,
         alpha = array_ops.expand_dims(alpha, 2)
         alphas.append(alpha)
-    if batch_size * height * width > np.iinfo(np.int32).max / 8:
-      error_msg = """The image size or batch size is sufficiently large
-          that the linearized addresses used by array_ops.gather
-          may exceed the int32 limit."""
-      raise ValueError(error_msg)
-    flattened_grid = array_ops.reshape(grid,
-                                       [batch_size * height * width, channels])
-    batch_offsets = array_ops.reshape(
-        math_ops.range(batch_size) * height * width, [batch_size, 1])
+    with ops.control_dependencies([
+        check_ops.assert_less_equal(
+            math_ops.cast(batch_size * height * width, dtype=dtypes.float32),
+            np.iinfo(np.int32).max / 8,
+            message="""The image size or batch size is sufficiently large
+            that the linearized addresses used by array_ops.gather
+            may exceed the int32 limit.""")
+    ]):
+      flattened_grid = array_ops.reshape(
+          grid, [batch_size * height * width, channels])
+      batch_offsets = array_ops.reshape(
+          math_ops.range(batch_size) * height * width, [batch_size, 1])
     # This wraps array_ops.gather. We reshape the image data such that the
     # batch, y, and x coordinates are pulled into the first dimension.
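
The static np.iinfo guard is likewise turned into a graph assertion; casting the element count to float32 lets it be compared against the int32 limit even when the count is only known at run time. A hedged sketch of that check in isolation, using public ops (the function name is hypothetical):

import numpy as np
import tensorflow as tf

def assert_gather_addresses_fit_int32(batch_size, height, width):
  # Linearized indices for gather span batch * height * width elements; keep
  # them well inside the int32 range (the /8 margin mirrors the original heuristic).
  num_elements = tf.cast(batch_size * height * width, tf.float32)
  return tf.assert_less_equal(
      num_elements,
      float(np.iinfo(np.int32).max) / 8.0,
      message='Image or batch size too large for int32 gather addressing.')
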
@@ -182,7 +195,11 @@ def dense_image_warp(image, flow, name='dense_image_warp'):
       of dimensions.
   """
   with ops.name_scope(name):
-    batch_size, height, width, channels = image.get_shape().as_list()
+    batch_size, height, width, channels = (array_ops.shape(image)[0],
+                                           array_ops.shape(image)[1],
+                                           array_ops.shape(image)[2],
+                                           array_ops.shape(image)[3])
     # The flow is defined on the image grid. Turn the flow into a list of query
     # points in the grid space.
     grid_x, grid_y = array_ops.meshgrid(
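
Taken together, these changes let dense_image_warp build a graph for images whose batch and spatial dimensions are unknown until run time, which the original get_shape().as_list() version could not handle. A graph-mode usage sketch under that assumption (shapes and values are illustrative):

import numpy as np
import tensorflow as tf

# Spatial dimensions unknown at graph-construction time.
image = tf.placeholder(tf.float32, shape=[None, None, None, 3])
flow = tf.placeholder(tf.float32, shape=[None, None, None, 2])
warped = tf.contrib.image.dense_image_warp(image, flow)

with tf.Session() as sess:
  out = sess.run(warped, feed_dict={
      image: np.random.rand(2, 64, 48, 3).astype(np.float32),
      flow: np.zeros([2, 64, 48, 2], np.float32),  # zero flow: identity warp
  })
  print(out.shape)  # (2, 64, 48, 3)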