Merge pull request #35530 from charmasaur:where_v2_in_expm

PiperOrigin-RevId: 288328469
Change-Id: If5507bbdc909d9cb5da873475520ded8cea2ea4d
This commit is contained in:
TensorFlower Gardener 2020-01-06 10:48:21 -08:00
commit 8016fed342

View File

@ -275,16 +275,16 @@ def matrix_exponential(input, name=None): # pylint: disable=redefined-builtin
math_ops.reduce_sum( math_ops.reduce_sum(
math_ops.abs(matrix), math_ops.abs(matrix),
axis=array_ops.size(array_ops.shape(matrix)) - 2), axis=array_ops.size(array_ops.shape(matrix)) - 2),
axis=-1) axis=-1)[..., array_ops.newaxis, array_ops.newaxis]
const = lambda x: constant_op.constant(x, l1_norm.dtype) const = lambda x: constant_op.constant(x, l1_norm.dtype)
def _nest_where(vals, cases): def _nest_where(vals, cases):
assert len(vals) == len(cases) - 1 assert len(vals) == len(cases) - 1
if len(vals) == 1: if len(vals) == 1:
return array_ops.where( return array_ops.where_v2(
math_ops.less(l1_norm, const(vals[0])), cases[0], cases[1]) math_ops.less(l1_norm, const(vals[0])), cases[0], cases[1])
else: else:
return array_ops.where( return array_ops.where_v2(
math_ops.less(l1_norm, const(vals[0])), cases[0], math_ops.less(l1_norm, const(vals[0])), cases[0],
_nest_where(vals[1:], cases[1:])) _nest_where(vals[1:], cases[1:]))
@ -295,9 +295,9 @@ def matrix_exponential(input, name=None): # pylint: disable=redefined-builtin
math_ops.log(l1_norm / maxnorm) / math_ops.log(const(2.0))), 0) math_ops.log(l1_norm / maxnorm) / math_ops.log(const(2.0))), 0)
u3, v3 = _matrix_exp_pade3(matrix) u3, v3 = _matrix_exp_pade3(matrix)
u5, v5 = _matrix_exp_pade5(matrix) u5, v5 = _matrix_exp_pade5(matrix)
u7, v7 = _matrix_exp_pade7(matrix / math_ops.cast( u7, v7 = _matrix_exp_pade7(
math_ops.pow(const(2.0), squarings), matrix /
matrix.dtype)[..., array_ops.newaxis, array_ops.newaxis]) math_ops.cast(math_ops.pow(const(2.0), squarings), matrix.dtype))
conds = (4.258730016922831e-001, 1.880152677804762e+000) conds = (4.258730016922831e-001, 1.880152677804762e+000)
u = _nest_where(conds, (u3, u5, u7)) u = _nest_where(conds, (u3, u5, u7))
v = _nest_where(conds, (v3, v5, v7)) v = _nest_where(conds, (v3, v5, v7))
@ -310,9 +310,9 @@ def matrix_exponential(input, name=None): # pylint: disable=redefined-builtin
u5, v5 = _matrix_exp_pade5(matrix) u5, v5 = _matrix_exp_pade5(matrix)
u7, v7 = _matrix_exp_pade7(matrix) u7, v7 = _matrix_exp_pade7(matrix)
u9, v9 = _matrix_exp_pade9(matrix) u9, v9 = _matrix_exp_pade9(matrix)
u13, v13 = _matrix_exp_pade13(matrix / math_ops.cast( u13, v13 = _matrix_exp_pade13(
math_ops.pow(const(2.0), squarings), matrix /
matrix.dtype)[..., array_ops.newaxis, array_ops.newaxis]) math_ops.cast(math_ops.pow(const(2.0), squarings), matrix.dtype))
conds = (1.495585217958292e-002, 2.539398330063230e-001, conds = (1.495585217958292e-002, 2.539398330063230e-001,
9.504178996162932e-001, 2.097847961257068e+000) 9.504178996162932e-001, 2.097847961257068e+000)
u = _nest_where(conds, (u3, u5, u7, u9, u13)) u = _nest_where(conds, (u3, u5, u7, u9, u13))
@ -329,7 +329,7 @@ def matrix_exponential(input, name=None): # pylint: disable=redefined-builtin
c = lambda i, r: math_ops.less(i, max_squarings) c = lambda i, r: math_ops.less(i, max_squarings)
def b(i, r): def b(i, r):
return i + 1, array_ops.where( return i + 1, array_ops.where_v2(
math_ops.less(i, squarings), math_ops.matmul(r, r), r) math_ops.less(i, squarings), math_ops.matmul(r, r), r)
_, result = control_flow_ops.while_loop(c, b, [i, result]) _, result = control_flow_ops.while_loop(c, b, [i, result])