Remove tensorflow/contrib/linalg library. linalg remains in core.

PiperOrigin-RevId: 213352573
Ian Langmore authored 2018-09-17 15:46:30 -07:00, committed by TensorFlower Gardener
parent 3365cd1cc7
commit d5f4c3aa59
11 changed files with 26 additions and 1000 deletions
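For most users this is an import-only change: every `LinearOperator*` class that `tf.contrib.linalg` re-exported already lives in core and remains available under `tf.linalg` (backed by `//tensorflow/python/ops/linalg`, the dependency swapped in below). A minimal migration sketch, assuming TF 1.x as current at this commit; note the contrib-only `add_operators` helper deleted below gets no core replacement in this change:

```python
# Before this commit (now removed):
# from tensorflow.contrib import linalg
# operator = linalg.LinearOperatorDiag([1., 2.])

# After: the same operator classes are exposed from core as tf.linalg.
import tensorflow as tf

operator = tf.linalg.LinearOperatorDiag([1., 2.])
dense = operator.to_dense()  # 2x2 Tensor [[1., 0.], [0., 2.]]
```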

CODEOWNERS

@@ -41,7 +41,6 @@
/tensorflow/contrib/labeled_tensor/ @shoyer
/tensorflow/contrib/layers/ @fchollet @martinwicke
/tensorflow/contrib/learn/ @martinwicke @ispirmustafa @alextp
/tensorflow/contrib/linalg/ @langmore
/tensorflow/contrib/linear_optimizer/ @petrosmol @andreasst @katsiapis
/tensorflow/contrib/lookup/ @ysuematsu @andreasst
/tensorflow/contrib/losses/ @alextp @ispirmustafa

tensorflow/contrib/BUILD

@@ -60,7 +60,6 @@ py_library(
"//tensorflow/contrib/learn",
"//tensorflow/contrib/legacy_seq2seq:seq2seq_py",
"//tensorflow/contrib/libsvm",
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/contrib/linear_optimizer:sdca_estimator_py",
"//tensorflow/contrib/linear_optimizer:sdca_ops_py",
"//tensorflow/contrib/lite/python:lite",

tensorflow/contrib/__init__.py

@@ -63,7 +63,6 @@ from tensorflow.contrib import labeled_tensor
from tensorflow.contrib import layers
from tensorflow.contrib import learn
from tensorflow.contrib import legacy_seq2seq
from tensorflow.contrib import linalg
from tensorflow.contrib import linear_optimizer
from tensorflow.contrib import lookup
from tensorflow.contrib import losses

tensorflow/contrib/cmake/python_modules.txt

@@ -273,9 +273,6 @@ tensorflow/contrib/libsvm
tensorflow/contrib/libsvm/python
tensorflow/contrib/libsvm/python/kernel_tests
tensorflow/contrib/libsvm/python/ops
tensorflow/contrib/linalg
tensorflow/contrib/linalg/python
tensorflow/contrib/linalg/python/ops
tensorflow/contrib/linear_optimizer
tensorflow/contrib/linear_optimizer/kernels
tensorflow/contrib/linear_optimizer/kernels/g3doc

tensorflow/contrib/cmake/tf_tests.cmake

@@ -183,7 +183,6 @@ if (tensorflow_BUILD_PYTHON_TESTS)
file(GLOB_RECURSE tf_test_src_py
${tf_test_src_py}
"${tensorflow_source_dir}/tensorflow/contrib/legacy_seq2seq/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/linalg/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/graph_editor/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/bayesflow/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/framework/*_test.py"

tensorflow/contrib/distributions/BUILD

@@ -25,7 +25,6 @@ py_library(
"`tf.contrib.distributions` to `tfp.distributions`."),
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
"//tensorflow/python:clip_ops",
@@ -61,7 +60,6 @@ py_library(
":bijectors_py",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/learn",
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
"//tensorflow/python:control_flow_ops",
@@ -706,8 +704,8 @@ cuda_py_test(
":bijectors_py",
":distributions_py",
"//third_party/py/numpy",
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
@@ -722,8 +720,8 @@ cuda_py_test(
additional_deps = [
":distributions_py",
"//third_party/py/numpy",
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python/ops/linalg",
],
shard_count = 4,
tags = ["noasan"], # times out, http://b/78588814
@@ -739,8 +737,8 @@ cuda_py_test(
additional_deps = [
":distributions_py",
"//third_party/py/numpy",
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
@@ -794,8 +792,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -831,8 +829,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -852,8 +850,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -871,8 +869,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -907,8 +905,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -926,10 +924,10 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python/ops/linalg",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
@@ -945,8 +943,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -964,8 +962,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -983,8 +981,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1002,8 +1000,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1021,8 +1019,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1040,8 +1038,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1075,8 +1073,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1126,8 +1124,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1161,8 +1159,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1180,8 +1178,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1201,8 +1199,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1221,8 +1219,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1240,8 +1238,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1259,8 +1257,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1278,8 +1276,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1297,8 +1295,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1316,8 +1314,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",

tensorflow/contrib/linalg/BUILD

@@ -1,44 +0,0 @@
# Description:
# Contains classes that provide access to common methods of a [batch] matrix,
# without the need to instantiate the matrix.
# This allows for exploitation of structure, as well as a generic interface
# suitable for iterative solvers.
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
package(default_visibility = ["//tensorflow:__subpackages__"])
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
py_library(
name = "linalg_py",
srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:util",
"//tensorflow/python/ops/linalg",
"@six_archive//:six",
],
)
cuda_py_test(
name = "linear_operator_addition_test",
size = "small",
srcs = ["python/kernel_tests/linear_operator_addition_test.py"],
additional_deps = [
":linalg_py",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
)
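The description above is the library's whole point: a `LinearOperator` exposes matrix methods while exploiting structure instead of materializing the matrix. A short sketch using the core classes, which this commit keeps:

```python
import tensorflow as tf

# A diagonal operator stores only the diagonal. matmul scales the rows of
# x elementwise, so the dense [3, 3] matrix is never built.
operator = tf.linalg.LinearOperatorDiag([1., 2., 3.])
x = tf.ones(shape=[3, 2])
y = operator.matmul(x)  # Tensor of shape [3, 2]
```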

tensorflow/contrib/linalg/__init__.py

@@ -1,58 +0,0 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Linear algebra libraries.
See the [Contrib Linalg](https://tensorflow.org/api_guides/python/contrib.linalg)
guide.
@@LinearOperator
@@LinearOperatorBlockDiag
@@LinearOperatorCirculant
@@LinearOperatorCirculant2D
@@LinearOperatorCirculant3D
@@LinearOperatorDiag
@@LinearOperatorIdentity
@@LinearOperatorScaledIdentity
@@LinearOperatorFullMatrix
@@LinearOperatorKronecker
@@LinearOperatorLowerTriangular
@@LinearOperatorLowRankUpdate
@@LinearOperatorComposition
@@add_operators
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member
from tensorflow.contrib.linalg.python.ops.linear_operator_addition import *
from tensorflow.python.ops.linalg.linear_operator import *
from tensorflow.python.ops.linalg.linear_operator_block_diag import *
from tensorflow.python.ops.linalg.linear_operator_circulant import *
from tensorflow.python.ops.linalg.linear_operator_composition import *
from tensorflow.python.ops.linalg.linear_operator_diag import *
from tensorflow.python.ops.linalg.linear_operator_full_matrix import *
from tensorflow.python.ops.linalg.linear_operator_identity import *
from tensorflow.python.ops.linalg.linear_operator_kronecker import *
from tensorflow.python.ops.linalg.linear_operator_low_rank_update import *
from tensorflow.python.ops.linalg.linear_operator_lower_triangular import *
# pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member
from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(__name__)

tensorflow/contrib/linalg/python/__init__.py

@@ -1,19 +0,0 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ops module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

tensorflow/contrib/linalg/python/kernel_tests/linear_operator_addition_test.py

@@ -1,412 +0,0 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib.linalg.python.ops import linear_operator_addition
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.platform import test
linalg = linalg_lib
random_seed.set_random_seed(23)
rng = np.random.RandomState(0)
add_operators = linear_operator_addition.add_operators
# pylint: disable=unused-argument
class _BadAdder(linear_operator_addition._Adder):
"""Adder that will fail if used."""
def can_add(self, op1, op2):
raise AssertionError("BadAdder.can_add called!")
def _add(self, op1, op2, operator_name, hints):
raise AssertionError("This line should not be reached")
# pylint: enable=unused-argument
class LinearOperatorAdditionCorrectnessTest(test.TestCase):
"""Tests correctness of addition with combinations of a few Adders.
Tests here are done with the _DEFAULT_ADDITION_TIERS, which means
add_operators should reduce all operators resulting in one single operator.
This shows that we are able to correctly combine adders using the tiered
system. All Adders should be tested separately, and there is no need to test
every Adder within this class.
"""
def test_one_operator_is_returned_unchanged(self):
op_a = linalg.LinearOperatorDiag([1., 1.])
op_sum = add_operators([op_a])
self.assertEqual(1, len(op_sum))
self.assertTrue(op_sum[0] is op_a)
def test_at_least_one_operators_required(self):
with self.assertRaisesRegexp(ValueError, "must contain at least one"):
add_operators([])
def test_attempting_to_add_numbers_raises(self):
with self.assertRaisesRegexp(TypeError, "contain only LinearOperator"):
add_operators([1, 2])
def test_two_diag_operators(self):
op_a = linalg.LinearOperatorDiag(
[1., 1.], is_positive_definite=True, name="A")
op_b = linalg.LinearOperatorDiag(
[2., 2.], is_positive_definite=True, name="B")
with self.cached_session():
op_sum = add_operators([op_a, op_b])
self.assertEqual(1, len(op_sum))
op = op_sum[0]
self.assertTrue(isinstance(op, linalg_lib.LinearOperatorDiag))
self.assertAllClose([[3., 0.], [0., 3.]], op.to_dense().eval())
# Adding positive definite operators produces positive def.
self.assertTrue(op.is_positive_definite)
# Real diagonal ==> self-adjoint.
self.assertTrue(op.is_self_adjoint)
# Positive definite ==> non-singular
self.assertTrue(op.is_non_singular)
# Enforce particular name for this simple case
self.assertEqual("Add/B__A/", op.name)
def test_three_diag_operators(self):
op1 = linalg.LinearOperatorDiag(
[1., 1.], is_positive_definite=True, name="op1")
op2 = linalg.LinearOperatorDiag(
[2., 2.], is_positive_definite=True, name="op2")
op3 = linalg.LinearOperatorDiag(
[3., 3.], is_positive_definite=True, name="op3")
with self.cached_session():
op_sum = add_operators([op1, op2, op3])
self.assertEqual(1, len(op_sum))
op = op_sum[0]
self.assertTrue(isinstance(op, linalg_lib.LinearOperatorDiag))
self.assertAllClose([[6., 0.], [0., 6.]], op.to_dense().eval())
# Adding positive definite operators produces positive def.
self.assertTrue(op.is_positive_definite)
# Real diagonal ==> self-adjoint.
self.assertTrue(op.is_self_adjoint)
# Positive definite ==> non-singular
self.assertTrue(op.is_non_singular)
def test_diag_tril_diag(self):
op1 = linalg.LinearOperatorDiag(
[1., 1.], is_non_singular=True, name="diag_a")
op2 = linalg.LinearOperatorLowerTriangular(
[[2., 0.], [0., 2.]],
is_self_adjoint=True,
is_non_singular=True,
name="tril")
op3 = linalg.LinearOperatorDiag(
[3., 3.], is_non_singular=True, name="diag_b")
with self.cached_session():
op_sum = add_operators([op1, op2, op3])
self.assertEqual(1, len(op_sum))
op = op_sum[0]
self.assertTrue(isinstance(op, linalg_lib.LinearOperatorLowerTriangular))
self.assertAllClose([[6., 0.], [0., 6.]], op.to_dense().eval())
# The diag operators will be self-adjoint (because real and diagonal).
# The TriL operator has the self-adjoint hint set.
self.assertTrue(op.is_self_adjoint)
# Even though op1/2/3 are non-singular, this does not imply op is.
# Since no custom hint was provided, we default to None (unknown).
self.assertEqual(None, op.is_non_singular)
def test_matrix_diag_tril_diag_uses_custom_name(self):
op0 = linalg.LinearOperatorFullMatrix(
[[-1., -1.], [-1., -1.]], name="matrix")
op1 = linalg.LinearOperatorDiag([1., 1.], name="diag_a")
op2 = linalg.LinearOperatorLowerTriangular(
[[2., 0.], [1.5, 2.]], name="tril")
op3 = linalg.LinearOperatorDiag([3., 3.], name="diag_b")
with self.cached_session():
op_sum = add_operators([op0, op1, op2, op3], operator_name="my_operator")
self.assertEqual(1, len(op_sum))
op = op_sum[0]
self.assertTrue(isinstance(op, linalg_lib.LinearOperatorFullMatrix))
self.assertAllClose([[5., -1.], [0.5, 5.]], op.to_dense().eval())
self.assertEqual("my_operator", op.name)
def test_incompatible_domain_dimensions_raises(self):
op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3))
op2 = linalg.LinearOperatorDiag(rng.rand(2, 4))
with self.assertRaisesRegexp(ValueError, "must.*same domain dimension"):
add_operators([op1, op2])
def test_incompatible_range_dimensions_raises(self):
op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3))
op2 = linalg.LinearOperatorDiag(rng.rand(3, 3))
with self.assertRaisesRegexp(ValueError, "must.*same range dimension"):
add_operators([op1, op2])
def test_non_broadcastable_batch_shape_raises(self):
op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3))
op2 = linalg.LinearOperatorDiag(rng.rand(4, 3, 3))
with self.assertRaisesRegexp(ValueError, "Incompatible shapes"):
add_operators([op1, op2])
class LinearOperatorOrderOfAdditionTest(test.TestCase):
"""Test that the order of addition is done as specified by tiers."""
def test_tier_0_additions_done_in_tier_0(self):
diag1 = linalg.LinearOperatorDiag([1.])
diag2 = linalg.LinearOperatorDiag([1.])
diag3 = linalg.LinearOperatorDiag([1.])
addition_tiers = [
[linear_operator_addition._AddAndReturnDiag()],
[_BadAdder()],
]
# Should not raise since all were added in tier 0, and tier 1 (with the
# _BadAdder) was never reached.
op_sum = add_operators([diag1, diag2, diag3], addition_tiers=addition_tiers)
self.assertEqual(1, len(op_sum))
self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorDiag))
def test_tier_1_additions_done_by_tier_1(self):
diag1 = linalg.LinearOperatorDiag([1.])
diag2 = linalg.LinearOperatorDiag([1.])
tril = linalg.LinearOperatorLowerTriangular([[1.]])
addition_tiers = [
[linear_operator_addition._AddAndReturnDiag()],
[linear_operator_addition._AddAndReturnTriL()],
[_BadAdder()],
]
# Should not raise since all were added by tier 1, and the
# _BadAdder was never reached.
op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
self.assertEqual(1, len(op_sum))
self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorLowerTriangular))
def test_tier_1_additions_done_by_tier_1_with_order_flipped(self):
diag1 = linalg.LinearOperatorDiag([1.])
diag2 = linalg.LinearOperatorDiag([1.])
tril = linalg.LinearOperatorLowerTriangular([[1.]])
addition_tiers = [
[linear_operator_addition._AddAndReturnTriL()],
[linear_operator_addition._AddAndReturnDiag()],
[_BadAdder()],
]
# Tier 0 could convert to TriL, and this converted everything to TriL,
# including the Diags.
# Tier 1 was never used.
# Tier 2 was never used (therefore, _BadAdder didn't raise).
op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
self.assertEqual(1, len(op_sum))
self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorLowerTriangular))
def test_cannot_add_everything_so_return_more_than_one_operator(self):
diag1 = linalg.LinearOperatorDiag([1.])
diag2 = linalg.LinearOperatorDiag([2.])
tril5 = linalg.LinearOperatorLowerTriangular([[5.]])
addition_tiers = [
[linear_operator_addition._AddAndReturnDiag()],
]
# Tier 0 (the only tier) can only convert to Diag, so it combines the two
# diags, but the TriL is unchanged.
# Result should contain two operators, one Diag, one TriL.
op_sum = add_operators([diag1, diag2, tril5], addition_tiers=addition_tiers)
self.assertEqual(2, len(op_sum))
found_diag = False
found_tril = False
with self.cached_session():
for op in op_sum:
if isinstance(op, linalg.LinearOperatorDiag):
found_diag = True
self.assertAllClose([[3.]], op.to_dense().eval())
if isinstance(op, linalg.LinearOperatorLowerTriangular):
found_tril = True
self.assertAllClose([[5.]], op.to_dense().eval())
self.assertTrue(found_diag and found_tril)
def test_intermediate_tier_is_not_skipped(self):
diag1 = linalg.LinearOperatorDiag([1.])
diag2 = linalg.LinearOperatorDiag([1.])
tril = linalg.LinearOperatorLowerTriangular([[1.]])
addition_tiers = [
[linear_operator_addition._AddAndReturnDiag()],
[_BadAdder()],
[linear_operator_addition._AddAndReturnTriL()],
]
# tril cannot be added in tier 0, and the intermediate tier 1 with the
# BadAdder will catch it and raise.
with self.assertRaisesRegexp(AssertionError, "BadAdder.can_add called"):
add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
class AddAndReturnScaledIdentityTest(test.TestCase):
def setUp(self):
self._adder = linear_operator_addition._AddAndReturnScaledIdentity()
def test_identity_plus_identity(self):
id1 = linalg.LinearOperatorIdentity(num_rows=2)
id2 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3])
hints = linear_operator_addition._Hints(
is_positive_definite=True, is_non_singular=True)
self.assertTrue(self._adder.can_add(id1, id2))
operator = self._adder.add(id1, id2, "my_operator", hints)
self.assertTrue(isinstance(operator, linalg.LinearOperatorScaledIdentity))
with self.cached_session():
self.assertAllClose(2 *
linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
operator.to_dense().eval())
self.assertTrue(operator.is_positive_definite)
self.assertTrue(operator.is_non_singular)
self.assertEqual("my_operator", operator.name)
def test_identity_plus_scaled_identity(self):
id1 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3])
id2 = linalg.LinearOperatorScaledIdentity(num_rows=2, multiplier=2.2)
hints = linear_operator_addition._Hints(
is_positive_definite=True, is_non_singular=True)
self.assertTrue(self._adder.can_add(id1, id2))
operator = self._adder.add(id1, id2, "my_operator", hints)
self.assertTrue(isinstance(operator, linalg.LinearOperatorScaledIdentity))
with self.cached_session():
self.assertAllClose(3.2 *
linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
operator.to_dense().eval())
self.assertTrue(operator.is_positive_definite)
self.assertTrue(operator.is_non_singular)
self.assertEqual("my_operator", operator.name)
def test_scaled_identity_plus_scaled_identity(self):
id1 = linalg.LinearOperatorScaledIdentity(
num_rows=2, multiplier=[2.2, 2.2, 2.2])
id2 = linalg.LinearOperatorScaledIdentity(num_rows=2, multiplier=-1.0)
hints = linear_operator_addition._Hints(
is_positive_definite=True, is_non_singular=True)
self.assertTrue(self._adder.can_add(id1, id2))
operator = self._adder.add(id1, id2, "my_operator", hints)
self.assertTrue(isinstance(operator, linalg.LinearOperatorScaledIdentity))
with self.cached_session():
self.assertAllClose(1.2 *
linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
operator.to_dense().eval())
self.assertTrue(operator.is_positive_definite)
self.assertTrue(operator.is_non_singular)
self.assertEqual("my_operator", operator.name)
class AddAndReturnDiagTest(test.TestCase):
def setUp(self):
self._adder = linear_operator_addition._AddAndReturnDiag()
def test_identity_plus_identity_returns_diag(self):
id1 = linalg.LinearOperatorIdentity(num_rows=2)
id2 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3])
hints = linear_operator_addition._Hints(
is_positive_definite=True, is_non_singular=True)
self.assertTrue(self._adder.can_add(id1, id2))
operator = self._adder.add(id1, id2, "my_operator", hints)
self.assertTrue(isinstance(operator, linalg.LinearOperatorDiag))
with self.cached_session():
self.assertAllClose(2 *
linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
operator.to_dense().eval())
self.assertTrue(operator.is_positive_definite)
self.assertTrue(operator.is_non_singular)
self.assertEqual("my_operator", operator.name)
def test_diag_plus_diag(self):
diag1 = rng.rand(2, 3, 4)
diag2 = rng.rand(4)
op1 = linalg.LinearOperatorDiag(diag1)
op2 = linalg.LinearOperatorDiag(diag2)
hints = linear_operator_addition._Hints(
is_positive_definite=True, is_non_singular=True)
self.assertTrue(self._adder.can_add(op1, op2))
operator = self._adder.add(op1, op2, "my_operator", hints)
self.assertTrue(isinstance(operator, linalg.LinearOperatorDiag))
with self.cached_session():
self.assertAllClose(
linalg.LinearOperatorDiag(diag1 + diag2).to_dense().eval(),
operator.to_dense().eval())
self.assertTrue(operator.is_positive_definite)
self.assertTrue(operator.is_non_singular)
self.assertEqual("my_operator", operator.name)
class AddAndReturnTriLTest(test.TestCase):
def setUp(self):
self._adder = linear_operator_addition._AddAndReturnTriL()
def test_diag_plus_tril(self):
diag = linalg.LinearOperatorDiag([1., 2.])
tril = linalg.LinearOperatorLowerTriangular([[10., 0.], [30., 0.]])
hints = linear_operator_addition._Hints(
is_positive_definite=True, is_non_singular=True)
self.assertTrue(self._adder.can_add(diag, diag))
self.assertTrue(self._adder.can_add(diag, tril))
operator = self._adder.add(diag, tril, "my_operator", hints)
self.assertTrue(isinstance(operator, linalg.LinearOperatorLowerTriangular))
with self.cached_session():
self.assertAllClose([[11., 0.], [30., 2.]], operator.to_dense().eval())
self.assertTrue(operator.is_positive_definite)
self.assertTrue(operator.is_non_singular)
self.assertEqual("my_operator", operator.name)
class AddAndReturnMatrixTest(test.TestCase):
def setUp(self):
self._adder = linear_operator_addition._AddAndReturnMatrix()
def test_diag_plus_diag(self):
diag1 = linalg.LinearOperatorDiag([1., 2.])
diag2 = linalg.LinearOperatorDiag([-1., 3.])
hints = linear_operator_addition._Hints(
is_positive_definite=False, is_non_singular=False)
self.assertTrue(self._adder.can_add(diag1, diag2))
operator = self._adder.add(diag1, diag2, "my_operator", hints)
self.assertTrue(isinstance(operator, linalg.LinearOperatorFullMatrix))
with self.cached_session():
self.assertAllClose([[0., 0.], [0., 5.]], operator.to_dense().eval())
self.assertFalse(operator.is_positive_definite)
self.assertFalse(operator.is_non_singular)
self.assertEqual("my_operator", operator.name)
if __name__ == "__main__":
test.main()
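For reference, the core behavior these tests pin down, written as a standalone sketch against the pre-removal API (TF 1.x graph mode assumed; `add_operators` is imported exactly as in the test above and is deleted by this commit):

```python
import tensorflow as tf
from tensorflow.contrib.linalg.python.ops import linear_operator_addition

op_a = tf.linalg.LinearOperatorDiag([1., 1.], is_positive_definite=True, name="A")
op_b = tf.linalg.LinearOperatorDiag([2., 2.], is_positive_definite=True, name="B")

# Both summands are diagonal, so the default tiers reduce the sum to a single
# LinearOperatorDiag, and hint inference marks it positive definite,
# self-adjoint, and non-singular.
op_sum = linear_operator_addition.add_operators([op_a, op_b])
assert len(op_sum) == 1 and op_sum[0].is_positive_definite

with tf.Session() as sess:
  print(sess.run(op_sum[0].to_dense()))  # [[3., 0.], [0., 3.]]
```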

tensorflow/contrib/linalg/python/ops/linear_operator_addition.py

@@ -1,432 +0,0 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Add one or more `LinearOperators` efficiently."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import six
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_diag
from tensorflow.python.ops.linalg import linear_operator_full_matrix
from tensorflow.python.ops.linalg import linear_operator_identity
from tensorflow.python.ops.linalg import linear_operator_lower_triangular
__all__ = []
def add_operators(operators,
operator_name=None,
addition_tiers=None,
name=None):
"""Efficiently add one or more linear operators.
Given operators `[A1, A2,...]`, this `Op` returns a possibly shorter list of
operators `[B1, B2,...]` such that
```sum_k Ak.matmul(x) = sum_k Bk.matmul(x).```
The operators `Bk` result by adding some of the `Ak`, as allowed by
`addition_tiers`.
Example of efficient adding of diagonal operators.
```python
A1 = LinearOperatorDiag(diag=[1., 1.], name="A1")
A2 = LinearOperatorDiag(diag=[2., 2.], name="A2")
# Use two tiers, the first contains an Adder that returns Diag. Since both
# A1 and A2 are Diag, they can use this Adder. The second tier will not be
# used.
addition_tiers = [
[_AddAndReturnDiag()],
[_AddAndReturnMatrix()]]
B_list = add_operators([A1, A2], addition_tiers=addition_tiers)
len(B_list)
==> 1
B_list[0].__class__.__name__
==> 'LinearOperatorDiag'
B_list[0].to_dense()
==> [[3., 0.],
[0., 3.]]
B_list[0].name
==> 'Add/A1__A2/'
```
Args:
operators: Iterable of `LinearOperator` objects with same `dtype`, domain
and range dimensions, and broadcastable batch shapes.
operator_name: String name for returned `LinearOperator`. Defaults to
a concatenation such as "Add/A__B/", indicating the order of addition steps.
addition_tiers: List tiers, like `[tier_0, tier_1, ...]`, where `tier_i`
is a list of `Adder` objects. This function attempts to do all additions
in tier `i` before trying tier `i + 1`.
name: A name for this `Op`. Defaults to `add_operators`.
Returns:
Subclass of `LinearOperator`. Class and order of addition may change as new
(and better) addition strategies emerge.
Raises:
ValueError: If `operators` argument is empty.
ValueError: If shapes are incompatible.
"""
# Default setting
if addition_tiers is None:
addition_tiers = _DEFAULT_ADDITION_TIERS
# Argument checking.
check_ops.assert_proper_iterable(operators)
operators = list(reversed(operators))
if len(operators) < 1:
raise ValueError(
"Argument 'operators' must contain at least one operator. "
"Found: %s" % operators)
if not all(
isinstance(op, linear_operator.LinearOperator) for op in operators):
raise TypeError(
"Argument 'operators' must contain only LinearOperator instances. "
"Found: %s" % operators)
_static_check_for_same_dimensions(operators)
_static_check_for_broadcastable_batch_shape(operators)
graph_parents = []
for operator in operators:
graph_parents.extend(operator.graph_parents)
with ops.name_scope(name or "add_operators", values=graph_parents):
# Additions done in one of the tiers. Try tier 0, 1,...
ops_to_try_at_next_tier = list(operators)
for tier in addition_tiers:
ops_to_try_at_this_tier = ops_to_try_at_next_tier
ops_to_try_at_next_tier = []
while ops_to_try_at_this_tier:
op1 = ops_to_try_at_this_tier.pop()
op2, adder = _pop_a_match_at_tier(op1, ops_to_try_at_this_tier, tier)
if op2 is not None:
# Will try to add the result of this again at this same tier.
new_operator = adder.add(op1, op2, operator_name)
ops_to_try_at_this_tier.append(new_operator)
else:
ops_to_try_at_next_tier.append(op1)
return ops_to_try_at_next_tier
def _pop_a_match_at_tier(op1, operator_list, tier):
# Search from the back of the list to the front in order to create a nice
# default order of operations.
for i in range(1, len(operator_list) + 1):
op2 = operator_list[-i]
for adder in tier:
if adder.can_add(op1, op2):
return operator_list.pop(-i), adder
return None, None
def _infer_hints_allowing_override(op1, op2, hints):
"""Infer hints from op1 and op2. hints argument is an override.
Args:
op1: LinearOperator
op2: LinearOperator
hints: _Hints object holding "is_X" boolean hints to use for returned
operator.
If some hint is None, try to set using op1 and op2. If the
hint is provided, ignore op1 and op2 hints. This allows an override
of previous hints, but does not allow forbidden hints (e.g. you still
cannot say a real diagonal operator is not self-adjoint).
Returns:
_Hints object.
"""
hints = hints or _Hints()
# If A, B are self-adjoint, then so is A + B.
if hints.is_self_adjoint is None:
is_self_adjoint = op1.is_self_adjoint and op2.is_self_adjoint
else:
is_self_adjoint = hints.is_self_adjoint
# If A, B are positive definite, then so is A + B.
if hints.is_positive_definite is None:
is_positive_definite = op1.is_positive_definite and op2.is_positive_definite
else:
is_positive_definite = hints.is_positive_definite
# A positive definite operator is always non-singular.
if is_positive_definite and hints.is_non_singular is None:
is_non_singular = True
else:
is_non_singular = hints.is_non_singular
return _Hints(
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite)
def _static_check_for_same_dimensions(operators):
"""ValueError if operators determined to have different dimensions."""
if len(operators) < 2:
return
domain_dimensions = [(op.name, op.domain_dimension.value) for op in operators
if op.domain_dimension.value is not None]
if len(set(value for name, value in domain_dimensions)) > 1:
raise ValueError("Operators must have the same domain dimension. Found: %s"
% domain_dimensions)
range_dimensions = [(op.name, op.range_dimension.value) for op in operators
if op.range_dimension.value is not None]
if len(set(value for name, value in range_dimensions)) > 1:
raise ValueError("Operators must have the same range dimension. Found: %s" %
range_dimensions)
def _static_check_for_broadcastable_batch_shape(operators):
"""ValueError if operators determined to have non-broadcastable shapes."""
if len(operators) < 2:
return
# This will fail if they cannot be broadcast together.
batch_shape = operators[0].batch_shape
for op in operators[1:]:
batch_shape = array_ops.broadcast_static_shape(batch_shape, op.batch_shape)
class _Hints(object):
"""Holds 'is_X' flags that every LinearOperator is initialized with."""
def __init__(self,
is_non_singular=None,
is_positive_definite=None,
is_self_adjoint=None):
self.is_non_singular = is_non_singular
self.is_positive_definite = is_positive_definite
self.is_self_adjoint = is_self_adjoint
################################################################################
# Classes to add two linear operators.
################################################################################
@six.add_metaclass(abc.ABCMeta)
class _Adder(object):
"""Abstract base class to add two operators.
Each `Adder` acts independently, adding everything it can, paying no attention
to whether another `Adder` could have done the addition more efficiently.
"""
@property
def name(self):
return self.__class__.__name__
@abc.abstractmethod
def can_add(self, op1, op2):
"""Returns `True` if this `Adder` can add `op1` and `op2`. Else `False`."""
pass
@abc.abstractmethod
def _add(self, op1, op2, operator_name, hints):
# Derived classes can assume op1 and op2 have been validated, e.g. they have
# the same dtype, and their domain/range dimensions match.
pass
def add(self, op1, op2, operator_name, hints=None):
"""Return new `LinearOperator` acting like `op1 + op2`.
Args:
op1: `LinearOperator`
op2: `LinearOperator`, with `shape` and `dtype` such that adding to
`op1` is allowed.
operator_name: `String` name to give to returned `LinearOperator`
hints: `_Hints` object. Returned `LinearOperator` will be created with
these hints.
Returns:
`LinearOperator`
"""
updated_hints = _infer_hints_allowing_override(op1, op2, hints)
if operator_name is None:
operator_name = "Add/" + op1.name + "__" + op2.name + "/"
values = op1.graph_parents + op2.graph_parents
scope_name = self.name
if scope_name.startswith("_"):
scope_name = scope_name[1:]
with ops.name_scope(scope_name, values=values):
return self._add(op1, op2, operator_name, updated_hints)
class _AddAndReturnScaledIdentity(_Adder):
"""Handles additions resulting in an Identity family member.
The Identity (`LinearOperatorScaledIdentity`, `LinearOperatorIdentity`) family
is closed under addition. This `Adder` respects that, and returns an Identity
family member.
"""
def can_add(self, op1, op2):
types = {_type(op1), _type(op2)}
return not types.difference(_IDENTITY_FAMILY)
def _add(self, op1, op2, operator_name, hints):
# Will build a LinearOperatorScaledIdentity.
if _type(op1) == _SCALED_IDENTITY:
multiplier_1 = op1.multiplier
else:
multiplier_1 = array_ops.ones(op1.batch_shape_tensor(), dtype=op1.dtype)
if _type(op2) == _SCALED_IDENTITY:
multiplier_2 = op2.multiplier
else:
multiplier_2 = array_ops.ones(op2.batch_shape_tensor(), dtype=op2.dtype)
return linear_operator_identity.LinearOperatorScaledIdentity(
num_rows=op1.range_dimension_tensor(),
multiplier=multiplier_1 + multiplier_2,
is_non_singular=hints.is_non_singular,
is_self_adjoint=hints.is_self_adjoint,
is_positive_definite=hints.is_positive_definite,
name=operator_name)
class _AddAndReturnDiag(_Adder):
"""Handles additions resulting in a Diag operator."""
def can_add(self, op1, op2):
types = {_type(op1), _type(op2)}
return not types.difference(_DIAG_LIKE)
def _add(self, op1, op2, operator_name, hints):
return linear_operator_diag.LinearOperatorDiag(
diag=op1.diag_part() + op2.diag_part(),
is_non_singular=hints.is_non_singular,
is_self_adjoint=hints.is_self_adjoint,
is_positive_definite=hints.is_positive_definite,
name=operator_name)
class _AddAndReturnTriL(_Adder):
"""Handles additions resulting in a TriL operator."""
def can_add(self, op1, op2):
types = {_type(op1), _type(op2)}
return not types.difference(_DIAG_LIKE.union({_TRIL}))
def _add(self, op1, op2, operator_name, hints):
if _type(op1) in _EFFICIENT_ADD_TO_TENSOR:
op_add_to_tensor, op_other = op1, op2
else:
op_add_to_tensor, op_other = op2, op1
return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
tril=op_add_to_tensor.add_to_tensor(op_other.to_dense()),
is_non_singular=hints.is_non_singular,
is_self_adjoint=hints.is_self_adjoint,
is_positive_definite=hints.is_positive_definite,
name=operator_name)
class _AddAndReturnMatrix(_Adder):
""""Handles additions resulting in a `LinearOperatorFullMatrix`."""
def can_add(self, op1, op2): # pylint: disable=unused-argument
return isinstance(op1, linear_operator.LinearOperator) and isinstance(
op2, linear_operator.LinearOperator)
def _add(self, op1, op2, operator_name, hints):
if _type(op1) in _EFFICIENT_ADD_TO_TENSOR:
op_add_to_tensor, op_other = op1, op2
else:
op_add_to_tensor, op_other = op2, op1
return linear_operator_full_matrix.LinearOperatorFullMatrix(
matrix=op_add_to_tensor.add_to_tensor(op_other.to_dense()),
is_non_singular=hints.is_non_singular,
is_self_adjoint=hints.is_self_adjoint,
is_positive_definite=hints.is_positive_definite,
name=operator_name)
################################################################################
# Constants designating types of LinearOperators
################################################################################
# Type name constants for LinearOperator classes.
_IDENTITY = "identity"
_SCALED_IDENTITY = "scaled_identity"
_DIAG = "diag"
_TRIL = "tril"
_MATRIX = "matrix"
# Groups of operators.
_DIAG_LIKE = {_DIAG, _IDENTITY, _SCALED_IDENTITY}
_IDENTITY_FAMILY = {_IDENTITY, _SCALED_IDENTITY}
# operators with an efficient .add_to_tensor() method.
_EFFICIENT_ADD_TO_TENSOR = _DIAG_LIKE
def _type(operator):
"""Returns the type name constant (e.g. _TRIL) for operator."""
if isinstance(operator, linear_operator_diag.LinearOperatorDiag):
return _DIAG
if isinstance(operator,
linear_operator_lower_triangular.LinearOperatorLowerTriangular):
return _TRIL
if isinstance(operator, linear_operator_full_matrix.LinearOperatorFullMatrix):
return _MATRIX
if isinstance(operator, linear_operator_identity.LinearOperatorIdentity):
return _IDENTITY
if isinstance(operator,
linear_operator_identity.LinearOperatorScaledIdentity):
return _SCALED_IDENTITY
raise TypeError("Operator type unknown: %s" % operator)
################################################################################
# Addition tiers:
# We attempt to use Adders in tier K before K+1.
#
# Organize tiers to
# (i) reduce O(..) complexity of forming final operator, and
# (ii) produce the "most efficient" final operator.
# Dev notes:
# * Results of addition at tier K will be added at tier K or higher.
# * Tiers may change, and users are warned that they may change.
################################################################################
# Note that the final tier, _AddAndReturnMatrix, will convert everything to a
# dense matrix. So it is sometimes very inefficient.
_DEFAULT_ADDITION_TIERS = [
[_AddAndReturnScaledIdentity()],
[_AddAndReturnDiag()],
[_AddAndReturnTriL()],
[_AddAndReturnMatrix()],
]
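Read top to bottom, these tiers keep a sum as structured as its least structured summand allows. A sketch of the fall-through, using the same pre-removal API as above: a Diag plus a TriL passes the identity and diag tiers and resolves at `_AddAndReturnTriL`, never reaching the dense `_AddAndReturnMatrix` tier:

```python
import tensorflow as tf
from tensorflow.contrib.linalg.python.ops import linear_operator_addition

diag = tf.linalg.LinearOperatorDiag([1., 2.])
tril = tf.linalg.LinearOperatorLowerTriangular([[10., 0.], [30., 0.]])

# Tier 0 (identity family) and tier 1 (diag-like) cannot absorb the TriL;
# tier 2 adds the diagonal into the triangle and stays triangular.
op_sum = linear_operator_addition.add_operators([diag, tril])
print(type(op_sum[0]).__name__)  # LinearOperatorLowerTriangular
```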