The linalg.LinearOperator*
Module APIs do not support top-level dispatching because they are classes w/ methods instead of top-level methods in TF's APIs. But, their class methods call out to APIs that do support dispatching.
This CL updates the convert_to_tensor calls in the `linalg.LinearOperator*` APIs to use the publicly exposed, dispatching `convert_to_tensor_v2_with_dispatch`, which enables the Operators to effectively work with dispatching as the APIs they call out to support dispatching as well. PiperOrigin-RevId: 324834645 Change-Id: If2e9f17be101e74f8835497d8ca51a0174055053
This commit is contained in:
parent
c8ddf5a1df
commit
12b7b9e06d
tensorflow/python
ops/linalg
linear_operator.pylinear_operator_block_diag.pylinear_operator_block_lower_triangular.pylinear_operator_circulant.pylinear_operator_diag.pylinear_operator_full_matrix.pylinear_operator_householder.pylinear_operator_identity.pylinear_operator_permutation.pylinear_operator_toeplitz.pylinear_operator_tridiag.pylinear_operator_util.py
util
@ -385,7 +385,7 @@ class LinearOperator(module.Module):
|
||||
# `shape` may be passed in if this can be pre-computed in a
|
||||
# more efficient manner, e.g. without excessive Tensor conversions.
|
||||
if self.tensor_rank is not None:
|
||||
return ops.convert_to_tensor(self.tensor_rank)
|
||||
return ops.convert_to_tensor_v2_with_dispatch(self.tensor_rank)
|
||||
else:
|
||||
shape = self.shape_tensor() if shape is None else shape
|
||||
return array_ops.size(shape)
|
||||
@ -429,7 +429,7 @@ class LinearOperator(module.Module):
|
||||
# more efficient manner, e.g. without excessive Tensor conversions.
|
||||
dim_value = tensor_shape.dimension_value(self.domain_dimension)
|
||||
if dim_value is not None:
|
||||
return ops.convert_to_tensor(dim_value)
|
||||
return ops.convert_to_tensor_v2_with_dispatch(dim_value)
|
||||
else:
|
||||
shape = self.shape_tensor() if shape is None else shape
|
||||
return shape[-1]
|
||||
@ -473,7 +473,7 @@ class LinearOperator(module.Module):
|
||||
# more efficient manner, e.g. without excessive Tensor conversions.
|
||||
dim_value = tensor_shape.dimension_value(self.range_dimension)
|
||||
if dim_value is not None:
|
||||
return ops.convert_to_tensor(dim_value)
|
||||
return ops.convert_to_tensor_v2_with_dispatch(dim_value)
|
||||
else:
|
||||
shape = self.shape_tensor() if shape is None else shape
|
||||
return shape[-2]
|
||||
@ -641,7 +641,7 @@ class LinearOperator(module.Module):
|
||||
return linear_operator_algebra.matmul(left_operator, right_operator)
|
||||
|
||||
with self._name_scope(name):
|
||||
x = ops.convert_to_tensor(x, name="x")
|
||||
x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
|
||||
self._check_input_dtype(x)
|
||||
|
||||
self_dim = -2 if adjoint else -1
|
||||
@ -688,7 +688,7 @@ class LinearOperator(module.Module):
|
||||
A `Tensor` with shape `[..., M]` and same `dtype` as `self`.
|
||||
"""
|
||||
with self._name_scope(name):
|
||||
x = ops.convert_to_tensor(x, name="x")
|
||||
x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
|
||||
self._check_input_dtype(x)
|
||||
self_dim = -2 if adjoint else -1
|
||||
tensor_shape.dimension_at_index(
|
||||
@ -834,7 +834,7 @@ class LinearOperator(module.Module):
|
||||
return linear_operator_algebra.solve(left_operator, right_operator)
|
||||
|
||||
with self._name_scope(name):
|
||||
rhs = ops.convert_to_tensor(rhs, name="rhs")
|
||||
rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
|
||||
self._check_input_dtype(rhs)
|
||||
|
||||
self_dim = -1 if adjoint else -2
|
||||
@ -891,7 +891,7 @@ class LinearOperator(module.Module):
|
||||
NotImplementedError: If `self.is_non_singular` or `is_square` is False.
|
||||
"""
|
||||
with self._name_scope(name):
|
||||
rhs = ops.convert_to_tensor(rhs, name="rhs")
|
||||
rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
|
||||
self._check_input_dtype(rhs)
|
||||
self_dim = -1 if adjoint else -2
|
||||
tensor_shape.dimension_at_index(
|
||||
@ -1054,7 +1054,7 @@ class LinearOperator(module.Module):
|
||||
A `Tensor` with broadcast shape and same `dtype` as `self`.
|
||||
"""
|
||||
with self._name_scope(name):
|
||||
x = ops.convert_to_tensor(x, name="x")
|
||||
x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
|
||||
self._check_input_dtype(x)
|
||||
return self._add_to_tensor(x)
|
||||
|
||||
|
@ -263,7 +263,7 @@ class LinearOperatorBlockDiag(linear_operator.LinearOperator):
|
||||
def _shape_tensor(self):
|
||||
# Avoid messy broadcasting if possible.
|
||||
if self.shape.is_fully_defined():
|
||||
return ops.convert_to_tensor(
|
||||
return ops.convert_to_tensor_v2_with_dispatch(
|
||||
self.shape.as_list(), dtype=dtypes.int32, name="shape")
|
||||
|
||||
domain_dimension = sum(self._block_domain_dimension_tensors())
|
||||
@ -330,12 +330,12 @@ class LinearOperatorBlockDiag(linear_operator.LinearOperator):
|
||||
if linear_operator_util.arg_is_blockwise(block_dimensions, x, arg_dim):
|
||||
for i, block in enumerate(x):
|
||||
if not isinstance(block, linear_operator.LinearOperator):
|
||||
block = ops.convert_to_tensor(block)
|
||||
block = ops.convert_to_tensor_v2_with_dispatch(block)
|
||||
self._check_input_dtype(block)
|
||||
block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim])
|
||||
x[i] = block
|
||||
else:
|
||||
x = ops.convert_to_tensor(x, name="x")
|
||||
x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
|
||||
self._check_input_dtype(x)
|
||||
op_dimension = (self.range_dimension if adjoint
|
||||
else self.domain_dimension)
|
||||
@ -404,7 +404,7 @@ class LinearOperatorBlockDiag(linear_operator.LinearOperator):
|
||||
if linear_operator_util.arg_is_blockwise(block_dimensions, x, -1):
|
||||
for i, block in enumerate(x):
|
||||
if not isinstance(block, linear_operator.LinearOperator):
|
||||
block = ops.convert_to_tensor(block)
|
||||
block = ops.convert_to_tensor_v2_with_dispatch(block)
|
||||
self._check_input_dtype(block)
|
||||
block_dimensions[i].assert_is_compatible_with(block.shape[-1])
|
||||
x[i] = block
|
||||
@ -412,7 +412,7 @@ class LinearOperatorBlockDiag(linear_operator.LinearOperator):
|
||||
y_mat = self.matmul(x_mat, adjoint=adjoint)
|
||||
return [array_ops.squeeze(y, axis=-1) for y in y_mat]
|
||||
|
||||
x = ops.convert_to_tensor(x, name="x")
|
||||
x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
|
||||
self._check_input_dtype(x)
|
||||
op_dimension = (self.range_dimension if adjoint
|
||||
else self.domain_dimension)
|
||||
@ -508,12 +508,12 @@ class LinearOperatorBlockDiag(linear_operator.LinearOperator):
|
||||
split_rhs = rhs
|
||||
for i, block in enumerate(split_rhs):
|
||||
if not isinstance(block, linear_operator.LinearOperator):
|
||||
block = ops.convert_to_tensor(block)
|
||||
block = ops.convert_to_tensor_v2_with_dispatch(block)
|
||||
self._check_input_dtype(block)
|
||||
block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim])
|
||||
split_rhs[i] = block
|
||||
else:
|
||||
rhs = ops.convert_to_tensor(rhs, name="rhs")
|
||||
rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
|
||||
self._check_input_dtype(rhs)
|
||||
op_dimension = (self.domain_dimension if adjoint
|
||||
else self.range_dimension)
|
||||
@ -583,7 +583,7 @@ class LinearOperatorBlockDiag(linear_operator.LinearOperator):
|
||||
if linear_operator_util.arg_is_blockwise(block_dimensions, rhs, -1):
|
||||
for i, block in enumerate(rhs):
|
||||
if not isinstance(block, linear_operator.LinearOperator):
|
||||
block = ops.convert_to_tensor(block)
|
||||
block = ops.convert_to_tensor_v2_with_dispatch(block)
|
||||
self._check_input_dtype(block)
|
||||
block_dimensions[i].assert_is_compatible_with(block.shape[-1])
|
||||
rhs[i] = block
|
||||
@ -591,7 +591,7 @@ class LinearOperatorBlockDiag(linear_operator.LinearOperator):
|
||||
solution_mat = self.solve(rhs_mat, adjoint=adjoint)
|
||||
return [array_ops.squeeze(x, axis=-1) for x in solution_mat]
|
||||
|
||||
rhs = ops.convert_to_tensor(rhs, name="rhs")
|
||||
rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
|
||||
self._check_input_dtype(rhs)
|
||||
op_dimension = (self.domain_dimension if adjoint
|
||||
else self.range_dimension)
|
||||
|
@ -366,7 +366,7 @@ class LinearOperatorBlockLowerTriangular(linear_operator.LinearOperator):
|
||||
def _shape_tensor(self):
|
||||
# Avoid messy broadcasting if possible.
|
||||
if self.shape.is_fully_defined():
|
||||
return ops.convert_to_tensor(
|
||||
return ops.convert_to_tensor_v2_with_dispatch(
|
||||
self.shape.as_list(), dtype=dtypes.int32, name="shape")
|
||||
|
||||
domain_dimension = sum(self._block_domain_dimension_tensors())
|
||||
@ -433,12 +433,12 @@ class LinearOperatorBlockLowerTriangular(linear_operator.LinearOperator):
|
||||
if linear_operator_util.arg_is_blockwise(block_dimensions, x, arg_dim):
|
||||
for i, block in enumerate(x):
|
||||
if not isinstance(block, linear_operator.LinearOperator):
|
||||
block = ops.convert_to_tensor(block)
|
||||
block = ops.convert_to_tensor_v2_with_dispatch(block)
|
||||
self._check_input_dtype(block)
|
||||
block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim])
|
||||
x[i] = block
|
||||
else:
|
||||
x = ops.convert_to_tensor(x, name="x")
|
||||
x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
|
||||
self._check_input_dtype(x)
|
||||
op_dimension = (self.range_dimension if adjoint
|
||||
else self.domain_dimension)
|
||||
@ -543,7 +543,7 @@ class LinearOperatorBlockLowerTriangular(linear_operator.LinearOperator):
|
||||
if linear_operator_util.arg_is_blockwise(block_dimensions, x, -1):
|
||||
for i, block in enumerate(x):
|
||||
if not isinstance(block, linear_operator.LinearOperator):
|
||||
block = ops.convert_to_tensor(block)
|
||||
block = ops.convert_to_tensor_v2_with_dispatch(block)
|
||||
self._check_input_dtype(block)
|
||||
block_dimensions[i].assert_is_compatible_with(block.shape[-1])
|
||||
x[i] = block
|
||||
@ -551,7 +551,7 @@ class LinearOperatorBlockLowerTriangular(linear_operator.LinearOperator):
|
||||
y_mat = self.matmul(x_mat, adjoint=adjoint)
|
||||
return [array_ops.squeeze(y, axis=-1) for y in y_mat]
|
||||
|
||||
x = ops.convert_to_tensor(x, name="x")
|
||||
x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
|
||||
self._check_input_dtype(x)
|
||||
op_dimension = (self.range_dimension if adjoint
|
||||
else self.domain_dimension)
|
||||
@ -674,7 +674,7 @@ class LinearOperatorBlockLowerTriangular(linear_operator.LinearOperator):
|
||||
if blockwise_arg:
|
||||
for i, block in enumerate(rhs):
|
||||
if not isinstance(block, linear_operator.LinearOperator):
|
||||
block = ops.convert_to_tensor(block)
|
||||
block = ops.convert_to_tensor_v2_with_dispatch(block)
|
||||
self._check_input_dtype(block)
|
||||
block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim])
|
||||
rhs[i] = block
|
||||
@ -684,7 +684,7 @@ class LinearOperatorBlockLowerTriangular(linear_operator.LinearOperator):
|
||||
split_rhs = rhs
|
||||
|
||||
else:
|
||||
rhs = ops.convert_to_tensor(rhs, name="rhs")
|
||||
rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
|
||||
self._check_input_dtype(rhs)
|
||||
op_dimension = (self.domain_dimension if adjoint
|
||||
else self.range_dimension)
|
||||
@ -795,14 +795,14 @@ class LinearOperatorBlockLowerTriangular(linear_operator.LinearOperator):
|
||||
if linear_operator_util.arg_is_blockwise(block_dimensions, rhs, -1):
|
||||
for i, block in enumerate(rhs):
|
||||
if not isinstance(block, linear_operator.LinearOperator):
|
||||
block = ops.convert_to_tensor(block)
|
||||
block = ops.convert_to_tensor_v2_with_dispatch(block)
|
||||
self._check_input_dtype(block)
|
||||
block_dimensions[i].assert_is_compatible_with(block.shape[-1])
|
||||
rhs[i] = block
|
||||
rhs_mat = [array_ops.expand_dims(block, axis=-1) for block in rhs]
|
||||
solution_mat = self.solve(rhs_mat, adjoint=adjoint)
|
||||
return [array_ops.squeeze(x, axis=-1) for x in solution_mat]
|
||||
rhs = ops.convert_to_tensor(rhs, name="rhs")
|
||||
rhs = ops.convert_to_tensor_v2_with_dispatch(rhs, name="rhs")
|
||||
self._check_input_dtype(rhs)
|
||||
op_dimension = (self.domain_dimension if adjoint
|
||||
else self.range_dimension)
|
||||
|
@ -378,7 +378,7 @@ class _BaseLinearOperatorCirculant(linear_operator.LinearOperator):
|
||||
|
||||
def _broadcast_batch_dims(self, x, spectrum):
|
||||
"""Broadcast batch dims of batch matrix `x` and spectrum."""
|
||||
spectrum = ops.convert_to_tensor(spectrum, name="spectrum")
|
||||
spectrum = ops.convert_to_tensor_v2_with_dispatch(spectrum, name="spectrum")
|
||||
# spectrum.shape = batch_shape + block_shape
|
||||
# First make spectrum a batch matrix with
|
||||
# spectrum.shape = batch_shape + [prod(block_shape), 1]
|
||||
@ -755,7 +755,7 @@ class LinearOperatorCirculant(_BaseLinearOperatorCirculant):
|
||||
name=name)
|
||||
|
||||
def _eigvals(self):
|
||||
return ops.convert_to_tensor(self.spectrum)
|
||||
return ops.convert_to_tensor_v2_with_dispatch(self.spectrum)
|
||||
|
||||
|
||||
@tf_export("linalg.LinearOperatorCirculant2D")
|
||||
|
@ -251,7 +251,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
|
||||
return array_ops.matrix_set_diag(x, new_diag)
|
||||
|
||||
def _eigvals(self):
|
||||
return ops.convert_to_tensor(self.diag)
|
||||
return ops.convert_to_tensor_v2_with_dispatch(self.diag)
|
||||
|
||||
def _cond(self):
|
||||
abs_diag = math_ops.abs(self.diag)
|
||||
|
@ -160,7 +160,7 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator):
|
||||
dtypes.complex128,
|
||||
]
|
||||
|
||||
matrix = ops.convert_to_tensor(matrix, name="matrix")
|
||||
matrix = ops.convert_to_tensor_v2_with_dispatch(matrix, name="matrix")
|
||||
|
||||
dtype = matrix.dtype
|
||||
if dtype not in allowed_dtypes:
|
||||
|
@ -198,7 +198,8 @@ class LinearOperatorHouseholder(linear_operator.LinearOperator):
|
||||
|
||||
# Note that because this is a reflection, it lies in O(n) (for real vector
|
||||
# spaces) or U(n) (for complex vector spaces), and thus is its own adjoint.
|
||||
reflection_axis = ops.convert_to_tensor(self.reflection_axis)
|
||||
reflection_axis = ops.convert_to_tensor_v2_with_dispatch(
|
||||
self.reflection_axis)
|
||||
x = linalg.adjoint(x) if adjoint_arg else x
|
||||
normalized_axis = reflection_axis / linalg.norm(
|
||||
reflection_axis, axis=-1, keepdims=True)
|
||||
@ -229,7 +230,8 @@ class LinearOperatorHouseholder(linear_operator.LinearOperator):
|
||||
return self._matmul(rhs, adjoint, adjoint_arg)
|
||||
|
||||
def _to_dense(self):
|
||||
reflection_axis = ops.convert_to_tensor(self.reflection_axis)
|
||||
reflection_axis = ops.convert_to_tensor_v2_with_dispatch(
|
||||
self.reflection_axis)
|
||||
normalized_axis = reflection_axis / linalg.norm(
|
||||
reflection_axis, axis=-1, keepdims=True)
|
||||
mat = normalized_axis[..., array_ops.newaxis]
|
||||
@ -238,7 +240,8 @@ class LinearOperatorHouseholder(linear_operator.LinearOperator):
|
||||
matrix, 1. + array_ops.matrix_diag_part(matrix))
|
||||
|
||||
def _diag_part(self):
|
||||
reflection_axis = ops.convert_to_tensor(self.reflection_axis)
|
||||
reflection_axis = ops.convert_to_tensor_v2_with_dispatch(
|
||||
self.reflection_axis)
|
||||
normalized_axis = reflection_axis / linalg.norm(
|
||||
reflection_axis, axis=-1, keepdims=True)
|
||||
return 1. - 2 * normalized_axis * math_ops.conj(normalized_axis)
|
||||
|
@ -394,7 +394,7 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
|
||||
A `Tensor` with broadcast shape and same `dtype` as `self`.
|
||||
"""
|
||||
with self._name_scope(name):
|
||||
mat = ops.convert_to_tensor(mat, name="mat")
|
||||
mat = ops.convert_to_tensor_v2_with_dispatch(mat, name="mat")
|
||||
mat_diag = array_ops.matrix_diag_part(mat)
|
||||
new_diag = 1 + mat_diag
|
||||
return array_ops.matrix_set_diag(mat, new_diag)
|
||||
@ -720,7 +720,7 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
|
||||
multiplier_vector = array_ops.expand_dims(self.multiplier, -1)
|
||||
|
||||
# Shape [C1,...,Cc, M, M]
|
||||
mat = ops.convert_to_tensor(mat, name="mat")
|
||||
mat = ops.convert_to_tensor_v2_with_dispatch(mat, name="mat")
|
||||
|
||||
# Shape [C1,...,Cc, M]
|
||||
mat_diag = array_ops.matrix_diag_part(mat)
|
||||
|
@ -197,7 +197,7 @@ class LinearOperatorPermutation(linear_operator.LinearOperator):
|
||||
return array_ops.shape(perm)[-1]
|
||||
|
||||
def _matmul(self, x, adjoint=False, adjoint_arg=False):
|
||||
perm = ops.convert_to_tensor(self.perm)
|
||||
perm = ops.convert_to_tensor_v2_with_dispatch(self.perm)
|
||||
if adjoint and not self.is_self_adjoint:
|
||||
# TODO(srvasude): invert_permutation doesn't work on batches so we use
|
||||
# argsort.
|
||||
@ -232,13 +232,13 @@ class LinearOperatorPermutation(linear_operator.LinearOperator):
|
||||
return self._matmul(rhs, adjoint=(not adjoint), adjoint_arg=adjoint_arg)
|
||||
|
||||
def _to_dense(self):
|
||||
perm = ops.convert_to_tensor(self.perm)
|
||||
perm = ops.convert_to_tensor_v2_with_dispatch(self.perm)
|
||||
return math_ops.cast(math_ops.equal(
|
||||
math_ops.range(0, self._domain_dimension_tensor(perm)),
|
||||
perm[..., array_ops.newaxis]), self.dtype)
|
||||
|
||||
def _diag_part(self):
|
||||
perm = ops.convert_to_tensor(self.perm)
|
||||
perm = ops.convert_to_tensor_v2_with_dispatch(self.perm)
|
||||
return math_ops.cast(math_ops.equal(
|
||||
math_ops.range(0, self._domain_dimension_tensor(perm)),
|
||||
perm), self.dtype)
|
||||
|
@ -209,8 +209,8 @@ class LinearOperatorToeplitz(linear_operator.LinearOperator):
|
||||
# for more details.
|
||||
x = linalg.adjoint(x) if adjoint_arg else x
|
||||
expanded_x = array_ops.concat([x, array_ops.zeros_like(x)], axis=-2)
|
||||
col = ops.convert_to_tensor(self.col)
|
||||
row = ops.convert_to_tensor(self.row)
|
||||
col = ops.convert_to_tensor_v2_with_dispatch(self.col)
|
||||
row = ops.convert_to_tensor_v2_with_dispatch(self.row)
|
||||
circulant_col = array_ops.concat(
|
||||
[col,
|
||||
array_ops.zeros_like(col[..., 0:1]),
|
||||
@ -236,8 +236,8 @@ class LinearOperatorToeplitz(linear_operator.LinearOperator):
|
||||
[self.domain_dimension_tensor()], self.dtype)
|
||||
|
||||
def _to_dense(self):
|
||||
row = ops.convert_to_tensor(self.row)
|
||||
col = ops.convert_to_tensor(self.col)
|
||||
row = ops.convert_to_tensor_v2_with_dispatch(self.row)
|
||||
col = ops.convert_to_tensor_v2_with_dispatch(self.col)
|
||||
total_shape = array_ops.broadcast_dynamic_shape(
|
||||
array_ops.shape(row), array_ops.shape(col))
|
||||
n = array_ops.shape(row)[-1]
|
||||
|
@ -246,7 +246,7 @@ class LinearOperatorTridiag(linear_operator.LinearOperator):
|
||||
self.diagonals, linalg.adjoint(self.diagonals),
|
||||
message='Matrix was not equal to its adjoint.')]
|
||||
elif self.diagonals_format == _COMPACT:
|
||||
diagonals = ops.convert_to_tensor(self.diagonals)
|
||||
diagonals = ops.convert_to_tensor_v2_with_dispatch(self.diagonals)
|
||||
asserts += [linear_operator_util.assert_zero_imag_part(
|
||||
diagonals[..., 1, :], message=diag_message)]
|
||||
# Roll the subdiagonal so the shifted argument is at the end.
|
||||
@ -353,7 +353,9 @@ class LinearOperatorTridiag(linear_operator.LinearOperator):
|
||||
align='LEFT_RIGHT',
|
||||
padding_value=0.)
|
||||
|
||||
diagonals = [ops.convert_to_tensor(d) for d in self.diagonals]
|
||||
diagonals = [
|
||||
ops.convert_to_tensor_v2_with_dispatch(d) for d in self.diagonals
|
||||
]
|
||||
diagonals = array_ops.stack(diagonals, axis=-2)
|
||||
|
||||
return gen_array_ops.matrix_diag_v3(
|
||||
|
@ -114,7 +114,7 @@ def convert_nonref_to_tensor(value, dtype=None, dtype_hint=None, name=None):
|
||||
raise TypeError('Mutable type must be of dtype "{}" but is "{}".'.format(
|
||||
dtype_name(dtype_base), dtype_name(value_dtype_base)))
|
||||
return value
|
||||
return ops.convert_to_tensor(
|
||||
return ops.convert_to_tensor_v2_with_dispatch(
|
||||
value, dtype=dtype, dtype_hint=dtype_hint, name=name)
|
||||
|
||||
|
||||
@ -189,10 +189,10 @@ def assert_no_entries_with_modulus_zero(
|
||||
An `Op` that asserts `x` has no entries with modulus zero.
|
||||
"""
|
||||
with ops.name_scope(name, values=[x]):
|
||||
x = ops.convert_to_tensor(x, name="x")
|
||||
x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
|
||||
dtype = x.dtype.base_dtype
|
||||
should_be_nonzero = math_ops.abs(x)
|
||||
zero = ops.convert_to_tensor(0, dtype=dtype.real_dtype)
|
||||
zero = ops.convert_to_tensor_v2_with_dispatch(0, dtype=dtype.real_dtype)
|
||||
return check_ops.assert_less(zero, should_be_nonzero, message=message)
|
||||
|
||||
|
||||
@ -208,13 +208,13 @@ def assert_zero_imag_part(x, message=None, name="assert_zero_imag_part"):
|
||||
An `Op` that asserts `x` has no entries with modulus zero.
|
||||
"""
|
||||
with ops.name_scope(name, values=[x]):
|
||||
x = ops.convert_to_tensor(x, name="x")
|
||||
x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
|
||||
dtype = x.dtype.base_dtype
|
||||
|
||||
if dtype.is_floating:
|
||||
return control_flow_ops.no_op()
|
||||
|
||||
zero = ops.convert_to_tensor(0, dtype=dtype.real_dtype)
|
||||
zero = ops.convert_to_tensor_v2_with_dispatch(0, dtype=dtype.real_dtype)
|
||||
return check_ops.assert_equal(zero, math_ops.imag(x), message=message)
|
||||
|
||||
|
||||
@ -261,7 +261,7 @@ def shape_tensor(shape, name=None):
|
||||
dtype = dtypes.int32
|
||||
else:
|
||||
dtype = None
|
||||
return ops.convert_to_tensor(shape, dtype=dtype, name=name)
|
||||
return ops.convert_to_tensor_v2_with_dispatch(shape, dtype=dtype, name=name)
|
||||
|
||||
|
||||
################################################################################
|
||||
@ -323,7 +323,7 @@ def broadcast_matrix_batch_dims(batch_matrices, name=None):
|
||||
batch_matrices = list(batch_matrices)
|
||||
|
||||
for i, mat in enumerate(batch_matrices):
|
||||
batch_matrices[i] = ops.convert_to_tensor(mat)
|
||||
batch_matrices[i] = ops.convert_to_tensor_v2_with_dispatch(mat)
|
||||
assert_is_batch_matrix(batch_matrices[i])
|
||||
|
||||
if len(batch_matrices) < 2:
|
||||
@ -366,8 +366,9 @@ def broadcast_matrix_batch_dims(batch_matrices, name=None):
|
||||
def matrix_solve_with_broadcast(matrix, rhs, adjoint=False, name=None):
|
||||
"""Solve systems of linear equations."""
|
||||
with ops.name_scope(name, "MatrixSolveWithBroadcast", [matrix, rhs]):
|
||||
matrix = ops.convert_to_tensor(matrix, name="matrix")
|
||||
rhs = ops.convert_to_tensor(rhs, name="rhs", dtype=matrix.dtype)
|
||||
matrix = ops.convert_to_tensor_v2_with_dispatch(matrix, name="matrix")
|
||||
rhs = ops.convert_to_tensor_v2_with_dispatch(
|
||||
rhs, name="rhs", dtype=matrix.dtype)
|
||||
|
||||
# If either matrix/rhs has extra dims, we can reshape to get rid of them.
|
||||
matrix, rhs, reshape_inv, still_need_to_transpose = _reshape_for_efficiency(
|
||||
@ -526,7 +527,8 @@ def arg_is_blockwise(block_dimensions, arg, arg_split_dim):
|
||||
if not any(nest.is_nested(x) for x in arg):
|
||||
return True
|
||||
else:
|
||||
arg_dims = [ops.convert_to_tensor(x).shape[arg_split_dim] for x in arg]
|
||||
arg_dims = [ops.convert_to_tensor_v2_with_dispatch(
|
||||
x).shape[arg_split_dim] for x in arg]
|
||||
self_dims = [dim.value for dim in block_dimensions]
|
||||
|
||||
# If none of the operator dimensions are known, interpret the input as
|
||||
|
@ -18,11 +18,13 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.linalg import linear_operator_diag
|
||||
from tensorflow.python.ops.proto_ops import decode_proto
|
||||
from tensorflow.python.platform import googletest
|
||||
from tensorflow.python.platform import test
|
||||
@ -60,6 +62,8 @@ class TensorTracer(object):
|
||||
self.name = name
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.shape = array_ops.ones(shape=(4, 4)).shape
|
||||
self.dtype = dtypes.float32
|
||||
|
||||
def __repr__(self):
|
||||
if self.args is None and self.kwargs is None:
|
||||
@ -70,6 +74,10 @@ class TensorTracer(object):
|
||||
["{}={}".format(name, x) for (name, x) in self.kwargs.items()])
|
||||
return "{}({})".format(self.name, ", ".join(args))
|
||||
|
||||
@property
|
||||
def is_tensor_like(self):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def _overload_all_operators(cls): # pylint: disable=invalid-name
|
||||
"""Register overloads for all operators."""
|
||||
@ -282,5 +290,42 @@ class DispatchTest(test_util.TensorFlowTestCase):
|
||||
# Clean up.
|
||||
dispatch._GLOBAL_DISPATCHERS = original_global_dispatchers
|
||||
|
||||
def testGlobalDispatcherLinearOperators(self):
|
||||
original_global_dispatchers = dispatch._GLOBAL_DISPATCHERS
|
||||
try:
|
||||
TensorTracerOpDispatcher().register()
|
||||
|
||||
x = TensorTracer("x")
|
||||
|
||||
# To grab the eigenvalues the diag operator just calls convert_to_tensor
|
||||
# (twice) in this case.
|
||||
trace = linear_operator_diag.LinearOperatorDiag(x).eigvals()
|
||||
self.assertEqual(
|
||||
str(trace),
|
||||
"convert_to_tensor(convert_to_tensor(x, dtype=None, dtype_hint=None, "
|
||||
"name=diag))")
|
||||
|
||||
# The diagonal tensor addition gets traced even though the linear_operator
|
||||
# API only uses dispatchable ops instead of directly exposing dispatching.
|
||||
trace = linear_operator_diag.LinearOperatorDiag(x).add_to_tensor(x)
|
||||
self.assertIn(
|
||||
"linalg.set_diag(convert_to_tensor(x, name=x), __operators__.add("
|
||||
"convert_to_tensor(x, dtype=None, dtype_hint=None, name=diag), "
|
||||
"linalg.diag_part(convert_to_tensor(x, name=x)), "
|
||||
"name=",
|
||||
str(trace))
|
||||
|
||||
# The dispatch-supporting ops the non-singular check calls out to
|
||||
# get traced.
|
||||
trace = linear_operator_diag.LinearOperatorDiag(x).assert_non_singular()
|
||||
self.assertIn("debugging.assert_less", str(trace))
|
||||
self.assertIn(
|
||||
"message=Singular operator: Diagonal contained zero values.",
|
||||
str(trace))
|
||||
|
||||
finally:
|
||||
# Clean up.
|
||||
dispatch._GLOBAL_DISPATCHERS = original_global_dispatchers
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user