Merge pull request #35530 from charmasaur:where_v2_in_expm
PiperOrigin-RevId: 288328469 Change-Id: If5507bbdc909d9cb5da873475520ded8cea2ea4d
This commit is contained in:
commit
8016fed342
@ -275,16 +275,16 @@ def matrix_exponential(input, name=None): # pylint: disable=redefined-builtin
|
||||
math_ops.reduce_sum(
|
||||
math_ops.abs(matrix),
|
||||
axis=array_ops.size(array_ops.shape(matrix)) - 2),
|
||||
axis=-1)
|
||||
axis=-1)[..., array_ops.newaxis, array_ops.newaxis]
|
||||
const = lambda x: constant_op.constant(x, l1_norm.dtype)
|
||||
|
||||
def _nest_where(vals, cases):
|
||||
assert len(vals) == len(cases) - 1
|
||||
if len(vals) == 1:
|
||||
return array_ops.where(
|
||||
return array_ops.where_v2(
|
||||
math_ops.less(l1_norm, const(vals[0])), cases[0], cases[1])
|
||||
else:
|
||||
return array_ops.where(
|
||||
return array_ops.where_v2(
|
||||
math_ops.less(l1_norm, const(vals[0])), cases[0],
|
||||
_nest_where(vals[1:], cases[1:]))
|
||||
|
||||
@ -295,9 +295,9 @@ def matrix_exponential(input, name=None): # pylint: disable=redefined-builtin
|
||||
math_ops.log(l1_norm / maxnorm) / math_ops.log(const(2.0))), 0)
|
||||
u3, v3 = _matrix_exp_pade3(matrix)
|
||||
u5, v5 = _matrix_exp_pade5(matrix)
|
||||
u7, v7 = _matrix_exp_pade7(matrix / math_ops.cast(
|
||||
math_ops.pow(const(2.0), squarings),
|
||||
matrix.dtype)[..., array_ops.newaxis, array_ops.newaxis])
|
||||
u7, v7 = _matrix_exp_pade7(
|
||||
matrix /
|
||||
math_ops.cast(math_ops.pow(const(2.0), squarings), matrix.dtype))
|
||||
conds = (4.258730016922831e-001, 1.880152677804762e+000)
|
||||
u = _nest_where(conds, (u3, u5, u7))
|
||||
v = _nest_where(conds, (v3, v5, v7))
|
||||
@ -310,9 +310,9 @@ def matrix_exponential(input, name=None): # pylint: disable=redefined-builtin
|
||||
u5, v5 = _matrix_exp_pade5(matrix)
|
||||
u7, v7 = _matrix_exp_pade7(matrix)
|
||||
u9, v9 = _matrix_exp_pade9(matrix)
|
||||
u13, v13 = _matrix_exp_pade13(matrix / math_ops.cast(
|
||||
math_ops.pow(const(2.0), squarings),
|
||||
matrix.dtype)[..., array_ops.newaxis, array_ops.newaxis])
|
||||
u13, v13 = _matrix_exp_pade13(
|
||||
matrix /
|
||||
math_ops.cast(math_ops.pow(const(2.0), squarings), matrix.dtype))
|
||||
conds = (1.495585217958292e-002, 2.539398330063230e-001,
|
||||
9.504178996162932e-001, 2.097847961257068e+000)
|
||||
u = _nest_where(conds, (u3, u5, u7, u9, u13))
|
||||
@ -329,7 +329,7 @@ def matrix_exponential(input, name=None): # pylint: disable=redefined-builtin
|
||||
c = lambda i, r: math_ops.less(i, max_squarings)
|
||||
|
||||
def b(i, r):
|
||||
return i + 1, array_ops.where(
|
||||
return i + 1, array_ops.where_v2(
|
||||
math_ops.less(i, squarings), math_ops.matmul(r, r), r)
|
||||
|
||||
_, result = control_flow_ops.while_loop(c, b, [i, result])
|
||||
|
Loading…
Reference in New Issue
Block a user