[TF.linalg LinearOperator] Add 'parameters' property to tf LinearOperator. (resubmission)

This matches the behavior of TFP Kernels, Distributions, Bijectors, etc, and
allows us to trace the constructor arguments of all objects used to create
Distributions and Kernels.

PiperOrigin-RevId: 327530603
Change-Id: I9fe6502d4ec5f20d5c185a34e074d122776aeb2d
This commit is contained in:
Eugene Brevdo 2020-08-19 16:32:23 -07:00 committed by TensorFlower Gardener
parent 1087d48004
commit 88379ce456
63 changed files with 531 additions and 5 deletions

View File

@ -144,6 +144,35 @@ class SquareLinearOperatorBlockDiagTest(
self.assertTrue(operator.is_non_singular)
self.assertFalse(operator.is_self_adjoint)
def test_is_x_parameters(self):
matrix = [[1., 0.], [1., 1.]]
sub_operator = linalg.LinearOperatorFullMatrix(matrix)
operator = block_diag.LinearOperatorBlockDiag(
[sub_operator],
is_positive_definite=True,
is_non_singular=True,
is_self_adjoint=False)
self.assertEqual(
operator.parameters,
{
"name": None,
"is_square": True,
"is_positive_definite": True,
"is_self_adjoint": False,
"is_non_singular": True,
"operators": [sub_operator],
})
self.assertEqual(
sub_operator.parameters,
{
"is_non_singular": None,
"is_positive_definite": None,
"is_self_adjoint": None,
"is_square": None,
"matrix": matrix,
"name": "LinearOperatorFullMatrix",
})
def test_block_diag_adjoint_type(self):
matrix = [[1., 0.], [0., 1.]]
operator = block_diag.LinearOperatorBlockDiag(

View File

@ -283,6 +283,18 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
operator = linalg.LinearOperatorCirculant(
lin_op_spectrum, input_output_dtype=dtype)
self.assertEqual(
operator.parameters,
{
"input_output_dtype": dtype,
"is_non_singular": None,
"is_positive_definite": None,
"is_self_adjoint": None,
"is_square": True,
"name": "LinearOperatorCirculant",
"spectrum": lin_op_spectrum,
})
mat = self._spectrum_to_circulant_1d(spectrum, shape, dtype=dtype)
return operator, mat
@ -526,6 +538,20 @@ class LinearOperatorCirculant2DTestHermitianSpectrum(
is_self_adjoint=True if ensure_self_adjoint_and_pd else None,
input_output_dtype=dtype)
self.assertEqual(
operator.parameters,
{
"input_output_dtype": dtype,
"is_non_singular": None,
"is_positive_definite": (
True if ensure_self_adjoint_and_pd else None),
"is_self_adjoint": (
True if ensure_self_adjoint_and_pd else None),
"is_square": True,
"name": "LinearOperatorCirculant2D",
"spectrum": lin_op_spectrum,
})
mat = self._spectrum_to_circulant_2d(spectrum, shape, dtype=dtype)
return operator, mat
@ -570,6 +596,19 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum(
operator = linalg.LinearOperatorCirculant2D(
lin_op_spectrum, input_output_dtype=dtype)
self.assertEqual(
operator.parameters,
{
"input_output_dtype": dtype,
"is_non_singular": None,
"is_positive_definite": None,
"is_self_adjoint": None,
"is_square": True,
"name": "LinearOperatorCirculant2D",
"spectrum": lin_op_spectrum,
}
)
mat = self._spectrum_to_circulant_2d(spectrum, shape, dtype=dtype)
return operator, mat
@ -675,6 +714,18 @@ class LinearOperatorCirculant3DTest(test.TestCase):
operator = linalg.LinearOperatorCirculant3D(spectrum)
self.assertAllEqual((2, 2 * 3 * 5, 2 * 3 * 5), operator.shape)
self.assertEqual(
operator.parameters,
{
"input_output_dtype": dtypes.complex64,
"is_non_singular": None,
"is_positive_definite": None,
"is_self_adjoint": None,
"is_square": True,
"name": "LinearOperatorCirculant3D",
"spectrum": spectrum,
})
matrix_tensor = operator.to_dense()
self.assertEqual(matrix_tensor.dtype, dtypes.complex64)
matrix_h = linalg.adjoint(matrix_tensor)

View File

@ -43,6 +43,14 @@ class LinearOperatorShape(linalg.LinearOperator):
is_self_adjoint=None,
is_positive_definite=None,
is_square=None):
parameters = dict(
shape=shape,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square
)
self._stored_shape = shape
super(LinearOperatorShape, self).__init__(
dtype=dtypes.float32,
@ -50,7 +58,8 @@ class LinearOperatorShape(linalg.LinearOperator):
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square)
is_square=is_square,
parameters=parameters)
def _shape(self):
return tensor_shape.TensorShape(self._stored_shape)
@ -71,13 +80,22 @@ class LinearOperatorMatmulSolve(linalg.LinearOperator):
is_self_adjoint=None,
is_positive_definite=None,
is_square=None):
parameters = dict(
matrix=matrix,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square
)
self._matrix = ops.convert_to_tensor(matrix, name="matrix")
super(LinearOperatorMatmulSolve, self).__init__(
dtype=self._matrix.dtype,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square)
is_square=is_square,
parameters=parameters)
def _shape(self):
return self._matrix.shape
@ -109,6 +127,14 @@ class LinearOperatorTest(test.TestCase):
self.assertAllEqual((1, 2), operator.batch_shape)
self.assertAllEqual(4, operator.domain_dimension)
self.assertAllEqual(3, operator.range_dimension)
expected_parameters = {
"is_non_singular": None,
"is_positive_definite": None,
"is_self_adjoint": None,
"is_square": None,
"shape": (1, 2, 3, 4),
}
self.assertEqual(expected_parameters, operator.parameters)
def test_all_shape_methods_defined_by_the_one_method_shape(self):
with self.cached_session():
@ -131,6 +157,19 @@ class LinearOperatorTest(test.TestCase):
self.assertTrue(operator.is_self_adjoint)
self.assertFalse(operator.is_positive_definite)
def test_nontrivial_parameters(self):
matrix = rng.randn(2, 3, 4)
matrix_ph = array_ops.placeholder_with_default(input=matrix, shape=None)
operator = LinearOperatorMatmulSolve(matrix_ph)
expected_parameters = {
"is_non_singular": None,
"is_positive_definite": None,
"is_self_adjoint": None,
"is_square": None,
"matrix": matrix_ph,
}
self.assertEqual(expected_parameters, operator.parameters)
def test_generic_to_dense_method_non_square_matrix_static(self):
matrix = rng.randn(2, 3, 4)
operator = LinearOperatorMatmulSolve(matrix)

View File

@ -146,6 +146,27 @@ class LinearOperator(module.Module):
* If `is_X == False`, callers should expect the operator to not have `X`.
* If `is_X == None` (the default), callers should have no expectation either
way.
#### Initialization parameters
All subclasses of `LinearOperator` are expected to pass a `parameters`
argument to `super().__init__()`. This should be a `dict` containing
the unadulterated arguments passed to the subclass `__init__`. For example,
`MyLinearOperator` with an initializer should look like:
```python
def __init__(self, operator, is_square=False, name=None):
parameters = dict(
operator=operator,
is_square=is_square,
name=name
)
...
super().__init__(..., parameters=parameters)
```
Users can then access `my_linear_operator.parameters` to see all arguments
passed to its initializer.
"""
# TODO(b/143910018) Remove graph_parents in V3.
@ -158,7 +179,8 @@ class LinearOperator(module.Module):
is_self_adjoint=None,
is_positive_definite=None,
is_square=None,
name=None):
name=None,
parameters=None):
r"""Initialize the `LinearOperator`.
**This is a private method for subclass use.**
@ -179,6 +201,8 @@ class LinearOperator(module.Module):
https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
is_square: Expect that this operator acts like square [batch] matrices.
name: A name for this `LinearOperator`.
parameters: Python `dict` of parameters used to instantiate this
`LinearOperator`.
Raises:
ValueError: If any member of graph_parents is `None` or not a `Tensor`.
@ -210,6 +234,8 @@ class LinearOperator(module.Module):
self._is_non_singular = is_non_singular
self._is_self_adjoint = is_self_adjoint
self._is_positive_definite = is_positive_definite
self._parameters = self._no_dependency(parameters)
self._parameters_sanitized = False
self._name = name or type(self).__name__
@contextlib.contextmanager
@ -221,6 +247,11 @@ class LinearOperator(module.Module):
with ops.name_scope(full_name) as scope:
yield scope
@property
def parameters(self):
"""Dictionary of parameters used to instantiate this `LinearOperator`."""
return dict(self._parameters)
@property
def dtype(self):
"""The `DType` of `Tensor`s handled by this `LinearOperator`."""

View File

@ -112,6 +112,14 @@ class LinearOperatorAdjoint(linear_operator.LinearOperator):
Raises:
ValueError: If `operator.is_non_singular` is False.
"""
parameters = dict(
operator=operator,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
name=name,
)
self._operator = operator
@ -150,6 +158,7 @@ class LinearOperatorAdjoint(linear_operator.LinearOperator):
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
parameters=parameters,
name=name)
# TODO(b/143910018) Remove graph_parents in V3.
self._set_graph_parents(operator.graph_parents)

View File

@ -163,6 +163,15 @@ class LinearOperatorBlockDiag(linear_operator.LinearOperator):
TypeError: If all operators do not have the same `dtype`.
ValueError: If `operators` is empty or are non-square.
"""
parameters = dict(
operators=operators,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
name=name
)
# Validate operators.
check_ops.assert_proper_iterable(operators)
operators = list(operators)
@ -224,6 +233,7 @@ class LinearOperatorBlockDiag(linear_operator.LinearOperator):
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=True,
parameters=parameters,
name=name)
# TODO(b/143910018) Remove graph_parents in V3.

View File

@ -231,6 +231,15 @@ class LinearOperatorBlockLowerTriangular(linear_operator.LinearOperator):
ValueError: If `operators` is empty, contains an erroneous number of
elements, or contains operators with incompatible shapes.
"""
parameters = dict(
operators=operators,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
name=name
)
# Validate operators.
check_ops.assert_proper_iterable(operators)
for row in operators:
@ -256,6 +265,7 @@ class LinearOperatorBlockLowerTriangular(linear_operator.LinearOperator):
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
parameters=parameters,
name=name)
def _validate_num_operators(self):

View File

@ -63,6 +63,7 @@ class _BaseLinearOperatorCirculant(linear_operator.LinearOperator):
is_self_adjoint=None,
is_positive_definite=None,
is_square=True,
parameters=None,
name="LinearOperatorCirculant"):
r"""Initialize an `_BaseLinearOperatorCirculant`.
@ -83,6 +84,8 @@ class _BaseLinearOperatorCirculant(linear_operator.LinearOperator):
https://en.wikipedia.org/wiki/Positive-definite_matrix\
#Extension_for_non_symmetric_matrices
is_square: Expect that this operator acts like square [batch] matrices.
parameters: Python `dict` of parameters used to instantiate this
`LinearOperator`.
name: A name to prepend to all ops created by this class.
Raises:
@ -121,6 +124,7 @@ class _BaseLinearOperatorCirculant(linear_operator.LinearOperator):
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
parameters=parameters,
name=name)
# TODO(b/143910018) Remove graph_parents in V3.
self._set_graph_parents([self.spectrum])
@ -744,6 +748,15 @@ class LinearOperatorCirculant(_BaseLinearOperatorCirculant):
is_square: Expect that this operator acts like square [batch] matrices.
name: A name to prepend to all ops created by this class.
"""
parameters = dict(
spectrum=spectrum,
input_output_dtype=input_output_dtype,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
name=name
)
super(LinearOperatorCirculant, self).__init__(
spectrum,
block_depth=1,
@ -752,6 +765,7 @@ class LinearOperatorCirculant(_BaseLinearOperatorCirculant):
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
parameters=parameters,
name=name)
def _eigvals(self):
@ -924,6 +938,15 @@ class LinearOperatorCirculant2D(_BaseLinearOperatorCirculant):
is_square: Expect that this operator acts like square [batch] matrices.
name: A name to prepend to all ops created by this class.
"""
parameters = dict(
spectrum=spectrum,
input_output_dtype=input_output_dtype,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
name=name
)
super(LinearOperatorCirculant2D, self).__init__(
spectrum,
block_depth=2,
@ -932,6 +955,7 @@ class LinearOperatorCirculant2D(_BaseLinearOperatorCirculant):
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
parameters=parameters,
name=name)
@ -1074,6 +1098,15 @@ class LinearOperatorCirculant3D(_BaseLinearOperatorCirculant):
is_square: Expect that this operator acts like square [batch] matrices.
name: A name to prepend to all ops created by this class.
"""
parameters = dict(
spectrum=spectrum,
input_output_dtype=input_output_dtype,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
name=name
)
super(LinearOperatorCirculant3D, self).__init__(
spectrum,
block_depth=3,
@ -1082,6 +1115,7 @@ class LinearOperatorCirculant3D(_BaseLinearOperatorCirculant):
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
parameters=parameters,
name=name)

View File

@ -143,6 +143,14 @@ class LinearOperatorComposition(linear_operator.LinearOperator):
TypeError: If all operators do not have the same `dtype`.
ValueError: If `operators` is empty.
"""
parameters = dict(
operators=operators,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
name=name)
# Validate operators.
check_ops.assert_proper_iterable(operators)
operators = list(operators)
@ -182,6 +190,7 @@ class LinearOperatorComposition(linear_operator.LinearOperator):
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
parameters=parameters,
name=name)
# TODO(b/143910018) Remove graph_parents in V3.
self._set_graph_parents(graph_parents)

View File

@ -139,6 +139,14 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
TypeError: If `diag.dtype` is not an allowed type.
ValueError: If `diag.dtype` is real, and `is_self_adjoint` is not `True`.
"""
parameters = dict(
diag=diag,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
name=name
)
with ops.name_scope(name, values=[diag]):
self._diag = linear_operator_util.convert_nonref_to_tensor(
@ -163,6 +171,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
parameters=parameters,
name=name)
# TODO(b/143910018) Remove graph_parents in V3.
self._set_graph_parents([self._diag])

View File

@ -133,6 +133,14 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator):
Raises:
TypeError: If `diag.dtype` is not an allowed type.
"""
parameters = dict(
matrix=matrix,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
name=name
)
with ops.name_scope(name, values=[matrix]):
self._matrix = linear_operator_util.convert_nonref_to_tensor(
@ -146,6 +154,7 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator):
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
parameters=parameters,
name=name)
# TODO(b/143910018) Remove graph_parents in V3.
self._set_graph_parents([self._matrix])

View File

@ -123,6 +123,14 @@ class LinearOperatorHouseholder(linear_operator.LinearOperator):
ValueError: `is_self_adjoint` is not `True`, `is_positive_definite` is
not `False` or `is_square` is not `True`.
"""
parameters = dict(
reflection_axis=reflection_axis,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
name=name
)
with ops.name_scope(name, values=[reflection_axis]):
self._reflection_axis = linear_operator_util.convert_nonref_to_tensor(
@ -152,6 +160,7 @@ class LinearOperatorHouseholder(linear_operator.LinearOperator):
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
parameters=parameters,
name=name)
# TODO(b/143910018) Remove graph_parents in V3.
self._set_graph_parents([self._reflection_axis])

View File

@ -252,6 +252,17 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
`{is_self_adjoint, is_non_singular, is_positive_definite}`.
TypeError: If `num_rows` or `batch_shape` is ref-type (e.g. Variable).
"""
parameters = dict(
num_rows=num_rows,
batch_shape=batch_shape,
dtype=dtype,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
assert_proper_shapes=assert_proper_shapes,
name=name)
dtype = dtype or dtypes.float32
self._assert_proper_shapes = assert_proper_shapes
@ -272,6 +283,7 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
parameters=parameters,
name=name)
linear_operator_util.assert_not_ref_type(num_rows, "num_rows")
@ -596,6 +608,16 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
ValueError: If `num_rows` is determined statically to be non-scalar, or
negative.
"""
parameters = dict(
num_rows=num_rows,
multiplier=multiplier,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
assert_proper_shapes=assert_proper_shapes,
name=name)
self._assert_proper_shapes = assert_proper_shapes
with ops.name_scope(name, values=[multiplier, num_rows]):
@ -620,6 +642,7 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
parameters=parameters,
name=name)
self._num_rows = linear_operator_util.shape_tensor(

View File

@ -113,6 +113,14 @@ class LinearOperatorInversion(linear_operator.LinearOperator):
Raises:
ValueError: If `operator.is_non_singular` is False.
"""
parameters = dict(
operator=operator,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
name=name
)
self._operator = operator
@ -163,6 +171,7 @@ class LinearOperatorInversion(linear_operator.LinearOperator):
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
parameters=parameters,
name=name)
# TODO(b/143910018) Remove graph_parents in V3.
self._set_graph_parents(operator.graph_parents)

View File

@ -167,6 +167,15 @@ class LinearOperatorKronecker(linear_operator.LinearOperator):
TypeError: If all operators do not have the same `dtype`.
ValueError: If `operators` is empty.
"""
parameters = dict(
operators=operators,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
name=name
)
# Validate operators.
check_ops.assert_proper_iterable(operators)
operators = list(operators)
@ -226,6 +235,7 @@ class LinearOperatorKronecker(linear_operator.LinearOperator):
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
parameters=parameters,
name=name)
# TODO(b/143910018) Remove graph_parents in V3.
self._set_graph_parents(graph_parents)

View File

@ -182,6 +182,18 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
Raises:
ValueError: If `is_X` flags are set in an inconsistent way.
"""
parameters = dict(
base_operator=base_operator,
u=u,
diag_update=diag_update,
v=v,
is_diag_update_positive=is_diag_update_positive,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
name=name
)
dtype = base_operator.dtype
if diag_update is not None:
@ -253,6 +265,7 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
parameters=parameters,
name=name)
self._set_graph_parents(graph_parents)

View File

@ -137,6 +137,14 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator):
Raises:
ValueError: If `is_square` is `False`.
"""
parameters = dict(
tril=tril,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
name=name
)
if is_square is False:
raise ValueError(
@ -155,6 +163,7 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator):
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
parameters=parameters,
name=name)
self._set_graph_parents([self._tril])

View File

@ -140,6 +140,15 @@ class LinearOperatorPermutation(linear_operator.LinearOperator):
ValueError: `is_self_adjoint` is not `True`, `is_positive_definite` is
not `False` or `is_square` is not `True`.
"""
parameters = dict(
perm=perm,
dtype=dtype,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
name=name
)
with ops.name_scope(name, values=[perm]):
self._perm = linear_operator_util.convert_nonref_to_tensor(
@ -160,6 +169,7 @@ class LinearOperatorPermutation(linear_operator.LinearOperator):
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
parameters=parameters,
name=name)
def _check_perm(self, perm):

View File

@ -138,6 +138,15 @@ class LinearOperatorToeplitz(linear_operator.LinearOperator):
is_square: Expect that this operator acts like square [batch] matrices.
name: A name for this `LinearOperator`.
"""
parameters = dict(
col=col,
row=row,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
name=name
)
with ops.name_scope(name, values=[row, col]):
self._row = linear_operator_util.convert_nonref_to_tensor(row, name="row")
@ -155,7 +164,9 @@ class LinearOperatorToeplitz(linear_operator.LinearOperator):
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
parameters=parameters,
name=name)
self._set_graph_parents([self._row, self._col])
def _check_row_col(self, row, col):

View File

@ -171,6 +171,15 @@ class LinearOperatorTridiag(linear_operator.LinearOperator):
TypeError: If `diag.dtype` is not an allowed type.
ValueError: If `diag.dtype` is real, and `is_self_adjoint` is not `True`.
"""
parameters = dict(
diagonals=diagonals,
diagonals_format=diagonals_format,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
name=name
)
with ops.name_scope(name, values=[diagonals]):
if diagonals_format not in _DIAGONAL_FORMATS:
@ -193,6 +202,7 @@ class LinearOperatorTridiag(linear_operator.LinearOperator):
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
parameters=parameters,
name=name)
def _shape(self):

View File

@ -176,6 +176,19 @@ class LinearOperatorZeros(linear_operator.LinearOperator):
ValueError: If any of the following is not `True`:
`{is_self_adjoint, is_non_singular, is_positive_definite}`.
"""
parameters = dict(
num_rows=num_rows,
num_columns=num_columns,
batch_shape=batch_shape,
dtype=dtype,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
assert_proper_shapes=assert_proper_shapes,
name=name
)
dtype = dtype or dtypes.float32
self._assert_proper_shapes = assert_proper_shapes
@ -194,6 +207,7 @@ class LinearOperatorZeros(linear_operator.LinearOperator):
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
parameters=parameters,
name=name)
linear_operator_util.assert_not_ref_type(num_rows, "num_rows")

View File

@ -54,6 +54,10 @@ tf_class {
name: "operator"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -54,6 +54,10 @@ tf_class {
name: "operators"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -54,6 +54,10 @@ tf_class {
name: "operators"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -59,6 +59,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -59,6 +59,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -59,6 +59,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -54,6 +54,10 @@ tf_class {
name: "operators"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -54,6 +54,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -50,6 +50,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -50,6 +50,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -51,6 +51,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -54,6 +54,10 @@ tf_class {
name: "operator"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -54,6 +54,10 @@ tf_class {
name: "operators"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -66,6 +66,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -50,6 +50,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -50,6 +50,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "perm"
mtype: "<type \'property\'>"

View File

@ -55,6 +55,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -54,6 +54,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -58,6 +58,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -50,6 +50,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -49,6 +49,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"
@ -75,7 +79,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'dtype\', \'graph_parents\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'dtype\', \'graph_parents\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'name\', \'parameters\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "add_to_tensor"

View File

@ -54,6 +54,10 @@ tf_class {
name: "operator"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -54,6 +54,10 @@ tf_class {
name: "operators"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -54,6 +54,10 @@ tf_class {
name: "operators"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -59,6 +59,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -59,6 +59,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -59,6 +59,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -54,6 +54,10 @@ tf_class {
name: "operators"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -54,6 +54,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -50,6 +50,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -50,6 +50,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -51,6 +51,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -54,6 +54,10 @@ tf_class {
name: "operator"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -54,6 +54,10 @@ tf_class {
name: "operators"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -66,6 +66,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -50,6 +50,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -50,6 +50,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "perm"
mtype: "<type \'property\'>"

View File

@ -55,6 +55,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -54,6 +54,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -58,6 +58,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -50,6 +50,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"

View File

@ -49,6 +49,10 @@ tf_class {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "parameters"
mtype: "<type \'property\'>"
}
member {
name: "range_dimension"
mtype: "<type \'property\'>"
@ -75,7 +79,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'dtype\', \'graph_parents\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'dtype\', \'graph_parents\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'name\', \'parameters\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "add_to_tensor"