Add LinearOperatorToeplitz for non-block Toeplitz matrices.

PiperOrigin-RevId: 246039525
This commit is contained in:
A. Unique TensorFlower 2019-04-30 15:57:47 -07:00 committed by TensorFlower Gardener
parent c14f7d33bb
commit d38ec4d493
8 changed files with 728 additions and 0 deletions

View File

@ -350,6 +350,31 @@ cuda_py_test(
xla_enable_strict_auto_jit = True,
)
cuda_py_test(
name = "linear_operator_toeplitz_test",
size = "medium",
srcs = ["linear_operator_toeplitz_test.py"],
additional_deps = [
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:spectral_ops_test_util",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python/ops/linalg",
"//tensorflow/python/ops/signal",
],
shard_count = 5,
tags = [
"noasan", # times out, b/63678675
"optonly", # times out, b/79171797
],
xla_enable_strict_auto_jit = True,
)
cuda_py_test(
name = "linear_operator_zeros_test",
size = "medium",

View File

@ -0,0 +1,139 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import numpy as np
import scipy.linalg
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import spectral_ops_test_util
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_test_util
from tensorflow.python.ops.linalg import linear_operator_toeplitz
from tensorflow.python.platform import test
linalg = linalg_lib
_to_complex = linear_operator_toeplitz._to_complex
class LinearOperatorToeplitzTest(
linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
"""Most tests done in the base class LinearOperatorDerivedClassTest."""
@contextlib.contextmanager
def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
"""We overwrite the FFT operation mapping for testing."""
with test.TestCase._constrain_devices_and_set_default(
self, sess, use_gpu, force_gpu) as sess:
with spectral_ops_test_util.fft_kernel_label_map():
yield sess
def setUp(self):
# TODO(srvasude): Lower these tolerances once specialized solve and
# determinants are implemented.
self._atol[dtypes.float32] = 2e-4
self._rtol[dtypes.float32] = 2e-5
self._atol[dtypes.complex64] = 4e-4
self._rtol[dtypes.complex64] = 4e-5
@staticmethod
def tests_to_skip():
# Skip solve tests, as these could have better stability
# (currently exercises the base class).
# TODO(srvasude): Enable these when solve is implemented.
return ["cholesky", "inverse", "solve", "solve_with_broadcast"]
@staticmethod
def operator_shapes_infos():
shape_info = linear_operator_test_util.OperatorShapesInfo
# non-batch operators (n, n) and batch operators.
return [
shape_info((1, 1)),
shape_info((1, 6, 6)),
shape_info((3, 4, 4)),
shape_info((2, 1, 3, 3))
]
def operator_and_matrix(
self, build_info, dtype, use_placeholder,
ensure_self_adjoint_and_pd=False):
shape = list(build_info.shape)
row = np.random.uniform(low=1., high=5., size=shape[:-1])
col = np.random.uniform(low=1., high=5., size=shape[:-1])
# Make sure first entry is the same
row[..., 0] = col[..., 0]
if ensure_self_adjoint_and_pd:
# Note that a Toeplitz matrix generated from a linearly decreasing
# non-negative sequence is positive definite. See
# https://www.math.cinvestav.mx/~grudsky/Papers/118_29062012_Albrecht.pdf
# for details.
row = np.linspace(start=10., stop=1., num=shape[-1])
# The entries for the first row and column should be the same to guarantee
# symmetric.
row = col
lin_op_row = math_ops.cast(row, dtype=dtype)
lin_op_col = math_ops.cast(col, dtype=dtype)
if use_placeholder:
lin_op_row = array_ops.placeholder_with_default(
lin_op_row, shape=None)
lin_op_col = array_ops.placeholder_with_default(
lin_op_col, shape=None)
operator = linear_operator_toeplitz.LinearOperatorToeplitz(
row=lin_op_row,
col=lin_op_col,
is_self_adjoint=True if ensure_self_adjoint_and_pd else None,
is_positive_definite=True if ensure_self_adjoint_and_pd else None)
flattened_row = np.reshape(row, (-1, shape[-1]))
flattened_col = np.reshape(col, (-1, shape[-1]))
flattened_toeplitz = np.zeros(
[flattened_row.shape[0], shape[-1], shape[-1]])
for i in range(flattened_row.shape[0]):
flattened_toeplitz[i] = scipy.linalg.toeplitz(
flattened_col[i],
flattened_row[i])
matrix = np.reshape(flattened_toeplitz, shape)
matrix = math_ops.cast(matrix, dtype=dtype)
return operator, matrix
def test_scalar_row_col_raises(self):
with self.assertRaisesRegexp(ValueError, "must have at least 1 dimension"):
linear_operator_toeplitz.LinearOperatorToeplitz(1., 1.)
with self.assertRaisesRegexp(ValueError, "must have at least 1 dimension"):
linear_operator_toeplitz.LinearOperatorToeplitz([1.], 1.)
with self.assertRaisesRegexp(ValueError, "must have at least 1 dimension"):
linear_operator_toeplitz.LinearOperatorToeplitz(1., [1.])
if __name__ == "__main__":
linear_operator_test_util.add_tests(LinearOperatorToeplitzTest)
test.main()

View File

@ -37,6 +37,7 @@ from tensorflow.python.ops.linalg.linear_operator_identity import *
from tensorflow.python.ops.linalg.linear_operator_kronecker import *
from tensorflow.python.ops.linalg.linear_operator_low_rank_update import *
from tensorflow.python.ops.linalg.linear_operator_lower_triangular import *
from tensorflow.python.ops.linalg.linear_operator_toeplitz import *
from tensorflow.python.ops.linalg.linear_operator_zeros import *
# pylint: enable=wildcard-import

View File

@ -0,0 +1,247 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""`LinearOperator` acting like a Toeplitz matrix."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as linalg
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_circulant
from tensorflow.python.ops.signal import fft_ops
from tensorflow.python.util.tf_export import tf_export
__all__ = ["LinearOperatorToeplitz",]
@tf_export("linalg.LinearOperatorToeplitz")
class LinearOperatorToeplitz(linear_operator.LinearOperator):
"""`LinearOperator` acting like a [batch] of toeplitz matrices.
This operator acts like a [batch] Toeplitz matrix `A` with shape
`[B1,...,Bb, N, N]` for some `b >= 0`. The first `b` indices index a
batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is
an `N x N` matrix. This matrix `A` is not materialized, but for
purposes of broadcasting this shape will be relevant.
#### Description in terms of toeplitz matrices
Toeplitz means that `A` has constant diagonals. Hence, `A` can be generated
with two vectors. One represents the first column of the matrix, and the
other represents the first row.
Below is a 4 x 4 example:
```
A = |a b c d|
|e a b c|
|f e a b|
|g f e a|
```
#### Example of a Toeplitz operator.
```python
# Create a 3 x 3 Toeplitz operator.
col = [1., 2., 3.]
row = [1., 4., -9.]
operator = LinearOperatorToeplitz(col, row)
operator.to_dense()
==> [[1., 4., -9.],
[2., 1., 4.],
[3., 2., 1.]]
operator.shape
==> [3, 3]
operator.log_abs_determinant()
==> scalar Tensor
x = ... Shape [3, 4] Tensor
operator.matmul(x)
==> Shape [3, 4] Tensor
#### Shape compatibility
This operator acts on [batch] matrix with compatible shape.
`x` is a batch matrix with compatible shape for `matmul` and `solve` if
```
operator.shape = [B1,...,Bb] + [N, N], with b >= 0
x.shape = [C1,...,Cc] + [N, R],
and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd]
```
#### Matrix property hints
This `LinearOperator` is initialized with boolean flags of the form `is_X`,
for `X = non_singular, self_adjoint, positive_definite, square`.
These have the following meaning:
* If `is_X == True`, callers should expect the operator to have the
property `X`. This is a promise that should be fulfilled, but is *not* a
runtime assert. For example, finite floating point precision may result
in these promises being violated.
* If `is_X == False`, callers should expect the operator to not have `X`.
* If `is_X == None` (the default), callers should have no expectation either
way.
"""
def __init__(self,
col,
row,
is_non_singular=None,
is_self_adjoint=None,
is_positive_definite=None,
is_square=None,
name="LinearOperatorToeplitz"):
r"""Initialize a `LinearOperatorToeplitz`.
Args:
col: Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0` `N >= 0`.
The first column of the operator. Allowed dtypes: `float16`, `float32`,
`float64`, `complex64`, `complex128`. Note that the first entry of
`col` is assumed to be the same as the first entry of `row`.
row: Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0` `N >= 0`.
The first row of the operator. Allowed dtypes: `float16`, `float32`,
`float64`, `complex64`, `complex128`. Note that the first entry of
`row` is assumed to be the same as the first entry of `col`.
is_non_singular: Expect that this operator is non-singular.
is_self_adjoint: Expect that this operator is equal to its hermitian
transpose. If `diag.dtype` is real, this is auto-set to `True`.
is_positive_definite: Expect that this operator is positive definite,
meaning the quadratic form `x^H A x` has positive real part for all
nonzero `x`. Note that we do not require the operator to be
self-adjoint to be positive-definite. See:
https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
is_square: Expect that this operator acts like square [batch] matrices.
name: A name for this `LinearOperator`.
"""
with ops.name_scope(name, values=[row, col]):
self._row = ops.convert_to_tensor(row, name="row")
self._col = ops.convert_to_tensor(col, name="col")
self._check_row_col(self._row, self._col)
circulant_col = array_ops.concat(
[self._col,
array_ops.zeros_like(self._col[..., 0:1]),
array_ops.reverse(self._row[..., 1:], axis=[-1])], axis=-1)
# To be used for matmul.
self._circulant = linear_operator_circulant.LinearOperatorCirculant(
fft_ops.fft(_to_complex(circulant_col)),
input_output_dtype=self._row.dtype)
if is_square is False: # pylint:disable=g-bool-id-comparison
raise ValueError("Only square Toeplitz operators currently supported.")
is_square = True
super(LinearOperatorToeplitz, self).__init__(
dtype=self._row.dtype,
graph_parents=[self._row, self._col],
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
name=name)
def _check_row_col(self, row, col):
"""Static check of row and column."""
for name, tensor in [["row", row], ["col", col]]:
if tensor.get_shape().ndims is not None and tensor.get_shape().ndims < 1:
raise ValueError("Argument {} must have at least 1 dimension. "
"Found: {}".format(name, tensor))
if row.get_shape()[-1] is not None and col.get_shape()[-1] is not None:
if row.get_shape()[-1] != col.get_shape()[-1]:
raise ValueError(
"Expected square matrix, got row and col with mismatched "
"dimensions.")
def _shape(self):
# If d_shape = [5, 3], we return [5, 3, 3].
v_shape = array_ops.broadcast_static_shape(
self.row.shape, self.col.shape)
return v_shape.concatenate(v_shape[-1:])
def _shape_tensor(self):
v_shape = array_ops.broadcast_dynamic_shape(
array_ops.shape(self.row),
array_ops.shape(self.col))
k = v_shape[-1]
return array_ops.concat((v_shape, [k]), 0)
def _assert_self_adjoint(self):
return check_ops.assert_equal(
self.row,
self.col,
message=("row and col are not the same, and "
"so this operator is not self-adjoint."))
# TODO(srvasude): Add efficient solver and determinant calculations to this
# class (based on Levinson recursion.)
def _matmul(self, x, adjoint=False, adjoint_arg=False):
# Given a Toeplitz matrix, we can embed it in a Circulant matrix to perform
# efficient matrix multiplications. Given a Toeplitz matrix with first row
# [t_0, t_1, ... t_{n-1}] and first column [t0, t_{-1}, ..., t_{-(n-1)},
# let C by the circulant matrix with first column [t0, t_{-1}, ...,
# t_{-(n-1)}, 0, t_{n-1}, ..., t_1]. Also adjoin to our input vector `x`
# `n` zeros, to make it a vector of length `2n` (call it y). It can be shown
# that if we take the first n entries of `Cy`, this is equal to the Toeplitz
# multiplication. See:
# http://math.mit.edu/icg/resources/teaching/18.085-spring2015/toeplitz.pdf
# for more details.
x = linalg.adjoint(x) if adjoint_arg else x
expanded_x = array_ops.concat([x, array_ops.zeros_like(x)], axis=-2)
result = self._circulant.matmul(
expanded_x, adjoint=adjoint, adjoint_arg=False)
return math_ops.cast(
result[..., :self.domain_dimension_tensor(), :],
self.dtype)
def _trace(self):
return math_ops.cast(
self.domain_dimension_tensor(),
dtype=self.dtype) * self.col[..., 0]
def _diag_part(self):
diag_entry = self.col[..., 0:1]
return diag_entry * array_ops.ones(
[self.domain_dimension_tensor()], self.dtype)
@property
def col(self):
return self._col
@property
def row(self):
return self._row
def _to_complex(x):
dtype = dtypes.complex64
if x.dtype in [dtypes.float64, dtypes.complex128]:
dtype = dtypes.complex128
return math_ops.cast(x, dtype)

View File

@ -0,0 +1,154 @@
path: "tensorflow.linalg.LinearOperatorToeplitz"
tf_class {
is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator_toeplitz.LinearOperatorToeplitz\'>"
is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
is_instance: "<type \'object\'>"
member {
name: "H"
mtype: "<type \'property\'>"
}
member {
name: "batch_shape"
mtype: "<type \'property\'>"
}
member {
name: "col"
mtype: "<type \'property\'>"
}
member {
name: "domain_dimension"
mtype: "<type \'property\'>"
}
member {
name: "dtype"
mtype: "<type \'property\'>"
}
member {
name: "graph_parents"
mtype: "<type \'property\'>"
}
member {
name: "is_non_singular"
mtype: "<type \'property\'>"
}
member {
name: "is_positive_definite"
mtype: "<type \'property\'>"
}
member {
name: "is_self_adjoint"
mtype: "<type \'property\'>"
}
member {
name: "is_square"
mtype: "<type \'property\'>"
}
member {
name: "name"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"
}
member {
name: "row"
mtype: "<type \'property\'>"
}
member {
name: "shape"
mtype: "<type \'property\'>"
}
member {
name: "tensor_rank"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'col\', \'row\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'LinearOperatorToeplitz\'], "
}
member_method {
name: "add_to_tensor"
argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
}
member_method {
name: "adjoint"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
}
member_method {
name: "assert_non_singular"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
}
member_method {
name: "assert_positive_definite"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_positive_definite\'], "
}
member_method {
name: "assert_self_adjoint"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_self_adjoint\'], "
}
member_method {
name: "batch_shape_tensor"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], "
}
member_method {
name: "cholesky"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'cholesky\'], "
}
member_method {
name: "determinant"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'det\'], "
}
member_method {
name: "diag_part"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'diag_part\'], "
}
member_method {
name: "domain_dimension_tensor"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'domain_dimension_tensor\'], "
}
member_method {
name: "inverse"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'inverse\'], "
}
member_method {
name: "log_abs_determinant"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'log_abs_det\'], "
}
member_method {
name: "matmul"
argspec: "args=[\'self\', \'x\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'matmul\'], "
}
member_method {
name: "matvec"
argspec: "args=[\'self\', \'x\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'matvec\'], "
}
member_method {
name: "range_dimension_tensor"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'range_dimension_tensor\'], "
}
member_method {
name: "shape_tensor"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'shape_tensor\'], "
}
member_method {
name: "solve"
argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'solve\'], "
}
member_method {
name: "solvevec"
argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'solve\'], "
}
member_method {
name: "tensor_rank_tensor"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'tensor_rank_tensor\'], "
}
member_method {
name: "to_dense"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'to_dense\'], "
}
member_method {
name: "trace"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'trace\'], "
}
}

View File

@ -64,6 +64,10 @@ tf_module {
name: "LinearOperatorScaledIdentity"
mtype: "<type \'type\'>"
}
member {
name: "LinearOperatorToeplitz"
mtype: "<type \'type\'>"
}
member {
name: "LinearOperatorZeros"
mtype: "<type \'type\'>"

View File

@ -0,0 +1,154 @@
path: "tensorflow.linalg.LinearOperatorToeplitz"
tf_class {
is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator_toeplitz.LinearOperatorToeplitz\'>"
is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
is_instance: "<type \'object\'>"
member {
name: "H"
mtype: "<type \'property\'>"
}
member {
name: "batch_shape"
mtype: "<type \'property\'>"
}
member {
name: "col"
mtype: "<type \'property\'>"
}
member {
name: "domain_dimension"
mtype: "<type \'property\'>"
}
member {
name: "dtype"
mtype: "<type \'property\'>"
}
member {
name: "graph_parents"
mtype: "<type \'property\'>"
}
member {
name: "is_non_singular"
mtype: "<type \'property\'>"
}
member {
name: "is_positive_definite"
mtype: "<type \'property\'>"
}
member {
name: "is_self_adjoint"
mtype: "<type \'property\'>"
}
member {
name: "is_square"
mtype: "<type \'property\'>"
}
member {
name: "name"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"
}
member {
name: "row"
mtype: "<type \'property\'>"
}
member {
name: "shape"
mtype: "<type \'property\'>"
}
member {
name: "tensor_rank"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'col\', \'row\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'LinearOperatorToeplitz\'], "
}
member_method {
name: "add_to_tensor"
argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
}
member_method {
name: "adjoint"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
}
member_method {
name: "assert_non_singular"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
}
member_method {
name: "assert_positive_definite"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_positive_definite\'], "
}
member_method {
name: "assert_self_adjoint"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_self_adjoint\'], "
}
member_method {
name: "batch_shape_tensor"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], "
}
member_method {
name: "cholesky"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'cholesky\'], "
}
member_method {
name: "determinant"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'det\'], "
}
member_method {
name: "diag_part"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'diag_part\'], "
}
member_method {
name: "domain_dimension_tensor"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'domain_dimension_tensor\'], "
}
member_method {
name: "inverse"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'inverse\'], "
}
member_method {
name: "log_abs_determinant"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'log_abs_det\'], "
}
member_method {
name: "matmul"
argspec: "args=[\'self\', \'x\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'matmul\'], "
}
member_method {
name: "matvec"
argspec: "args=[\'self\', \'x\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'matvec\'], "
}
member_method {
name: "range_dimension_tensor"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'range_dimension_tensor\'], "
}
member_method {
name: "shape_tensor"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'shape_tensor\'], "
}
member_method {
name: "solve"
argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'solve\'], "
}
member_method {
name: "solvevec"
argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'solve\'], "
}
member_method {
name: "tensor_rank_tensor"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'tensor_rank_tensor\'], "
}
member_method {
name: "to_dense"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'to_dense\'], "
}
member_method {
name: "trace"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'trace\'], "
}
}

View File

@ -64,6 +64,10 @@ tf_module {
name: "LinearOperatorScaledIdentity"
mtype: "<type \'type\'>"
}
member {
name: "LinearOperatorToeplitz"
mtype: "<type \'type\'>"
}
member {
name: "LinearOperatorZeros"
mtype: "<type \'type\'>"