Remove tensorflow/contrib/linalg library. linalg remains in core.
PiperOrigin-RevId: 213352573
This commit is contained in:
parent
3365cd1cc7
commit
d5f4c3aa59
@ -41,7 +41,6 @@
|
||||
/tensorflow/contrib/labeled_tensor/ @shoyer
|
||||
/tensorflow/contrib/layers/ @fchollet @martinwicke
|
||||
/tensorflow/contrib/learn/ @martinwicke @ispirmustafa @alextp
|
||||
/tensorflow/contrib/linalg/ @langmore
|
||||
/tensorflow/contrib/linear_optimizer/ @petrosmol @andreasst @katsiapis
|
||||
/tensorflow/contrib/lookup/ @ysuematsu @andreasst
|
||||
/tensorflow/contrib/losses/ @alextp @ispirmustafa
|
||||
|
@ -60,7 +60,6 @@ py_library(
|
||||
"//tensorflow/contrib/learn",
|
||||
"//tensorflow/contrib/legacy_seq2seq:seq2seq_py",
|
||||
"//tensorflow/contrib/libsvm",
|
||||
"//tensorflow/contrib/linalg:linalg_py",
|
||||
"//tensorflow/contrib/linear_optimizer:sdca_estimator_py",
|
||||
"//tensorflow/contrib/linear_optimizer:sdca_ops_py",
|
||||
"//tensorflow/contrib/lite/python:lite",
|
||||
|
@ -63,7 +63,6 @@ from tensorflow.contrib import labeled_tensor
|
||||
from tensorflow.contrib import layers
|
||||
from tensorflow.contrib import learn
|
||||
from tensorflow.contrib import legacy_seq2seq
|
||||
from tensorflow.contrib import linalg
|
||||
from tensorflow.contrib import linear_optimizer
|
||||
from tensorflow.contrib import lookup
|
||||
from tensorflow.contrib import losses
|
||||
|
@ -273,9 +273,6 @@ tensorflow/contrib/libsvm
|
||||
tensorflow/contrib/libsvm/python
|
||||
tensorflow/contrib/libsvm/python/kernel_tests
|
||||
tensorflow/contrib/libsvm/python/ops
|
||||
tensorflow/contrib/linalg
|
||||
tensorflow/contrib/linalg/python
|
||||
tensorflow/contrib/linalg/python/ops
|
||||
tensorflow/contrib/linear_optimizer
|
||||
tensorflow/contrib/linear_optimizer/kernels
|
||||
tensorflow/contrib/linear_optimizer/kernels/g3doc
|
||||
|
@ -183,7 +183,6 @@ if (tensorflow_BUILD_PYTHON_TESTS)
|
||||
file(GLOB_RECURSE tf_test_src_py
|
||||
${tf_test_src_py}
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/legacy_seq2seq/*_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/linalg/*_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/graph_editor/*_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/bayesflow/*_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/framework/*_test.py"
|
||||
|
@ -25,7 +25,6 @@ py_library(
|
||||
"`tf.contrib.distributions` to `tfp.distributions`."),
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/contrib/linalg:linalg_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:check_ops",
|
||||
"//tensorflow/python:clip_ops",
|
||||
@ -61,7 +60,6 @@ py_library(
|
||||
":bijectors_py",
|
||||
"//tensorflow/contrib/framework:framework_py",
|
||||
"//tensorflow/contrib/learn",
|
||||
"//tensorflow/contrib/linalg:linalg_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:check_ops",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
@ -706,8 +704,8 @@ cuda_py_test(
|
||||
":bijectors_py",
|
||||
":distributions_py",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/contrib/linalg:linalg_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python/ops/linalg",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:math_ops",
|
||||
@ -722,8 +720,8 @@ cuda_py_test(
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/contrib/linalg:linalg_py",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/ops/linalg",
|
||||
],
|
||||
shard_count = 4,
|
||||
tags = ["noasan"], # times out, http://b/78588814
|
||||
@ -739,8 +737,8 @@ cuda_py_test(
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/contrib/linalg:linalg_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python/ops/linalg",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:math_ops",
|
||||
@ -794,8 +792,8 @@ cuda_py_test(
|
||||
":distributions_py",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
"//tensorflow/contrib/linalg:linalg_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python/ops/linalg",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
@ -831,8 +829,8 @@ cuda_py_test(
|
||||
":distributions_py",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
"//tensorflow/contrib/linalg:linalg_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python/ops/linalg",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
@ -852,8 +850,8 @@ cuda_py_test(
|
||||
":distributions_py",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
"//tensorflow/contrib/linalg:linalg_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python/ops/linalg",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
@ -871,8 +869,8 @@ cuda_py_test(
|
||||
":distributions_py",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
"//tensorflow/contrib/linalg:linalg_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python/ops/linalg",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
@ -907,8 +905,8 @@ cuda_py_test(
|
||||
":distributions_py",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
"//tensorflow/contrib/linalg:linalg_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python/ops/linalg",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
@ -926,10 +924,10 @@ cuda_py_test(
|
||||
":distributions_py",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
"//tensorflow/contrib/linalg:linalg_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python/ops/linalg",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
@ -945,8 +943,8 @@ cuda_py_test(
|
||||
":distributions_py",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
"//tensorflow/contrib/linalg:linalg_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python/ops/linalg",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
@ -964,8 +962,8 @@ cuda_py_test(
|
||||
":distributions_py",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
"//tensorflow/contrib/linalg:linalg_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python/ops/linalg",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
@ -983,8 +981,8 @@ cuda_py_test(
|
||||
":distributions_py",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
"//tensorflow/contrib/linalg:linalg_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python/ops/linalg",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
@ -1002,8 +1000,8 @@ cuda_py_test(
|
||||
":distributions_py",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
"//tensorflow/contrib/linalg:linalg_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python/ops/linalg",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
@ -1021,8 +1019,8 @@ cuda_py_test(
|
||||
":distributions_py",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
"//tensorflow/contrib/linalg:linalg_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python/ops/linalg",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
@ -1040,8 +1038,8 @@ cuda_py_test(
|
||||
":distributions_py",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
"//tensorflow/contrib/linalg:linalg_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python/ops/linalg",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
@ -1075,8 +1073,8 @@ cuda_py_test(
|
||||
":distributions_py",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
"//tensorflow/contrib/linalg:linalg_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python/ops/linalg",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
@ -1126,8 +1124,8 @@ cuda_py_test(
|
||||
":distributions_py",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
"//tensorflow/contrib/linalg:linalg_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python/ops/linalg",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
@ -1161,8 +1159,8 @@ cuda_py_test(
|
||||
":distributions_py",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
"//tensorflow/contrib/linalg:linalg_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python/ops/linalg",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
@ -1180,8 +1178,8 @@ cuda_py_test(
|
||||
":distributions_py",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
"//tensorflow/contrib/linalg:linalg_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python/ops/linalg",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
@ -1201,8 +1199,8 @@ cuda_py_test(
|
||||
":distributions_py",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
"//tensorflow/contrib/linalg:linalg_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python/ops/linalg",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
@ -1221,8 +1219,8 @@ cuda_py_test(
|
||||
":distributions_py",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
"//tensorflow/contrib/linalg:linalg_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python/ops/linalg",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
@ -1240,8 +1238,8 @@ cuda_py_test(
|
||||
":distributions_py",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
"//tensorflow/contrib/linalg:linalg_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python/ops/linalg",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
@ -1259,8 +1257,8 @@ cuda_py_test(
|
||||
":distributions_py",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
"//tensorflow/contrib/linalg:linalg_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python/ops/linalg",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
@ -1278,8 +1276,8 @@ cuda_py_test(
|
||||
":distributions_py",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
"//tensorflow/contrib/linalg:linalg_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python/ops/linalg",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
@ -1297,8 +1295,8 @@ cuda_py_test(
|
||||
":distributions_py",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
"//tensorflow/contrib/linalg:linalg_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python/ops/linalg",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
@ -1316,8 +1314,8 @@ cuda_py_test(
|
||||
":distributions_py",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
"//tensorflow/contrib/linalg:linalg_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python/ops/linalg",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
|
@ -1,44 +0,0 @@
|
||||
# Description:
|
||||
# Contains classes that provide access to common method of a [batch] matrix,
|
||||
# without the need to instantiate the matrix.
|
||||
# This allows for exploitation of structure, as well as a generic interface
|
||||
# suitable for iterative solvers.
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
package(default_visibility = ["//tensorflow:__subpackages__"])
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
|
||||
py_library(
|
||||
name = "linalg_py",
|
||||
srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:check_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/ops/linalg",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "linear_operator_addition_test",
|
||||
size = "small",
|
||||
srcs = ["python/kernel_tests/linear_operator_addition_test.py"],
|
||||
additional_deps = [
|
||||
":linalg_py",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
@ -1,58 +0,0 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Linear algebra libraries.
|
||||
|
||||
See the[Contrib Linalg](https://tensorflow.org/api_guides/python/contrib.linalg)
|
||||
guide.
|
||||
|
||||
@@LinearOperator
|
||||
@@LinearOperatorBlockDiag
|
||||
@@LinearOperatorCirculant
|
||||
@@LinearOperatorCirculant2D
|
||||
@@LinearOperatorCirculant3D
|
||||
@@LinearOperatorDiag
|
||||
@@LinearOperatorIdentity
|
||||
@@LinearOperatorScaledIdentity
|
||||
@@LinearOperatorFullMatrix
|
||||
@@LinearOperatorKronecker
|
||||
@@LinearOperatorLowerTriangular
|
||||
@@LinearOperatorLowRankUpdate
|
||||
@@LinearOperatorComposition
|
||||
@@add_operators
|
||||
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member
|
||||
|
||||
from tensorflow.contrib.linalg.python.ops.linear_operator_addition import *
|
||||
from tensorflow.python.ops.linalg.linear_operator import *
|
||||
from tensorflow.python.ops.linalg.linear_operator_block_diag import *
|
||||
from tensorflow.python.ops.linalg.linear_operator_circulant import *
|
||||
from tensorflow.python.ops.linalg.linear_operator_composition import *
|
||||
from tensorflow.python.ops.linalg.linear_operator_diag import *
|
||||
from tensorflow.python.ops.linalg.linear_operator_full_matrix import *
|
||||
from tensorflow.python.ops.linalg.linear_operator_identity import *
|
||||
from tensorflow.python.ops.linalg.linear_operator_kronecker import *
|
||||
from tensorflow.python.ops.linalg.linear_operator_low_rank_update import *
|
||||
from tensorflow.python.ops.linalg.linear_operator_lower_triangular import *
|
||||
|
||||
# pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member
|
||||
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
remove_undocumented(__name__)
|
@ -1,19 +0,0 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""ops module."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
@ -1,412 +0,0 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.linalg.python.ops import linear_operator_addition
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.ops import linalg_ops
|
||||
from tensorflow.python.ops.linalg import linalg as linalg_lib
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
linalg = linalg_lib
|
||||
random_seed.set_random_seed(23)
|
||||
rng = np.random.RandomState(0)
|
||||
|
||||
add_operators = linear_operator_addition.add_operators
|
||||
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
class _BadAdder(linear_operator_addition._Adder):
|
||||
"""Adder that will fail if used."""
|
||||
|
||||
def can_add(self, op1, op2):
|
||||
raise AssertionError("BadAdder.can_add called!")
|
||||
|
||||
def _add(self, op1, op2, operator_name, hints):
|
||||
raise AssertionError("This line should not be reached")
|
||||
|
||||
|
||||
# pylint: enable=unused-argument
|
||||
|
||||
|
||||
class LinearOperatorAdditionCorrectnessTest(test.TestCase):
|
||||
"""Tests correctness of addition with combinations of a few Adders.
|
||||
|
||||
Tests here are done with the _DEFAULT_ADDITION_TIERS, which means
|
||||
add_operators should reduce all operators resulting in one single operator.
|
||||
|
||||
This shows that we are able to correctly combine adders using the tiered
|
||||
system. All Adders should be tested separately, and there is no need to test
|
||||
every Adder within this class.
|
||||
"""
|
||||
|
||||
def test_one_operator_is_returned_unchanged(self):
|
||||
op_a = linalg.LinearOperatorDiag([1., 1.])
|
||||
op_sum = add_operators([op_a])
|
||||
self.assertEqual(1, len(op_sum))
|
||||
self.assertTrue(op_sum[0] is op_a)
|
||||
|
||||
def test_at_least_one_operators_required(self):
|
||||
with self.assertRaisesRegexp(ValueError, "must contain at least one"):
|
||||
add_operators([])
|
||||
|
||||
def test_attempting_to_add_numbers_raises(self):
|
||||
with self.assertRaisesRegexp(TypeError, "contain only LinearOperator"):
|
||||
add_operators([1, 2])
|
||||
|
||||
def test_two_diag_operators(self):
|
||||
op_a = linalg.LinearOperatorDiag(
|
||||
[1., 1.], is_positive_definite=True, name="A")
|
||||
op_b = linalg.LinearOperatorDiag(
|
||||
[2., 2.], is_positive_definite=True, name="B")
|
||||
with self.cached_session():
|
||||
op_sum = add_operators([op_a, op_b])
|
||||
self.assertEqual(1, len(op_sum))
|
||||
op = op_sum[0]
|
||||
self.assertTrue(isinstance(op, linalg_lib.LinearOperatorDiag))
|
||||
self.assertAllClose([[3., 0.], [0., 3.]], op.to_dense().eval())
|
||||
# Adding positive definite operators produces positive def.
|
||||
self.assertTrue(op.is_positive_definite)
|
||||
# Real diagonal ==> self-adjoint.
|
||||
self.assertTrue(op.is_self_adjoint)
|
||||
# Positive definite ==> non-singular
|
||||
self.assertTrue(op.is_non_singular)
|
||||
# Enforce particular name for this simple case
|
||||
self.assertEqual("Add/B__A/", op.name)
|
||||
|
||||
def test_three_diag_operators(self):
|
||||
op1 = linalg.LinearOperatorDiag(
|
||||
[1., 1.], is_positive_definite=True, name="op1")
|
||||
op2 = linalg.LinearOperatorDiag(
|
||||
[2., 2.], is_positive_definite=True, name="op2")
|
||||
op3 = linalg.LinearOperatorDiag(
|
||||
[3., 3.], is_positive_definite=True, name="op3")
|
||||
with self.cached_session():
|
||||
op_sum = add_operators([op1, op2, op3])
|
||||
self.assertEqual(1, len(op_sum))
|
||||
op = op_sum[0]
|
||||
self.assertTrue(isinstance(op, linalg_lib.LinearOperatorDiag))
|
||||
self.assertAllClose([[6., 0.], [0., 6.]], op.to_dense().eval())
|
||||
# Adding positive definite operators produces positive def.
|
||||
self.assertTrue(op.is_positive_definite)
|
||||
# Real diagonal ==> self-adjoint.
|
||||
self.assertTrue(op.is_self_adjoint)
|
||||
# Positive definite ==> non-singular
|
||||
self.assertTrue(op.is_non_singular)
|
||||
|
||||
def test_diag_tril_diag(self):
|
||||
op1 = linalg.LinearOperatorDiag(
|
||||
[1., 1.], is_non_singular=True, name="diag_a")
|
||||
op2 = linalg.LinearOperatorLowerTriangular(
|
||||
[[2., 0.], [0., 2.]],
|
||||
is_self_adjoint=True,
|
||||
is_non_singular=True,
|
||||
name="tril")
|
||||
op3 = linalg.LinearOperatorDiag(
|
||||
[3., 3.], is_non_singular=True, name="diag_b")
|
||||
with self.cached_session():
|
||||
op_sum = add_operators([op1, op2, op3])
|
||||
self.assertEqual(1, len(op_sum))
|
||||
op = op_sum[0]
|
||||
self.assertTrue(isinstance(op, linalg_lib.LinearOperatorLowerTriangular))
|
||||
self.assertAllClose([[6., 0.], [0., 6.]], op.to_dense().eval())
|
||||
|
||||
# The diag operators will be self-adjoint (because real and diagonal).
|
||||
# The TriL operator has the self-adjoint hint set.
|
||||
self.assertTrue(op.is_self_adjoint)
|
||||
|
||||
# Even though op1/2/3 are non-singular, this does not imply op is.
|
||||
# Since no custom hint was provided, we default to None (unknown).
|
||||
self.assertEqual(None, op.is_non_singular)
|
||||
|
||||
def test_matrix_diag_tril_diag_uses_custom_name(self):
|
||||
op0 = linalg.LinearOperatorFullMatrix(
|
||||
[[-1., -1.], [-1., -1.]], name="matrix")
|
||||
op1 = linalg.LinearOperatorDiag([1., 1.], name="diag_a")
|
||||
op2 = linalg.LinearOperatorLowerTriangular(
|
||||
[[2., 0.], [1.5, 2.]], name="tril")
|
||||
op3 = linalg.LinearOperatorDiag([3., 3.], name="diag_b")
|
||||
with self.cached_session():
|
||||
op_sum = add_operators([op0, op1, op2, op3], operator_name="my_operator")
|
||||
self.assertEqual(1, len(op_sum))
|
||||
op = op_sum[0]
|
||||
self.assertTrue(isinstance(op, linalg_lib.LinearOperatorFullMatrix))
|
||||
self.assertAllClose([[5., -1.], [0.5, 5.]], op.to_dense().eval())
|
||||
self.assertEqual("my_operator", op.name)
|
||||
|
||||
def test_incompatible_domain_dimensions_raises(self):
|
||||
op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3))
|
||||
op2 = linalg.LinearOperatorDiag(rng.rand(2, 4))
|
||||
with self.assertRaisesRegexp(ValueError, "must.*same domain dimension"):
|
||||
add_operators([op1, op2])
|
||||
|
||||
def test_incompatible_range_dimensions_raises(self):
|
||||
op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3))
|
||||
op2 = linalg.LinearOperatorDiag(rng.rand(3, 3))
|
||||
with self.assertRaisesRegexp(ValueError, "must.*same range dimension"):
|
||||
add_operators([op1, op2])
|
||||
|
||||
def test_non_broadcastable_batch_shape_raises(self):
|
||||
op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3))
|
||||
op2 = linalg.LinearOperatorDiag(rng.rand(4, 3, 3))
|
||||
with self.assertRaisesRegexp(ValueError, "Incompatible shapes"):
|
||||
add_operators([op1, op2])
|
||||
|
||||
|
||||
class LinearOperatorOrderOfAdditionTest(test.TestCase):
|
||||
"""Test that the order of addition is done as specified by tiers."""
|
||||
|
||||
def test_tier_0_additions_done_in_tier_0(self):
|
||||
diag1 = linalg.LinearOperatorDiag([1.])
|
||||
diag2 = linalg.LinearOperatorDiag([1.])
|
||||
diag3 = linalg.LinearOperatorDiag([1.])
|
||||
addition_tiers = [
|
||||
[linear_operator_addition._AddAndReturnDiag()],
|
||||
[_BadAdder()],
|
||||
]
|
||||
# Should not raise since all were added in tier 0, and tier 1 (with the
|
||||
# _BadAdder) was never reached.
|
||||
op_sum = add_operators([diag1, diag2, diag3], addition_tiers=addition_tiers)
|
||||
self.assertEqual(1, len(op_sum))
|
||||
self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorDiag))
|
||||
|
||||
def test_tier_1_additions_done_by_tier_1(self):
|
||||
diag1 = linalg.LinearOperatorDiag([1.])
|
||||
diag2 = linalg.LinearOperatorDiag([1.])
|
||||
tril = linalg.LinearOperatorLowerTriangular([[1.]])
|
||||
addition_tiers = [
|
||||
[linear_operator_addition._AddAndReturnDiag()],
|
||||
[linear_operator_addition._AddAndReturnTriL()],
|
||||
[_BadAdder()],
|
||||
]
|
||||
# Should not raise since all were added by tier 1, and the
|
||||
# _BadAdder) was never reached.
|
||||
op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
|
||||
self.assertEqual(1, len(op_sum))
|
||||
self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorLowerTriangular))
|
||||
|
||||
def test_tier_1_additions_done_by_tier_1_with_order_flipped(self):
|
||||
diag1 = linalg.LinearOperatorDiag([1.])
|
||||
diag2 = linalg.LinearOperatorDiag([1.])
|
||||
tril = linalg.LinearOperatorLowerTriangular([[1.]])
|
||||
addition_tiers = [
|
||||
[linear_operator_addition._AddAndReturnTriL()],
|
||||
[linear_operator_addition._AddAndReturnDiag()],
|
||||
[_BadAdder()],
|
||||
]
|
||||
# Tier 0 could convert to TriL, and this converted everything to TriL,
|
||||
# including the Diags.
|
||||
# Tier 1 was never used.
|
||||
# Tier 2 was never used (therefore, _BadAdder didn't raise).
|
||||
op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
|
||||
self.assertEqual(1, len(op_sum))
|
||||
self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorLowerTriangular))
|
||||
|
||||
def test_cannot_add_everything_so_return_more_than_one_operator(self):
|
||||
diag1 = linalg.LinearOperatorDiag([1.])
|
||||
diag2 = linalg.LinearOperatorDiag([2.])
|
||||
tril5 = linalg.LinearOperatorLowerTriangular([[5.]])
|
||||
addition_tiers = [
|
||||
[linear_operator_addition._AddAndReturnDiag()],
|
||||
]
|
||||
# Tier 0 (the only tier) can only convert to Diag, so it combines the two
|
||||
# diags, but the TriL is unchanged.
|
||||
# Result should contain two operators, one Diag, one TriL.
|
||||
op_sum = add_operators([diag1, diag2, tril5], addition_tiers=addition_tiers)
|
||||
self.assertEqual(2, len(op_sum))
|
||||
found_diag = False
|
||||
found_tril = False
|
||||
with self.cached_session():
|
||||
for op in op_sum:
|
||||
if isinstance(op, linalg.LinearOperatorDiag):
|
||||
found_diag = True
|
||||
self.assertAllClose([[3.]], op.to_dense().eval())
|
||||
if isinstance(op, linalg.LinearOperatorLowerTriangular):
|
||||
found_tril = True
|
||||
self.assertAllClose([[5.]], op.to_dense().eval())
|
||||
self.assertTrue(found_diag and found_tril)
|
||||
|
||||
def test_intermediate_tier_is_not_skipped(self):
|
||||
diag1 = linalg.LinearOperatorDiag([1.])
|
||||
diag2 = linalg.LinearOperatorDiag([1.])
|
||||
tril = linalg.LinearOperatorLowerTriangular([[1.]])
|
||||
addition_tiers = [
|
||||
[linear_operator_addition._AddAndReturnDiag()],
|
||||
[_BadAdder()],
|
||||
[linear_operator_addition._AddAndReturnTriL()],
|
||||
]
|
||||
# tril cannot be added in tier 0, and the intermediate tier 1 with the
|
||||
# BadAdder will catch it and raise.
|
||||
with self.assertRaisesRegexp(AssertionError, "BadAdder.can_add called"):
|
||||
add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
|
||||
|
||||
|
||||
class AddAndReturnScaledIdentityTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self._adder = linear_operator_addition._AddAndReturnScaledIdentity()
|
||||
|
||||
def test_identity_plus_identity(self):
|
||||
id1 = linalg.LinearOperatorIdentity(num_rows=2)
|
||||
id2 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3])
|
||||
hints = linear_operator_addition._Hints(
|
||||
is_positive_definite=True, is_non_singular=True)
|
||||
|
||||
self.assertTrue(self._adder.can_add(id1, id2))
|
||||
operator = self._adder.add(id1, id2, "my_operator", hints)
|
||||
self.assertTrue(isinstance(operator, linalg.LinearOperatorScaledIdentity))
|
||||
|
||||
with self.cached_session():
|
||||
self.assertAllClose(2 *
|
||||
linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
|
||||
operator.to_dense().eval())
|
||||
self.assertTrue(operator.is_positive_definite)
|
||||
self.assertTrue(operator.is_non_singular)
|
||||
self.assertEqual("my_operator", operator.name)
|
||||
|
||||
def test_identity_plus_scaled_identity(self):
|
||||
id1 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3])
|
||||
id2 = linalg.LinearOperatorScaledIdentity(num_rows=2, multiplier=2.2)
|
||||
hints = linear_operator_addition._Hints(
|
||||
is_positive_definite=True, is_non_singular=True)
|
||||
|
||||
self.assertTrue(self._adder.can_add(id1, id2))
|
||||
operator = self._adder.add(id1, id2, "my_operator", hints)
|
||||
self.assertTrue(isinstance(operator, linalg.LinearOperatorScaledIdentity))
|
||||
|
||||
with self.cached_session():
|
||||
self.assertAllClose(3.2 *
|
||||
linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
|
||||
operator.to_dense().eval())
|
||||
self.assertTrue(operator.is_positive_definite)
|
||||
self.assertTrue(operator.is_non_singular)
|
||||
self.assertEqual("my_operator", operator.name)
|
||||
|
||||
def test_scaled_identity_plus_scaled_identity(self):
|
||||
id1 = linalg.LinearOperatorScaledIdentity(
|
||||
num_rows=2, multiplier=[2.2, 2.2, 2.2])
|
||||
id2 = linalg.LinearOperatorScaledIdentity(num_rows=2, multiplier=-1.0)
|
||||
hints = linear_operator_addition._Hints(
|
||||
is_positive_definite=True, is_non_singular=True)
|
||||
|
||||
self.assertTrue(self._adder.can_add(id1, id2))
|
||||
operator = self._adder.add(id1, id2, "my_operator", hints)
|
||||
self.assertTrue(isinstance(operator, linalg.LinearOperatorScaledIdentity))
|
||||
|
||||
with self.cached_session():
|
||||
self.assertAllClose(1.2 *
|
||||
linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
|
||||
operator.to_dense().eval())
|
||||
self.assertTrue(operator.is_positive_definite)
|
||||
self.assertTrue(operator.is_non_singular)
|
||||
self.assertEqual("my_operator", operator.name)
|
||||
|
||||
|
||||
class AddAndReturnDiagTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self._adder = linear_operator_addition._AddAndReturnDiag()
|
||||
|
||||
def test_identity_plus_identity_returns_diag(self):
|
||||
id1 = linalg.LinearOperatorIdentity(num_rows=2)
|
||||
id2 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3])
|
||||
hints = linear_operator_addition._Hints(
|
||||
is_positive_definite=True, is_non_singular=True)
|
||||
|
||||
self.assertTrue(self._adder.can_add(id1, id2))
|
||||
operator = self._adder.add(id1, id2, "my_operator", hints)
|
||||
self.assertTrue(isinstance(operator, linalg.LinearOperatorDiag))
|
||||
|
||||
with self.cached_session():
|
||||
self.assertAllClose(2 *
|
||||
linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
|
||||
operator.to_dense().eval())
|
||||
self.assertTrue(operator.is_positive_definite)
|
||||
self.assertTrue(operator.is_non_singular)
|
||||
self.assertEqual("my_operator", operator.name)
|
||||
|
||||
def test_diag_plus_diag(self):
|
||||
diag1 = rng.rand(2, 3, 4)
|
||||
diag2 = rng.rand(4)
|
||||
op1 = linalg.LinearOperatorDiag(diag1)
|
||||
op2 = linalg.LinearOperatorDiag(diag2)
|
||||
hints = linear_operator_addition._Hints(
|
||||
is_positive_definite=True, is_non_singular=True)
|
||||
|
||||
self.assertTrue(self._adder.can_add(op1, op2))
|
||||
operator = self._adder.add(op1, op2, "my_operator", hints)
|
||||
self.assertTrue(isinstance(operator, linalg.LinearOperatorDiag))
|
||||
|
||||
with self.cached_session():
|
||||
self.assertAllClose(
|
||||
linalg.LinearOperatorDiag(diag1 + diag2).to_dense().eval(),
|
||||
operator.to_dense().eval())
|
||||
self.assertTrue(operator.is_positive_definite)
|
||||
self.assertTrue(operator.is_non_singular)
|
||||
self.assertEqual("my_operator", operator.name)
|
||||
|
||||
|
||||
class AddAndReturnTriLTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self._adder = linear_operator_addition._AddAndReturnTriL()
|
||||
|
||||
def test_diag_plus_tril(self):
|
||||
diag = linalg.LinearOperatorDiag([1., 2.])
|
||||
tril = linalg.LinearOperatorLowerTriangular([[10., 0.], [30., 0.]])
|
||||
hints = linear_operator_addition._Hints(
|
||||
is_positive_definite=True, is_non_singular=True)
|
||||
|
||||
self.assertTrue(self._adder.can_add(diag, diag))
|
||||
self.assertTrue(self._adder.can_add(diag, tril))
|
||||
operator = self._adder.add(diag, tril, "my_operator", hints)
|
||||
self.assertTrue(isinstance(operator, linalg.LinearOperatorLowerTriangular))
|
||||
|
||||
with self.cached_session():
|
||||
self.assertAllClose([[11., 0.], [30., 2.]], operator.to_dense().eval())
|
||||
self.assertTrue(operator.is_positive_definite)
|
||||
self.assertTrue(operator.is_non_singular)
|
||||
self.assertEqual("my_operator", operator.name)
|
||||
|
||||
|
||||
class AddAndReturnMatrixTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self._adder = linear_operator_addition._AddAndReturnMatrix()
|
||||
|
||||
def test_diag_plus_diag(self):
|
||||
diag1 = linalg.LinearOperatorDiag([1., 2.])
|
||||
diag2 = linalg.LinearOperatorDiag([-1., 3.])
|
||||
hints = linear_operator_addition._Hints(
|
||||
is_positive_definite=False, is_non_singular=False)
|
||||
|
||||
self.assertTrue(self._adder.can_add(diag1, diag2))
|
||||
operator = self._adder.add(diag1, diag2, "my_operator", hints)
|
||||
self.assertTrue(isinstance(operator, linalg.LinearOperatorFullMatrix))
|
||||
|
||||
with self.cached_session():
|
||||
self.assertAllClose([[0., 0.], [0., 5.]], operator.to_dense().eval())
|
||||
self.assertFalse(operator.is_positive_definite)
|
||||
self.assertFalse(operator.is_non_singular)
|
||||
self.assertEqual("my_operator", operator.name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -1,432 +0,0 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Add one or more `LinearOperators` efficiently."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import abc
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops.linalg import linear_operator
|
||||
from tensorflow.python.ops.linalg import linear_operator_diag
|
||||
from tensorflow.python.ops.linalg import linear_operator_full_matrix
|
||||
from tensorflow.python.ops.linalg import linear_operator_identity
|
||||
from tensorflow.python.ops.linalg import linear_operator_lower_triangular
|
||||
|
||||
__all__ = []
|
||||
|
||||
|
||||
def add_operators(operators,
|
||||
operator_name=None,
|
||||
addition_tiers=None,
|
||||
name=None):
|
||||
"""Efficiently add one or more linear operators.
|
||||
|
||||
Given operators `[A1, A2,...]`, this `Op` returns a possibly shorter list of
|
||||
operators `[B1, B2,...]` such that
|
||||
|
||||
```sum_k Ak.matmul(x) = sum_k Bk.matmul(x).```
|
||||
|
||||
The operators `Bk` result by adding some of the `Ak`, as allowed by
|
||||
`addition_tiers`.
|
||||
|
||||
Example of efficient adding of diagonal operators.
|
||||
|
||||
```python
|
||||
A1 = LinearOperatorDiag(diag=[1., 1.], name="A1")
|
||||
A2 = LinearOperatorDiag(diag=[2., 2.], name="A2")
|
||||
|
||||
# Use two tiers, the first contains an Adder that returns Diag. Since both
|
||||
# A1 and A2 are Diag, they can use this Adder. The second tier will not be
|
||||
# used.
|
||||
addition_tiers = [
|
||||
[_AddAndReturnDiag()],
|
||||
[_AddAndReturnMatrix()]]
|
||||
B_list = add_operators([A1, A2], addition_tiers=addition_tiers)
|
||||
|
||||
len(B_list)
|
||||
==> 1
|
||||
|
||||
B_list[0].__class__.__name__
|
||||
==> 'LinearOperatorDiag'
|
||||
|
||||
B_list[0].to_dense()
|
||||
==> [[3., 0.],
|
||||
[0., 3.]]
|
||||
|
||||
B_list[0].name
|
||||
==> 'Add/A1__A2/'
|
||||
```
|
||||
|
||||
Args:
|
||||
operators: Iterable of `LinearOperator` objects with same `dtype`, domain
|
||||
and range dimensions, and broadcastable batch shapes.
|
||||
operator_name: String name for returned `LinearOperator`. Defaults to
|
||||
concatenation of "Add/A__B/" that indicates the order of addition steps.
|
||||
addition_tiers: List tiers, like `[tier_0, tier_1, ...]`, where `tier_i`
|
||||
is a list of `Adder` objects. This function attempts to do all additions
|
||||
in tier `i` before trying tier `i + 1`.
|
||||
name: A name for this `Op`. Defaults to `add_operators`.
|
||||
|
||||
Returns:
|
||||
Subclass of `LinearOperator`. Class and order of addition may change as new
|
||||
(and better) addition strategies emerge.
|
||||
|
||||
Raises:
|
||||
ValueError: If `operators` argument is empty.
|
||||
ValueError: If shapes are incompatible.
|
||||
"""
|
||||
# Default setting
|
||||
if addition_tiers is None:
|
||||
addition_tiers = _DEFAULT_ADDITION_TIERS
|
||||
|
||||
# Argument checking.
|
||||
check_ops.assert_proper_iterable(operators)
|
||||
operators = list(reversed(operators))
|
||||
if len(operators) < 1:
|
||||
raise ValueError(
|
||||
"Argument 'operators' must contain at least one operator. "
|
||||
"Found: %s" % operators)
|
||||
if not all(
|
||||
isinstance(op, linear_operator.LinearOperator) for op in operators):
|
||||
raise TypeError(
|
||||
"Argument 'operators' must contain only LinearOperator instances. "
|
||||
"Found: %s" % operators)
|
||||
_static_check_for_same_dimensions(operators)
|
||||
_static_check_for_broadcastable_batch_shape(operators)
|
||||
|
||||
graph_parents = []
|
||||
for operator in operators:
|
||||
graph_parents.extend(operator.graph_parents)
|
||||
|
||||
with ops.name_scope(name or "add_operators", values=graph_parents):
|
||||
|
||||
# Additions done in one of the tiers. Try tier 0, 1,...
|
||||
ops_to_try_at_next_tier = list(operators)
|
||||
for tier in addition_tiers:
|
||||
ops_to_try_at_this_tier = ops_to_try_at_next_tier
|
||||
ops_to_try_at_next_tier = []
|
||||
while ops_to_try_at_this_tier:
|
||||
op1 = ops_to_try_at_this_tier.pop()
|
||||
op2, adder = _pop_a_match_at_tier(op1, ops_to_try_at_this_tier, tier)
|
||||
if op2 is not None:
|
||||
# Will try to add the result of this again at this same tier.
|
||||
new_operator = adder.add(op1, op2, operator_name)
|
||||
ops_to_try_at_this_tier.append(new_operator)
|
||||
else:
|
||||
ops_to_try_at_next_tier.append(op1)
|
||||
|
||||
return ops_to_try_at_next_tier
|
||||
|
||||
|
||||
def _pop_a_match_at_tier(op1, operator_list, tier):
|
||||
# Search from the back of list to the front in order to create nice default
|
||||
# order of operations.
|
||||
for i in range(1, len(operator_list) + 1):
|
||||
op2 = operator_list[-i]
|
||||
for adder in tier:
|
||||
if adder.can_add(op1, op2):
|
||||
return operator_list.pop(-i), adder
|
||||
return None, None
|
||||
|
||||
|
||||
def _infer_hints_allowing_override(op1, op2, hints):
|
||||
"""Infer hints from op1 and op2. hints argument is an override.
|
||||
|
||||
Args:
|
||||
op1: LinearOperator
|
||||
op2: LinearOperator
|
||||
hints: _Hints object holding "is_X" boolean hints to use for returned
|
||||
operator.
|
||||
If some hint is None, try to set using op1 and op2. If the
|
||||
hint is provided, ignore op1 and op2 hints. This allows an override
|
||||
of previous hints, but does not allow forbidden hints (e.g. you still
|
||||
cannot say a real diagonal operator is not self-adjoint.
|
||||
|
||||
Returns:
|
||||
_Hints object.
|
||||
"""
|
||||
hints = hints or _Hints()
|
||||
# If A, B are self-adjoint, then so is A + B.
|
||||
if hints.is_self_adjoint is None:
|
||||
is_self_adjoint = op1.is_self_adjoint and op2.is_self_adjoint
|
||||
else:
|
||||
is_self_adjoint = hints.is_self_adjoint
|
||||
|
||||
# If A, B are positive definite, then so is A + B.
|
||||
if hints.is_positive_definite is None:
|
||||
is_positive_definite = op1.is_positive_definite and op2.is_positive_definite
|
||||
else:
|
||||
is_positive_definite = hints.is_positive_definite
|
||||
|
||||
# A positive definite operator is always non-singular.
|
||||
if is_positive_definite and hints.is_positive_definite is None:
|
||||
is_non_singular = True
|
||||
else:
|
||||
is_non_singular = hints.is_non_singular
|
||||
|
||||
return _Hints(
|
||||
is_non_singular=is_non_singular,
|
||||
is_self_adjoint=is_self_adjoint,
|
||||
is_positive_definite=is_positive_definite)
|
||||
|
||||
|
||||
def _static_check_for_same_dimensions(operators):
|
||||
"""ValueError if operators determined to have different dimensions."""
|
||||
if len(operators) < 2:
|
||||
return
|
||||
|
||||
domain_dimensions = [(op.name, op.domain_dimension.value) for op in operators
|
||||
if op.domain_dimension.value is not None]
|
||||
if len(set(value for name, value in domain_dimensions)) > 1:
|
||||
raise ValueError("Operators must have the same domain dimension. Found: %s"
|
||||
% domain_dimensions)
|
||||
|
||||
range_dimensions = [(op.name, op.range_dimension.value) for op in operators
|
||||
if op.range_dimension.value is not None]
|
||||
if len(set(value for name, value in range_dimensions)) > 1:
|
||||
raise ValueError("Operators must have the same range dimension. Found: %s" %
|
||||
range_dimensions)
|
||||
|
||||
|
||||
def _static_check_for_broadcastable_batch_shape(operators):
|
||||
"""ValueError if operators determined to have non-broadcastable shapes."""
|
||||
if len(operators) < 2:
|
||||
return
|
||||
|
||||
# This will fail if they cannot be broadcast together.
|
||||
batch_shape = operators[0].batch_shape
|
||||
for op in operators[1:]:
|
||||
batch_shape = array_ops.broadcast_static_shape(batch_shape, op.batch_shape)
|
||||
|
||||
|
||||
class _Hints(object):
|
||||
"""Holds 'is_X' flags that every LinearOperator is initialized with."""
|
||||
|
||||
def __init__(self,
|
||||
is_non_singular=None,
|
||||
is_positive_definite=None,
|
||||
is_self_adjoint=None):
|
||||
self.is_non_singular = is_non_singular
|
||||
self.is_positive_definite = is_positive_definite
|
||||
self.is_self_adjoint = is_self_adjoint
|
||||
|
||||
|
||||
################################################################################
|
||||
# Classes to add two linear operators.
|
||||
################################################################################
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class _Adder(object):
|
||||
"""Abstract base class to add two operators.
|
||||
|
||||
Each `Adder` acts independently, adding everything it can, paying no attention
|
||||
as to whether another `Adder` could have done the addition more efficiently.
|
||||
"""
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.__class__.__name__
|
||||
|
||||
@abc.abstractmethod
|
||||
def can_add(self, op1, op2):
|
||||
"""Returns `True` if this `Adder` can add `op1` and `op2`. Else `False`."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _add(self, op1, op2, operator_name, hints):
|
||||
# Derived classes can assume op1 and op2 have been validated, e.g. they have
|
||||
# the same dtype, and their domain/range dimensions match.
|
||||
pass
|
||||
|
||||
def add(self, op1, op2, operator_name, hints=None):
|
||||
"""Return new `LinearOperator` acting like `op1 + op2`.
|
||||
|
||||
Args:
|
||||
op1: `LinearOperator`
|
||||
op2: `LinearOperator`, with `shape` and `dtype` such that adding to
|
||||
`op1` is allowed.
|
||||
operator_name: `String` name to give to returned `LinearOperator`
|
||||
hints: `_Hints` object. Returned `LinearOperator` will be created with
|
||||
these hints.
|
||||
|
||||
Returns:
|
||||
`LinearOperator`
|
||||
"""
|
||||
updated_hints = _infer_hints_allowing_override(op1, op2, hints)
|
||||
|
||||
if operator_name is None:
|
||||
operator_name = "Add/" + op1.name + "__" + op2.name + "/"
|
||||
|
||||
values = op1.graph_parents + op2.graph_parents
|
||||
scope_name = self.name
|
||||
if scope_name.startswith("_"):
|
||||
scope_name = scope_name[1:]
|
||||
with ops.name_scope(scope_name, values=values):
|
||||
return self._add(op1, op2, operator_name, updated_hints)
|
||||
|
||||
|
||||
class _AddAndReturnScaledIdentity(_Adder):
|
||||
"""Handles additions resulting in an Identity family member.
|
||||
|
||||
The Identity (`LinearOperatorScaledIdentity`, `LinearOperatorIdentity`) family
|
||||
is closed under addition. This `Adder` respects that, and returns an Identity
|
||||
"""
|
||||
|
||||
def can_add(self, op1, op2):
|
||||
types = {_type(op1), _type(op2)}
|
||||
return not types.difference(_IDENTITY_FAMILY)
|
||||
|
||||
def _add(self, op1, op2, operator_name, hints):
|
||||
# Will build a LinearOperatorScaledIdentity.
|
||||
|
||||
if _type(op1) == _SCALED_IDENTITY:
|
||||
multiplier_1 = op1.multiplier
|
||||
else:
|
||||
multiplier_1 = array_ops.ones(op1.batch_shape_tensor(), dtype=op1.dtype)
|
||||
|
||||
if _type(op2) == _SCALED_IDENTITY:
|
||||
multiplier_2 = op2.multiplier
|
||||
else:
|
||||
multiplier_2 = array_ops.ones(op2.batch_shape_tensor(), dtype=op2.dtype)
|
||||
|
||||
return linear_operator_identity.LinearOperatorScaledIdentity(
|
||||
num_rows=op1.range_dimension_tensor(),
|
||||
multiplier=multiplier_1 + multiplier_2,
|
||||
is_non_singular=hints.is_non_singular,
|
||||
is_self_adjoint=hints.is_self_adjoint,
|
||||
is_positive_definite=hints.is_positive_definite,
|
||||
name=operator_name)
|
||||
|
||||
|
||||
class _AddAndReturnDiag(_Adder):
|
||||
"""Handles additions resulting in a Diag operator."""
|
||||
|
||||
def can_add(self, op1, op2):
|
||||
types = {_type(op1), _type(op2)}
|
||||
return not types.difference(_DIAG_LIKE)
|
||||
|
||||
def _add(self, op1, op2, operator_name, hints):
|
||||
return linear_operator_diag.LinearOperatorDiag(
|
||||
diag=op1.diag_part() + op2.diag_part(),
|
||||
is_non_singular=hints.is_non_singular,
|
||||
is_self_adjoint=hints.is_self_adjoint,
|
||||
is_positive_definite=hints.is_positive_definite,
|
||||
name=operator_name)
|
||||
|
||||
|
||||
class _AddAndReturnTriL(_Adder):
|
||||
"""Handles additions resulting in a TriL operator."""
|
||||
|
||||
def can_add(self, op1, op2):
|
||||
types = {_type(op1), _type(op2)}
|
||||
return not types.difference(_DIAG_LIKE.union({_TRIL}))
|
||||
|
||||
def _add(self, op1, op2, operator_name, hints):
|
||||
if _type(op1) in _EFFICIENT_ADD_TO_TENSOR:
|
||||
op_add_to_tensor, op_other = op1, op2
|
||||
else:
|
||||
op_add_to_tensor, op_other = op2, op1
|
||||
|
||||
return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
|
||||
tril=op_add_to_tensor.add_to_tensor(op_other.to_dense()),
|
||||
is_non_singular=hints.is_non_singular,
|
||||
is_self_adjoint=hints.is_self_adjoint,
|
||||
is_positive_definite=hints.is_positive_definite,
|
||||
name=operator_name)
|
||||
|
||||
|
||||
class _AddAndReturnMatrix(_Adder):
|
||||
""""Handles additions resulting in a `LinearOperatorFullMatrix`."""
|
||||
|
||||
def can_add(self, op1, op2): # pylint: disable=unused-argument
|
||||
return isinstance(op1, linear_operator.LinearOperator) and isinstance(
|
||||
op2, linear_operator.LinearOperator)
|
||||
|
||||
def _add(self, op1, op2, operator_name, hints):
|
||||
if _type(op1) in _EFFICIENT_ADD_TO_TENSOR:
|
||||
op_add_to_tensor, op_other = op1, op2
|
||||
else:
|
||||
op_add_to_tensor, op_other = op2, op1
|
||||
return linear_operator_full_matrix.LinearOperatorFullMatrix(
|
||||
matrix=op_add_to_tensor.add_to_tensor(op_other.to_dense()),
|
||||
is_non_singular=hints.is_non_singular,
|
||||
is_self_adjoint=hints.is_self_adjoint,
|
||||
is_positive_definite=hints.is_positive_definite,
|
||||
name=operator_name)
|
||||
|
||||
|
||||
################################################################################
|
||||
# Constants designating types of LinearOperators
|
||||
################################################################################
|
||||
|
||||
# Type name constants for LinearOperator classes.
|
||||
_IDENTITY = "identity"
|
||||
_SCALED_IDENTITY = "scaled_identity"
|
||||
_DIAG = "diag"
|
||||
_TRIL = "tril"
|
||||
_MATRIX = "matrix"
|
||||
|
||||
# Groups of operators.
|
||||
_DIAG_LIKE = {_DIAG, _IDENTITY, _SCALED_IDENTITY}
|
||||
_IDENTITY_FAMILY = {_IDENTITY, _SCALED_IDENTITY}
|
||||
# operators with an efficient .add_to_tensor() method.
|
||||
_EFFICIENT_ADD_TO_TENSOR = _DIAG_LIKE
|
||||
|
||||
|
||||
def _type(operator):
|
||||
"""Returns the type name constant (e.g. _TRIL) for operator."""
|
||||
if isinstance(operator, linear_operator_diag.LinearOperatorDiag):
|
||||
return _DIAG
|
||||
if isinstance(operator,
|
||||
linear_operator_lower_triangular.LinearOperatorLowerTriangular):
|
||||
return _TRIL
|
||||
if isinstance(operator, linear_operator_full_matrix.LinearOperatorFullMatrix):
|
||||
return _MATRIX
|
||||
if isinstance(operator, linear_operator_identity.LinearOperatorIdentity):
|
||||
return _IDENTITY
|
||||
if isinstance(operator,
|
||||
linear_operator_identity.LinearOperatorScaledIdentity):
|
||||
return _SCALED_IDENTITY
|
||||
raise TypeError("Operator type unknown: %s" % operator)
|
||||
|
||||
|
||||
################################################################################
|
||||
# Addition tiers:
|
||||
# We attempt to use Adders in tier K before K+1.
|
||||
#
|
||||
# Organize tiers to
|
||||
# (i) reduce O(..) complexity of forming final operator, and
|
||||
# (ii) produce the "most efficient" final operator.
|
||||
# Dev notes:
|
||||
# * Results of addition at tier K will be added at tier K or higher.
|
||||
# * Tiers may change, and we warn the user that it may change.
|
||||
################################################################################
|
||||
|
||||
# Note that the final tier, _AddAndReturnMatrix, will convert everything to a
|
||||
# dense matrix. So it is sometimes very inefficient.
|
||||
_DEFAULT_ADDITION_TIERS = [
|
||||
[_AddAndReturnScaledIdentity()],
|
||||
[_AddAndReturnDiag()],
|
||||
[_AddAndReturnTriL()],
|
||||
[_AddAndReturnMatrix()],
|
||||
]
|
Loading…
Reference in New Issue
Block a user