Add pfor converters for:

  AdjustContrastV2
  AdjustHue
  AdjustSaturation
  AvgPool3D
  AvgPool3DGrad
  BatchToSpaceND
  MaxPoolGradGradV2
  MaxPoolGradV2
  MaxPoolV2
  SpaceToBatchND

PiperOrigin-RevId: 294814076
Change-Id: Icf9bdebce6e4c0699a4e39c30214e128f6641f0a
A. Unique TensorFlower 2020-02-12 19:13:08 -08:00 committed by TensorFlower Gardener
parent 9afba0fc0d
commit af458ea2ed
3 changed files with 174 additions and 0 deletions
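
A minimal usage sketch (not part of this change) of what the new converters
enable: an op such as AdjustContrastV2 can now be vectorized by pfor /
tf.vectorized_map instead of falling back to a while_loop. The shapes and
names below are illustrative only.

import tensorflow as tf

images = tf.random.uniform([8, 2, 4, 4, 3])  # 8 loop iterations, 2 images each

def per_elem(x):
  # Dispatches to the AdjustContrastV2 converter when vectorized.
  return tf.image.adjust_contrast(x, 2.0)

out = tf.vectorized_map(per_elem, images)  # shape (8, 2, 4, 4, 3)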


@@ -438,6 +438,27 @@ class ArrayTest(PForTestCase):

    self._test_loop_fn(loop_fn, 2)

  def test_batch_to_space_nd(self):
    x = random_ops.random_uniform([7, 5 * 2 * 3, 2, 2, 3, 2])
    block_shapes = [2, 3]
    crops = [[1, 2], [1, 0]]

    def loop_fn(i):
      x1 = array_ops.gather(x, i)
      return array_ops.batch_to_space_nd(x1, block_shapes, crops)

    self._test_loop_fn(loop_fn, 7)

  def test_space_to_batch_nd(self):
    x = random_ops.random_uniform([7, 5, 2 * 2 - 3, 2 * 3 - 1, 3, 2])
    block_shapes = [2, 3]
    paddings = [[1, 2], [1, 0]]

    def loop_fn(i):
      x1 = array_ops.gather(x, i)
      return array_ops.space_to_batch_nd(x1, block_shapes, paddings)

    self._test_loop_fn(loop_fn, 7)


if __name__ == "__main__":
  test.main()


@@ -41,7 +41,9 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradients as gradient_ops
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
@@ -337,6 +339,37 @@ class BitwiseTest(PForTestCase):

    self._test_loop_fn(loop_fn, 3)


@test_util.run_all_in_graph_and_eager_modes
class ImageTest(PForTestCase):

  def test_adjust_contrast(self):
    images = random_ops.random_uniform([3, 2, 4, 4, 3])

    def loop_fn(i):
      image = array_ops.gather(images, i)
      return image_ops.adjust_contrast(image, 2.0)

    self._test_loop_fn(loop_fn, 3)

  def test_adjust_hue(self):
    images = random_ops.random_uniform([3, 2, 4, 4, 3])

    def loop_fn(i):
      image = array_ops.gather(images, i)
      return image_ops.adjust_hue(image, .25)

    self._test_loop_fn(loop_fn, 3)

  def test_adjust_saturation(self):
    images = random_ops.random_uniform([3, 2, 4, 4, 3])

    def loop_fn(i):
      image = array_ops.gather(images, i)
      return image_ops.adjust_saturation(image, 0.1)

    self._test_loop_fn(loop_fn, 3)


@test_util.run_all_in_graph_and_eager_modes
class NNTest(PForTestCase):
@@ -409,6 +442,27 @@ class NNTest(PForTestCase):

    self._test_loop_fn(loop_fn, 3)

  def test_avg_pool3d(self):
    with backprop.GradientTape(persistent=True) as g:
      x = random_ops.random_uniform([5, 3, 7, 6, 6, 5])
      g.watch(x)
      ksize = [1, 2, 2, 2, 1]
      strides = [1, 2, 2, 2, 1]

    def loop_fn(i):
      with g:
        x1 = array_ops.gather(x, i)
        output = nn.avg_pool3d(
            x1,
            ksize,
            strides=strides,
            padding="VALID",
            data_format="NDHWC")
        loss = nn.l2_loss(output)
      return output, g.gradient(loss, x1)

    self._test_loop_fn(loop_fn, 3)

  def test_max_pool(self):
    with backprop.GradientTape(persistent=True) as g:
      x = random_ops.random_uniform([3, 2, 12, 12, 3])
@@ -430,6 +484,27 @@ class NNTest(PForTestCase):

    self._test_loop_fn(loop_fn, 3)

  def test_max_pool_v2(self):
    with backprop.GradientTape(persistent=True) as g:
      x = random_ops.random_uniform([3, 2, 12, 12, 3])
      g.watch(x)
      ksize = [1, 3, 3, 1]
      strides = [1, 2, 2, 1]

    def loop_fn(i):
      with g:
        x1 = array_ops.gather(x, i)
        output = gen_nn_ops.max_pool_v2(
            x1, ksize, strides=strides, padding="VALID", data_format="NHWC")
        loss = nn.l2_loss(output)
        ones = array_ops.ones_like(output)
        g.watch(ones)
        grad = g.gradient(loss, x1, output_gradients=ones)
      grad_grad = g.gradient(grad, ones)
      return output, grad, grad_grad

    self._test_loop_fn(loop_fn, 3)

  def test_max_pool3d(self):
    if test.is_built_with_rocm():
      self.skipTest("Pooling with 3D tensors is not supported in ROCm")


@@ -44,6 +44,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_image_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gen_parsing_ops
@@ -1544,6 +1545,31 @@ class PFor(object):

# The code below defines converters for different operations. Please see comment
# for RegisterPFor to see how converters should be defined.


# image_ops


@RegisterPFor("AdjustContrastv2")
def _convert_adjust_contrastv2(pfor_input):
  images = pfor_input.stacked_input(0)
  contrast_factor = pfor_input.unstacked_input(1)
  return wrap(gen_image_ops.adjust_contrastv2(images, contrast_factor), True)


@RegisterPFor("AdjustHue")
def _convert_adjust_hue(pfor_input):
  images = pfor_input.stacked_input(0)
  delta = pfor_input.unstacked_input(1)
  return wrap(gen_image_ops.adjust_hue(images, delta), True)


@RegisterPFor("AdjustSaturation")
def _convert_adjust_saturation(pfor_input):
  images = pfor_input.stacked_input(0)
  scale = pfor_input.unstacked_input(1)
  return wrap(gen_image_ops.adjust_saturation(images, scale), True)
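
Each of the converters above keeps the images stacked (one slice per loop
iteration) but reads the factor argument via unstacked_input, i.e. the
contrast factor / delta / scale is assumed to be the same across iterations.
A hedged sketch of the kind of call this covers (tf.vectorized_map and
tf.image.adjust_hue are the public entry points; the shapes are made up):

import tensorflow as tf

imgs = tf.random.uniform([4, 8, 8, 3])
# delta is loop-invariant, so the AdjustHue converter applies directly.
out = tf.vectorized_map(lambda im: tf.image.adjust_hue(im, 0.25), imgs)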
# nn_ops
@@ -1580,12 +1606,16 @@ def _inputs_with_flattening(pfor_input, input_indices):
@RegisterPForWithArgs("Conv2D", dims=[0])
@RegisterPForWithArgs("DepthToSpace", dims=[0])
@RegisterPForWithArgs("AvgPool", dims=[0])
@RegisterPForWithArgs("AvgPool3D", dims=[0])
@RegisterPForWithArgs("MaxPool", dims=[0])
@RegisterPForWithArgs("MaxPoolV2", dims=[0])
@RegisterPForWithArgs("MaxPool3D", dims=[0])
@RegisterPForWithArgs("MaxPool3DGrad", dims=[0, 1, 2])
@RegisterPForWithArgs("MaxPoolGrad", dims=[0, 1, 2])
@RegisterPForWithArgs("MaxPoolGradV2", dims=[0, 1, 2])
@RegisterPForWithArgs("MaxPool3DGradGrad", dims=[0, 1, 2])
@RegisterPForWithArgs("MaxPoolGradGrad", dims=[0, 1, 2])
@RegisterPForWithArgs("MaxPoolGradGradV2", dims=[0, 1, 2])
@RegisterPForWithArgs("SoftmaxCrossEntropyWithLogits", dims=[0, 1])
@RegisterPForWithArgs("SpaceToDepth", dims=[0])
def _convert_flatten_batch(pfor_input, op_type, dims):
@@ -1603,6 +1633,53 @@ def _convert_flatten_batch(pfor_input, op_type, dims):

_channel_flatten_input_cache = {}


@RegisterPFor("BatchToSpaceND")
def _convert_batch_to_space_nd(pfor_input):
  inp = pfor_input.stacked_input(0)
  block_shape = pfor_input.unstacked_input(1)
  crops = pfor_input.unstacked_input(2)

  inp_shape = array_ops.shape(inp)
  n = pfor_input.pfor.loop_len_vector

  # Reshape and transpose to move the vectorization axis inside the axes that
  # will move to space.
  # Reshape to 4D and transpose
  block_size = math_ops.reduce_prod(block_shape)
  new_shape = [n[0], block_size, inp_shape[1] // block_size, -1]
  inp = array_ops.reshape(inp, new_shape)
  inp = array_ops.transpose(inp, [1, 0, 2, 3])

  # Reshape back to merge the block, vectorization and batch dimension, and
  # restore the other dimensions.
  new_shape = array_ops.concat([n * inp_shape[1], inp_shape[2:]], axis=0)
  inp = array_ops.reshape(inp, new_shape)

  # Call batch_to_space and then split the new batch axis.
  output = gen_array_ops.batch_to_space_nd(inp, block_shape, crops)
  output = _unflatten_first_dim(output, n)
  return wrap(output, True)
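
A hedged sanity check (not part of the commit) for the converter above: the
vectorized result should match applying batch_to_space per loop iteration.
tf.batch_to_space and tf.vectorized_map are the public TF 2.x entry points;
the shapes mirror the new test case.

import tensorflow as tf

x = tf.random.uniform([7, 6, 2, 2, 3, 2])  # 7 loop iterations, inner batch 6
block_shape = [2, 3]
crops = [[1, 2], [1, 0]]

vectorized = tf.vectorized_map(
    lambda e: tf.batch_to_space(e, block_shape, crops), x)
looped = tf.stack(
    [tf.batch_to_space(x[i], block_shape, crops) for i in range(7)])
tf.debugging.assert_near(vectorized, looped)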
@RegisterPFor("SpaceToBatchND")
def _convert_space_to_batch_nd(pfor_input):
inp = pfor_input.stacked_input(0)
block_shape = pfor_input.unstacked_input(1)
paddings = pfor_input.unstacked_input(2)
n = pfor_input.pfor.loop_len_vector
inp_shape = array_ops.shape(inp)
inp = _flatten_first_two_dims(inp)
output = gen_array_ops.space_to_batch_nd(inp, block_shape, paddings)
output_shape = array_ops.shape(output)
block_size = math_ops.reduce_prod(block_shape)
new_shape = [block_size, n[0], -1]
output = array_ops.reshape(output, new_shape)
output = array_ops.transpose(output, [1, 0, 2])
new_shape = array_ops.concat(
[n, block_size * inp_shape[1:2], output_shape[1:]], axis=0)
output = array_ops.reshape(output, new_shape)
return wrap(output, True)
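
And the mirror-image check for the SpaceToBatchND converter, under the same
assumptions (tf.space_to_batch is the public wrapper for the op being
vectorized; the shapes mirror the new test case):

import tensorflow as tf

x = tf.random.uniform([7, 5, 1, 5, 3, 2])  # 7 loop iterations, inner batch 5
block_shape = [2, 3]
paddings = [[1, 2], [1, 0]]

vectorized = tf.vectorized_map(
    lambda e: tf.space_to_batch(e, block_shape, paddings), x)
looped = tf.stack(
    [tf.space_to_batch(x[i], block_shape, paddings) for i in range(7)])
tf.debugging.assert_near(vectorized, looped)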
def _channel_flatten_input(x, data_format):
"""Merge the stack dimension with the channel dimension.
@@ -1734,6 +1811,7 @@ def _convert_fused_batch_norm_grad(pfor_input):
@RegisterPForWithArgs("Conv2DBackpropInput", flatten_dims=[2], shape_dim=0)
@RegisterPForWithArgs("AvgPoolGrad", flatten_dims=[1], shape_dim=0)
@RegisterPForWithArgs("AvgPool3DGrad", flatten_dims=[1], shape_dim=0)
def _convert_flatten_batch_shape_input(pfor_input, op_type, flatten_dims,
                                       shape_dim):
  del op_type