pfor: Add converters for
Bucketize CheckNumerics ClipByValue ConjugateTranspose PiperOrigin-RevId: 301383967 Change-Id: I2762dd56e995a7dc1368762d53596e0bb7d90825
This commit is contained in:
parent
5597c17b6a
commit
c816c56914
tensorflow/python/ops/parallel_for
@ -292,6 +292,17 @@ class ArrayTest(PForTestCase):
|
||||
|
||||
self._test_loop_fn(loop_fn, 3)
|
||||
|
||||
def test_conjugate_transpose(self):
|
||||
x = math_ops.complex(
|
||||
random_ops.random_uniform([3, 2, 3, 4]),
|
||||
random_ops.random_uniform([3, 2, 3, 4]))
|
||||
|
||||
def loop_fn(i):
|
||||
x_i = array_ops.gather(x, i)
|
||||
return array_ops.conjugate_transpose(x_i, [2, 1, 0])
|
||||
|
||||
self._test_loop_fn(loop_fn, 3)
|
||||
|
||||
def test_zeros_like(self):
|
||||
x = random_ops.random_uniform([3, 2, 3])
|
||||
|
||||
@ -476,5 +487,15 @@ class ArrayTest(PForTestCase):
|
||||
|
||||
self._test_loop_fn(loop_fn, 7)
|
||||
|
||||
def test_check_numerics(self):
|
||||
x = random_ops.random_uniform([2, 3, 4])
|
||||
|
||||
def loop_fn(i):
|
||||
x_i = array_ops.gather(x, i)
|
||||
return array_ops.check_numerics(x_i, "test_message")
|
||||
|
||||
self._test_loop_fn(loop_fn, 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -50,7 +50,6 @@ class MathTest(PForTestCase, parameterized.TestCase):
|
||||
x = math_ops.complex(x, y)
|
||||
|
||||
# pylint: disable=cell-var-from-loop
|
||||
output_dtypes = []
|
||||
|
||||
def loop_fn(i):
|
||||
with g:
|
||||
@ -65,8 +64,6 @@ class MathTest(PForTestCase, parameterized.TestCase):
|
||||
grad = g.gradient(loss, x1)
|
||||
if grad is not None:
|
||||
outputs.append(grad)
|
||||
del output_dtypes[:]
|
||||
output_dtypes.extend(t.dtype for t in outputs)
|
||||
return outputs
|
||||
|
||||
# pylint: enable=cell-var-from-loop
|
||||
@ -374,6 +371,24 @@ class MathTest(PForTestCase, parameterized.TestCase):
|
||||
|
||||
self._test_loop_fn(loop_fn, 2)
|
||||
|
||||
def test_bucketize(self):
|
||||
x = random_ops.random_uniform([2, 3, 4])
|
||||
|
||||
def loop_fn(i):
|
||||
a = array_ops.gather(x, i)
|
||||
return math_ops.bucketize(a, [-1, 0.5, 1])
|
||||
|
||||
self._test_loop_fn(loop_fn, 2)
|
||||
|
||||
def test_clip_by_value(self):
|
||||
x = random_ops.random_uniform([2, 3, 4])
|
||||
|
||||
def loop_fn(i):
|
||||
a = array_ops.gather(x, i)
|
||||
return clip_ops.clip_by_value(a, 0.5, 1.0)
|
||||
|
||||
self._test_loop_fn(loop_fn, 2)
|
||||
|
||||
def test_cum_sum(self):
|
||||
x = random_ops.random_uniform([2, 3, 4, 5])
|
||||
for axis in (1, -2):
|
||||
|
@ -45,6 +45,7 @@ from tensorflow.python.ops import data_flow_ops
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
from tensorflow.python.ops import gen_image_ops
|
||||
from tensorflow.python.ops import gen_linalg_ops
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
from tensorflow.python.ops import gen_nn_ops
|
||||
from tensorflow.python.ops import gen_parsing_ops
|
||||
from tensorflow.python.ops import gen_random_ops
|
||||
@ -2167,12 +2168,13 @@ def _convert_reverse(pfor_input):
|
||||
return wrap(gen_array_ops.reverse_v2(value, axis=new_axis), True)
|
||||
|
||||
|
||||
@RegisterPFor("Transpose")
|
||||
def _convert_transpose(pfor_input):
|
||||
@RegisterPForWithArgs("Transpose", gen_array_ops.transpose)
|
||||
@RegisterPForWithArgs("ConjugateTranspose", gen_array_ops.conjugate_transpose)
|
||||
def _convert_transpose(pfor_input, _, op_func):
|
||||
t = pfor_input.stacked_input(0)
|
||||
perm = pfor_input.unstacked_input(1)
|
||||
new_perm = array_ops.concat([[0], perm + 1], axis=0)
|
||||
return wrap(array_ops.transpose(t, new_perm), True)
|
||||
return wrap(op_func(t, new_perm), True)
|
||||
|
||||
|
||||
@RegisterPFor("ZerosLike")
|
||||
@ -2339,6 +2341,13 @@ def _convert_strided_slice_grad(pfor_input):
|
||||
shrink_axis_mask=shrink_axis_mask), True)
|
||||
|
||||
|
||||
@RegisterPFor("CheckNumerics")
|
||||
def _convert_check_numerics(pfor_input):
|
||||
t = pfor_input.stacked_input(0)
|
||||
message = pfor_input.get_attr("message")
|
||||
return wrap(gen_array_ops.check_numerics(t, message), True)
|
||||
|
||||
|
||||
# math_ops
|
||||
|
||||
|
||||
@ -2453,6 +2462,22 @@ def _convert_argmax_argmin(pfor_input, _, op_func):
|
||||
return wrap(op_func(t, axis=dimension, output_type=output_type), True)
|
||||
|
||||
|
||||
@RegisterPFor("Bucketize")
|
||||
def _convert_bucketize(pfor_input):
|
||||
t = pfor_input.stacked_input(0)
|
||||
boundaries = pfor_input.get_attr("boundaries")
|
||||
return wrap(math_ops.bucketize(t, boundaries), True)
|
||||
|
||||
|
||||
@RegisterPFor("ClipByValue")
|
||||
def _convert_clip_by_value(pfor_input):
|
||||
t = pfor_input.stacked_input(0)
|
||||
clip_value_min = pfor_input.unstacked_input(1)
|
||||
clip_value_max = pfor_input.unstacked_input(2)
|
||||
return wrap(gen_math_ops.clip_by_value(t, clip_value_min, clip_value_max),
|
||||
True)
|
||||
|
||||
|
||||
@RegisterPForWithArgs("Cumsum", math_ops.cumsum)
|
||||
@RegisterPForWithArgs("Cumprod", math_ops.cumprod)
|
||||
def _convert_cumfoo(pfor_input, _, op_func):
|
||||
|
Loading…
Reference in New Issue
Block a user