From 64ea20632bf346a9474b4e0420f1277e8054a002 Mon Sep 17 00:00:00 2001
From: Ian Langmore <langmore@google.com>
Date: Tue, 17 Jan 2017 10:52:32 -0800
Subject: [PATCH] Name change in LinearOperator:  batch_shape_dynamic -->
 batch_shape_tensor. Similarly for other "dynamic" Ops. Change: 144728885

---
 .../distributions/python/ops/bijector.py      |  2 +-
 .../linear_operator_composition_test.py       |  8 +-
 .../kernel_tests/linear_operator_test.py      | 18 ++---
 .../kernel_tests/linear_operator_util_test.py |  2 +-
 .../linalg/python/ops/linear_operator.py      | 76 ++++++++++---------
 .../python/ops/linear_operator_composition.py | 10 +--
 .../linalg/python/ops/linear_operator_diag.py |  2 +-
 .../python/ops/linear_operator_identity.py    | 10 +--
 .../python/ops/linear_operator_matrix.py      |  2 +-
 .../python/ops/linear_operator_test_util.py   | 10 +--
 .../linalg/python/ops/linear_operator_tril.py |  2 +-
 .../linalg/python/ops/linear_operator_util.py |  4 +-
 12 files changed, 74 insertions(+), 72 deletions(-)

diff --git a/tensorflow/contrib/distributions/python/ops/bijector.py b/tensorflow/contrib/distributions/python/ops/bijector.py
index 7e92f496773..41a4f9d8592 100644
--- a/tensorflow/contrib/distributions/python/ops/bijector.py
+++ b/tensorflow/contrib/distributions/python/ops/bijector.py
@@ -1977,7 +1977,7 @@ class AffineLinearOperator(Bijector):
         if scale.tensor_rank is not None:
           batch_ndims = scale.tensor_rank - 2
         else:
-          batch_ndims = scale.tensor_rank_dynamic() - 2
+          batch_ndims = scale.tensor_rank_tensor() - 2
           graph_parents += [batch_ndims]
       else:
         batch_ndims = 0  # We won't need shape inference when scale is None.
diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_composition_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_composition_test.py
index 2f60554104d..6309d36258e 100644
--- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_composition_test.py
+++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_composition_test.py
@@ -200,16 +200,16 @@ class NonSquareLinearOperatorCompositionTest(
     operator = linalg.LinearOperatorComposition(operators)
     self.assertAllEqual((2, 3, 5), operator.shape)
 
-  def test_dynamic_shapes_when_statically_available(self):
+  def test_shape_tensors_when_statically_available(self):
     operators = [
         linalg.LinearOperatorMatrix(rng.rand(2, 3, 4)),
         linalg.LinearOperatorMatrix(rng.rand(2, 4, 5))
     ]
     operator = linalg.LinearOperatorComposition(operators)
     with self.test_session():
-      self.assertAllEqual((2, 3, 5), operator.shape_dynamic().eval())
+      self.assertAllEqual((2, 3, 5), operator.shape_tensor().eval())
 
-  def test_dynamic_shapes_when_only_dynamically_available(self):
+  def test_shape_tensors_when_only_dynamically_available(self):
     mat_1 = rng.rand(1, 2, 3, 4)
     mat_2 = rng.rand(1, 2, 4, 5)
     mat_ph_1 = array_ops.placeholder(dtypes.float64)
@@ -223,7 +223,7 @@ class NonSquareLinearOperatorCompositionTest(
     operator = linalg.LinearOperatorComposition(operators)
     with self.test_session():
       self.assertAllEqual(
-          (1, 2, 3, 5), operator.shape_dynamic().eval(feed_dict=feed_dict))
+          (1, 2, 3, 5), operator.shape_tensor().eval(feed_dict=feed_dict))
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_test.py
index 8f77c5e6e33..c099194eed6 100644
--- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_test.py
+++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_test.py
@@ -31,7 +31,7 @@ rng = np.random.RandomState(123)
 
 
 class LinearOperatorShape(linalg.LinearOperator):
-  """LinearOperator that implements the methods ._shape and _shape_dynamic."""
+  """LinearOperator that implements the methods ._shape and _shape_tensor."""
 
   def __init__(self,
                shape,
@@ -49,7 +49,7 @@ class LinearOperatorShape(linalg.LinearOperator):
   def _shape(self):
     return tensor_shape.TensorShape(self._stored_shape)
 
-  def _shape_dynamic(self):
+  def _shape_tensor(self):
     return constant_op.constant(self._stored_shape, dtype=dtypes.int32)
 
 
@@ -71,7 +71,7 @@ class LinearOperatorApplyOnly(linalg.LinearOperator):
   def _shape(self):
     return self._matrix.get_shape()
 
-  def _shape_dynamic(self):
+  def _shape_tensor(self):
     return array_ops.shape(self._matrix)
 
   def _apply(self, x, adjoint=False):
@@ -96,11 +96,11 @@ class LinearOperatorTest(test.TestCase):
       shape = (1, 2, 3, 4)
       operator = LinearOperatorShape(shape)
 
-      self.assertAllEqual(shape, operator.shape_dynamic().eval())
-      self.assertAllEqual(4, operator.tensor_rank_dynamic().eval())
-      self.assertAllEqual((1, 2), operator.batch_shape_dynamic().eval())
-      self.assertAllEqual(4, operator.domain_dimension_dynamic().eval())
-      self.assertAllEqual(3, operator.range_dimension_dynamic().eval())
+      self.assertAllEqual(shape, operator.shape_tensor().eval())
+      self.assertAllEqual(4, operator.tensor_rank_tensor().eval())
+      self.assertAllEqual((1, 2), operator.batch_shape_tensor().eval())
+      self.assertAllEqual(4, operator.domain_dimension_tensor().eval())
+      self.assertAllEqual(3, operator.range_dimension_tensor().eval())
 
   def test_is_x_properties(self):
     operator = LinearOperatorShape(
@@ -120,7 +120,7 @@ class LinearOperatorTest(test.TestCase):
       self.assertAllEqual((2, 3, 4), operator_dense.get_shape())
       self.assertAllClose(matrix, operator_dense.eval())
 
-  def test_generic_to_dense_method_non_square_matrix_dynamic(self):
+  def test_generic_to_dense_method_non_square_matrix_tensor(self):
     matrix = rng.randn(2, 3, 4)
     matrix_ph = array_ops.placeholder(dtypes.float64)
     operator = LinearOperatorApplyOnly(matrix_ph)
diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_util_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_util_test.py
index 4eac01092f1..bf6f8f83027 100644
--- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_util_test.py
+++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_util_test.py
@@ -96,7 +96,7 @@ class DomainDimensionStubOperator(object):
   def __init__(self, domain_dimension):
     self._domain_dimension = ops.convert_to_tensor(domain_dimension)
 
-  def domain_dimension_dynamic(self):
+  def domain_dimension_tensor(self):
     return self._domain_dimension
 
 
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator.py b/tensorflow/contrib/linalg/python/ops/linear_operator.py
index e229820edc1..2467603605c 100644
--- a/tensorflow/contrib/linalg/python/ops/linear_operator.py
+++ b/tensorflow/contrib/linalg/python/ops/linear_operator.py
@@ -180,13 +180,15 @@ class LinearOperator(object):
     self._is_positive_definite = is_positive_definite
     self._name = name or type(self).__name__
 
-    # We will cache some values to avoid repeatedly adding shape
-    # manipulation ops to the graph.  Cleaner.
-    self._cached_shape_dynamic = None
-    self._cached_batch_shape_dynamic = None
-    self._cached_domain_dimension_dynamic = None
-    self._cached_range_dimension_dynamic = None
-    self._cached_tensor_rank_dynamic = None
+    # We will cache some tensors to avoid repeatedly adding shape
+    # manipulation ops to the graph.
+    # Naming convention:
+    #   self._cached_X_tensor is the cached version of self._X_tensor.
+    self._cached_shape_tensor = None
+    self._cached_batch_shape_tensor = None
+    self._cached_domain_dimension_tensor = None
+    self._cached_range_dimension_tensor = None
+    self._cached_tensor_rank_tensor = None
 
   @contextlib.contextmanager
   def _name_scope(self, name=None, values=None):
@@ -240,10 +242,10 @@ class LinearOperator(object):
     """
     return self._shape()
 
-  def _shape_dynamic(self):
-    raise NotImplementedError("_shape_dynamic is not implemented.")
+  def _shape_tensor(self):
+    raise NotImplementedError("_shape_tensor is not implemented.")
 
-  def shape_dynamic(self, name="shape_dynamic"):
+  def shape_tensor(self, name="shape_tensor"):
     """Shape of this `LinearOperator`, determined at runtime.
 
     If this operator acts like the batch matrix `A` with
@@ -258,14 +260,14 @@ class LinearOperator(object):
     """
     with self._name_scope(name):
       # Be clean by avoiding adding shape Ops to the graph too many times.
-      if self._cached_shape_dynamic is None:
+      if self._cached_shape_tensor is None:
         # Prefer to use statically defined shape if available.
         if self.shape.is_fully_defined():
-          self._cached_shape_dynamic = linear_operator_util.shape_tensor(
+          self._cached_shape_tensor = linear_operator_util.shape_tensor(
               self.shape.as_list())
         else:
-          self._cached_shape_dynamic = self._shape_dynamic()
-      return self._cached_shape_dynamic
+          self._cached_shape_tensor = self._shape_tensor()
+      return self._cached_shape_tensor
 
   @property
   def batch_shape(self):
@@ -281,7 +283,7 @@ class LinearOperator(object):
     # Derived classes get this "for free" once .shape is implemented.
     return self.shape[:-2]
 
-  def batch_shape_dynamic(self, name="batch_shape_dynamic"):
+  def batch_shape_tensor(self, name="batch_shape_tensor"):
     """Shape of batch dimensions of this operator, determined at runtime.
 
     If this operator acts like the batch matrix `A` with
@@ -296,14 +298,14 @@ class LinearOperator(object):
     """
     # Derived classes get this "for free" once .shape() is implemented.
     with self._name_scope(name):
-      if self._cached_batch_shape_dynamic is None:
+      if self._cached_batch_shape_tensor is None:
         # Prefer to use statically defined shape if available.
         if self.batch_shape.is_fully_defined():
-          self._cached_batch_shape_dynamic = linear_operator_util.shape_tensor(
+          self._cached_batch_shape_tensor = linear_operator_util.shape_tensor(
               self.batch_shape.as_list(), name="batch_shape")
         else:
-          self._cached_batch_shape_dynamic = self.shape_dynamic()[:-2]
-      return self._cached_batch_shape_dynamic
+          self._cached_batch_shape_tensor = self.shape_tensor()[:-2]
+      return self._cached_batch_shape_tensor
 
   @property
   def tensor_rank(self, name="tensor_rank"):
@@ -322,7 +324,7 @@ class LinearOperator(object):
     with self._name_scope(name):
       return self.shape.ndims
 
-  def tensor_rank_dynamic(self, name="tensor_rank_dynamic"):
+  def tensor_rank_tensor(self, name="tensor_rank_tensor"):
     """Rank (in the sense of tensors) of matrix corresponding to this operator.
 
     If this operator acts like the batch matrix `A` with
@@ -336,15 +338,15 @@ class LinearOperator(object):
     """
     # Derived classes get this "for free" once .shape() is implemented.
     with self._name_scope(name):
-      if self._cached_tensor_rank_dynamic is None:
+      if self._cached_tensor_rank_tensor is None:
         # Prefer to use statically defined shape if available.
         if self.tensor_rank is not None:
-          self._cached_tensor_rank_dynamic = ops.convert_to_tensor(
+          self._cached_tensor_rank_tensor = ops.convert_to_tensor(
               self.tensor_rank)
         else:
-          self._cached_tensor_rank_dynamic = array_ops.size(
-              self.shape_dynamic())
-      return self._cached_tensor_rank_dynamic
+          self._cached_tensor_rank_tensor = array_ops.size(
+              self.shape_tensor())
+      return self._cached_tensor_rank_tensor
 
   @property
   def domain_dimension(self):
@@ -359,7 +361,7 @@ class LinearOperator(object):
     # Derived classes get this "for free" once .shape is implemented.
     return self.shape[-1]
 
-  def domain_dimension_dynamic(self, name="domain_dimension_dynamic"):
+  def domain_dimension_tensor(self, name="domain_dimension_tensor"):
     """Dimension (in the sense of vector spaces) of the domain of this operator.
 
     Determined at runtime.
@@ -375,14 +377,14 @@ class LinearOperator(object):
     """
     # Derived classes get this "for free" once .shape() is implemented.
     with self._name_scope(name):
-      if self._cached_domain_dimension_dynamic is None:
+      if self._cached_domain_dimension_tensor is None:
         # Prefer to use statically defined shape if available.
         if self.domain_dimension.value is not None:
-          self._cached_domain_dimension_dynamic = ops.convert_to_tensor(
+          self._cached_domain_dimension_tensor = ops.convert_to_tensor(
               self.domain_dimension.value)
         else:
-          self._cached_domain_dimension_dynamic = self.shape_dynamic()[-1]
-      return self._cached_domain_dimension_dynamic
+          self._cached_domain_dimension_tensor = self.shape_tensor()[-1]
+      return self._cached_domain_dimension_tensor
 
   @property
   def range_dimension(self):
@@ -397,7 +399,7 @@ class LinearOperator(object):
     # Derived classes get this "for free" once .shape is implemented.
     return self.shape[-2]
 
-  def range_dimension_dynamic(self, name="range_dimension_dynamic"):
+  def range_dimension_tensor(self, name="range_dimension_tensor"):
     """Dimension (in the sense of vector spaces) of the range of this operator.
 
     Determined at runtime.
@@ -413,14 +415,14 @@ class LinearOperator(object):
     """
     # Derived classes get this "for free" once .shape() is implemented.
     with self._name_scope(name):
-      if self._cached_range_dimension_dynamic is None:
+      if self._cached_range_dimension_tensor is None:
         # Prefer to use statically defined shape if available.
         if self.range_dimension.value is not None:
-          self._cached_range_dimension_dynamic = ops.convert_to_tensor(
+          self._cached_range_dimension_tensor = ops.convert_to_tensor(
               self.range_dimension.value)
         else:
-          self._cached_range_dimension_dynamic = self.shape_dynamic()[-2]
-      return self._cached_range_dimension_dynamic
+          self._cached_range_dimension_tensor = self.shape_tensor()[-2]
+      return self._cached_range_dimension_tensor
 
   def _assert_non_singular(self):
     raise NotImplementedError("assert_non_singular is not implemented.")
@@ -574,12 +576,12 @@ class LinearOperator(object):
     if self.batch_shape.is_fully_defined():
       batch_shape = self.batch_shape
     else:
-      batch_shape = self.batch_shape_dynamic()
+      batch_shape = self.batch_shape_tensor()
 
     if self.domain_dimension.value is not None:
       n = self.domain_dimension.value
     else:
-      n = self.domain_dimension_dynamic()
+      n = self.domain_dimension_tensor()
 
     eye = linalg_ops.eye(num_rows=n, batch_shape=batch_shape, dtype=self.dtype)
     return self.apply(eye)
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_composition.py b/tensorflow/contrib/linalg/python/ops/linear_operator_composition.py
index 3e118ebbd47..81e77358410 100644
--- a/tensorflow/contrib/linalg/python/ops/linear_operator_composition.py
+++ b/tensorflow/contrib/linalg/python/ops/linear_operator_composition.py
@@ -202,7 +202,7 @@ class LinearOperatorComposition(linear_operator.LinearOperator):
 
     return batch_shape.concatenate(matrix_shape)
 
-  def _shape_dynamic(self):
+  def _shape_tensor(self):
     # Avoid messy broadcasting if possible.
     if self.shape.is_fully_defined():
       return ops.convert_to_tensor(
@@ -212,14 +212,14 @@ class LinearOperatorComposition(linear_operator.LinearOperator):
     # the graph.  Things will fail at runtime naturally if shapes are
     # incompatible.
     matrix_shape = array_ops.stack([
-        self.operators[0].range_dimension_dynamic(),
-        self.operators[-1].domain_dimension_dynamic()
+        self.operators[0].range_dimension_tensor(),
+        self.operators[-1].domain_dimension_tensor()
     ])
 
     # Dummy Tensor of zeros.  Will never be materialized.
-    zeros = array_ops.zeros(shape=self.operators[0].batch_shape_dynamic())
+    zeros = array_ops.zeros(shape=self.operators[0].batch_shape_tensor())
     for operator in self.operators[1:]:
-      zeros += array_ops.zeros(shape=operator.batch_shape_dynamic())
+      zeros += array_ops.zeros(shape=operator.batch_shape_tensor())
     batch_shape = array_ops.shape(zeros)
 
     return array_ops.concat((batch_shape, matrix_shape), 0)
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py b/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py
index d59e8be767d..4700e655186 100644
--- a/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py
+++ b/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py
@@ -166,7 +166,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
     d_shape = self._diag.get_shape()
     return d_shape.concatenate(d_shape[-1:])
 
-  def _shape_dynamic(self):
+  def _shape_tensor(self):
     d_shape = array_ops.shape(self._diag)
     k = d_shape[-1]
     return array_ops.concat((d_shape, [k]), 0)
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_identity.py b/tensorflow/contrib/linalg/python/ops/linear_operator_identity.py
index 3304698ec67..6559f8b1168 100644
--- a/tensorflow/contrib/linalg/python/ops/linear_operator_identity.py
+++ b/tensorflow/contrib/linalg/python/ops/linear_operator_identity.py
@@ -261,7 +261,7 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
     batch_shape = tensor_shape.TensorShape(self._batch_shape_static)
     return batch_shape.concatenate(matrix_shape)
 
-  def _shape_dynamic(self):
+  def _shape_tensor(self):
     matrix_shape = array_ops.stack(
         (self._num_rows, self._num_rows), axis=0)
     if self._batch_shape_arg is None:
@@ -307,7 +307,7 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
     # Dynamic broadcast:
     #   Always add to an array of zeros, rather than using a "cond", since a
     #   cond would require copying data from GPU --> CPU.
-    special_shape = array_ops.concat((self.batch_shape_dynamic(), [1, 1]), 0)
+    special_shape = array_ops.concat((self.batch_shape_tensor(), [1, 1]), 0)
     zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
     return x + zeros
 
@@ -320,10 +320,10 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
     return self._possibly_broadcast_batch_shape(x)
 
   def _determinant(self):
-    return array_ops.ones(shape=self.batch_shape_dynamic(), dtype=self.dtype)
+    return array_ops.ones(shape=self.batch_shape_tensor(), dtype=self.dtype)
 
   def _log_abs_determinant(self):
-    return array_ops.zeros(shape=self.batch_shape_dynamic(), dtype=self.dtype)
+    return array_ops.zeros(shape=self.batch_shape_tensor(), dtype=self.dtype)
 
   def _solve(self, rhs, adjoint=False):
     return self._apply(rhs)
@@ -566,7 +566,7 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
     batch_shape = self.multiplier.get_shape()
     return batch_shape.concatenate(matrix_shape)
 
-  def _shape_dynamic(self):
+  def _shape_tensor(self):
     matrix_shape = array_ops.stack(
         (self._num_rows, self._num_rows), axis=0)
 
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_matrix.py b/tensorflow/contrib/linalg/python/ops/linear_operator_matrix.py
index 7ca18450d1e..3b5dc7c4819 100644
--- a/tensorflow/contrib/linalg/python/ops/linear_operator_matrix.py
+++ b/tensorflow/contrib/linalg/python/ops/linear_operator_matrix.py
@@ -157,7 +157,7 @@ class LinearOperatorMatrix(linear_operator.LinearOperator):
   def _shape(self):
     return self._matrix.get_shape()
 
-  def _shape_dynamic(self):
+  def _shape_tensor(self):
     return array_ops.shape(self._matrix)
 
   def _apply(self, x, adjoint=False):
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py b/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py
index 5de9bb5d775..466fedd578f 100644
--- a/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py
+++ b/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py
@@ -262,8 +262,8 @@ class SquareLinearOperatorDerivedClassTest(LinearOperatorDerivedClassTest):
       n = operator.domain_dimension.value
       x_shape = batch_shape + [n, r]
     else:
-      batch_shape = operator.batch_shape_dynamic()
-      n = operator.domain_dimension_dynamic()
+      batch_shape = operator.batch_shape_tensor()
+      n = operator.domain_dimension_tensor()
       x_shape = array_ops.concat((batch_shape, [n, r]), 0)
 
     return random_normal(x_shape, dtype=operator.dtype)
@@ -316,11 +316,11 @@ class NonSquareLinearOperatorDerivedClassTest(LinearOperatorDerivedClassTest):
         n = operator.domain_dimension.value
       x_shape = batch_shape + [n, r]
     else:
-      batch_shape = operator.batch_shape_dynamic()
+      batch_shape = operator.batch_shape_tensor()
       if adjoint:
-        n = operator.range_dimension_dynamic()
+        n = operator.range_dimension_tensor()
       else:
-        n = operator.domain_dimension_dynamic()
+        n = operator.domain_dimension_tensor()
       x_shape = array_ops.concat((batch_shape, [n, r]), 0)
 
     return random_normal(x_shape, dtype=operator.dtype)
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_tril.py b/tensorflow/contrib/linalg/python/ops/linear_operator_tril.py
index 7c5b9b6b547..2b1fb4c04ca 100644
--- a/tensorflow/contrib/linalg/python/ops/linear_operator_tril.py
+++ b/tensorflow/contrib/linalg/python/ops/linear_operator_tril.py
@@ -157,7 +157,7 @@ class LinearOperatorTriL(linear_operator.LinearOperator):
   def _shape(self):
     return self._tril.get_shape()
 
-  def _shape_dynamic(self):
+  def _shape_tensor(self):
     return array_ops.shape(self._tril)
 
   def _assert_non_singular(self):
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_util.py b/tensorflow/contrib/linalg/python/ops/linear_operator_util.py
index 44092f0c062..6e56fac2e3d 100644
--- a/tensorflow/contrib/linalg/python/ops/linear_operator_util.py
+++ b/tensorflow/contrib/linalg/python/ops/linear_operator_util.py
@@ -83,10 +83,10 @@ def assert_compatible_matrix_dimensions(operator, x):
   Returns:
     `Assert` `Op`.
   """
-  # Static checks are done in the base class.  Only dynamic asserts here.
+  # Static checks are done in the base class.  Only tensor asserts here.
   assert_same_dd = check_ops.assert_equal(
       array_ops.shape(x)[-2],
-      operator.domain_dimension_dynamic(),
+      operator.domain_dimension_tensor(),
       message=(
           "Incompatible matrix dimensions.  "
           "shape[-2] of argument to be the same as this operator"))