From 66fe41900b21396221df20eb2c5ced1126ffb0f5 Mon Sep 17 00:00:00 2001
From: Mark Daoust
Date: Mon, 21 Sep 2020 10:34:34 -0700
Subject: [PATCH] Convert tf.einsum docstring to use doctest.

PiperOrigin-RevId: 332881293
Change-Id: I9b26c57bc58e241205e5f8b1e531852797442d6a
---
 tensorflow/python/ops/special_math_ops.py | 119 +++++++++++++++++-----
 1 file changed, 93 insertions(+), 26 deletions(-)

diff --git a/tensorflow/python/ops/special_math_ops.py b/tensorflow/python/ops/special_math_ops.py
index 6bddd3ea9bf..2f4589cdf98 100644
--- a/tensorflow/python/ops/special_math_ops.py
+++ b/tensorflow/python/ops/special_math_ops.py
@@ -606,21 +606,25 @@ def _enclosing_tpu_context():
 @tf_export('einsum', 'linalg.einsum')
 @dispatch.add_dispatch_support
 def einsum(equation, *inputs, **kwargs):
-  """Tensor contraction over specified indices and outer product.
+  r"""Tensor contraction over specified indices and outer product.
 
   Einsum allows defining Tensors by defining their element-wise computation.
   This computation is defined by `equation`, a shorthand form based on Einstein
   summation. As an example, consider multiplying two matrices A and B to form a
   matrix C. The elements of C are given by:
 
-  ```
-  C[i,k] = sum_j A[i,j] * B[j,k]
-  ```
+  $$ C_{i,k} = \sum_j A_{i,j} B_{j,k} $$
 
-  The corresponding `equation` is:
+  or
 
   ```
-  ij,jk->ik
+  C[i,k] = sum_j A[i,j] * B[j,k]
   ```
+
+  The corresponding einsum `equation` is:
+
+  ```
+  ij,jk->ik
+  ```
 
   In general, to convert the element-wise equation into the `equation` string,
@@ -632,35 +636,98 @@ def einsum(equation, *inputs, **kwargs):
   3. drop summation signs, and (`ik = ij, jk`)
   4. move the output to the right, while replacing "=" with "->". (`ij,jk->ik`)
 
+  Note: If the output indices are not specified, repeated indices are summed.
+  So `ij,jk->ik` can be simplified to `ij,jk`.
+
   Many common operations can be expressed in this way. For example:
 
-  ```python
-  # Matrix multiplication
-  einsum('ij,jk->ik', m0, m1)  # output[i,k] = sum_j m0[i,j] * m1[j, k]
+  **Matrix multiplication**
 
-  # Dot product
-  einsum('i,i->', u, v)  # output = sum_i u[i]*v[i]
+  >>> m0 = tf.random.normal(shape=[2, 3])
+  >>> m1 = tf.random.normal(shape=[3, 5])
+  >>> e = tf.einsum('ij,jk->ik', m0, m1)
+  >>> # output[i,k] = sum_j m0[i,j] * m1[j, k]
+  >>> print(e.shape)
+  (2, 5)
 
-  # Outer product
-  einsum('i,j->ij', u, v)  # output[i,j] = u[i]*v[j]
+  Repeated indices are summed if the output indices are not specified.
 
-  # Transpose
-  einsum('ij->ji', m)  # output[j,i] = m[i,j]
+  >>> e = tf.einsum('ij,jk', m0, m1)  # output[i,k] = sum_j m0[i,j] * m1[j, k]
+  >>> print(e.shape)
+  (2, 5)
 
-  # Trace
-  einsum('ii', m)  # output[j,i] = trace(m) = sum_i m[i, i]
+
-  # Batch matrix multiplication
-  einsum('aij,ajk->aik', s, t)  # out[a,i,k] = sum_j s[a,i,j] * t[a, j, k]
-  ```
+  **Dot product**
 
-  To enable and control broadcasting, use an ellipsis.  For example, to perform
-  batch matrix multiplication with NumPy-style broadcasting across the batch
-  dimensions, use:
+  >>> u = tf.random.normal(shape=[5])
+  >>> v = tf.random.normal(shape=[5])
+  >>> e = tf.einsum('i,i->', u, v)  # output = sum_i u[i]*v[i]
+  >>> print(e.shape)
+  ()
 
-  ```python
-  einsum('...ij,...jk->...ik', u, v)
-  ```
+  **Outer product**
+
+  >>> u = tf.random.normal(shape=[3])
+  >>> v = tf.random.normal(shape=[5])
+  >>> e = tf.einsum('i,j->ij', u, v)  # output[i,j] = u[i]*v[j]
+  >>> print(e.shape)
+  (3, 5)
+
+  **Transpose**
+
+  >>> m = tf.ones([2,3])
+  >>> e = tf.einsum('ij->ji', m)  # output[j,i] = m[i,j]
+  >>> print(e.shape)
+  (3, 2)
+
+  **Diag**
+
+  >>> m = tf.reshape(tf.range(9), [3,3])
+  >>> diag = tf.einsum('ii->i', m)
+  >>> print(diag.shape)
+  (3,)
+
+  **Trace**
+
+  >>> # Repeated indices are summed.
+  >>> trace = tf.einsum('ii', m)  # output = trace(m) = sum_i m[i, i]
+  >>> assert trace == sum(diag)
+  >>> print(trace.shape)
+  ()
+
+  **Batch matrix multiplication**
+
+  >>> s = tf.random.normal(shape=[7,5,3])
+  >>> t = tf.random.normal(shape=[7,3,2])
+  >>> e = tf.einsum('bij,bjk->bik', s, t)
+  >>> # output[b,i,k] = sum_j s[b,i,j] * t[b,j,k]
+  >>> print(e.shape)
+  (7, 5, 2)
+
+  This method does not support broadcasting on named axes. All axes with
+  matching labels must have the same length. If you have length-1 axes,
+  use `tf.squeeze` or `tf.reshape` to eliminate them.
+
+  To write code that is agnostic to the number of indices in the input,
+  use an ellipsis. The ellipsis is a placeholder for "whatever other
+  indices fit here".
+
+  For example, to perform batch matrix multiplication with NumPy-style
+  broadcasting, where the matrix multiply acts on the last two axes of the
+  input, use:
+
+  >>> s = tf.random.normal(shape=[11, 7, 5, 3])
+  >>> t = tf.random.normal(shape=[11, 7, 3, 2])
+  >>> e = tf.einsum('...ij,...jk->...ik', s, t)
+  >>> print(e.shape)
+  (11, 7, 5, 2)
+
+  Einsum **will** broadcast over axes covered by the ellipsis.
+
+  >>> s = tf.random.normal(shape=[11, 1, 5, 3])
+  >>> t = tf.random.normal(shape=[1, 7, 3, 2])
+  >>> e = tf.einsum('...ij,...jk->...ik', s, t)
+  >>> print(e.shape)
+  (11, 7, 5, 2)
 
   Args:
     equation: a `str` describing the contraction, in the same format as
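A minimal sketch of the length-1-axis workaround described in the new docstring text above. The names and shapes here are illustrative, assuming TensorFlow 2.x eager execution:

```python
import tensorflow as tf

# einsum does not broadcast named axes: in 'ij,ij->j' both axes labeled j
# must have the same length, so the length-1 axis of `u` cannot broadcast
# against the length-7 axis of `v`.
u = tf.random.normal(shape=[5, 1])  # trailing length-1 axis
v = tf.random.normal(shape=[5, 7])

# Drop the size-1 axis with tf.squeeze (tf.reshape(u, [5]) also works) and
# relabel the equation so every label has one consistent length.
u1 = tf.squeeze(u, axis=1)       # shape (5,)
e = tf.einsum('i,ij->j', u1, v)  # output[j] = sum_i u1[i] * v[i, j]
print(e.shape)                   # (7,)
```

Only axes covered by an ellipsis broadcast, as the two `...ij,...jk->...ik` examples above show.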