The `linalg.LinearOperator*` module APIs do not support top-level dispatching because they are classes with methods rather than top-level functions in TF's API. However, their class methods call out to APIs that do support dispatching.

This CL updates the `convert_to_tensor` calls in the `linalg.LinearOperator*` APIs to use the publicly exposed, dispatching `convert_to_tensor_v2_with_dispatch`. Because the APIs the operators call out to already support dispatching, this lets the operators work with dispatching end to end.
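
As a rough illustration of the effect (not part of this CL, only a sketch): with the internal `dispatch.dispatch_for_types` helper, a conversion override registered for the dispatching `convert_to_tensor` is now also picked up inside the operators. `BoxedValue` below is a hypothetical wrapper type used only for this example.

# Sketch, assuming dispatch.dispatch_for_types behaves as exercised in
# dispatch_test.py; BoxedValue is a made-up type, not a TF API.
import numpy as np

from tensorflow.python.framework import ops
from tensorflow.python.ops.linalg import linear_operator_diag
from tensorflow.python.util import dispatch


class BoxedValue(object):
  """A toy non-Tensor type that TF cannot convert on its own."""

  def __init__(self, value):
    self.value = value


# Override conversion for BoxedValue. The signature must exactly match
# convert_to_tensor_v2_with_dispatch for dispatch_for_types to accept it.
@dispatch.dispatch_for_types(ops.convert_to_tensor_v2_with_dispatch, BoxedValue)
def _convert_boxed_value(value, dtype=None, dtype_hint=None, name=None):
  # Unbox and fall through to the ordinary conversion.
  return ops.convert_to_tensor_v2_with_dispatch(
      value.value, dtype=dtype, dtype_hint=dtype_hint, name=name)


# LinearOperatorDiag now converts its `diag` argument with the dispatching
# convert_to_tensor, so the override above is picked up; with the old
# non-dispatching ops.convert_to_tensor this constructor raised a TypeError.
operator = linear_operator_diag.LinearOperatorDiag(
    BoxedValue(np.array([1., 2., 3.], dtype=np.float32)))
print(operator.eigvals())  # roughly: tf.Tensor([1. 2. 3.], shape=(3,), dtype=float32)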

PiperOrigin-RevId: 324834645
Change-Id: If2e9f17be101e74f8835497d8ca51a0174055053
Tomer Kaftan authored on 2020-08-04 10:08:00 -07:00; committed by TensorFlower Gardener
parent c8ddf5a1df
commit 12b7b9e06d
13 changed files with 106 additions and 54 deletions

@ -385,7 +385,7 @@ class LinearOperator(module.Module):
# `shape` may be passed in if this can be pre-computed in a
# more efficient manner, e.g. without excessive Tensor conversions.
if self.tensor_rank is not None:
return ops.convert_to_tensor(self.tensor_rank)
return ops.convert_to_tensor_v2_with_dispatch(self.tensor_rank)
else:
shape = self.shape_tensor() if shape is None else shape
return array_ops.size(shape)
@ -429,7 +429,7 @@ class LinearOperator(module.Module):
# more efficient manner, e.g. without excessive Tensor conversions.
dim_value = tensor_shape.dimension_value(self.domain_dimension)
if dim_value is not None:
return ops.convert_to_tensor(dim_value)
return ops.convert_to_tensor_v2_with_dispatch(dim_value)
else:
shape = self.shape_tensor() if shape is None else shape
return shape[-1]
@ -473,7 +473,7 @@ class LinearOperator(module.Module):
# more efficient manner, e.g. without excessive Tensor conversions.
dim_value = tensor_shape.dimension_value(self.range_dimension)
if dim_value is not None:
return ops.convert_to_tensor(dim_value)
return ops.convert_to_tensor_v2_with_dispatch(dim_value)
else:
shape = self.shape_tensor() if shape is None else shape
return shape[-2]
@ -641,7 +641,7 @@ class LinearOperator(module.Module):
return linear_operator_algebra.matmul(left_operator, right_operator)
with self._name_scope(name):
x = ops.convert_to_tensor(x, name="x")
x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
self._check_input_dtype(x)
self_dim = -2 if adjoint else -1
@ -688,7 +688,7 @@ class LinearOperator(module.Module):
A `Tensor` with shape `[..., M]` and same `dtype` as `self`.
"""
with self._name_scope(name):
x = ops.convert_to_tensor(x, name="x")
x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
self._check_input_dtype(x)
self_dim = -2 if adjoint else -1
tensor_shape.dimension_at_index(
@ -834,7 +834,7 @@ class LinearOperator(module.Module):
return linear_operator_algebra.solve(left_operator, right_operator)
with self._name_scope(name):
rhs = ops.convert_to_tensor(rhs, name="rhs")
rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
self._check_input_dtype(rhs)
self_dim = -1 if adjoint else -2
@ -891,7 +891,7 @@ class LinearOperator(module.Module):
NotImplementedError: If `self.is_non_singular` or `is_square` is False.
"""
with self._name_scope(name):
rhs = ops.convert_to_tensor(rhs, name="rhs")
rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
self._check_input_dtype(rhs)
self_dim = -1 if adjoint else -2
tensor_shape.dimension_at_index(
@ -1054,7 +1054,7 @@ class LinearOperator(module.Module):
A `Tensor` with broadcast shape and same `dtype` as `self`.
"""
with self._name_scope(name):
x = ops.convert_to_tensor(x, name="x")
x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
self._check_input_dtype(x)
return self._add_to_tensor(x)

@ -263,7 +263,7 @@ class LinearOperatorBlockDiag(linear_operator.LinearOperator):
def _shape_tensor(self):
# Avoid messy broadcasting if possible.
if self.shape.is_fully_defined():
return ops.convert_to_tensor(
return ops.convert_to_tensor_v2_with_dispatch(
self.shape.as_list(), dtype=dtypes.int32, name="shape")
domain_dimension = sum(self._block_domain_dimension_tensors())
@ -330,12 +330,12 @@ class LinearOperatorBlockDiag(linear_operator.LinearOperator):
if linear_operator_util.arg_is_blockwise(block_dimensions, x, arg_dim):
for i, block in enumerate(x):
if not isinstance(block, linear_operator.LinearOperator):
block = ops.convert_to_tensor(block)
block = ops.convert_to_tensor_v2_with_dispatch(block)
self._check_input_dtype(block)
block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim])
x[i] = block
else:
x = ops.convert_to_tensor(x, name="x")
x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
self._check_input_dtype(x)
op_dimension = (self.range_dimension if adjoint
else self.domain_dimension)
@ -404,7 +404,7 @@ class LinearOperatorBlockDiag(linear_operator.LinearOperator):
if linear_operator_util.arg_is_blockwise(block_dimensions, x, -1):
for i, block in enumerate(x):
if not isinstance(block, linear_operator.LinearOperator):
block = ops.convert_to_tensor(block)
block = ops.convert_to_tensor_v2_with_dispatch(block)
self._check_input_dtype(block)
block_dimensions[i].assert_is_compatible_with(block.shape[-1])
x[i] = block
@ -412,7 +412,7 @@ class LinearOperatorBlockDiag(linear_operator.LinearOperator):
y_mat = self.matmul(x_mat, adjoint=adjoint)
return [array_ops.squeeze(y, axis=-1) for y in y_mat]
x = ops.convert_to_tensor(x, name="x")
x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
self._check_input_dtype(x)
op_dimension = (self.range_dimension if adjoint
else self.domain_dimension)
@ -508,12 +508,12 @@ class LinearOperatorBlockDiag(linear_operator.LinearOperator):
split_rhs = rhs
for i, block in enumerate(split_rhs):
if not isinstance(block, linear_operator.LinearOperator):
block = ops.convert_to_tensor(block)
block = ops.convert_to_tensor_v2_with_dispatch(block)
self._check_input_dtype(block)
block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim])
split_rhs[i] = block
else:
rhs = ops.convert_to_tensor(rhs, name="rhs")
rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
self._check_input_dtype(rhs)
op_dimension = (self.domain_dimension if adjoint
else self.range_dimension)
@ -583,7 +583,7 @@ class LinearOperatorBlockDiag(linear_operator.LinearOperator):
if linear_operator_util.arg_is_blockwise(block_dimensions, rhs, -1):
for i, block in enumerate(rhs):
if not isinstance(block, linear_operator.LinearOperator):
block = ops.convert_to_tensor(block)
block = ops.convert_to_tensor_v2_with_dispatch(block)
self._check_input_dtype(block)
block_dimensions[i].assert_is_compatible_with(block.shape[-1])
rhs[i] = block
@ -591,7 +591,7 @@ class LinearOperatorBlockDiag(linear_operator.LinearOperator):
solution_mat = self.solve(rhs_mat, adjoint=adjoint)
return [array_ops.squeeze(x, axis=-1) for x in solution_mat]
rhs = ops.convert_to_tensor(rhs, name="rhs")
rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
self._check_input_dtype(rhs)
op_dimension = (self.domain_dimension if adjoint
else self.range_dimension)

@ -366,7 +366,7 @@ class LinearOperatorBlockLowerTriangular(linear_operator.LinearOperator):
def _shape_tensor(self):
# Avoid messy broadcasting if possible.
if self.shape.is_fully_defined():
return ops.convert_to_tensor(
return ops.convert_to_tensor_v2_with_dispatch(
self.shape.as_list(), dtype=dtypes.int32, name="shape")
domain_dimension = sum(self._block_domain_dimension_tensors())
@ -433,12 +433,12 @@ class LinearOperatorBlockLowerTriangular(linear_operator.LinearOperator):
if linear_operator_util.arg_is_blockwise(block_dimensions, x, arg_dim):
for i, block in enumerate(x):
if not isinstance(block, linear_operator.LinearOperator):
block = ops.convert_to_tensor(block)
block = ops.convert_to_tensor_v2_with_dispatch(block)
self._check_input_dtype(block)
block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim])
x[i] = block
else:
x = ops.convert_to_tensor(x, name="x")
x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
self._check_input_dtype(x)
op_dimension = (self.range_dimension if adjoint
else self.domain_dimension)
@ -543,7 +543,7 @@ class LinearOperatorBlockLowerTriangular(linear_operator.LinearOperator):
if linear_operator_util.arg_is_blockwise(block_dimensions, x, -1):
for i, block in enumerate(x):
if not isinstance(block, linear_operator.LinearOperator):
block = ops.convert_to_tensor(block)
block = ops.convert_to_tensor_v2_with_dispatch(block)
self._check_input_dtype(block)
block_dimensions[i].assert_is_compatible_with(block.shape[-1])
x[i] = block
@ -551,7 +551,7 @@ class LinearOperatorBlockLowerTriangular(linear_operator.LinearOperator):
y_mat = self.matmul(x_mat, adjoint=adjoint)
return [array_ops.squeeze(y, axis=-1) for y in y_mat]
x = ops.convert_to_tensor(x, name="x")
x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
self._check_input_dtype(x)
op_dimension = (self.range_dimension if adjoint
else self.domain_dimension)
@ -674,7 +674,7 @@ class LinearOperatorBlockLowerTriangular(linear_operator.LinearOperator):
if blockwise_arg:
for i, block in enumerate(rhs):
if not isinstance(block, linear_operator.LinearOperator):
block = ops.convert_to_tensor(block)
block = ops.convert_to_tensor_v2_with_dispatch(block)
self._check_input_dtype(block)
block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim])
rhs[i] = block
@ -684,7 +684,7 @@ class LinearOperatorBlockLowerTriangular(linear_operator.LinearOperator):
split_rhs = rhs
else:
rhs = ops.convert_to_tensor(rhs, name="rhs")
rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
self._check_input_dtype(rhs)
op_dimension = (self.domain_dimension if adjoint
else self.range_dimension)
@ -795,14 +795,14 @@ class LinearOperatorBlockLowerTriangular(linear_operator.LinearOperator):
if linear_operator_util.arg_is_blockwise(block_dimensions, rhs, -1):
for i, block in enumerate(rhs):
if not isinstance(block, linear_operator.LinearOperator):
block = ops.convert_to_tensor(block)
block = ops.convert_to_tensor_v2_with_dispatch(block)
self._check_input_dtype(block)
block_dimensions[i].assert_is_compatible_with(block.shape[-1])
rhs[i] = block
rhs_mat = [array_ops.expand_dims(block, axis=-1) for block in rhs]
solution_mat = self.solve(rhs_mat, adjoint=adjoint)
return [array_ops.squeeze(x, axis=-1) for x in solution_mat]
rhs = ops.convert_to_tensor(rhs, name="rhs")
rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
self._check_input_dtype(rhs)
op_dimension = (self.domain_dimension if adjoint
else self.range_dimension)

@ -378,7 +378,7 @@ class _BaseLinearOperatorCirculant(linear_operator.LinearOperator):
def _broadcast_batch_dims(self, x, spectrum):
"""Broadcast batch dims of batch matrix `x` and spectrum."""
spectrum = ops.convert_to_tensor(spectrum, name="spectrum")
spectrum = ops.convert_to_tensor_v2_with_dispatch(spectrum, name="spectrum")
# spectrum.shape = batch_shape + block_shape
# First make spectrum a batch matrix with
# spectrum.shape = batch_shape + [prod(block_shape), 1]
@ -755,7 +755,7 @@ class LinearOperatorCirculant(_BaseLinearOperatorCirculant):
name=name)
def _eigvals(self):
return ops.convert_to_tensor(self.spectrum)
return ops.convert_to_tensor_v2_with_dispatch(self.spectrum)
@tf_export("linalg.LinearOperatorCirculant2D")

@ -251,7 +251,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
return array_ops.matrix_set_diag(x, new_diag)
def _eigvals(self):
return ops.convert_to_tensor(self.diag)
return ops.convert_to_tensor_v2_with_dispatch(self.diag)
def _cond(self):
abs_diag = math_ops.abs(self.diag)

@ -160,7 +160,7 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator):
dtypes.complex128,
]
matrix = ops.convert_to_tensor(matrix, name="matrix")
matrix = ops.convert_to_tensor_v2_with_dispatch(matrix, name="matrix")
dtype = matrix.dtype
if dtype not in allowed_dtypes:

@ -198,7 +198,8 @@ class LinearOperatorHouseholder(linear_operator.LinearOperator):
# Note that because this is a reflection, it lies in O(n) (for real vector
# spaces) or U(n) (for complex vector spaces), and thus is its own adjoint.
reflection_axis = ops.convert_to_tensor(self.reflection_axis)
reflection_axis = ops.convert_to_tensor_v2_with_dispatch(
self.reflection_axis)
x = linalg.adjoint(x) if adjoint_arg else x
normalized_axis = reflection_axis / linalg.norm(
reflection_axis, axis=-1, keepdims=True)
@ -229,7 +230,8 @@ class LinearOperatorHouseholder(linear_operator.LinearOperator):
return self._matmul(rhs, adjoint, adjoint_arg)
def _to_dense(self):
reflection_axis = ops.convert_to_tensor(self.reflection_axis)
reflection_axis = ops.convert_to_tensor_v2_with_dispatch(
self.reflection_axis)
normalized_axis = reflection_axis / linalg.norm(
reflection_axis, axis=-1, keepdims=True)
mat = normalized_axis[..., array_ops.newaxis]
@ -238,7 +240,8 @@ class LinearOperatorHouseholder(linear_operator.LinearOperator):
matrix, 1. + array_ops.matrix_diag_part(matrix))
def _diag_part(self):
reflection_axis = ops.convert_to_tensor(self.reflection_axis)
reflection_axis = ops.convert_to_tensor_v2_with_dispatch(
self.reflection_axis)
normalized_axis = reflection_axis / linalg.norm(
reflection_axis, axis=-1, keepdims=True)
return 1. - 2 * normalized_axis * math_ops.conj(normalized_axis)

@ -394,7 +394,7 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
A `Tensor` with broadcast shape and same `dtype` as `self`.
"""
with self._name_scope(name):
mat = ops.convert_to_tensor(mat, name="mat")
mat = ops.convert_to_tensor_v2_with_dispatch(mat, name="mat")
mat_diag = array_ops.matrix_diag_part(mat)
new_diag = 1 + mat_diag
return array_ops.matrix_set_diag(mat, new_diag)
@ -720,7 +720,7 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
multiplier_vector = array_ops.expand_dims(self.multiplier, -1)
# Shape [C1,...,Cc, M, M]
mat = ops.convert_to_tensor(mat, name="mat")
mat = ops.convert_to_tensor_v2_with_dispatch(mat, name="mat")
# Shape [C1,...,Cc, M]
mat_diag = array_ops.matrix_diag_part(mat)

@ -197,7 +197,7 @@ class LinearOperatorPermutation(linear_operator.LinearOperator):
return array_ops.shape(perm)[-1]
def _matmul(self, x, adjoint=False, adjoint_arg=False):
perm = ops.convert_to_tensor(self.perm)
perm = ops.convert_to_tensor_v2_with_dispatch(self.perm)
if adjoint and not self.is_self_adjoint:
# TODO(srvasude): invert_permutation doesn't work on batches so we use
# argsort.
@ -232,13 +232,13 @@ class LinearOperatorPermutation(linear_operator.LinearOperator):
return self._matmul(rhs, adjoint=(not adjoint), adjoint_arg=adjoint_arg)
def _to_dense(self):
perm = ops.convert_to_tensor(self.perm)
perm = ops.convert_to_tensor_v2_with_dispatch(self.perm)
return math_ops.cast(math_ops.equal(
math_ops.range(0, self._domain_dimension_tensor(perm)),
perm[..., array_ops.newaxis]), self.dtype)
def _diag_part(self):
perm = ops.convert_to_tensor(self.perm)
perm = ops.convert_to_tensor_v2_with_dispatch(self.perm)
return math_ops.cast(math_ops.equal(
math_ops.range(0, self._domain_dimension_tensor(perm)),
perm), self.dtype)

@ -209,8 +209,8 @@ class LinearOperatorToeplitz(linear_operator.LinearOperator):
# for more details.
x = linalg.adjoint(x) if adjoint_arg else x
expanded_x = array_ops.concat([x, array_ops.zeros_like(x)], axis=-2)
col = ops.convert_to_tensor(self.col)
row = ops.convert_to_tensor(self.row)
col = ops.convert_to_tensor_v2_with_dispatch(self.col)
row = ops.convert_to_tensor_v2_with_dispatch(self.row)
circulant_col = array_ops.concat(
[col,
array_ops.zeros_like(col[..., 0:1]),
@ -236,8 +236,8 @@ class LinearOperatorToeplitz(linear_operator.LinearOperator):
[self.domain_dimension_tensor()], self.dtype)
def _to_dense(self):
row = ops.convert_to_tensor(self.row)
col = ops.convert_to_tensor(self.col)
row = ops.convert_to_tensor_v2_with_dispatch(self.row)
col = ops.convert_to_tensor_v2_with_dispatch(self.col)
total_shape = array_ops.broadcast_dynamic_shape(
array_ops.shape(row), array_ops.shape(col))
n = array_ops.shape(row)[-1]

@ -246,7 +246,7 @@ class LinearOperatorTridiag(linear_operator.LinearOperator):
self.diagonals, linalg.adjoint(self.diagonals),
message='Matrix was not equal to its adjoint.')]
elif self.diagonals_format == _COMPACT:
diagonals = ops.convert_to_tensor(self.diagonals)
diagonals = ops.convert_to_tensor_v2_with_dispatch(self.diagonals)
asserts += [linear_operator_util.assert_zero_imag_part(
diagonals[..., 1, :], message=diag_message)]
# Roll the subdiagonal so the shifted argument is at the end.
@ -353,7 +353,9 @@ class LinearOperatorTridiag(linear_operator.LinearOperator):
align='LEFT_RIGHT',
padding_value=0.)
diagonals = [ops.convert_to_tensor(d) for d in self.diagonals]
diagonals = [
ops.convert_to_tensor_v2_with_dispatch(d) for d in self.diagonals
]
diagonals = array_ops.stack(diagonals, axis=-2)
return gen_array_ops.matrix_diag_v3(

@ -114,7 +114,7 @@ def convert_nonref_to_tensor(value, dtype=None, dtype_hint=None, name=None):
raise TypeError('Mutable type must be of dtype "{}" but is "{}".'.format(
dtype_name(dtype_base), dtype_name(value_dtype_base)))
return value
return ops.convert_to_tensor(
return ops.convert_to_tensor_v2_with_dispatch(
value, dtype=dtype, dtype_hint=dtype_hint, name=name)
@ -189,10 +189,10 @@ def assert_no_entries_with_modulus_zero(
An `Op` that asserts `x` has no entries with modulus zero.
"""
with ops.name_scope(name, values=[x]):
x = ops.convert_to_tensor(x, name="x")
x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
dtype = x.dtype.base_dtype
should_be_nonzero = math_ops.abs(x)
zero = ops.convert_to_tensor(0, dtype=dtype.real_dtype)
zero = ops.convert_to_tensor_v2_with_dispatch(0, dtype=dtype.real_dtype)
return check_ops.assert_less(zero, should_be_nonzero, message=message)
@ -208,13 +208,13 @@ def assert_zero_imag_part(x, message=None, name="assert_zero_imag_part"):
An `Op` that asserts `x` has no entries with modulus zero.
"""
with ops.name_scope(name, values=[x]):
x = ops.convert_to_tensor(x, name="x")
x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
dtype = x.dtype.base_dtype
if dtype.is_floating:
return control_flow_ops.no_op()
zero = ops.convert_to_tensor(0, dtype=dtype.real_dtype)
zero = ops.convert_to_tensor_v2_with_dispatch(0, dtype=dtype.real_dtype)
return check_ops.assert_equal(zero, math_ops.imag(x), message=message)
@ -261,7 +261,7 @@ def shape_tensor(shape, name=None):
dtype = dtypes.int32
else:
dtype = None
return ops.convert_to_tensor(shape, dtype=dtype, name=name)
return ops.convert_to_tensor_v2_with_dispatch(shape, dtype=dtype, name=name)
################################################################################
@ -323,7 +323,7 @@ def broadcast_matrix_batch_dims(batch_matrices, name=None):
batch_matrices = list(batch_matrices)
for i, mat in enumerate(batch_matrices):
batch_matrices[i] = ops.convert_to_tensor(mat)
batch_matrices[i] = ops.convert_to_tensor_v2_with_dispatch(mat)
assert_is_batch_matrix(batch_matrices[i])
if len(batch_matrices) < 2:
@ -366,8 +366,9 @@ def broadcast_matrix_batch_dims(batch_matrices, name=None):
def matrix_solve_with_broadcast(matrix, rhs, adjoint=False, name=None):
"""Solve systems of linear equations."""
with ops.name_scope(name, "MatrixSolveWithBroadcast", [matrix, rhs]):
matrix = ops.convert_to_tensor(matrix, name="matrix")
rhs = ops.convert_to_tensor(rhs, name="rhs", dtype=matrix.dtype)
matrix = ops.convert_to_tensor_v2_with_dispatch(matrix, name="matrix")
rhs = ops.convert_to_tensor_v2_with_dispatch(
rhs, name="rhs", dtype=matrix.dtype)
# If either matrix/rhs has extra dims, we can reshape to get rid of them.
matrix, rhs, reshape_inv, still_need_to_transpose = _reshape_for_efficiency(
@ -526,7 +527,8 @@ def arg_is_blockwise(block_dimensions, arg, arg_split_dim):
if not any(nest.is_nested(x) for x in arg):
return True
else:
arg_dims = [ops.convert_to_tensor(x).shape[arg_split_dim] for x in arg]
arg_dims = [ops.convert_to_tensor_v2_with_dispatch(
x).shape[arg_split_dim] for x in arg]
self_dims = [dim.value for dim in block_dimensions]
# If none of the operator dimensions are known, interpret the input as

@ -18,11 +18,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linear_operator_diag
from tensorflow.python.ops.proto_ops import decode_proto
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
@ -60,6 +62,8 @@ class TensorTracer(object):
self.name = name
self.args = args
self.kwargs = kwargs
self.shape = array_ops.ones(shape=(4, 4)).shape
self.dtype = dtypes.float32
def __repr__(self):
if self.args is None and self.kwargs is None:
@ -70,6 +74,10 @@ class TensorTracer(object):
["{}={}".format(name, x) for (name, x) in self.kwargs.items()])
return "{}({})".format(self.name, ", ".join(args))
@property
def is_tensor_like(self):
return True
@classmethod
def _overload_all_operators(cls): # pylint: disable=invalid-name
"""Register overloads for all operators."""
@ -282,5 +290,42 @@ class DispatchTest(test_util.TensorFlowTestCase):
# Clean up.
dispatch._GLOBAL_DISPATCHERS = original_global_dispatchers
def testGlobalDispatcherLinearOperators(self):
original_global_dispatchers = dispatch._GLOBAL_DISPATCHERS
try:
TensorTracerOpDispatcher().register()
x = TensorTracer("x")
# To grab the eigenvalues the diag operator just calls convert_to_tensor
# (twice) in this case.
trace = linear_operator_diag.LinearOperatorDiag(x).eigvals()
self.assertEqual(
str(trace),
"convert_to_tensor(convert_to_tensor(x, dtype=None, dtype_hint=None, "
"name=diag))")
# The diagonal tensor addition gets traced even though the linear_operator
# API only uses dispatchable ops instead of directly exposing dispatching.
trace = linear_operator_diag.LinearOperatorDiag(x).add_to_tensor(x)
self.assertIn(
"linalg.set_diag(convert_to_tensor(x, name=x), __operators__.add("
"convert_to_tensor(x, dtype=None, dtype_hint=None, name=diag), "
"linalg.diag_part(convert_to_tensor(x, name=x)), "
"name=",
str(trace))
# The dispatch-supporting ops the non-singular check calls out to
# get traced.
trace = linear_operator_diag.LinearOperatorDiag(x).assert_non_singular()
self.assertIn("debugging.assert_less", str(trace))
self.assertIn(
"message=Singular operator: Diagonal contained zero values.",
str(trace))
finally:
# Clean up.
dispatch._GLOBAL_DISPATCHERS = original_global_dispatchers
if __name__ == "__main__":
googletest.main()