[TF.linalg LinearOperator] Add 'parameters' property to tf LinearOperator. (resubmission)
This matches the behavior of TFP Kernels, Distributions, Bijectors, etc, and allows us to trace the constructor arguments of all objects used to create Distributions and Kernels. PiperOrigin-RevId: 327530603 Change-Id: I9fe6502d4ec5f20d5c185a34e074d122776aeb2d
This commit is contained in:
parent
1087d48004
commit
88379ce456
@ -144,6 +144,35 @@ class SquareLinearOperatorBlockDiagTest(
|
|||||||
self.assertTrue(operator.is_non_singular)
|
self.assertTrue(operator.is_non_singular)
|
||||||
self.assertFalse(operator.is_self_adjoint)
|
self.assertFalse(operator.is_self_adjoint)
|
||||||
|
|
||||||
|
def test_is_x_parameters(self):
|
||||||
|
matrix = [[1., 0.], [1., 1.]]
|
||||||
|
sub_operator = linalg.LinearOperatorFullMatrix(matrix)
|
||||||
|
operator = block_diag.LinearOperatorBlockDiag(
|
||||||
|
[sub_operator],
|
||||||
|
is_positive_definite=True,
|
||||||
|
is_non_singular=True,
|
||||||
|
is_self_adjoint=False)
|
||||||
|
self.assertEqual(
|
||||||
|
operator.parameters,
|
||||||
|
{
|
||||||
|
"name": None,
|
||||||
|
"is_square": True,
|
||||||
|
"is_positive_definite": True,
|
||||||
|
"is_self_adjoint": False,
|
||||||
|
"is_non_singular": True,
|
||||||
|
"operators": [sub_operator],
|
||||||
|
})
|
||||||
|
self.assertEqual(
|
||||||
|
sub_operator.parameters,
|
||||||
|
{
|
||||||
|
"is_non_singular": None,
|
||||||
|
"is_positive_definite": None,
|
||||||
|
"is_self_adjoint": None,
|
||||||
|
"is_square": None,
|
||||||
|
"matrix": matrix,
|
||||||
|
"name": "LinearOperatorFullMatrix",
|
||||||
|
})
|
||||||
|
|
||||||
def test_block_diag_adjoint_type(self):
|
def test_block_diag_adjoint_type(self):
|
||||||
matrix = [[1., 0.], [0., 1.]]
|
matrix = [[1., 0.], [0., 1.]]
|
||||||
operator = block_diag.LinearOperatorBlockDiag(
|
operator = block_diag.LinearOperatorBlockDiag(
|
||||||
|
@ -283,6 +283,18 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
|
|||||||
operator = linalg.LinearOperatorCirculant(
|
operator = linalg.LinearOperatorCirculant(
|
||||||
lin_op_spectrum, input_output_dtype=dtype)
|
lin_op_spectrum, input_output_dtype=dtype)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
operator.parameters,
|
||||||
|
{
|
||||||
|
"input_output_dtype": dtype,
|
||||||
|
"is_non_singular": None,
|
||||||
|
"is_positive_definite": None,
|
||||||
|
"is_self_adjoint": None,
|
||||||
|
"is_square": True,
|
||||||
|
"name": "LinearOperatorCirculant",
|
||||||
|
"spectrum": lin_op_spectrum,
|
||||||
|
})
|
||||||
|
|
||||||
mat = self._spectrum_to_circulant_1d(spectrum, shape, dtype=dtype)
|
mat = self._spectrum_to_circulant_1d(spectrum, shape, dtype=dtype)
|
||||||
|
|
||||||
return operator, mat
|
return operator, mat
|
||||||
@ -526,6 +538,20 @@ class LinearOperatorCirculant2DTestHermitianSpectrum(
|
|||||||
is_self_adjoint=True if ensure_self_adjoint_and_pd else None,
|
is_self_adjoint=True if ensure_self_adjoint_and_pd else None,
|
||||||
input_output_dtype=dtype)
|
input_output_dtype=dtype)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
operator.parameters,
|
||||||
|
{
|
||||||
|
"input_output_dtype": dtype,
|
||||||
|
"is_non_singular": None,
|
||||||
|
"is_positive_definite": (
|
||||||
|
True if ensure_self_adjoint_and_pd else None),
|
||||||
|
"is_self_adjoint": (
|
||||||
|
True if ensure_self_adjoint_and_pd else None),
|
||||||
|
"is_square": True,
|
||||||
|
"name": "LinearOperatorCirculant2D",
|
||||||
|
"spectrum": lin_op_spectrum,
|
||||||
|
})
|
||||||
|
|
||||||
mat = self._spectrum_to_circulant_2d(spectrum, shape, dtype=dtype)
|
mat = self._spectrum_to_circulant_2d(spectrum, shape, dtype=dtype)
|
||||||
|
|
||||||
return operator, mat
|
return operator, mat
|
||||||
@ -570,6 +596,19 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum(
|
|||||||
operator = linalg.LinearOperatorCirculant2D(
|
operator = linalg.LinearOperatorCirculant2D(
|
||||||
lin_op_spectrum, input_output_dtype=dtype)
|
lin_op_spectrum, input_output_dtype=dtype)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
operator.parameters,
|
||||||
|
{
|
||||||
|
"input_output_dtype": dtype,
|
||||||
|
"is_non_singular": None,
|
||||||
|
"is_positive_definite": None,
|
||||||
|
"is_self_adjoint": None,
|
||||||
|
"is_square": True,
|
||||||
|
"name": "LinearOperatorCirculant2D",
|
||||||
|
"spectrum": lin_op_spectrum,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
mat = self._spectrum_to_circulant_2d(spectrum, shape, dtype=dtype)
|
mat = self._spectrum_to_circulant_2d(spectrum, shape, dtype=dtype)
|
||||||
|
|
||||||
return operator, mat
|
return operator, mat
|
||||||
@ -675,6 +714,18 @@ class LinearOperatorCirculant3DTest(test.TestCase):
|
|||||||
operator = linalg.LinearOperatorCirculant3D(spectrum)
|
operator = linalg.LinearOperatorCirculant3D(spectrum)
|
||||||
self.assertAllEqual((2, 2 * 3 * 5, 2 * 3 * 5), operator.shape)
|
self.assertAllEqual((2, 2 * 3 * 5, 2 * 3 * 5), operator.shape)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
operator.parameters,
|
||||||
|
{
|
||||||
|
"input_output_dtype": dtypes.complex64,
|
||||||
|
"is_non_singular": None,
|
||||||
|
"is_positive_definite": None,
|
||||||
|
"is_self_adjoint": None,
|
||||||
|
"is_square": True,
|
||||||
|
"name": "LinearOperatorCirculant3D",
|
||||||
|
"spectrum": spectrum,
|
||||||
|
})
|
||||||
|
|
||||||
matrix_tensor = operator.to_dense()
|
matrix_tensor = operator.to_dense()
|
||||||
self.assertEqual(matrix_tensor.dtype, dtypes.complex64)
|
self.assertEqual(matrix_tensor.dtype, dtypes.complex64)
|
||||||
matrix_h = linalg.adjoint(matrix_tensor)
|
matrix_h = linalg.adjoint(matrix_tensor)
|
||||||
|
@ -43,6 +43,14 @@ class LinearOperatorShape(linalg.LinearOperator):
|
|||||||
is_self_adjoint=None,
|
is_self_adjoint=None,
|
||||||
is_positive_definite=None,
|
is_positive_definite=None,
|
||||||
is_square=None):
|
is_square=None):
|
||||||
|
parameters = dict(
|
||||||
|
shape=shape,
|
||||||
|
is_non_singular=is_non_singular,
|
||||||
|
is_self_adjoint=is_self_adjoint,
|
||||||
|
is_positive_definite=is_positive_definite,
|
||||||
|
is_square=is_square
|
||||||
|
)
|
||||||
|
|
||||||
self._stored_shape = shape
|
self._stored_shape = shape
|
||||||
super(LinearOperatorShape, self).__init__(
|
super(LinearOperatorShape, self).__init__(
|
||||||
dtype=dtypes.float32,
|
dtype=dtypes.float32,
|
||||||
@ -50,7 +58,8 @@ class LinearOperatorShape(linalg.LinearOperator):
|
|||||||
is_non_singular=is_non_singular,
|
is_non_singular=is_non_singular,
|
||||||
is_self_adjoint=is_self_adjoint,
|
is_self_adjoint=is_self_adjoint,
|
||||||
is_positive_definite=is_positive_definite,
|
is_positive_definite=is_positive_definite,
|
||||||
is_square=is_square)
|
is_square=is_square,
|
||||||
|
parameters=parameters)
|
||||||
|
|
||||||
def _shape(self):
|
def _shape(self):
|
||||||
return tensor_shape.TensorShape(self._stored_shape)
|
return tensor_shape.TensorShape(self._stored_shape)
|
||||||
@ -71,13 +80,22 @@ class LinearOperatorMatmulSolve(linalg.LinearOperator):
|
|||||||
is_self_adjoint=None,
|
is_self_adjoint=None,
|
||||||
is_positive_definite=None,
|
is_positive_definite=None,
|
||||||
is_square=None):
|
is_square=None):
|
||||||
|
parameters = dict(
|
||||||
|
matrix=matrix,
|
||||||
|
is_non_singular=is_non_singular,
|
||||||
|
is_self_adjoint=is_self_adjoint,
|
||||||
|
is_positive_definite=is_positive_definite,
|
||||||
|
is_square=is_square
|
||||||
|
)
|
||||||
|
|
||||||
self._matrix = ops.convert_to_tensor(matrix, name="matrix")
|
self._matrix = ops.convert_to_tensor(matrix, name="matrix")
|
||||||
super(LinearOperatorMatmulSolve, self).__init__(
|
super(LinearOperatorMatmulSolve, self).__init__(
|
||||||
dtype=self._matrix.dtype,
|
dtype=self._matrix.dtype,
|
||||||
is_non_singular=is_non_singular,
|
is_non_singular=is_non_singular,
|
||||||
is_self_adjoint=is_self_adjoint,
|
is_self_adjoint=is_self_adjoint,
|
||||||
is_positive_definite=is_positive_definite,
|
is_positive_definite=is_positive_definite,
|
||||||
is_square=is_square)
|
is_square=is_square,
|
||||||
|
parameters=parameters)
|
||||||
|
|
||||||
def _shape(self):
|
def _shape(self):
|
||||||
return self._matrix.shape
|
return self._matrix.shape
|
||||||
@ -109,6 +127,14 @@ class LinearOperatorTest(test.TestCase):
|
|||||||
self.assertAllEqual((1, 2), operator.batch_shape)
|
self.assertAllEqual((1, 2), operator.batch_shape)
|
||||||
self.assertAllEqual(4, operator.domain_dimension)
|
self.assertAllEqual(4, operator.domain_dimension)
|
||||||
self.assertAllEqual(3, operator.range_dimension)
|
self.assertAllEqual(3, operator.range_dimension)
|
||||||
|
expected_parameters = {
|
||||||
|
"is_non_singular": None,
|
||||||
|
"is_positive_definite": None,
|
||||||
|
"is_self_adjoint": None,
|
||||||
|
"is_square": None,
|
||||||
|
"shape": (1, 2, 3, 4),
|
||||||
|
}
|
||||||
|
self.assertEqual(expected_parameters, operator.parameters)
|
||||||
|
|
||||||
def test_all_shape_methods_defined_by_the_one_method_shape(self):
|
def test_all_shape_methods_defined_by_the_one_method_shape(self):
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
@ -131,6 +157,19 @@ class LinearOperatorTest(test.TestCase):
|
|||||||
self.assertTrue(operator.is_self_adjoint)
|
self.assertTrue(operator.is_self_adjoint)
|
||||||
self.assertFalse(operator.is_positive_definite)
|
self.assertFalse(operator.is_positive_definite)
|
||||||
|
|
||||||
|
def test_nontrivial_parameters(self):
|
||||||
|
matrix = rng.randn(2, 3, 4)
|
||||||
|
matrix_ph = array_ops.placeholder_with_default(input=matrix, shape=None)
|
||||||
|
operator = LinearOperatorMatmulSolve(matrix_ph)
|
||||||
|
expected_parameters = {
|
||||||
|
"is_non_singular": None,
|
||||||
|
"is_positive_definite": None,
|
||||||
|
"is_self_adjoint": None,
|
||||||
|
"is_square": None,
|
||||||
|
"matrix": matrix_ph,
|
||||||
|
}
|
||||||
|
self.assertEqual(expected_parameters, operator.parameters)
|
||||||
|
|
||||||
def test_generic_to_dense_method_non_square_matrix_static(self):
|
def test_generic_to_dense_method_non_square_matrix_static(self):
|
||||||
matrix = rng.randn(2, 3, 4)
|
matrix = rng.randn(2, 3, 4)
|
||||||
operator = LinearOperatorMatmulSolve(matrix)
|
operator = LinearOperatorMatmulSolve(matrix)
|
||||||
|
@ -146,6 +146,27 @@ class LinearOperator(module.Module):
|
|||||||
* If `is_X == False`, callers should expect the operator to not have `X`.
|
* If `is_X == False`, callers should expect the operator to not have `X`.
|
||||||
* If `is_X == None` (the default), callers should have no expectation either
|
* If `is_X == None` (the default), callers should have no expectation either
|
||||||
way.
|
way.
|
||||||
|
|
||||||
|
#### Initialization parameters
|
||||||
|
|
||||||
|
All subclasses of `LinearOperator` are expected to pass a `parameters`
|
||||||
|
argument to `super().__init__()`. This should be a `dict` containing
|
||||||
|
the unadulterated arguments passed to the subclass `__init__`. For example,
|
||||||
|
`MyLinearOperator` with an initializer should look like:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def __init__(self, operator, is_square=False, name=None):
|
||||||
|
parameters = dict(
|
||||||
|
operator=operator,
|
||||||
|
is_square=is_square,
|
||||||
|
name=name
|
||||||
|
)
|
||||||
|
...
|
||||||
|
super().__init__(..., parameters=parameters)
|
||||||
|
```
|
||||||
|
|
||||||
|
Users can then access `my_linear_operator.parameters` to see all arguments
|
||||||
|
passed to its initializer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# TODO(b/143910018) Remove graph_parents in V3.
|
# TODO(b/143910018) Remove graph_parents in V3.
|
||||||
@ -158,7 +179,8 @@ class LinearOperator(module.Module):
|
|||||||
is_self_adjoint=None,
|
is_self_adjoint=None,
|
||||||
is_positive_definite=None,
|
is_positive_definite=None,
|
||||||
is_square=None,
|
is_square=None,
|
||||||
name=None):
|
name=None,
|
||||||
|
parameters=None):
|
||||||
r"""Initialize the `LinearOperator`.
|
r"""Initialize the `LinearOperator`.
|
||||||
|
|
||||||
**This is a private method for subclass use.**
|
**This is a private method for subclass use.**
|
||||||
@ -179,6 +201,8 @@ class LinearOperator(module.Module):
|
|||||||
https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
|
https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices
|
||||||
is_square: Expect that this operator acts like square [batch] matrices.
|
is_square: Expect that this operator acts like square [batch] matrices.
|
||||||
name: A name for this `LinearOperator`.
|
name: A name for this `LinearOperator`.
|
||||||
|
parameters: Python `dict` of parameters used to instantiate this
|
||||||
|
`LinearOperator`.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If any member of graph_parents is `None` or not a `Tensor`.
|
ValueError: If any member of graph_parents is `None` or not a `Tensor`.
|
||||||
@ -210,6 +234,8 @@ class LinearOperator(module.Module):
|
|||||||
self._is_non_singular = is_non_singular
|
self._is_non_singular = is_non_singular
|
||||||
self._is_self_adjoint = is_self_adjoint
|
self._is_self_adjoint = is_self_adjoint
|
||||||
self._is_positive_definite = is_positive_definite
|
self._is_positive_definite = is_positive_definite
|
||||||
|
self._parameters = self._no_dependency(parameters)
|
||||||
|
self._parameters_sanitized = False
|
||||||
self._name = name or type(self).__name__
|
self._name = name or type(self).__name__
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
@ -221,6 +247,11 @@ class LinearOperator(module.Module):
|
|||||||
with ops.name_scope(full_name) as scope:
|
with ops.name_scope(full_name) as scope:
|
||||||
yield scope
|
yield scope
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self):
|
||||||
|
"""Dictionary of parameters used to instantiate this `LinearOperator`."""
|
||||||
|
return dict(self._parameters)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dtype(self):
|
def dtype(self):
|
||||||
"""The `DType` of `Tensor`s handled by this `LinearOperator`."""
|
"""The `DType` of `Tensor`s handled by this `LinearOperator`."""
|
||||||
|
@ -112,6 +112,14 @@ class LinearOperatorAdjoint(linear_operator.LinearOperator):
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If `operator.is_non_singular` is False.
|
ValueError: If `operator.is_non_singular` is False.
|
||||||
"""
|
"""
|
||||||
|
parameters = dict(
|
||||||
|
operator=operator,
|
||||||
|
is_non_singular=is_non_singular,
|
||||||
|
is_self_adjoint=is_self_adjoint,
|
||||||
|
is_positive_definite=is_positive_definite,
|
||||||
|
is_square=is_square,
|
||||||
|
name=name,
|
||||||
|
)
|
||||||
|
|
||||||
self._operator = operator
|
self._operator = operator
|
||||||
|
|
||||||
@ -150,6 +158,7 @@ class LinearOperatorAdjoint(linear_operator.LinearOperator):
|
|||||||
is_self_adjoint=is_self_adjoint,
|
is_self_adjoint=is_self_adjoint,
|
||||||
is_positive_definite=is_positive_definite,
|
is_positive_definite=is_positive_definite,
|
||||||
is_square=is_square,
|
is_square=is_square,
|
||||||
|
parameters=parameters,
|
||||||
name=name)
|
name=name)
|
||||||
# TODO(b/143910018) Remove graph_parents in V3.
|
# TODO(b/143910018) Remove graph_parents in V3.
|
||||||
self._set_graph_parents(operator.graph_parents)
|
self._set_graph_parents(operator.graph_parents)
|
||||||
|
@ -163,6 +163,15 @@ class LinearOperatorBlockDiag(linear_operator.LinearOperator):
|
|||||||
TypeError: If all operators do not have the same `dtype`.
|
TypeError: If all operators do not have the same `dtype`.
|
||||||
ValueError: If `operators` is empty or are non-square.
|
ValueError: If `operators` is empty or are non-square.
|
||||||
"""
|
"""
|
||||||
|
parameters = dict(
|
||||||
|
operators=operators,
|
||||||
|
is_non_singular=is_non_singular,
|
||||||
|
is_self_adjoint=is_self_adjoint,
|
||||||
|
is_positive_definite=is_positive_definite,
|
||||||
|
is_square=is_square,
|
||||||
|
name=name
|
||||||
|
)
|
||||||
|
|
||||||
# Validate operators.
|
# Validate operators.
|
||||||
check_ops.assert_proper_iterable(operators)
|
check_ops.assert_proper_iterable(operators)
|
||||||
operators = list(operators)
|
operators = list(operators)
|
||||||
@ -224,6 +233,7 @@ class LinearOperatorBlockDiag(linear_operator.LinearOperator):
|
|||||||
is_self_adjoint=is_self_adjoint,
|
is_self_adjoint=is_self_adjoint,
|
||||||
is_positive_definite=is_positive_definite,
|
is_positive_definite=is_positive_definite,
|
||||||
is_square=True,
|
is_square=True,
|
||||||
|
parameters=parameters,
|
||||||
name=name)
|
name=name)
|
||||||
|
|
||||||
# TODO(b/143910018) Remove graph_parents in V3.
|
# TODO(b/143910018) Remove graph_parents in V3.
|
||||||
|
@ -231,6 +231,15 @@ class LinearOperatorBlockLowerTriangular(linear_operator.LinearOperator):
|
|||||||
ValueError: If `operators` is empty, contains an erroneous number of
|
ValueError: If `operators` is empty, contains an erroneous number of
|
||||||
elements, or contains operators with incompatible shapes.
|
elements, or contains operators with incompatible shapes.
|
||||||
"""
|
"""
|
||||||
|
parameters = dict(
|
||||||
|
operators=operators,
|
||||||
|
is_non_singular=is_non_singular,
|
||||||
|
is_self_adjoint=is_self_adjoint,
|
||||||
|
is_positive_definite=is_positive_definite,
|
||||||
|
is_square=is_square,
|
||||||
|
name=name
|
||||||
|
)
|
||||||
|
|
||||||
# Validate operators.
|
# Validate operators.
|
||||||
check_ops.assert_proper_iterable(operators)
|
check_ops.assert_proper_iterable(operators)
|
||||||
for row in operators:
|
for row in operators:
|
||||||
@ -256,6 +265,7 @@ class LinearOperatorBlockLowerTriangular(linear_operator.LinearOperator):
|
|||||||
is_self_adjoint=is_self_adjoint,
|
is_self_adjoint=is_self_adjoint,
|
||||||
is_positive_definite=is_positive_definite,
|
is_positive_definite=is_positive_definite,
|
||||||
is_square=is_square,
|
is_square=is_square,
|
||||||
|
parameters=parameters,
|
||||||
name=name)
|
name=name)
|
||||||
|
|
||||||
def _validate_num_operators(self):
|
def _validate_num_operators(self):
|
||||||
|
@ -63,6 +63,7 @@ class _BaseLinearOperatorCirculant(linear_operator.LinearOperator):
|
|||||||
is_self_adjoint=None,
|
is_self_adjoint=None,
|
||||||
is_positive_definite=None,
|
is_positive_definite=None,
|
||||||
is_square=True,
|
is_square=True,
|
||||||
|
parameters=None,
|
||||||
name="LinearOperatorCirculant"):
|
name="LinearOperatorCirculant"):
|
||||||
r"""Initialize an `_BaseLinearOperatorCirculant`.
|
r"""Initialize an `_BaseLinearOperatorCirculant`.
|
||||||
|
|
||||||
@ -83,6 +84,8 @@ class _BaseLinearOperatorCirculant(linear_operator.LinearOperator):
|
|||||||
https://en.wikipedia.org/wiki/Positive-definite_matrix\
|
https://en.wikipedia.org/wiki/Positive-definite_matrix\
|
||||||
#Extension_for_non_symmetric_matrices
|
#Extension_for_non_symmetric_matrices
|
||||||
is_square: Expect that this operator acts like square [batch] matrices.
|
is_square: Expect that this operator acts like square [batch] matrices.
|
||||||
|
parameters: Python `dict` of parameters used to instantiate this
|
||||||
|
`LinearOperator`.
|
||||||
name: A name to prepend to all ops created by this class.
|
name: A name to prepend to all ops created by this class.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
@ -121,6 +124,7 @@ class _BaseLinearOperatorCirculant(linear_operator.LinearOperator):
|
|||||||
is_self_adjoint=is_self_adjoint,
|
is_self_adjoint=is_self_adjoint,
|
||||||
is_positive_definite=is_positive_definite,
|
is_positive_definite=is_positive_definite,
|
||||||
is_square=is_square,
|
is_square=is_square,
|
||||||
|
parameters=parameters,
|
||||||
name=name)
|
name=name)
|
||||||
# TODO(b/143910018) Remove graph_parents in V3.
|
# TODO(b/143910018) Remove graph_parents in V3.
|
||||||
self._set_graph_parents([self.spectrum])
|
self._set_graph_parents([self.spectrum])
|
||||||
@ -744,6 +748,15 @@ class LinearOperatorCirculant(_BaseLinearOperatorCirculant):
|
|||||||
is_square: Expect that this operator acts like square [batch] matrices.
|
is_square: Expect that this operator acts like square [batch] matrices.
|
||||||
name: A name to prepend to all ops created by this class.
|
name: A name to prepend to all ops created by this class.
|
||||||
"""
|
"""
|
||||||
|
parameters = dict(
|
||||||
|
spectrum=spectrum,
|
||||||
|
input_output_dtype=input_output_dtype,
|
||||||
|
is_non_singular=is_non_singular,
|
||||||
|
is_self_adjoint=is_self_adjoint,
|
||||||
|
is_positive_definite=is_positive_definite,
|
||||||
|
is_square=is_square,
|
||||||
|
name=name
|
||||||
|
)
|
||||||
super(LinearOperatorCirculant, self).__init__(
|
super(LinearOperatorCirculant, self).__init__(
|
||||||
spectrum,
|
spectrum,
|
||||||
block_depth=1,
|
block_depth=1,
|
||||||
@ -752,6 +765,7 @@ class LinearOperatorCirculant(_BaseLinearOperatorCirculant):
|
|||||||
is_self_adjoint=is_self_adjoint,
|
is_self_adjoint=is_self_adjoint,
|
||||||
is_positive_definite=is_positive_definite,
|
is_positive_definite=is_positive_definite,
|
||||||
is_square=is_square,
|
is_square=is_square,
|
||||||
|
parameters=parameters,
|
||||||
name=name)
|
name=name)
|
||||||
|
|
||||||
def _eigvals(self):
|
def _eigvals(self):
|
||||||
@ -924,6 +938,15 @@ class LinearOperatorCirculant2D(_BaseLinearOperatorCirculant):
|
|||||||
is_square: Expect that this operator acts like square [batch] matrices.
|
is_square: Expect that this operator acts like square [batch] matrices.
|
||||||
name: A name to prepend to all ops created by this class.
|
name: A name to prepend to all ops created by this class.
|
||||||
"""
|
"""
|
||||||
|
parameters = dict(
|
||||||
|
spectrum=spectrum,
|
||||||
|
input_output_dtype=input_output_dtype,
|
||||||
|
is_non_singular=is_non_singular,
|
||||||
|
is_self_adjoint=is_self_adjoint,
|
||||||
|
is_positive_definite=is_positive_definite,
|
||||||
|
is_square=is_square,
|
||||||
|
name=name
|
||||||
|
)
|
||||||
super(LinearOperatorCirculant2D, self).__init__(
|
super(LinearOperatorCirculant2D, self).__init__(
|
||||||
spectrum,
|
spectrum,
|
||||||
block_depth=2,
|
block_depth=2,
|
||||||
@ -932,6 +955,7 @@ class LinearOperatorCirculant2D(_BaseLinearOperatorCirculant):
|
|||||||
is_self_adjoint=is_self_adjoint,
|
is_self_adjoint=is_self_adjoint,
|
||||||
is_positive_definite=is_positive_definite,
|
is_positive_definite=is_positive_definite,
|
||||||
is_square=is_square,
|
is_square=is_square,
|
||||||
|
parameters=parameters,
|
||||||
name=name)
|
name=name)
|
||||||
|
|
||||||
|
|
||||||
@ -1074,6 +1098,15 @@ class LinearOperatorCirculant3D(_BaseLinearOperatorCirculant):
|
|||||||
is_square: Expect that this operator acts like square [batch] matrices.
|
is_square: Expect that this operator acts like square [batch] matrices.
|
||||||
name: A name to prepend to all ops created by this class.
|
name: A name to prepend to all ops created by this class.
|
||||||
"""
|
"""
|
||||||
|
parameters = dict(
|
||||||
|
spectrum=spectrum,
|
||||||
|
input_output_dtype=input_output_dtype,
|
||||||
|
is_non_singular=is_non_singular,
|
||||||
|
is_self_adjoint=is_self_adjoint,
|
||||||
|
is_positive_definite=is_positive_definite,
|
||||||
|
is_square=is_square,
|
||||||
|
name=name
|
||||||
|
)
|
||||||
super(LinearOperatorCirculant3D, self).__init__(
|
super(LinearOperatorCirculant3D, self).__init__(
|
||||||
spectrum,
|
spectrum,
|
||||||
block_depth=3,
|
block_depth=3,
|
||||||
@ -1082,6 +1115,7 @@ class LinearOperatorCirculant3D(_BaseLinearOperatorCirculant):
|
|||||||
is_self_adjoint=is_self_adjoint,
|
is_self_adjoint=is_self_adjoint,
|
||||||
is_positive_definite=is_positive_definite,
|
is_positive_definite=is_positive_definite,
|
||||||
is_square=is_square,
|
is_square=is_square,
|
||||||
|
parameters=parameters,
|
||||||
name=name)
|
name=name)
|
||||||
|
|
||||||
|
|
||||||
|
@ -143,6 +143,14 @@ class LinearOperatorComposition(linear_operator.LinearOperator):
|
|||||||
TypeError: If all operators do not have the same `dtype`.
|
TypeError: If all operators do not have the same `dtype`.
|
||||||
ValueError: If `operators` is empty.
|
ValueError: If `operators` is empty.
|
||||||
"""
|
"""
|
||||||
|
parameters = dict(
|
||||||
|
operators=operators,
|
||||||
|
is_non_singular=is_non_singular,
|
||||||
|
is_self_adjoint=is_self_adjoint,
|
||||||
|
is_positive_definite=is_positive_definite,
|
||||||
|
is_square=is_square,
|
||||||
|
name=name)
|
||||||
|
|
||||||
# Validate operators.
|
# Validate operators.
|
||||||
check_ops.assert_proper_iterable(operators)
|
check_ops.assert_proper_iterable(operators)
|
||||||
operators = list(operators)
|
operators = list(operators)
|
||||||
@ -182,6 +190,7 @@ class LinearOperatorComposition(linear_operator.LinearOperator):
|
|||||||
is_self_adjoint=is_self_adjoint,
|
is_self_adjoint=is_self_adjoint,
|
||||||
is_positive_definite=is_positive_definite,
|
is_positive_definite=is_positive_definite,
|
||||||
is_square=is_square,
|
is_square=is_square,
|
||||||
|
parameters=parameters,
|
||||||
name=name)
|
name=name)
|
||||||
# TODO(b/143910018) Remove graph_parents in V3.
|
# TODO(b/143910018) Remove graph_parents in V3.
|
||||||
self._set_graph_parents(graph_parents)
|
self._set_graph_parents(graph_parents)
|
||||||
|
@ -139,6 +139,14 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
|
|||||||
TypeError: If `diag.dtype` is not an allowed type.
|
TypeError: If `diag.dtype` is not an allowed type.
|
||||||
ValueError: If `diag.dtype` is real, and `is_self_adjoint` is not `True`.
|
ValueError: If `diag.dtype` is real, and `is_self_adjoint` is not `True`.
|
||||||
"""
|
"""
|
||||||
|
parameters = dict(
|
||||||
|
diag=diag,
|
||||||
|
is_non_singular=is_non_singular,
|
||||||
|
is_self_adjoint=is_self_adjoint,
|
||||||
|
is_positive_definite=is_positive_definite,
|
||||||
|
is_square=is_square,
|
||||||
|
name=name
|
||||||
|
)
|
||||||
|
|
||||||
with ops.name_scope(name, values=[diag]):
|
with ops.name_scope(name, values=[diag]):
|
||||||
self._diag = linear_operator_util.convert_nonref_to_tensor(
|
self._diag = linear_operator_util.convert_nonref_to_tensor(
|
||||||
@ -163,6 +171,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
|
|||||||
is_self_adjoint=is_self_adjoint,
|
is_self_adjoint=is_self_adjoint,
|
||||||
is_positive_definite=is_positive_definite,
|
is_positive_definite=is_positive_definite,
|
||||||
is_square=is_square,
|
is_square=is_square,
|
||||||
|
parameters=parameters,
|
||||||
name=name)
|
name=name)
|
||||||
# TODO(b/143910018) Remove graph_parents in V3.
|
# TODO(b/143910018) Remove graph_parents in V3.
|
||||||
self._set_graph_parents([self._diag])
|
self._set_graph_parents([self._diag])
|
||||||
|
@ -133,6 +133,14 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator):
|
|||||||
Raises:
|
Raises:
|
||||||
TypeError: If `diag.dtype` is not an allowed type.
|
TypeError: If `diag.dtype` is not an allowed type.
|
||||||
"""
|
"""
|
||||||
|
parameters = dict(
|
||||||
|
matrix=matrix,
|
||||||
|
is_non_singular=is_non_singular,
|
||||||
|
is_self_adjoint=is_self_adjoint,
|
||||||
|
is_positive_definite=is_positive_definite,
|
||||||
|
is_square=is_square,
|
||||||
|
name=name
|
||||||
|
)
|
||||||
|
|
||||||
with ops.name_scope(name, values=[matrix]):
|
with ops.name_scope(name, values=[matrix]):
|
||||||
self._matrix = linear_operator_util.convert_nonref_to_tensor(
|
self._matrix = linear_operator_util.convert_nonref_to_tensor(
|
||||||
@ -146,6 +154,7 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator):
|
|||||||
is_self_adjoint=is_self_adjoint,
|
is_self_adjoint=is_self_adjoint,
|
||||||
is_positive_definite=is_positive_definite,
|
is_positive_definite=is_positive_definite,
|
||||||
is_square=is_square,
|
is_square=is_square,
|
||||||
|
parameters=parameters,
|
||||||
name=name)
|
name=name)
|
||||||
# TODO(b/143910018) Remove graph_parents in V3.
|
# TODO(b/143910018) Remove graph_parents in V3.
|
||||||
self._set_graph_parents([self._matrix])
|
self._set_graph_parents([self._matrix])
|
||||||
|
@ -123,6 +123,14 @@ class LinearOperatorHouseholder(linear_operator.LinearOperator):
|
|||||||
ValueError: `is_self_adjoint` is not `True`, `is_positive_definite` is
|
ValueError: `is_self_adjoint` is not `True`, `is_positive_definite` is
|
||||||
not `False` or `is_square` is not `True`.
|
not `False` or `is_square` is not `True`.
|
||||||
"""
|
"""
|
||||||
|
parameters = dict(
|
||||||
|
reflection_axis=reflection_axis,
|
||||||
|
is_non_singular=is_non_singular,
|
||||||
|
is_self_adjoint=is_self_adjoint,
|
||||||
|
is_positive_definite=is_positive_definite,
|
||||||
|
is_square=is_square,
|
||||||
|
name=name
|
||||||
|
)
|
||||||
|
|
||||||
with ops.name_scope(name, values=[reflection_axis]):
|
with ops.name_scope(name, values=[reflection_axis]):
|
||||||
self._reflection_axis = linear_operator_util.convert_nonref_to_tensor(
|
self._reflection_axis = linear_operator_util.convert_nonref_to_tensor(
|
||||||
@ -152,6 +160,7 @@ class LinearOperatorHouseholder(linear_operator.LinearOperator):
|
|||||||
is_self_adjoint=is_self_adjoint,
|
is_self_adjoint=is_self_adjoint,
|
||||||
is_positive_definite=is_positive_definite,
|
is_positive_definite=is_positive_definite,
|
||||||
is_square=is_square,
|
is_square=is_square,
|
||||||
|
parameters=parameters,
|
||||||
name=name)
|
name=name)
|
||||||
# TODO(b/143910018) Remove graph_parents in V3.
|
# TODO(b/143910018) Remove graph_parents in V3.
|
||||||
self._set_graph_parents([self._reflection_axis])
|
self._set_graph_parents([self._reflection_axis])
|
||||||
|
@ -252,6 +252,17 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
|
|||||||
`{is_self_adjoint, is_non_singular, is_positive_definite}`.
|
`{is_self_adjoint, is_non_singular, is_positive_definite}`.
|
||||||
TypeError: If `num_rows` or `batch_shape` is ref-type (e.g. Variable).
|
TypeError: If `num_rows` or `batch_shape` is ref-type (e.g. Variable).
|
||||||
"""
|
"""
|
||||||
|
parameters = dict(
|
||||||
|
num_rows=num_rows,
|
||||||
|
batch_shape=batch_shape,
|
||||||
|
dtype=dtype,
|
||||||
|
is_non_singular=is_non_singular,
|
||||||
|
is_self_adjoint=is_self_adjoint,
|
||||||
|
is_positive_definite=is_positive_definite,
|
||||||
|
is_square=is_square,
|
||||||
|
assert_proper_shapes=assert_proper_shapes,
|
||||||
|
name=name)
|
||||||
|
|
||||||
dtype = dtype or dtypes.float32
|
dtype = dtype or dtypes.float32
|
||||||
self._assert_proper_shapes = assert_proper_shapes
|
self._assert_proper_shapes = assert_proper_shapes
|
||||||
|
|
||||||
@ -272,6 +283,7 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
|
|||||||
is_self_adjoint=is_self_adjoint,
|
is_self_adjoint=is_self_adjoint,
|
||||||
is_positive_definite=is_positive_definite,
|
is_positive_definite=is_positive_definite,
|
||||||
is_square=is_square,
|
is_square=is_square,
|
||||||
|
parameters=parameters,
|
||||||
name=name)
|
name=name)
|
||||||
|
|
||||||
linear_operator_util.assert_not_ref_type(num_rows, "num_rows")
|
linear_operator_util.assert_not_ref_type(num_rows, "num_rows")
|
||||||
@ -596,6 +608,16 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
|
|||||||
ValueError: If `num_rows` is determined statically to be non-scalar, or
|
ValueError: If `num_rows` is determined statically to be non-scalar, or
|
||||||
negative.
|
negative.
|
||||||
"""
|
"""
|
||||||
|
parameters = dict(
|
||||||
|
num_rows=num_rows,
|
||||||
|
multiplier=multiplier,
|
||||||
|
is_non_singular=is_non_singular,
|
||||||
|
is_self_adjoint=is_self_adjoint,
|
||||||
|
is_positive_definite=is_positive_definite,
|
||||||
|
is_square=is_square,
|
||||||
|
assert_proper_shapes=assert_proper_shapes,
|
||||||
|
name=name)
|
||||||
|
|
||||||
self._assert_proper_shapes = assert_proper_shapes
|
self._assert_proper_shapes = assert_proper_shapes
|
||||||
|
|
||||||
with ops.name_scope(name, values=[multiplier, num_rows]):
|
with ops.name_scope(name, values=[multiplier, num_rows]):
|
||||||
@ -620,6 +642,7 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
|
|||||||
is_self_adjoint=is_self_adjoint,
|
is_self_adjoint=is_self_adjoint,
|
||||||
is_positive_definite=is_positive_definite,
|
is_positive_definite=is_positive_definite,
|
||||||
is_square=is_square,
|
is_square=is_square,
|
||||||
|
parameters=parameters,
|
||||||
name=name)
|
name=name)
|
||||||
|
|
||||||
self._num_rows = linear_operator_util.shape_tensor(
|
self._num_rows = linear_operator_util.shape_tensor(
|
||||||
|
@ -113,6 +113,14 @@ class LinearOperatorInversion(linear_operator.LinearOperator):
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If `operator.is_non_singular` is False.
|
ValueError: If `operator.is_non_singular` is False.
|
||||||
"""
|
"""
|
||||||
|
parameters = dict(
|
||||||
|
operator=operator,
|
||||||
|
is_non_singular=is_non_singular,
|
||||||
|
is_self_adjoint=is_self_adjoint,
|
||||||
|
is_positive_definite=is_positive_definite,
|
||||||
|
is_square=is_square,
|
||||||
|
name=name
|
||||||
|
)
|
||||||
|
|
||||||
self._operator = operator
|
self._operator = operator
|
||||||
|
|
||||||
@ -163,6 +171,7 @@ class LinearOperatorInversion(linear_operator.LinearOperator):
|
|||||||
is_self_adjoint=is_self_adjoint,
|
is_self_adjoint=is_self_adjoint,
|
||||||
is_positive_definite=is_positive_definite,
|
is_positive_definite=is_positive_definite,
|
||||||
is_square=is_square,
|
is_square=is_square,
|
||||||
|
parameters=parameters,
|
||||||
name=name)
|
name=name)
|
||||||
# TODO(b/143910018) Remove graph_parents in V3.
|
# TODO(b/143910018) Remove graph_parents in V3.
|
||||||
self._set_graph_parents(operator.graph_parents)
|
self._set_graph_parents(operator.graph_parents)
|
||||||
|
@ -167,6 +167,15 @@ class LinearOperatorKronecker(linear_operator.LinearOperator):
|
|||||||
TypeError: If all operators do not have the same `dtype`.
|
TypeError: If all operators do not have the same `dtype`.
|
||||||
ValueError: If `operators` is empty.
|
ValueError: If `operators` is empty.
|
||||||
"""
|
"""
|
||||||
|
parameters = dict(
|
||||||
|
operators=operators,
|
||||||
|
is_non_singular=is_non_singular,
|
||||||
|
is_self_adjoint=is_self_adjoint,
|
||||||
|
is_positive_definite=is_positive_definite,
|
||||||
|
is_square=is_square,
|
||||||
|
name=name
|
||||||
|
)
|
||||||
|
|
||||||
# Validate operators.
|
# Validate operators.
|
||||||
check_ops.assert_proper_iterable(operators)
|
check_ops.assert_proper_iterable(operators)
|
||||||
operators = list(operators)
|
operators = list(operators)
|
||||||
@ -226,6 +235,7 @@ class LinearOperatorKronecker(linear_operator.LinearOperator):
|
|||||||
is_self_adjoint=is_self_adjoint,
|
is_self_adjoint=is_self_adjoint,
|
||||||
is_positive_definite=is_positive_definite,
|
is_positive_definite=is_positive_definite,
|
||||||
is_square=is_square,
|
is_square=is_square,
|
||||||
|
parameters=parameters,
|
||||||
name=name)
|
name=name)
|
||||||
# TODO(b/143910018) Remove graph_parents in V3.
|
# TODO(b/143910018) Remove graph_parents in V3.
|
||||||
self._set_graph_parents(graph_parents)
|
self._set_graph_parents(graph_parents)
|
||||||
|
@ -182,6 +182,18 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If `is_X` flags are set in an inconsistent way.
|
ValueError: If `is_X` flags are set in an inconsistent way.
|
||||||
"""
|
"""
|
||||||
|
parameters = dict(
|
||||||
|
base_operator=base_operator,
|
||||||
|
u=u,
|
||||||
|
diag_update=diag_update,
|
||||||
|
v=v,
|
||||||
|
is_diag_update_positive=is_diag_update_positive,
|
||||||
|
is_non_singular=is_non_singular,
|
||||||
|
is_self_adjoint=is_self_adjoint,
|
||||||
|
is_positive_definite=is_positive_definite,
|
||||||
|
is_square=is_square,
|
||||||
|
name=name
|
||||||
|
)
|
||||||
dtype = base_operator.dtype
|
dtype = base_operator.dtype
|
||||||
|
|
||||||
if diag_update is not None:
|
if diag_update is not None:
|
||||||
@ -253,6 +265,7 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
|
|||||||
is_self_adjoint=is_self_adjoint,
|
is_self_adjoint=is_self_adjoint,
|
||||||
is_positive_definite=is_positive_definite,
|
is_positive_definite=is_positive_definite,
|
||||||
is_square=is_square,
|
is_square=is_square,
|
||||||
|
parameters=parameters,
|
||||||
name=name)
|
name=name)
|
||||||
self._set_graph_parents(graph_parents)
|
self._set_graph_parents(graph_parents)
|
||||||
|
|
||||||
|
@ -137,6 +137,14 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator):
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If `is_square` is `False`.
|
ValueError: If `is_square` is `False`.
|
||||||
"""
|
"""
|
||||||
|
parameters = dict(
|
||||||
|
tril=tril,
|
||||||
|
is_non_singular=is_non_singular,
|
||||||
|
is_self_adjoint=is_self_adjoint,
|
||||||
|
is_positive_definite=is_positive_definite,
|
||||||
|
is_square=is_square,
|
||||||
|
name=name
|
||||||
|
)
|
||||||
|
|
||||||
if is_square is False:
|
if is_square is False:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -155,6 +163,7 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator):
|
|||||||
is_self_adjoint=is_self_adjoint,
|
is_self_adjoint=is_self_adjoint,
|
||||||
is_positive_definite=is_positive_definite,
|
is_positive_definite=is_positive_definite,
|
||||||
is_square=is_square,
|
is_square=is_square,
|
||||||
|
parameters=parameters,
|
||||||
name=name)
|
name=name)
|
||||||
self._set_graph_parents([self._tril])
|
self._set_graph_parents([self._tril])
|
||||||
|
|
||||||
|
@ -140,6 +140,15 @@ class LinearOperatorPermutation(linear_operator.LinearOperator):
|
|||||||
ValueError: `is_self_adjoint` is not `True`, `is_positive_definite` is
|
ValueError: `is_self_adjoint` is not `True`, `is_positive_definite` is
|
||||||
not `False` or `is_square` is not `True`.
|
not `False` or `is_square` is not `True`.
|
||||||
"""
|
"""
|
||||||
|
parameters = dict(
|
||||||
|
perm=perm,
|
||||||
|
dtype=dtype,
|
||||||
|
is_non_singular=is_non_singular,
|
||||||
|
is_self_adjoint=is_self_adjoint,
|
||||||
|
is_positive_definite=is_positive_definite,
|
||||||
|
is_square=is_square,
|
||||||
|
name=name
|
||||||
|
)
|
||||||
|
|
||||||
with ops.name_scope(name, values=[perm]):
|
with ops.name_scope(name, values=[perm]):
|
||||||
self._perm = linear_operator_util.convert_nonref_to_tensor(
|
self._perm = linear_operator_util.convert_nonref_to_tensor(
|
||||||
@ -160,6 +169,7 @@ class LinearOperatorPermutation(linear_operator.LinearOperator):
|
|||||||
is_self_adjoint=is_self_adjoint,
|
is_self_adjoint=is_self_adjoint,
|
||||||
is_positive_definite=is_positive_definite,
|
is_positive_definite=is_positive_definite,
|
||||||
is_square=is_square,
|
is_square=is_square,
|
||||||
|
parameters=parameters,
|
||||||
name=name)
|
name=name)
|
||||||
|
|
||||||
def _check_perm(self, perm):
|
def _check_perm(self, perm):
|
||||||
|
@ -138,6 +138,15 @@ class LinearOperatorToeplitz(linear_operator.LinearOperator):
|
|||||||
is_square: Expect that this operator acts like square [batch] matrices.
|
is_square: Expect that this operator acts like square [batch] matrices.
|
||||||
name: A name for this `LinearOperator`.
|
name: A name for this `LinearOperator`.
|
||||||
"""
|
"""
|
||||||
|
parameters = dict(
|
||||||
|
col=col,
|
||||||
|
row=row,
|
||||||
|
is_non_singular=is_non_singular,
|
||||||
|
is_self_adjoint=is_self_adjoint,
|
||||||
|
is_positive_definite=is_positive_definite,
|
||||||
|
is_square=is_square,
|
||||||
|
name=name
|
||||||
|
)
|
||||||
|
|
||||||
with ops.name_scope(name, values=[row, col]):
|
with ops.name_scope(name, values=[row, col]):
|
||||||
self._row = linear_operator_util.convert_nonref_to_tensor(row, name="row")
|
self._row = linear_operator_util.convert_nonref_to_tensor(row, name="row")
|
||||||
@ -155,7 +164,9 @@ class LinearOperatorToeplitz(linear_operator.LinearOperator):
|
|||||||
is_self_adjoint=is_self_adjoint,
|
is_self_adjoint=is_self_adjoint,
|
||||||
is_positive_definite=is_positive_definite,
|
is_positive_definite=is_positive_definite,
|
||||||
is_square=is_square,
|
is_square=is_square,
|
||||||
|
parameters=parameters,
|
||||||
name=name)
|
name=name)
|
||||||
|
|
||||||
self._set_graph_parents([self._row, self._col])
|
self._set_graph_parents([self._row, self._col])
|
||||||
|
|
||||||
def _check_row_col(self, row, col):
|
def _check_row_col(self, row, col):
|
||||||
|
@ -171,6 +171,15 @@ class LinearOperatorTridiag(linear_operator.LinearOperator):
|
|||||||
TypeError: If `diag.dtype` is not an allowed type.
|
TypeError: If `diag.dtype` is not an allowed type.
|
||||||
ValueError: If `diag.dtype` is real, and `is_self_adjoint` is not `True`.
|
ValueError: If `diag.dtype` is real, and `is_self_adjoint` is not `True`.
|
||||||
"""
|
"""
|
||||||
|
parameters = dict(
|
||||||
|
diagonals=diagonals,
|
||||||
|
diagonals_format=diagonals_format,
|
||||||
|
is_non_singular=is_non_singular,
|
||||||
|
is_self_adjoint=is_self_adjoint,
|
||||||
|
is_positive_definite=is_positive_definite,
|
||||||
|
is_square=is_square,
|
||||||
|
name=name
|
||||||
|
)
|
||||||
|
|
||||||
with ops.name_scope(name, values=[diagonals]):
|
with ops.name_scope(name, values=[diagonals]):
|
||||||
if diagonals_format not in _DIAGONAL_FORMATS:
|
if diagonals_format not in _DIAGONAL_FORMATS:
|
||||||
@ -193,6 +202,7 @@ class LinearOperatorTridiag(linear_operator.LinearOperator):
|
|||||||
is_self_adjoint=is_self_adjoint,
|
is_self_adjoint=is_self_adjoint,
|
||||||
is_positive_definite=is_positive_definite,
|
is_positive_definite=is_positive_definite,
|
||||||
is_square=is_square,
|
is_square=is_square,
|
||||||
|
parameters=parameters,
|
||||||
name=name)
|
name=name)
|
||||||
|
|
||||||
def _shape(self):
|
def _shape(self):
|
||||||
|
@ -176,6 +176,19 @@ class LinearOperatorZeros(linear_operator.LinearOperator):
|
|||||||
ValueError: If any of the following is not `True`:
|
ValueError: If any of the following is not `True`:
|
||||||
`{is_self_adjoint, is_non_singular, is_positive_definite}`.
|
`{is_self_adjoint, is_non_singular, is_positive_definite}`.
|
||||||
"""
|
"""
|
||||||
|
parameters = dict(
|
||||||
|
num_rows=num_rows,
|
||||||
|
num_columns=num_columns,
|
||||||
|
batch_shape=batch_shape,
|
||||||
|
dtype=dtype,
|
||||||
|
is_non_singular=is_non_singular,
|
||||||
|
is_self_adjoint=is_self_adjoint,
|
||||||
|
is_positive_definite=is_positive_definite,
|
||||||
|
is_square=is_square,
|
||||||
|
assert_proper_shapes=assert_proper_shapes,
|
||||||
|
name=name
|
||||||
|
)
|
||||||
|
|
||||||
dtype = dtype or dtypes.float32
|
dtype = dtype or dtypes.float32
|
||||||
self._assert_proper_shapes = assert_proper_shapes
|
self._assert_proper_shapes = assert_proper_shapes
|
||||||
|
|
||||||
@ -194,6 +207,7 @@ class LinearOperatorZeros(linear_operator.LinearOperator):
|
|||||||
is_self_adjoint=is_self_adjoint,
|
is_self_adjoint=is_self_adjoint,
|
||||||
is_positive_definite=is_positive_definite,
|
is_positive_definite=is_positive_definite,
|
||||||
is_square=is_square,
|
is_square=is_square,
|
||||||
|
parameters=parameters,
|
||||||
name=name)
|
name=name)
|
||||||
|
|
||||||
linear_operator_util.assert_not_ref_type(num_rows, "num_rows")
|
linear_operator_util.assert_not_ref_type(num_rows, "num_rows")
|
||||||
|
@ -54,6 +54,10 @@ tf_class {
|
|||||||
name: "operator"
|
name: "operator"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -54,6 +54,10 @@ tf_class {
|
|||||||
name: "operators"
|
name: "operators"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -54,6 +54,10 @@ tf_class {
|
|||||||
name: "operators"
|
name: "operators"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -59,6 +59,10 @@ tf_class {
|
|||||||
name: "name_scope"
|
name: "name_scope"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -59,6 +59,10 @@ tf_class {
|
|||||||
name: "name_scope"
|
name: "name_scope"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -59,6 +59,10 @@ tf_class {
|
|||||||
name: "name_scope"
|
name: "name_scope"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -54,6 +54,10 @@ tf_class {
|
|||||||
name: "operators"
|
name: "operators"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -54,6 +54,10 @@ tf_class {
|
|||||||
name: "name_scope"
|
name: "name_scope"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -50,6 +50,10 @@ tf_class {
|
|||||||
name: "name_scope"
|
name: "name_scope"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -50,6 +50,10 @@ tf_class {
|
|||||||
name: "name_scope"
|
name: "name_scope"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -51,6 +51,10 @@ tf_class {
|
|||||||
name: "name_scope"
|
name: "name_scope"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -54,6 +54,10 @@ tf_class {
|
|||||||
name: "operator"
|
name: "operator"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -54,6 +54,10 @@ tf_class {
|
|||||||
name: "operators"
|
name: "operators"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -66,6 +66,10 @@ tf_class {
|
|||||||
name: "name_scope"
|
name: "name_scope"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -50,6 +50,10 @@ tf_class {
|
|||||||
name: "name_scope"
|
name: "name_scope"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -50,6 +50,10 @@ tf_class {
|
|||||||
name: "name_scope"
|
name: "name_scope"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "perm"
|
name: "perm"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -55,6 +55,10 @@ tf_class {
|
|||||||
name: "name_scope"
|
name: "name_scope"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -54,6 +54,10 @@ tf_class {
|
|||||||
name: "name_scope"
|
name: "name_scope"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -58,6 +58,10 @@ tf_class {
|
|||||||
name: "name_scope"
|
name: "name_scope"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -50,6 +50,10 @@ tf_class {
|
|||||||
name: "name_scope"
|
name: "name_scope"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -49,6 +49,10 @@ tf_class {
|
|||||||
name: "name_scope"
|
name: "name_scope"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
@ -75,7 +79,7 @@ tf_class {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'dtype\', \'graph_parents\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
argspec: "args=[\'self\', \'dtype\', \'graph_parents\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'name\', \'parameters\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "add_to_tensor"
|
name: "add_to_tensor"
|
||||||
|
@ -54,6 +54,10 @@ tf_class {
|
|||||||
name: "operator"
|
name: "operator"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -54,6 +54,10 @@ tf_class {
|
|||||||
name: "operators"
|
name: "operators"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -54,6 +54,10 @@ tf_class {
|
|||||||
name: "operators"
|
name: "operators"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -59,6 +59,10 @@ tf_class {
|
|||||||
name: "name_scope"
|
name: "name_scope"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -59,6 +59,10 @@ tf_class {
|
|||||||
name: "name_scope"
|
name: "name_scope"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -59,6 +59,10 @@ tf_class {
|
|||||||
name: "name_scope"
|
name: "name_scope"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -54,6 +54,10 @@ tf_class {
|
|||||||
name: "operators"
|
name: "operators"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -54,6 +54,10 @@ tf_class {
|
|||||||
name: "name_scope"
|
name: "name_scope"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -50,6 +50,10 @@ tf_class {
|
|||||||
name: "name_scope"
|
name: "name_scope"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -50,6 +50,10 @@ tf_class {
|
|||||||
name: "name_scope"
|
name: "name_scope"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -51,6 +51,10 @@ tf_class {
|
|||||||
name: "name_scope"
|
name: "name_scope"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -54,6 +54,10 @@ tf_class {
|
|||||||
name: "operator"
|
name: "operator"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -54,6 +54,10 @@ tf_class {
|
|||||||
name: "operators"
|
name: "operators"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -66,6 +66,10 @@ tf_class {
|
|||||||
name: "name_scope"
|
name: "name_scope"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -50,6 +50,10 @@ tf_class {
|
|||||||
name: "name_scope"
|
name: "name_scope"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -50,6 +50,10 @@ tf_class {
|
|||||||
name: "name_scope"
|
name: "name_scope"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "perm"
|
name: "perm"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -55,6 +55,10 @@ tf_class {
|
|||||||
name: "name_scope"
|
name: "name_scope"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -54,6 +54,10 @@ tf_class {
|
|||||||
name: "name_scope"
|
name: "name_scope"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -58,6 +58,10 @@ tf_class {
|
|||||||
name: "name_scope"
|
name: "name_scope"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -50,6 +50,10 @@ tf_class {
|
|||||||
name: "name_scope"
|
name: "name_scope"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
|
@ -49,6 +49,10 @@ tf_class {
|
|||||||
name: "name_scope"
|
name: "name_scope"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "parameters"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
member {
|
member {
|
||||||
name: "range_dimension"
|
name: "range_dimension"
|
||||||
mtype: "<type \'property\'>"
|
mtype: "<type \'property\'>"
|
||||||
@ -75,7 +79,7 @@ tf_class {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
argspec: "args=[\'self\', \'dtype\', \'graph_parents\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
argspec: "args=[\'self\', \'dtype\', \'graph_parents\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'name\', \'parameters\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "add_to_tensor"
|
name: "add_to_tensor"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user