From af458ea2ed148519333b9ece3d74e3201a7a5222 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 12 Feb 2020 19:13:08 -0800 Subject: [PATCH] Add pfor converters for: AdjustContrastV2 AdjustHue AdjustSaturation AvgPool3D AvgPool3DGrad BatchToSpaceND MaxPoolGradGradV2 MaxPoolGradV2 MaxPoolV2 SpaceToBatchND PiperOrigin-RevId: 294814076 Change-Id: Icf9bdebce6e4c0699a4e39c30214e128f6641f0a --- .../python/ops/parallel_for/array_test.py | 21 +++++ .../ops/parallel_for/control_flow_ops_test.py | 75 ++++++++++++++++++ tensorflow/python/ops/parallel_for/pfor.py | 78 +++++++++++++++++++ 3 files changed, 174 insertions(+) diff --git a/tensorflow/python/ops/parallel_for/array_test.py b/tensorflow/python/ops/parallel_for/array_test.py index 2792d968e89..23258c077ae 100644 --- a/tensorflow/python/ops/parallel_for/array_test.py +++ b/tensorflow/python/ops/parallel_for/array_test.py @@ -438,6 +438,27 @@ class ArrayTest(PForTestCase): self._test_loop_fn(loop_fn, 2) + def test_batch_to_space_nd(self): + x = random_ops.random_uniform([7, 5 * 2 * 3, 2, 2, 3, 2]) + block_shapes = [2, 3] + crops = [[1, 2], [1, 0]] + + def loop_fn(i): + x1 = array_ops.gather(x, i) + return array_ops.batch_to_space_nd(x1, block_shapes, crops) + + self._test_loop_fn(loop_fn, 7) + + def test_space_to_batch_nd(self): + x = random_ops.random_uniform([7, 5, 2 * 2 - 3, 2 * 3 - 1, 3, 2]) + block_shapes = [2, 3] + paddings = [[1, 2], [1, 0]] + + def loop_fn(i): + x1 = array_ops.gather(x, i) + return array_ops.space_to_batch_nd(x1, block_shapes, paddings) + + self._test_loop_fn(loop_fn, 7) if __name__ == "__main__": test.main() diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py index 7d3d45326c1..fd071dd413d 100644 --- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py +++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py @@ -41,7 +41,9 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import bitwise_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import gradients as gradient_ops +from tensorflow.python.ops import image_ops from tensorflow.python.ops import logging_ops from tensorflow.python.ops import map_fn from tensorflow.python.ops import math_ops @@ -337,6 +339,37 @@ class BitwiseTest(PForTestCase): self._test_loop_fn(loop_fn, 3) +@test_util.run_all_in_graph_and_eager_modes +class ImageTest(PForTestCase): + + def test_adjust_contrast(self): + images = random_ops.random_uniform([3, 2, 4, 4, 3]) + + def loop_fn(i): + image = array_ops.gather(images, i) + return image_ops.adjust_contrast(image, 2.0) + + self._test_loop_fn(loop_fn, 3) + + def test_adjust_hue(self): + images = random_ops.random_uniform([3, 2, 4, 4, 3]) + + def loop_fn(i): + image = array_ops.gather(images, i) + return image_ops.adjust_hue(image, .25) + + self._test_loop_fn(loop_fn, 3) + + def test_adjust_saturation(self): + images = random_ops.random_uniform([3, 2, 4, 4, 3]) + + def loop_fn(i): + image = array_ops.gather(images, i) + return image_ops.adjust_saturation(image, 0.1) + + self._test_loop_fn(loop_fn, 3) + + @test_util.run_all_in_graph_and_eager_modes class NNTest(PForTestCase): @@ -409,6 +442,27 @@ class NNTest(PForTestCase): self._test_loop_fn(loop_fn, 3) + def test_avg_pool3d(self): + with backprop.GradientTape(persistent=True) as g: + x = random_ops.random_uniform([5, 3, 7, 6, 6, 5]) + g.watch(x) + ksize = [1, 2, 2, 2, 1] + strides = [1, 2, 2, 2, 1] + + def loop_fn(i): + with g: + x1 = array_ops.gather(x, i) + output = nn.avg_pool3d( + x1, + ksize, + strides=strides, + padding="VALID", + data_format="NDHWC") + loss = nn.l2_loss(output) + return output, g.gradient(loss, x1) + + self._test_loop_fn(loop_fn, 3) + def test_max_pool(self): with backprop.GradientTape(persistent=True) as g: x = random_ops.random_uniform([3, 2, 12, 12, 3]) @@ -430,6 +484,27 @@ class NNTest(PForTestCase): self._test_loop_fn(loop_fn, 3) + def test_max_pool_v2(self): + with backprop.GradientTape(persistent=True) as g: + x = random_ops.random_uniform([3, 2, 12, 12, 3]) + g.watch(x) + ksize = [1, 3, 3, 1] + strides = [1, 2, 2, 1] + + def loop_fn(i): + with g: + x1 = array_ops.gather(x, i) + output = gen_nn_ops.max_pool_v2( + x1, ksize, strides=strides, padding="VALID", data_format="NHWC") + loss = nn.l2_loss(output) + ones = array_ops.ones_like(output) + g.watch(ones) + grad = g.gradient(loss, x1, output_gradients=ones) + grad_grad = g.gradient(grad, ones) + return output, grad, grad_grad + + self._test_loop_fn(loop_fn, 3) + def test_max_pool3d(self): if test.is_built_with_rocm(): self.skipTest("Pooling with 3D tensors is not supported in ROCm") diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py index 3737cfcde0f..6eea73bfb08 100644 --- a/tensorflow/python/ops/parallel_for/pfor.py +++ b/tensorflow/python/ops/parallel_for/pfor.py @@ -44,6 +44,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import gen_image_ops from tensorflow.python.ops import gen_linalg_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import gen_parsing_ops @@ -1544,6 +1545,31 @@ class PFor(object): # The code below defines converters for different operations. Please see comment # for RegisterPFor to see how converters should be defined. + +# image_ops + + +@RegisterPFor("AdjustContrastv2") +def _convert_adjust_contrastv2(pfor_input): + images = pfor_input.stacked_input(0) + contrast_factor = pfor_input.unstacked_input(1) + return wrap(gen_image_ops.adjust_contrastv2(images, contrast_factor), True) + + +@RegisterPFor("AdjustHue") +def _convert_adjust_hue(pfor_input): + images = pfor_input.stacked_input(0) + delta = pfor_input.unstacked_input(1) + return wrap(gen_image_ops.adjust_hue(images, delta), True) + + +@RegisterPFor("AdjustSaturation") +def _convert_adjust_saturation(pfor_input): + images = pfor_input.stacked_input(0) + scale = pfor_input.unstacked_input(1) + return wrap(gen_image_ops.adjust_saturation(images, scale), True) + + # nn_ops @@ -1580,12 +1606,16 @@ def _inputs_with_flattening(pfor_input, input_indices): @RegisterPForWithArgs("Conv2D", dims=[0]) @RegisterPForWithArgs("DepthToSpace", dims=[0]) @RegisterPForWithArgs("AvgPool", dims=[0]) +@RegisterPForWithArgs("AvgPool3D", dims=[0]) @RegisterPForWithArgs("MaxPool", dims=[0]) +@RegisterPForWithArgs("MaxPoolV2", dims=[0]) @RegisterPForWithArgs("MaxPool3D", dims=[0]) @RegisterPForWithArgs("MaxPool3DGrad", dims=[0, 1, 2]) @RegisterPForWithArgs("MaxPoolGrad", dims=[0, 1, 2]) +@RegisterPForWithArgs("MaxPoolGradV2", dims=[0, 1, 2]) @RegisterPForWithArgs("MaxPool3DGradGrad", dims=[0, 1, 2]) @RegisterPForWithArgs("MaxPoolGradGrad", dims=[0, 1, 2]) +@RegisterPForWithArgs("MaxPoolGradGradV2", dims=[0, 1, 2]) @RegisterPForWithArgs("SoftmaxCrossEntropyWithLogits", dims=[0, 1]) @RegisterPForWithArgs("SpaceToDepth", dims=[0]) def _convert_flatten_batch(pfor_input, op_type, dims): @@ -1603,6 +1633,53 @@ def _convert_flatten_batch(pfor_input, op_type, dims): _channel_flatten_input_cache = {} +@RegisterPFor("BatchToSpaceND") +def _convert_batch_to_space_nd(pfor_input): + inp = pfor_input.stacked_input(0) + block_shape = pfor_input.unstacked_input(1) + crops = pfor_input.unstacked_input(2) + + inp_shape = array_ops.shape(inp) + n = pfor_input.pfor.loop_len_vector + + # Reshape and transpose to move the vectorization axis inside the axes that + # will move to space. + # Reshape to 4D and transpose + block_size = math_ops.reduce_prod(block_shape) + new_shape = [n[0], block_size, inp_shape[1] // block_size, -1] + inp = array_ops.reshape(inp, new_shape) + inp = array_ops.transpose(inp, [1, 0, 2, 3]) + # Reshape back to merge the block, vectorization and batch dimension, and + # restore the other dimensions. + new_shape = array_ops.concat([n * inp_shape[1], inp_shape[2:]], axis=0) + inp = array_ops.reshape(inp, new_shape) + # Call batch_to_space and then split the new batch axis. + output = gen_array_ops.batch_to_space_nd(inp, block_shape, crops) + output = _unflatten_first_dim(output, n) + return wrap(output, True) + + +@RegisterPFor("SpaceToBatchND") +def _convert_space_to_batch_nd(pfor_input): + inp = pfor_input.stacked_input(0) + block_shape = pfor_input.unstacked_input(1) + paddings = pfor_input.unstacked_input(2) + + n = pfor_input.pfor.loop_len_vector + inp_shape = array_ops.shape(inp) + inp = _flatten_first_two_dims(inp) + output = gen_array_ops.space_to_batch_nd(inp, block_shape, paddings) + output_shape = array_ops.shape(output) + block_size = math_ops.reduce_prod(block_shape) + new_shape = [block_size, n[0], -1] + output = array_ops.reshape(output, new_shape) + output = array_ops.transpose(output, [1, 0, 2]) + new_shape = array_ops.concat( + [n, block_size * inp_shape[1:2], output_shape[1:]], axis=0) + output = array_ops.reshape(output, new_shape) + return wrap(output, True) + + def _channel_flatten_input(x, data_format): """Merge the stack dimension with the channel dimension. @@ -1734,6 +1811,7 @@ def _convert_fused_batch_norm_grad(pfor_input): @RegisterPForWithArgs("Conv2DBackpropInput", flatten_dims=[2], shape_dim=0) @RegisterPForWithArgs("AvgPoolGrad", flatten_dims=[1], shape_dim=0) +@RegisterPForWithArgs("AvgPool3DGrad", flatten_dims=[1], shape_dim=0) def _convert_flatten_batch_shape_input(pfor_input, op_type, flatten_dims, shape_dim): del op_type