Merge pull request #35530 from charmasaur:where_v2_in_expm

PiperOrigin-RevId: 288328469
Change-Id: If5507bbdc909d9cb5da873475520ded8cea2ea4d
This commit is contained in:
TensorFlower Gardener 2020-01-06 10:48:21 -08:00
commit 8016fed342

View File

@ -275,16 +275,16 @@ def matrix_exponential(input, name=None): # pylint: disable=redefined-builtin
math_ops.reduce_sum(
math_ops.abs(matrix),
axis=array_ops.size(array_ops.shape(matrix)) - 2),
axis=-1)
axis=-1)[..., array_ops.newaxis, array_ops.newaxis]
const = lambda x: constant_op.constant(x, l1_norm.dtype)
def _nest_where(vals, cases):
assert len(vals) == len(cases) - 1
if len(vals) == 1:
return array_ops.where(
return array_ops.where_v2(
math_ops.less(l1_norm, const(vals[0])), cases[0], cases[1])
else:
return array_ops.where(
return array_ops.where_v2(
math_ops.less(l1_norm, const(vals[0])), cases[0],
_nest_where(vals[1:], cases[1:]))
@ -295,9 +295,9 @@ def matrix_exponential(input, name=None): # pylint: disable=redefined-builtin
math_ops.log(l1_norm / maxnorm) / math_ops.log(const(2.0))), 0)
u3, v3 = _matrix_exp_pade3(matrix)
u5, v5 = _matrix_exp_pade5(matrix)
u7, v7 = _matrix_exp_pade7(matrix / math_ops.cast(
math_ops.pow(const(2.0), squarings),
matrix.dtype)[..., array_ops.newaxis, array_ops.newaxis])
u7, v7 = _matrix_exp_pade7(
matrix /
math_ops.cast(math_ops.pow(const(2.0), squarings), matrix.dtype))
conds = (4.258730016922831e-001, 1.880152677804762e+000)
u = _nest_where(conds, (u3, u5, u7))
v = _nest_where(conds, (v3, v5, v7))
@ -310,9 +310,9 @@ def matrix_exponential(input, name=None): # pylint: disable=redefined-builtin
u5, v5 = _matrix_exp_pade5(matrix)
u7, v7 = _matrix_exp_pade7(matrix)
u9, v9 = _matrix_exp_pade9(matrix)
u13, v13 = _matrix_exp_pade13(matrix / math_ops.cast(
math_ops.pow(const(2.0), squarings),
matrix.dtype)[..., array_ops.newaxis, array_ops.newaxis])
u13, v13 = _matrix_exp_pade13(
matrix /
math_ops.cast(math_ops.pow(const(2.0), squarings), matrix.dtype))
conds = (1.495585217958292e-002, 2.539398330063230e-001,
9.504178996162932e-001, 2.097847961257068e+000)
u = _nest_where(conds, (u3, u5, u7, u9, u13))
@ -329,7 +329,7 @@ def matrix_exponential(input, name=None): # pylint: disable=redefined-builtin
c = lambda i, r: math_ops.less(i, max_squarings)
def b(i, r):
return i + 1, array_ops.where(
return i + 1, array_ops.where_v2(
math_ops.less(i, squarings), math_ops.matmul(r, r), r)
_, result = control_flow_ops.while_loop(c, b, [i, result])