From 66fbe49c0747fb4fa626e127d34b9baf73fa287b Mon Sep 17 00:00:00 2001 From: Harry Slatyer Date: Thu, 2 Jan 2020 16:04:50 +1100 Subject: [PATCH] Switch from where to where_v2 in matrix_exponential (because `where` is deprecated). `where` and `where_v2` have different broadcasting rules: the former "broadcasts" by treating a 1D condition as a mask on the *outer* dimension of x and y; the latter follows standard broadcasting rules, which cause a 1D condition to act as a mask on the *inner* dimension of x and y. E.g. with `where_v2`, if x and y are [n, d, d], then a condition [n] will either fail to build (if n != d) or be treated as [1, 1, n] (if n == d). In this case, we want the condition to act as a mask on the outer dimension, i.e. be treated as [n, 1, 1]. The way to make that happen with `where_v2` is simply to expand the condition's dimensions to that shape manually. In the case of matrix exponential, by expanding the shape of `l1_norm` to [n, 1, 1]: - the conditions in `_nest_where` become the right shape, - the `squarings` variables get expanded too (i.e. [n] -> [n, 1, 1]), which means that... - ... when scaling the Pade approximants we no longer need to expand the dimensions of 2^squarings, and ... - ... the condition in `b` (used for the squaring while loop) becomes the right shape — which is what we need. 
--- tensorflow/python/ops/linalg/linalg_impl.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tensorflow/python/ops/linalg/linalg_impl.py b/tensorflow/python/ops/linalg/linalg_impl.py index 3412486fb9e..8fc8646a99f 100644 --- a/tensorflow/python/ops/linalg/linalg_impl.py +++ b/tensorflow/python/ops/linalg/linalg_impl.py @@ -276,16 +276,16 @@ def matrix_exponential(input, name=None): # pylint: disable=redefined-builtin math_ops.reduce_sum( math_ops.abs(matrix), axis=array_ops.size(array_ops.shape(matrix)) - 2), - axis=-1) + axis=-1)[..., array_ops.newaxis, array_ops.newaxis] const = lambda x: constant_op.constant(x, l1_norm.dtype) def _nest_where(vals, cases): assert len(vals) == len(cases) - 1 if len(vals) == 1: - return array_ops.where( + return array_ops.where_v2( math_ops.less(l1_norm, const(vals[0])), cases[0], cases[1]) else: - return array_ops.where( + return array_ops.where_v2( math_ops.less(l1_norm, const(vals[0])), cases[0], _nest_where(vals[1:], cases[1:])) @@ -297,8 +297,7 @@ def matrix_exponential(input, name=None): # pylint: disable=redefined-builtin u3, v3 = _matrix_exp_pade3(matrix) u5, v5 = _matrix_exp_pade5(matrix) u7, v7 = _matrix_exp_pade7(matrix / math_ops.cast( - math_ops.pow(const(2.0), squarings), - matrix.dtype)[..., array_ops.newaxis, array_ops.newaxis]) + math_ops.pow(const(2.0), squarings), matrix.dtype)) conds = (4.258730016922831e-001, 1.880152677804762e+000) u = _nest_where(conds, (u3, u5, u7)) v = _nest_where(conds, (v3, v5, v7)) @@ -312,8 +311,7 @@ def matrix_exponential(input, name=None): # pylint: disable=redefined-builtin u7, v7 = _matrix_exp_pade7(matrix) u9, v9 = _matrix_exp_pade9(matrix) u13, v13 = _matrix_exp_pade13(matrix / math_ops.cast( - math_ops.pow(const(2.0), squarings), - matrix.dtype)[..., array_ops.newaxis, array_ops.newaxis]) + math_ops.pow(const(2.0), squarings), matrix.dtype)) conds = (1.495585217958292e-002, 2.539398330063230e-001, 9.504178996162932e-001, 2.097847961257068e+000) 
u = _nest_where(conds, (u3, u5, u7, u9, u13)) @@ -330,7 +328,7 @@ def matrix_exponential(input, name=None): # pylint: disable=redefined-builtin c = lambda i, r: math_ops.less(i, max_squarings) def b(i, r): - return i + 1, array_ops.where( + return i + 1, array_ops.where_v2( math_ops.less(i, squarings), math_ops.matmul(r, r), r) _, result = control_flow_ops.while_loop(c, b, [i, result])