pfor: Add converters for

Bucketize
  CheckNumerics
  ClipByValue
  ConjugateTranspose

PiperOrigin-RevId: 301383967
Change-Id: I2762dd56e995a7dc1368762d53596e0bb7d90825
This commit is contained in:
A. Unique TensorFlower 2020-03-17 09:00:09 -07:00 committed by TensorFlower Gardener
parent 5597c17b6a
commit c816c56914
3 changed files with 67 additions and 6 deletions
tensorflow/python/ops/parallel_for

View File

@ -292,6 +292,17 @@ class ArrayTest(PForTestCase):
self._test_loop_fn(loop_fn, 3)
def test_conjugate_transpose(self):
x = math_ops.complex(
random_ops.random_uniform([3, 2, 3, 4]),
random_ops.random_uniform([3, 2, 3, 4]))
def loop_fn(i):
x_i = array_ops.gather(x, i)
return array_ops.conjugate_transpose(x_i, [2, 1, 0])
self._test_loop_fn(loop_fn, 3)
def test_zeros_like(self):
x = random_ops.random_uniform([3, 2, 3])
@ -476,5 +487,15 @@ class ArrayTest(PForTestCase):
self._test_loop_fn(loop_fn, 7)
def test_check_numerics(self):
x = random_ops.random_uniform([2, 3, 4])
def loop_fn(i):
x_i = array_ops.gather(x, i)
return array_ops.check_numerics(x_i, "test_message")
self._test_loop_fn(loop_fn, 2)
if __name__ == "__main__":
test.main()

View File

@ -50,7 +50,6 @@ class MathTest(PForTestCase, parameterized.TestCase):
x = math_ops.complex(x, y)
# pylint: disable=cell-var-from-loop
output_dtypes = []
def loop_fn(i):
with g:
@ -65,8 +64,6 @@ class MathTest(PForTestCase, parameterized.TestCase):
grad = g.gradient(loss, x1)
if grad is not None:
outputs.append(grad)
del output_dtypes[:]
output_dtypes.extend(t.dtype for t in outputs)
return outputs
# pylint: enable=cell-var-from-loop
@ -374,6 +371,24 @@ class MathTest(PForTestCase, parameterized.TestCase):
self._test_loop_fn(loop_fn, 2)
def test_bucketize(self):
x = random_ops.random_uniform([2, 3, 4])
def loop_fn(i):
a = array_ops.gather(x, i)
return math_ops.bucketize(a, [-1, 0.5, 1])
self._test_loop_fn(loop_fn, 2)
def test_clip_by_value(self):
x = random_ops.random_uniform([2, 3, 4])
def loop_fn(i):
a = array_ops.gather(x, i)
return clip_ops.clip_by_value(a, 0.5, 1.0)
self._test_loop_fn(loop_fn, 2)
def test_cum_sum(self):
x = random_ops.random_uniform([2, 3, 4, 5])
for axis in (1, -2):

View File

@ -45,6 +45,7 @@ from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_image_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gen_parsing_ops
from tensorflow.python.ops import gen_random_ops
@ -2167,12 +2168,13 @@ def _convert_reverse(pfor_input):
return wrap(gen_array_ops.reverse_v2(value, axis=new_axis), True)
@RegisterPFor("Transpose")
def _convert_transpose(pfor_input):
@RegisterPForWithArgs("Transpose", gen_array_ops.transpose)
@RegisterPForWithArgs("ConjugateTranspose", gen_array_ops.conjugate_transpose)
def _convert_transpose(pfor_input, _, op_func):
t = pfor_input.stacked_input(0)
perm = pfor_input.unstacked_input(1)
new_perm = array_ops.concat([[0], perm + 1], axis=0)
return wrap(array_ops.transpose(t, new_perm), True)
return wrap(op_func(t, new_perm), True)
@RegisterPFor("ZerosLike")
@ -2339,6 +2341,13 @@ def _convert_strided_slice_grad(pfor_input):
shrink_axis_mask=shrink_axis_mask), True)
@RegisterPFor("CheckNumerics")
def _convert_check_numerics(pfor_input):
t = pfor_input.stacked_input(0)
message = pfor_input.get_attr("message")
return wrap(gen_array_ops.check_numerics(t, message), True)
# math_ops
@ -2453,6 +2462,22 @@ def _convert_argmax_argmin(pfor_input, _, op_func):
return wrap(op_func(t, axis=dimension, output_type=output_type), True)
@RegisterPFor("Bucketize")
def _convert_bucketize(pfor_input):
t = pfor_input.stacked_input(0)
boundaries = pfor_input.get_attr("boundaries")
return wrap(math_ops.bucketize(t, boundaries), True)
@RegisterPFor("ClipByValue")
def _convert_clip_by_value(pfor_input):
t = pfor_input.stacked_input(0)
clip_value_min = pfor_input.unstacked_input(1)
clip_value_max = pfor_input.unstacked_input(2)
return wrap(gen_math_ops.clip_by_value(t, clip_value_min, clip_value_max),
True)
@RegisterPForWithArgs("Cumsum", math_ops.cumsum)
@RegisterPForWithArgs("Cumprod", math_ops.cumprod)
def _convert_cumfoo(pfor_input, _, op_func):