From 0b9ff2eb1a097602206c6b29823543768bfb34fe Mon Sep 17 00:00:00 2001
From: Peng Wang <wangpeng@google.com>
Date: Thu, 4 Feb 2021 19:13:46 -0800
Subject: [PATCH] Remove ndarray wrapper from TF Numpy. We return tensors
 directly.

PiperOrigin-RevId: 355761429
Change-Id: I1ab012bcd831550cd2aa2a8de3d758c23bc6332a
---
 tensorflow/core/protobuf/struct.proto         |   2 +-
 tensorflow/python/eager/backprop.py           |  34 --
 tensorflow/python/eager/forwardprop.py        |   8 +-
 tensorflow/python/eager/function.py           |  11 -
 tensorflow/python/framework/ops.py            |  36 +-
 tensorflow/python/framework/ops_test.py       |  11 +
 tensorflow/python/framework/tensor_shape.py   |  10 +
 .../python/framework/tensor_shape_test.py     |  14 +
 .../mixed_precision/autocast_variable_test.py |   4 +-
 tensorflow/python/ops/array_ops.py            |   3 +
 tensorflow/python/ops/map_fn.py               |   8 -
 tensorflow/python/ops/math_ops.py             |  89 ++++-
 tensorflow/python/ops/numpy_ops/BUILD         |  12 +
 .../ops/numpy_ops/integration_test/BUILD      |  11 +
 .../integration_test/np_config_test.py        |  44 +++
 .../python/ops/numpy_ops/np_array_ops.py      | 313 ++++++++----------
 .../python/ops/numpy_ops/np_array_ops_test.py |  77 ++---
 tensorflow/python/ops/numpy_ops/np_arrays.py  | 302 +----------------
 .../python/ops/numpy_ops/np_arrays_test.py    |  93 +++---
 tensorflow/python/ops/numpy_ops/np_config.py  |  39 +++
 tensorflow/python/ops/numpy_ops/np_dtypes.py  |  31 +-
 .../python/ops/numpy_ops/np_dtypes_test.py    |  57 ++++
 .../python/ops/numpy_ops/np_interop_test.py   |  21 +-
 .../python/ops/numpy_ops/np_logic_test.py     |  10 +-
 .../python/ops/numpy_ops/np_math_ops.py       | 209 ++++++------
 .../python/ops/numpy_ops/np_math_ops_test.py  |   4 +-
 tensorflow/python/ops/numpy_ops/np_random.py  |  15 +-
 .../python/ops/numpy_ops/np_random_test.py    |   4 +-
 tensorflow/python/ops/numpy_ops/np_utils.py   |   5 -
 .../ops/parallel_for/control_flow_ops.py      |  17 -
 tensorflow/python/saved_model/load_test.py    |  30 --
 .../saved_model/nested_structure_coder.py     |   3 -
 .../nested_structure_coder_test.py            |   9 -
 ...ensorflow.experimental.numpy.ndarray.pbtxt |  67 ++--
 34 files changed, 744 insertions(+), 859 deletions(-)
 create mode 100644 tensorflow/python/ops/numpy_ops/integration_test/np_config_test.py
 create mode 100644 tensorflow/python/ops/numpy_ops/np_config.py
 create mode 100644 tensorflow/python/ops/numpy_ops/np_dtypes_test.py

diff --git a/tensorflow/core/protobuf/struct.proto b/tensorflow/core/protobuf/struct.proto
index c99eab5dd88..19cd5dfb6dd 100644
--- a/tensorflow/core/protobuf/struct.proto
+++ b/tensorflow/core/protobuf/struct.proto
@@ -136,7 +136,7 @@ message TypeSpecProto {
     PER_REPLICA_SPEC = 8;     // PerReplicaSpec from distribute/values.py
     VARIABLE_SPEC = 9;        // tf.VariableSpec
     ROW_PARTITION_SPEC = 10;  // RowPartitionSpec from ragged/row_partition.py
-    NDARRAY_SPEC = 11;        // TF Numpy NDarray spec
+    reserved 11;
   }
   TypeSpecClass type_spec_class = 1;
 
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 94a3c5a67ac..4681b9c6185 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -62,9 +62,6 @@ from tensorflow.python.util.tf_export import tf_export
 pfor_ops = LazyLoader(
     "pfor_ops", globals(),
     "tensorflow.python.ops.parallel_for.control_flow_ops")
-np_arrays = LazyLoader(
-    "np_arrays", globals(),
-    "tensorflow.python.ops.numpy_ops.np_arrays")
 
 function = LazyLoader("function", globals(),
                       "tensorflow.python.eager.function")
@@ -727,8 +724,6 @@ def _handle_or_self(x):
   """Unwrap resource variable/ndarray to return tensors."""
   if resource_variable_ops.is_resource_variable(x):
     return x.handle
-  if isinstance(x, np_arrays.ndarray):
-    return x.data
   return x
 
 
@@ -1034,7 +1029,6 @@ class GradientTape(object):
                       " of Tensors or Variables to be differentiated,"
                       " but recieved %r" % (target))
 
-    num_ndarrays = 0
     flat_targets = []
     for t in nest.flatten(target):
       if not backprop_util.IsTrainable(t):
@@ -1045,12 +1039,7 @@ class GradientTape(object):
       if resource_variable_ops.is_resource_variable(t):
         with self:
           t = ops.convert_to_tensor(t)
-      elif isinstance(t, np_arrays.ndarray):
-        t = t.data
-        num_ndarrays += 1
       flat_targets.append(t)
-    # Only rewrap if all targets are ndarray. If not, prefer tensors.
-    rewrap_as_ndarray = num_ndarrays == len(flat_targets)
 
     flat_sources = nest.flatten(sources)
     flat_sources_raw = flat_sources
@@ -1083,13 +1072,6 @@ class GradientTape(object):
       self._watched_variables = self._tape.watched_variables()
       self._tape = None
 
-    if rewrap_as_ndarray:
-      def _tensor_to_ndarray(x):
-        if x is not None:
-          return np_arrays.tensor_to_ndarray(x)
-        return None
-      flat_grad = nest.map_structure(_tensor_to_ndarray, flat_grad)
-
     grad = nest.pack_sequence_as(sources, flat_grad)
     return grad
 
@@ -1158,10 +1140,6 @@ class GradientTape(object):
                          "compute one set of gradients (or jacobians)")
 
     flat_sources = nest.flatten(sources)
-    rewrap_as_ndarray = False
-    if isinstance(target, np_arrays.ndarray):
-      target = target.data
-      rewrap_as_ndarray = True
     target_static_shape = target.shape
     target_shape = array_ops.shape(target)
     # Note that we push and pop the tape here and below. This is needed since we
@@ -1211,8 +1189,6 @@ class GradientTape(object):
         out = array_ops.reshape(out, new_shape)
         if context.executing_eagerly():
           out.set_shape(target_static_shape.concatenate(flat_sources[i].shape))
-      if rewrap_as_ndarray:
-        out = np_arrays.tensor_to_ndarray(out)
       output[i] = out
 
     return nest.pack_sequence_as(sources, output)
@@ -1287,12 +1263,6 @@ class GradientTape(object):
     if self._tape is None:
       raise RuntimeError("A non-persistent GradientTape can only be used to"
                          "compute one set of gradients (or jacobians)")
-    rewrap_as_ndarray = False
-    if isinstance(target, np_arrays.ndarray):
-      target = target.data
-      rewrap_as_ndarray = True
-    if isinstance(source, np_arrays.ndarray):
-      source = source.data
     target_shape = target.shape
     if target_shape.rank is None:
       dim = tensor_shape.Dimension(None)
@@ -1354,8 +1324,6 @@ class GradientTape(object):
       # represent unconnected gradients. This is to maintain compatibility with
       # the previous behavior, which ignored `unconnected_gradients`.
       output = array_ops.zeros(new_shape, target.dtype)
-      if rewrap_as_ndarray:
-        output = np_arrays.tensor_to_ndarray(output)
       return output
     else:
       output = array_ops.reshape(output,
@@ -1363,6 +1331,4 @@ class GradientTape(object):
       output = array_ops.transpose(output, [1, 0, 2])
 
       output = array_ops.reshape(output, new_shape)
-      if rewrap_as_ndarray:
-        output = np_arrays.tensor_to_ndarray(output)
       return output
diff --git a/tensorflow/python/eager/forwardprop.py b/tensorflow/python/eager/forwardprop.py
index 2f64bad0dff..f3ae1643073 100644
--- a/tensorflow/python/eager/forwardprop.py
+++ b/tensorflow/python/eager/forwardprop.py
@@ -32,7 +32,6 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops.numpy_ops import np_arrays
 from tensorflow.python.ops.parallel_for import control_flow_ops
 from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
 from tensorflow.python.platform import tf_logging as logging
@@ -441,16 +440,11 @@ class ForwardAccumulator():
       if hasattr(tensor, "handle"):
         unwrapped_tensor = ops.convert_to_tensor(tensor.handle)
       else:
-        if isinstance(tensor, np_arrays.ndarray):
-          unwrapped_tensor = tensor.data
-        else:
-          unwrapped_tensor = tensor
+        unwrapped_tensor = tensor
       result = pywrap_tfe.TFE_Py_ForwardAccumulatorJVP(self._accumulator,
                                                        unwrapped_tensor)
       if result is None and unconnected_gradients == UnconnectedGradients.ZERO:
         result = array_ops.zeros_like(tensor)
-      if result is not None and isinstance(tensor, np_arrays.ndarray):
-        return np_arrays.tensor_to_ndarray(result)
       return result
 
     return nest.map_structure(_fetch_jvp, primals)
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 280040d4157..2daccff8a89 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -1522,11 +1522,6 @@ class ConcreteFunction(object):
     self._func_graph = func_graph
     self._captured_inputs = self._func_graph.external_captures
     self._captured_closures = self._func_graph.deferred_external_captures
-    structured_outputs = self._func_graph.structured_outputs
-    self._ndarrays_list = (
-        isinstance(structured_outputs, (list, tuple)) and structured_outputs and
-        all(isinstance(o, np_arrays.ndarray) for o in structured_outputs))
-    self._ndarray_singleton = isinstance(structured_outputs, np_arrays.ndarray)
 
     # function_spec defines the structured signature.
     self._set_function_spec(function_spec)
@@ -2176,12 +2171,6 @@ class ConcreteFunction(object):
     if self._func_graph.structured_outputs is None:
       return result
 
-    if result:
-      if self._ndarrays_list:
-        return [np_arrays.tensor_to_ndarray(o) for o in result]
-      elif self._ndarray_singleton:
-        return np_arrays.tensor_to_ndarray(result[0])
-
     # Replace outputs with results, skipping over any 'None' values.
     outputs_list = nest.flatten(
         self._func_graph.structured_outputs, expand_composites=True)
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index ea6ae1b2f6f..794f48799ab 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -257,7 +257,7 @@ def disable_tensor_equality():
 
 
 # TODO(mdan): This object should subclass Symbol, not just Tensor.
-@tf_export("Tensor")
+@tf_export("Tensor", "experimental.numpy.ndarray", v1=["Tensor"])
 class Tensor(internal.NativeObject, core_tf_types.Tensor):
   """A tensor is a multidimensional array of elements represented by a
 
@@ -386,6 +386,16 @@ class Tensor(internal.NativeObject, core_tf_types.Tensor):
     self._id = uid()
     self._name = None
 
+  def __getattr__(self, name):
+    if name in {"T", "astype", "ravel", "transpose", "reshape", "clip", "size",
+                "tolist", "data"}:
+      # TODO(wangpeng): Export the enable_numpy_behavior knob
+      raise AttributeError("""
+        If you are looking for numpy-related methods, please run the following:
+        import tensorflow.python.ops.numpy_ops.np_config
+        np_config.enable_numpy_behavior()""")
+    self.__getattribute__(name)
+
   @staticmethod
   def _create_with_tf_output(op, value_index, dtype, tf_output):
     ret = Tensor(op, value_index, dtype)
@@ -6943,6 +6953,30 @@ def _reconstruct_sequence_inputs(op_def, inputs, attrs):
   return grouped_inputs
 
 
+_numpy_style_type_promotion = False
+
+
+def enable_numpy_style_type_promotion():
+  """If called, follows NumPy's rules for type promotion.
+
+  Used for enabling NumPy behavior on methods for TF NumPy.
+  """
+  global _numpy_style_type_promotion
+  _numpy_style_type_promotion = True
+
+
+_numpy_style_slicing = False
+
+
+def enable_numpy_style_slicing():
+  """If called, follows NumPy's rules for slicing Tensors.
+
+  Used for enabling NumPy behavior on slicing for TF NumPy.
+  """
+  global _numpy_style_slicing
+  _numpy_style_slicing = True
+
+
 class _TensorIterator(object):
   """Iterates over the leading dim of a Tensor. Performs no error checks."""
 
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index 14db2375d96..19894060ab5 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -202,6 +202,17 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
     self.assertAllEqual(np.array(x), np.ones((3, 4)))
     self.assertEqual(len(x), 3)
 
+  def testConstructor(self):
+    a = array_ops.ones([])
+    for name in ["T", "astype", "ravel", "transpose", "reshape", "clip", "size",
+                 "tolist", "data"]:
+      with self.assertRaisesRegex(
+          AttributeError, r"If you are looking for numpy-related methods"):
+        getattr(a, name)
+    with self.assertRaisesRegex(
+        AttributeError, r"object has no attribute"):
+      a.foo_bar()
+
   def testRef(self):
     x1 = constant_op.constant(3)
     x2 = x1
diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py
index 8c45906ab37..08c866cd2cf 100644
--- a/tensorflow/python/framework/tensor_shape.py
+++ b/tensorflow/python/framework/tensor_shape.py
@@ -250,6 +250,11 @@ class Dimension(object):
     # Allow use in Python 3 range
     return self._value
 
+  def __hash__(self):
+    if self._value is None:
+      raise ValueError("Unable to hash Dimension with value 'None'")
+    return hash(self._value)
+
   @property
   def value(self):
     """The value of this dimension, or None if it is unknown."""
@@ -986,6 +991,11 @@ class TensorShape(object):
       other = TensorShape(other)
     return other.concatenate(self)
 
+  def __hash__(self):
+    if not self.is_fully_defined():
+      raise ValueError("Unable to hash partially defined TensorShape.")
+    return hash(tuple(self._dims))
+
   def concatenate(self, other):
     """Returns the concatenation of the dimension in `self` and `other`.
 
diff --git a/tensorflow/python/framework/tensor_shape_test.py b/tensorflow/python/framework/tensor_shape_test.py
index a83d024d5ca..5eb3ff02672 100644
--- a/tensorflow/python/framework/tensor_shape_test.py
+++ b/tensorflow/python/framework/tensor_shape_test.py
@@ -384,6 +384,20 @@ class ShapeTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     else:
       self.assertEqual(expected, mcs.as_list())
 
+  def testHash(self):
+    base = tensor_shape.TensorShape([1, 2, 3, 4])
+    base_copy = tensor_shape.TensorShape([1, 2, 3, 4])
+    self.assertEqual(hash(base), hash(base_copy))
+
+    with self.assertRaisesRegex(ValueError, r"Unable to hash partially"):
+      hash(tensor_shape.TensorShape([1, 2, 3, 4, None]))
+
+    with self.assertRaisesRegex(ValueError, r"Unable to hash partially"):
+      hash(tensor_shape.TensorShape(None))
+
+    with self.assertRaisesRegex(ValueError, r"Unable to hash Dimension"):
+      hash(tensor_shape.Dimension(None))
+
   def testMostSpecificCompatibleShape(self):
     self._testMostSpecificCompatibleShapeHelper([1, 2], None, None)
     self._testMostSpecificCompatibleShapeHelper(None, [1, 2], None)
diff --git a/tensorflow/python/keras/mixed_precision/autocast_variable_test.py b/tensorflow/python/keras/mixed_precision/autocast_variable_test.py
index 6d70aaaca27..3fd00608e38 100644
--- a/tensorflow/python/keras/mixed_precision/autocast_variable_test.py
+++ b/tensorflow/python/keras/mixed_precision/autocast_variable_test.py
@@ -368,9 +368,7 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
       # mode. Variable.assign(...).op is None in Eager mode and an op in Graph
       # mode or a tf.function. We test this is also true of AutoCastVariable.
       if context.executing_eagerly():
-        with self.assertRaisesRegex(
-            AttributeError,
-            'Tensor.op is meaningless when eager execution is enabled'):
+        with self.assertRaises(AttributeError):
           x.op  # pylint: disable=pointless-statement
         self.assertIsNone(x.assign(1.0).op)
         self.assertIsNone(x.assign_add(1.0).op)
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 3e053fca024..f0f97f6a054 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -962,6 +962,9 @@ def _slice_helper(tensor, slice_spec, var=None):
       tf.newaxis or scalar int32/int64 tensors.
   """
   tensor = ops.convert_to_tensor(tensor)
+  # TODO(wangpeng): Consider supporting var
+  if var is None and ops._numpy_style_slicing:  # pylint: disable=protected-access
+    return tensor._numpy_style_getitem(slice_spec)  # pylint: disable=protected-access
 
   if isinstance(slice_spec, bool) or \
   (isinstance(slice_spec, ops.Tensor) and slice_spec.dtype == dtypes.bool) or \
diff --git a/tensorflow/python/ops/map_fn.py b/tensorflow/python/ops/map_fn.py
index dc4ee6bc5d7..75394ba7fe7 100644
--- a/tensorflow/python/ops/map_fn.py
+++ b/tensorflow/python/ops/map_fn.py
@@ -38,16 +38,10 @@ from tensorflow.python.ops import variable_scope as vs
 from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import deprecation
-from tensorflow.python.util import lazy_loader
 from tensorflow.python.util import nest
 from tensorflow.python.util.tf_export import tf_export
 
 
-np_arrays = lazy_loader.LazyLoader(
-    "np_arrays", globals(),
-    "tensorflow.python.ops.numpy_ops.np_arrays")
-
-
 @tf_export(v1=["map_fn"])
 @deprecation.deprecated_args(None, "Use fn_output_signature instead", "dtype")
 def map_fn(fn,
@@ -426,8 +420,6 @@ def map_fn(fn,
 
     # Check that inputs are not scalars.
     first_elem = elems_flat[0]
-    if isinstance(first_elem, np_arrays.ndarray):
-      first_elem = first_elem.data
     elems_static_shape = first_elem.shape
     if elems_static_shape.ndims is not None and elems_static_shape.ndims < 1:
       if len(elems_flat) == 1:
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index decad558e85..1cdc90193ab 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -70,6 +70,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import numbers
 import numpy as np
 import six
 from six.moves import builtins
@@ -99,9 +100,17 @@ from tensorflow.python.util import compat
 from tensorflow.python.util import deprecation
 from tensorflow.python.util import dispatch
 from tensorflow.python.util import nest
+from tensorflow.python.util import tf_decorator
 from tensorflow.python.util.compat import collections_abc
+from tensorflow.python.util.lazy_loader import LazyLoader
 from tensorflow.python.util.tf_export import tf_export
 
+
+np_dtypes = LazyLoader(
+    "np_dtypes", globals(),
+    "tensorflow.python.ops.numpy_ops.np_dtypes")
+
+
 # Aliases for some automatically-generated names.
 nextafter = gen_math_ops.next_after
 
@@ -1130,6 +1139,48 @@ ops.Tensor._override_operator("__neg__", gen_math_ops.neg)
 ops.Tensor._override_operator("__abs__", abs)
 
 
+def _maybe_get_dtype(x):
+  """Returns a numpy type if available from x. Skips if x is numpy.ndarray."""
+  # Don't put np.ndarray in this list, because np.result_type looks at the
+  # value (not just dtype) of np.ndarray to decide the result type.
+  if isinstance(x, numbers.Real):
+    return x
+  if isinstance(x, ops.Tensor):
+    return x.dtype.as_numpy_dtype
+  if isinstance(x, dtypes.DType):
+    return x.as_numpy_dtype
+  if isinstance(x, tensor_shape.TensorShape):
+    return np.int32
+  if isinstance(x, (list, tuple)):
+    raise ValueError("Got sequence {}".format(x))
+  return x
+
+
+def maybe_promote_tensors(*tensors, force_same_dtype=True):
+  """Promote tensors if numpy style promotion is enabled."""
+  if not tensors:
+    return tensors
+  if not ops._numpy_style_type_promotion:
+    if not force_same_dtype:
+      return tensors
+    promoted_tensors = []
+    promoted_tensors.append(tensors[0])
+    dtype = tensors[0].dtype.base_dtype
+    for tensor in tensors[1:]:
+      promoted_tensors.append(
+          ops.convert_to_tensor(tensor, dtype, name="x"))
+    return promoted_tensors
+  result_type = np_dtypes._result_type(
+      *[_maybe_get_dtype(x) for x in nest.flatten(tensors)])
+  def _promote_or_cast(x):
+    if isinstance(x, ops.Tensor):
+      x = cast(x, result_type)
+    else:
+      x = ops.convert_to_tensor(x, result_type)
+    return x
+  return [_promote_or_cast(x) for x in tensors]
+
+
 def _OverrideBinaryOperatorHelper(func, op_name, clazz_object=ops.Tensor):
   """Register operators with different tensor and scalar versions.
 
@@ -1145,6 +1196,10 @@ def _OverrideBinaryOperatorHelper(func, op_name, clazz_object=ops.Tensor):
   def binary_op_wrapper(x, y):
     with ops.name_scope(None, op_name, [x, y]) as name:
       try:
+        # force_same_dtype=False to preserve existing TF behavior
+        # TODO(b/178860388): Figure out why binary_op_wrapper and
+        #   r_binary_op_wrapper use different force_same_dtype values.
+        x, y = maybe_promote_tensors(x, y, force_same_dtype=False)
         return func(x, y, name=name)
       except (TypeError, ValueError) as e:
         # Even if dispatching the op failed, the RHS may be a tensor aware
@@ -1175,7 +1230,9 @@ def _OverrideBinaryOperatorHelper(func, op_name, clazz_object=ops.Tensor):
 
   def r_binary_op_wrapper(y, x):
     with ops.name_scope(None, op_name, [x, y]) as name:
-      x = ops.convert_to_tensor(x, dtype=y.dtype.base_dtype, name="x")
+      # TODO(b/178860388): Figure out why binary_op_wrapper and
+      #   r_binary_op_wrapper use different force_same_dtype values.
+      y, x = maybe_promote_tensors(y, x)
       return func(x, y, name=name)
 
   # Propagate func.__doc__ to the wrappers
@@ -1581,10 +1638,21 @@ _OverrideBinaryOperatorHelper(xor_, "xor")
 ops.Tensor._override_operator("__invert__", invert_)
 
 
-ops.Tensor._override_operator("__lt__", gen_math_ops.less)
-ops.Tensor._override_operator("__le__", gen_math_ops.less_equal)
-ops.Tensor._override_operator("__gt__", gen_math_ops.greater)
-ops.Tensor._override_operator("__ge__", gen_math_ops.greater_equal)
+def _promote_dtypes_decorator(fn):
+  def wrapper(x, y, *args, **kwargs):
+    x, y = maybe_promote_tensors(x, y, force_same_dtype=False)
+    return fn(x, y, *args, **kwargs)
+  return tf_decorator.make_decorator(fn, wrapper)
+
+
+ops.Tensor._override_operator("__lt__", _promote_dtypes_decorator(
+    gen_math_ops.less))
+ops.Tensor._override_operator("__le__", _promote_dtypes_decorator(
+    gen_math_ops.less_equal))
+ops.Tensor._override_operator("__gt__", _promote_dtypes_decorator(
+    gen_math_ops.greater))
+ops.Tensor._override_operator("__ge__", _promote_dtypes_decorator(
+    gen_math_ops.greater_equal))
 
 
 @tf_export("math.equal", "equal")
@@ -1691,6 +1759,7 @@ def tensor_equals(self, other):
   g = getattr(self, "graph", None)
   if (ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions() and
       (g is None or g.building_function)):
+    self, other = maybe_promote_tensors(self, other)
     return gen_math_ops.equal(self, other, incompatible_shape_error=False)
   else:
     # In legacy graph mode, tensor equality is object equality
@@ -1727,6 +1796,7 @@ def tensor_not_equals(self, other):
   if other is None:
     return True
   if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions():
+    self, other = maybe_promote_tensors(self, other)
     return gen_math_ops.not_equal(self, other, incompatible_shape_error=False)
   else:
     # In legacy graph mode, tensor equality is object equality
@@ -3482,7 +3552,14 @@ def matvec(a,
     return array_ops.squeeze(output, axis=-1)
 
 
-_OverrideBinaryOperatorHelper(matmul, "matmul")
+# TODO(b/178650720): Also support numpy-style type promotion in freestanding TF
+#   functions (e.g. tf.add).
+def matmul_wrapper(a, b, name=None):  # pylint: disable=missing-function-docstring
+  if ops._numpy_style_type_promotion:
+    return a._matmul(b)
+  return matmul(a, b, name=name)
+matmul_wrapper.__doc__ = matmul.__doc__
+_OverrideBinaryOperatorHelper(matmul_wrapper, "matmul")
 
 sparse_matmul = deprecation.deprecated(None, "Use `tf.linalg.matmul` instead")(
     gen_math_ops.sparse_mat_mul)
diff --git a/tensorflow/python/ops/numpy_ops/BUILD b/tensorflow/python/ops/numpy_ops/BUILD
index 8b742f68c02..514e2ab0cad 100644
--- a/tensorflow/python/ops/numpy_ops/BUILD
+++ b/tensorflow/python/ops/numpy_ops/BUILD
@@ -13,6 +13,7 @@ py_library(
         "__init__.py",
         "np_array_ops.py",
         "np_arrays.py",
+        "np_config.py",
         "np_dtypes.py",
         "np_export.py",
         "np_math_ops.py",
@@ -40,6 +41,17 @@ py_library(
     ],
 )
 
+cuda_py_test(
+    name = "np_dtypes_test",
+    srcs = ["np_dtypes_test.py"],
+    deps = [
+        ":numpy",
+        "//tensorflow/python:platform",
+        "//third_party/py/numpy",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
+
 cuda_py_test(
     name = "np_arrays_test",
     srcs = ["np_arrays_test.py"],
diff --git a/tensorflow/python/ops/numpy_ops/integration_test/BUILD b/tensorflow/python/ops/numpy_ops/integration_test/BUILD
index e5483166406..06cce0c466e 100644
--- a/tensorflow/python/ops/numpy_ops/integration_test/BUILD
+++ b/tensorflow/python/ops/numpy_ops/integration_test/BUILD
@@ -1,4 +1,5 @@
 load("//tensorflow:tensorflow.bzl", "py_test")
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
 
 licenses(["notice"])
 
@@ -10,3 +11,13 @@ py_test(
         "//tensorflow:tensorflow_py",
     ],
 )
+
+cuda_py_test(
+    name = "np_config_test",
+    srcs = ["np_config_test.py"],
+    deps = [
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python/ops/numpy_ops:numpy",
+        "//third_party/py/numpy",
+    ],
+)
diff --git a/tensorflow/python/ops/numpy_ops/integration_test/np_config_test.py b/tensorflow/python/ops/numpy_ops/integration_test/np_config_test.py
new file mode 100644
index 00000000000..014eba12960
--- /dev/null
+++ b/tensorflow/python/ops/numpy_ops/integration_test/np_config_test.py
@@ -0,0 +1,44 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests that an error is raised when numpy functions are called."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow.compat.v2 as tf
+from tensorflow.python.ops.numpy_ops import np_config
+
+
+class ConfigTest(tf.test.TestCase):
+
+  def testMethods(self):
+    a = tf.constant(1.)
+
+    for name in {'T', 'astype', 'ravel', 'transpose', 'reshape', 'clip', 'size',
+                 'tolist'}:
+      with self.assertRaisesRegex(AttributeError, 'enable_numpy_behavior'):
+        getattr(a, name)
+
+    np_config.enable_numpy_behavior()
+
+    for name in {'T', 'astype', 'ravel', 'transpose', 'reshape', 'clip', 'size',
+                 'tolist'}:
+      _ = getattr(a, name)
+
+
+if __name__ == '__main__':
+  tf.compat.v1.enable_eager_execution()
+  tf.test.main()
diff --git a/tensorflow/python/ops/numpy_ops/np_array_ops.py b/tensorflow/python/ops/numpy_ops/np_array_ops.py
index 3d5c3f93d2e..042dc096586 100644
--- a/tensorflow/python/ops/numpy_ops/np_array_ops.py
+++ b/tensorflow/python/ops/numpy_ops/np_array_ops.py
@@ -61,15 +61,11 @@ def empty_like(a, dtype=None):
 def zeros(shape, dtype=float):  # pylint: disable=redefined-outer-name
   dtype = (
       np_utils.result_type(dtype) if dtype else np_dtypes.default_float_type())
-  if isinstance(shape, np_arrays.ndarray):
-    shape = shape.data
-  return np_arrays.tensor_to_ndarray(array_ops.zeros(shape, dtype=dtype))
+  return array_ops.zeros(shape, dtype=dtype)
 
 
 @np_utils.np_doc('zeros_like')
 def zeros_like(a, dtype=None):  # pylint: disable=missing-docstring
-  if isinstance(a, np_arrays.ndarray):
-    a = a.data
   if dtype is None:
     # We need to let np_utils.result_type decide the dtype, not tf.zeros_like
     dtype = np_utils.result_type(a)
@@ -78,27 +74,23 @@ def zeros_like(a, dtype=None):  # pylint: disable=missing-docstring
     # `float`, so we let `np_utils.result_type` decide.
     dtype = np_utils.result_type(dtype)
   dtype = dtypes.as_dtype(dtype)  # Work around b/149877262
-  return np_arrays.tensor_to_ndarray(array_ops.zeros_like(a, dtype))
+  return array_ops.zeros_like(a, dtype)
 
 
 @np_utils.np_doc('ones')
 def ones(shape, dtype=float):  # pylint: disable=redefined-outer-name
   if dtype:
     dtype = np_utils.result_type(dtype)
-  if isinstance(shape, np_arrays.ndarray):
-    shape = shape.data
-  return np_arrays.tensor_to_ndarray(array_ops.ones(shape, dtype=dtype))
+  return array_ops.ones(shape, dtype=dtype)
 
 
 @np_utils.np_doc('ones_like')
 def ones_like(a, dtype=None):
-  if isinstance(a, np_arrays.ndarray):
-    a = a.data
   if dtype is None:
     dtype = np_utils.result_type(a)
   else:
     dtype = np_utils.result_type(dtype)
-  return np_arrays.tensor_to_ndarray(array_ops.ones_like(a, dtype))
+  return array_ops.ones_like(a, dtype)
 
 
 @np_utils.np_doc('eye')
@@ -115,7 +107,7 @@ def eye(N, M=None, k=0, dtype=float):  # pylint: disable=invalid-name,missing-do
     # tf.linalg.diag will raise an error in this case
     return zeros([N, M], dtype=dtype)
   if k == 0:
-    return np_arrays.tensor_to_ndarray(linalg_ops.eye(N, M, dtype=dtype))
+    return linalg_ops.eye(N, M, dtype=dtype)
   # We need the precise length, otherwise tf.linalg.diag will raise an error
   diag_len = min(N, M)
   if k > 0:
@@ -129,8 +121,7 @@ def eye(N, M=None, k=0, dtype=float):  # pylint: disable=invalid-name,missing-do
     elif M - k > N:
       diag_len = N + k
   diagonal_ = array_ops.ones([diag_len], dtype=dtype)
-  return np_arrays.tensor_to_ndarray(
-      array_ops.matrix_diag(diagonal=diagonal_, num_rows=N, num_cols=M, k=k))
+  return array_ops.matrix_diag(diagonal=diagonal_, num_rows=N, num_cols=M, k=k)
 
 
 @np_utils.np_doc('identity')
@@ -142,10 +133,9 @@ def identity(n, dtype=float):
 def full(shape, fill_value, dtype=None):  # pylint: disable=redefined-outer-name
   if not isinstance(shape, np_arrays.ndarray):
     shape = asarray(np_arrays.convert_to_tensor(shape, dtype_hint=np.int32))
-  shape = atleast_1d(shape).data
+  shape = atleast_1d(shape)
   fill_value = asarray(fill_value, dtype=dtype)
-  return np_arrays.tensor_to_ndarray(
-      array_ops.broadcast_to(fill_value.data, shape))
+  return array_ops.broadcast_to(fill_value, shape)
 
 
 # Using doc only here since np full_like signature doesn't seem to have the
@@ -160,19 +150,15 @@ def full_like(a, fill_value, dtype=None, order='K', subok=True, shape=None):  #
   if shape:
     raise ValueError('Overriding the shape is not supported.')
 
-  a = asarray(a).data
+  a = asarray(a)
   dtype = dtype or np_utils.result_type(a)
   fill_value = asarray(fill_value, dtype=dtype)
-  return np_arrays.tensor_to_ndarray(
-      array_ops.broadcast_to(fill_value.data, array_ops.shape(a)))
+  return array_ops.broadcast_to(fill_value, array_ops.shape(a))
 
 
 def _array_internal(val, dtype=None, copy=True, ndmin=0):  # pylint: disable=redefined-outer-name
   """Main implementation of np.array()."""
-  if isinstance(val, np_arrays.ndarray):
-    result_t = val.data
-  else:
-    result_t = val
+  result_t = val
 
   if not isinstance(result_t, ops.Tensor):
     if not dtype:
@@ -180,13 +166,7 @@ def _array_internal(val, dtype=None, copy=True, ndmin=0):  # pylint: disable=red
     # We can't call `convert_to_tensor(result_t, dtype=dtype)` here because
     # convert_to_tensor doesn't allow incompatible arguments such as (5.5, int)
     # while np.array allows them. We need to convert-then-cast.
-    def maybe_data(x):
-      if isinstance(x, np_arrays.ndarray):
-        return x.data
-      return x
 
-    # Handles lists of ndarrays
-    result_t = nest.map_structure(maybe_data, result_t)
     # EagerTensor conversion complains about "mixed types" when converting
     # tensors with no dtype information. This is because it infers types based
     # on one selected item in the list. So e.g. when converting [2., 2j]
@@ -204,7 +184,7 @@ def _array_internal(val, dtype=None, copy=True, ndmin=0):  # pylint: disable=red
     result_t = array_ops.identity(result_t)
 
   if ndmin == 0:
-    return np_arrays.tensor_to_ndarray(result_t)
+    return result_t
 
   ndims = array_ops.rank(result_t)
 
@@ -216,7 +196,7 @@ def _array_internal(val, dtype=None, copy=True, ndmin=0):  # pylint: disable=red
 
   result_t = np_utils.cond(
       np_utils.greater(ndmin, ndims), true_fn, lambda: result_t)
-  return np_arrays.tensor_to_ndarray(result_t)
+  return result_t
 
 
 # TODO(wangpeng): investigate whether we can make `copy` default to False.
@@ -241,7 +221,8 @@ def array(val, dtype=None, copy=True, ndmin=0):  # pylint: disable=redefined-out
 def asarray(a, dtype=None):
   if dtype:
     dtype = np_utils.result_type(dtype)
-  if isinstance(a, np_arrays.ndarray) and (not dtype or dtype == a.dtype):
+  if isinstance(a, np_arrays.ndarray) and (
+      not dtype or dtype == a.dtype.as_numpy_dtype):
     return a
   return array(a, dtype, copy=False)
 
@@ -294,15 +275,15 @@ def arange(start, stop=None, step=1, dtype=None):
     return array([], dtype=dtype)
   # TODO(srbs): There are some bugs when start or stop is float type and dtype
   # is integer type.
-  return np_arrays.tensor_to_ndarray(
-      math_ops.cast(math_ops.range(start, limit=stop, delta=step), dtype=dtype))
+  return math_ops.cast(
+      math_ops.range(start, limit=stop, delta=step), dtype=dtype)
 
 
 # Building matrices.
 @np_utils.np_doc('diag')
 def diag(v, k=0):  # pylint: disable=missing-docstring
   """Raises an error if input is not 1- or 2-d."""
-  v = asarray(v).data
+  v = asarray(v)
   v_rank = array_ops.rank(v)
 
   v.shape.with_rank_at_most(2)
@@ -331,20 +312,20 @@ def diag(v, k=0):  # pylint: disable=missing-docstring
 
   result = np_utils.cond(
       math_ops.equal(v_rank, 1), lambda: _diag(v, k), lambda: _diag_part(v, k))
-  return np_utils.tensor_to_ndarray(result)
+  return result
 
 
 @np_utils.np_doc('diagonal')
 def diagonal(a, offset=0, axis1=0, axis2=1):  # pylint: disable=missing-docstring
-  a = asarray(a).data
+  a = asarray(a)
 
   maybe_rank = a.shape.rank
   if maybe_rank is not None and offset == 0 and (
       axis1 == maybe_rank - 2 or axis1 == -2) and (axis2 == maybe_rank - 1 or
                                                    axis2 == -1):
-    return np_utils.tensor_to_ndarray(array_ops.matrix_diag_part(a))
+    return array_ops.matrix_diag_part(a)
 
-  a = moveaxis(np_utils.tensor_to_ndarray(a), (axis1, axis2), (-2, -1)).data
+  a = moveaxis(a, (axis1, axis2), (-2, -1))
 
   a_shape = array_ops.shape(a)
 
@@ -361,20 +342,20 @@ def diagonal(a, offset=0, axis1=0, axis2=1):  # pylint: disable=missing-docstrin
           np_utils.greater_equal(offset, np_utils.getitem(a_shape, -1)),
       ), _zeros, lambda: (a, offset))
 
-  a = np_utils.tensor_to_ndarray(array_ops.matrix_diag_part(a, k=offset))
+  a = array_ops.matrix_diag_part(a, k=offset)
   return a
 
 
 @np_utils.np_doc('diagflat')
 def diagflat(v, k=0):
   v = asarray(v)
-  return diag(array_ops.reshape(v.data, [-1]), k)
+  return diag(array_ops.reshape(v, [-1]), k)
 
 
 def _promote_dtype(*arrays):
   dtype = np_utils.result_type(*arrays)
   def _fast_asarray(a):
-    if isinstance(a, np_arrays.ndarray) and dtype == a.dtype:
+    if isinstance(a, np_arrays.ndarray) and dtype == a.dtype.as_numpy_dtype:
       return a
     return _array_internal(a, dtype=dtype, copy=False)
   return [_fast_asarray(a) for a in arrays]
@@ -382,9 +363,11 @@ def _promote_dtype(*arrays):
 
 def _promote_dtype_binary(t1, t2):
   dtype = np_utils._result_type_binary(t1, t2)  # pylint: disable=protected-access
-  if not(isinstance(t1, np_arrays.ndarray) and dtype == t1.dtype):
+  if not(
+      isinstance(t1, np_arrays.ndarray) and dtype == t1.dtype.as_numpy_dtype):
     t1 = _array_internal(t1, dtype=dtype, copy=False)
-  if not(isinstance(t2, np_arrays.ndarray) and dtype == t2.dtype):
+  if not(
+      isinstance(t2, np_arrays.ndarray) and dtype == t2.dtype.as_numpy_dtype):
     t2 = _array_internal(t2, dtype=dtype, copy=False)
   return t1, t2
 
@@ -392,15 +375,13 @@ def _promote_dtype_binary(t1, t2):
 @np_utils.np_doc('all')
 def all(a, axis=None, keepdims=None):  # pylint: disable=redefined-builtin
   a = asarray(a, dtype=bool)
-  return np_utils.tensor_to_ndarray(
-      math_ops.reduce_all(input_tensor=a.data, axis=axis, keepdims=keepdims))
+  return math_ops.reduce_all(input_tensor=a, axis=axis, keepdims=keepdims)
 
 
 @np_utils.np_doc('any')
 def any(a, axis=None, keepdims=None):  # pylint: disable=redefined-builtin
   a = asarray(a, dtype=bool)
-  return np_utils.tensor_to_ndarray(
-      math_ops.reduce_any(input_tensor=a.data, axis=axis, keepdims=keepdims))
+  return math_ops.reduce_any(input_tensor=a, axis=axis, keepdims=keepdims)
 
 
 @np_utils.np_doc('compress')
@@ -425,13 +406,12 @@ def compress(condition, a, axis=None):  # pylint: disable=redefined-outer-name,m
 
   # `tf.boolean_mask` requires the first dimensions of array and condition to
   # match. `np.compress` pads condition with False when it is shorter.
-  condition_t = condition.data
-  a_t = a.data
+  condition_t = condition
+  a_t = a
   if condition.shape[0] < a.shape[axis]:
     padding = array_ops.fill([a.shape[axis] - condition.shape[0]], False)
     condition_t = array_ops.concat([condition_t, padding], axis=0)
-  return np_utils.tensor_to_ndarray(
-      array_ops.boolean_mask(tensor=a_t, mask=condition_t, axis=axis))
+  return array_ops.boolean_mask(tensor=a_t, mask=condition_t, axis=axis)
 
 
 @np_utils.np_doc('copy')
@@ -443,8 +423,9 @@ def _maybe_promote_to_int(a):
   if dtypes.as_dtype(a.dtype).is_integer:
     # If a is an integer type and its precision is less than that of `int`,
     # the output type will be `int`.
-    output_type = np.promote_types(a.dtype, int)
-    if output_type != a.dtype:
+    a_numpy_dtype = a.dtype.as_numpy_dtype
+    output_type = np.promote_types(a_numpy_dtype, int)
+    if output_type != a_numpy_dtype:
       a = asarray(a, dtype=output_type)
 
   return a
@@ -462,8 +443,8 @@ def cumprod(a, axis=None, dtype=None):  # pylint: disable=missing-docstring
     a = ravel(a)
     axis = 0
   elif axis < 0:
-    axis += array_ops.rank(a.data)
-  return np_utils.tensor_to_ndarray(math_ops.cumprod(a.data, axis))
+    axis += array_ops.rank(a)
+  return math_ops.cumprod(a, axis)
 
 
 @np_utils.np_doc('cumsum')
@@ -478,8 +459,8 @@ def cumsum(a, axis=None, dtype=None):  # pylint: disable=missing-docstring
     a = ravel(a)
     axis = 0
   elif axis < 0:
-    axis += array_ops.rank(a.data)
-  return np_utils.tensor_to_ndarray(math_ops.cumsum(a.data, axis))
+    axis += array_ops.rank(a)
+  return math_ops.cumsum(a, axis)
 
 
 @np_utils.np_doc('imag')
@@ -487,7 +468,7 @@ def imag(val):
   val = asarray(val)
   # TODO(srbs): np.imag returns a scalar if `val` is a scalar, whereas we always
   # return an ndarray.
-  return np_utils.tensor_to_ndarray(math_ops.imag(val.data))
+  return math_ops.imag(val)
 
 
 _TO_INT_ = 0
@@ -532,10 +513,9 @@ def _reduce(tf_fn,
   a = asarray(a, dtype=dtype)
   if ((dtype == np.bool_ or preserve_bool and a.dtype == np.bool_) and
       tf_bool_fn is not None):
-    return np_utils.tensor_to_ndarray(
-        tf_bool_fn(input_tensor=a.data, axis=axis, keepdims=keepdims))
+    return tf_bool_fn(input_tensor=a, axis=axis, keepdims=keepdims)
   if dtype is None:
-    dtype = a.dtype
+    dtype = a.dtype.as_numpy_dtype
     if np.issubdtype(dtype, np.integer) or dtype == np.bool_:
       if promote_int == _TO_INT_:
         # If a is an integer/bool type and whose bit width is less than np.int_,
@@ -554,12 +534,15 @@ def _reduce(tf_fn,
             dtype = np.int_
           else:
             dtype = np.uint
-          a = a.astype(dtype)
+          a = math_ops.cast(a, dtype)
       elif promote_int == _TO_FLOAT:
-        a = a.astype(np_dtypes.default_float_type())
+        a = math_ops.cast(a, np_dtypes.default_float_type())
 
-  return np_utils.tensor_to_ndarray(
-      tf_fn(input_tensor=a.data, axis=axis, keepdims=keepdims))
+  if isinstance(axis, ops.Tensor) and axis.dtype not in (
+      dtypes.int32, dtypes.int64):
+    axis = math_ops.cast(axis, dtypes.int64)
+
+  return tf_fn(input_tensor=a, axis=axis, keepdims=keepdims)
 
 
 # TODO (DarrenZhang01): Add `axis` support to the `size` API.
@@ -570,11 +553,11 @@ def size(x, axis=None):  # pylint: disable=missing-docstring
                               '`np.size` implementation')
   if isinstance(x, (int, float, np.int32, np.int64, np.float32, np.float64)):
     return 1
-  x = asarray(x).data
+  x = asarray(x)
   if x.shape.is_fully_defined():
     return np.prod(x.shape.as_list(), dtype=int)
   else:
-    return np_utils.tensor_to_ndarray(array_ops.size_v2(x))
+    return array_ops.size_v2(x)
 
 
 @np_utils.np_doc('sum')
@@ -677,10 +660,10 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=None):  # pylint: d
       axis=axis,
       dtype=working_dtype,
       keepdims=keepdims,
-      promote_int=_TO_FLOAT).data
+      promote_int=_TO_FLOAT)
   if dtype:
     result = math_ops.cast(result, dtype)
-  return np_utils.tensor_to_ndarray(result)
+  return result
 
 
 @np_utils.np_doc('std')
@@ -697,13 +680,7 @@ def std(a, axis=None, keepdims=None):  # pylint: disable=missing-function-docstr
 @np_utils.np_doc('ravel')
 def ravel(a):  # pylint: disable=missing-docstring
   a = asarray(a)
-  out = np_utils.cond(
-      math_ops.equal(a.ndim, 1), lambda: a.data,
-      lambda: array_ops.reshape(a.data, [-1]))
-  return np_utils.tensor_to_ndarray(out)
-
-
-setattr(np_arrays.ndarray, 'ravel', ravel)
+  return array_ops.reshape(a, [-1])
 
 
 @np_utils.np_doc('real')
@@ -711,12 +688,12 @@ def real(val):
   val = asarray(val)
   # TODO(srbs): np.real returns a scalar if val is a scalar, whereas we always
   # return an ndarray.
-  return np_utils.tensor_to_ndarray(math_ops.real(val.data))
+  return math_ops.real(val)
 
 
 @np_utils.np_doc('repeat')
 def repeat(a, repeats, axis=None):  # pylint: disable=missing-docstring
-  a = asarray(a).data
+  a = asarray(a)
   original_shape = a._shape_as_list()  # pylint: disable=protected-access
   # Best effort recovery of the shape.
   known_shape = original_shape is not None and None not in original_shape
@@ -737,18 +714,18 @@ def repeat(a, repeats, axis=None):  # pylint: disable=missing-docstring
         else:
           original_shape[axis] = repeats_np.sum()
 
-  repeats = asarray(repeats).data
+  repeats = asarray(repeats)
   result = array_ops.repeat(a, repeats, axis)
   if known_shape:
     result.set_shape(original_shape)
 
-  return np_utils.tensor_to_ndarray(result)
+  return result
 
 
 @np_utils.np_doc('around')
 def around(a, decimals=0):  # pylint: disable=missing-docstring
   a = asarray(a)
-  dtype = a.dtype
+  dtype = a.dtype.as_numpy_dtype
   factor = math.pow(10, decimals)
   if np.issubdtype(dtype, np.inexact):
     factor = math_ops.cast(factor, dtype)
@@ -756,12 +733,12 @@ def around(a, decimals=0):  # pylint: disable=missing-docstring
     # Use float as the working dtype when a.dtype is exact (e.g. integer),
     # because `decimals` can be negative.
     float_dtype = np_dtypes.default_float_type()
-    a = a.astype(float_dtype).data
+    a = a.astype(float_dtype)
     factor = math_ops.cast(factor, float_dtype)
   a = math_ops.multiply(a, factor)
   a = math_ops.round(a)
   a = math_ops.divide(a, factor)
-  return np_utils.tensor_to_ndarray(a).astype(dtype)
+  return a.astype(dtype)
 
 
 setattr(np_arrays.ndarray, '__round__', around)
@@ -774,18 +751,16 @@ def reshape(a, newshape, order='C'):
     raise ValueError('Unsupported order argument {}'.format(order))
 
   a = asarray(a)
-  if isinstance(newshape, np_arrays.ndarray):
-    newshape = newshape.data
   if isinstance(newshape, int):
     newshape = [newshape]
 
   if order == 'F':
     r = array_ops.transpose(
-        array_ops.reshape(array_ops.transpose(a.data), newshape[::-1]))
+        array_ops.reshape(array_ops.transpose(a), newshape[::-1]))
   else:
-    r = array_ops.reshape(a.data, newshape)
+    r = array_ops.reshape(a, newshape)
 
-  return np_utils.tensor_to_ndarray(r)
+  return r
 
 
 def _reshape_method_wrapper(a, *newshape, **kwargs):
@@ -802,13 +777,13 @@ def _reshape_method_wrapper(a, *newshape, **kwargs):
 @np_utils.np_doc('expand_dims')
 def expand_dims(a, axis):
   a = asarray(a)
-  return np_utils.tensor_to_ndarray(array_ops.expand_dims(a.data, axis=axis))
+  return array_ops.expand_dims(a, axis=axis)
 
 
 @np_utils.np_doc('squeeze')
 def squeeze(a, axis=None):
   a = asarray(a)
-  return np_utils.tensor_to_ndarray(array_ops.squeeze(a, axis))
+  return array_ops.squeeze(a, axis)
 
 
 @np_utils.np_doc('transpose')
@@ -816,12 +791,12 @@ def transpose(a, axes=None):
   a = asarray(a)
   if axes is not None:
     axes = asarray(axes)
-  return np_utils.tensor_to_ndarray(array_ops.transpose(a=a.data, perm=axes))
+  return array_ops.transpose(a=a, perm=axes)
 
 
 @np_utils.np_doc('swapaxes')
 def swapaxes(a, axis1, axis2):  # pylint: disable=missing-docstring
-  a = asarray(a).data
+  a = asarray(a)
   def adjust_axes(axes, rank):
     def f(x):
       if isinstance(x, int):
@@ -848,7 +823,7 @@ def swapaxes(a, axis1, axis2):  # pylint: disable=missing-docstring
     perm = array_ops.tensor_scatter_update(perm, [[axis1], [axis2]],
                                            [axis2, axis1])
   a = array_ops.transpose(a, perm)
-  return np_utils.tensor_to_ndarray(a)
+  return a
 
 
 @np_utils.np_doc('moveaxis')
@@ -857,7 +832,7 @@ def moveaxis(a, source, destination):  # pylint: disable=missing-docstring
   if not source and not destination:
     return a
 
-  a = asarray(a).data
+  a = asarray(a)
 
   if isinstance(source, int):
     source = (source,)
@@ -908,13 +883,7 @@ def moveaxis(a, source, destination):  # pylint: disable=missing-docstring
         perm, array_ops.expand_dims(destination, 1), source)
   a = array_ops.transpose(a, perm)
 
-  return np_utils.tensor_to_ndarray(a)
-
-
-# TODO(wangpeng): Make a custom `setattr` that also sets docstring for the
-#   method.
-setattr(np_arrays.ndarray, 'transpose', transpose)
-setattr(np_arrays.ndarray, 'reshape', _reshape_method_wrapper)
+  return a
 
 
 @np_utils.np_doc('pad')
@@ -926,12 +895,11 @@ def pad(array, pad_width, mode, **kwargs):  # pylint: disable=redefined-outer-na
   mode = mode.upper()
   array = asarray(array)
   pad_width = asarray(pad_width, dtype=dtypes.int32)
-  return np_utils.tensor_to_ndarray(
-      array_ops.pad(
-          tensor=array.data,
-          paddings=pad_width.data,
-          mode=mode,
-          constant_values=constant_values))
+  return array_ops.pad(
+      tensor=array,
+      paddings=pad_width,
+      mode=mode,
+      constant_values=constant_values)
 
 
 @np_utils.np_doc('take')
@@ -943,8 +911,8 @@ def take(a, indices, axis=None, out=None, mode='clip'):
   if mode not in {'raise', 'clip', 'wrap'}:
     raise ValueError("Invalid mode '{}' for take".format(mode))
 
-  a = asarray(a).data
-  indices = asarray(indices).data
+  a = asarray(a)
+  indices = asarray(indices)
 
   if axis is None:
     a = array_ops.reshape(a, [-1])
@@ -958,7 +926,7 @@ def take(a, indices, axis=None, out=None, mode='clip'):
   else:
     raise ValueError("The 'raise' mode to take is not supported.")
 
-  return np_utils.tensor_to_ndarray(array_ops.gather(a, indices, axis=axis))
+  return array_ops.gather(a, indices, axis=axis)
 
 
 @np_utils.np_doc_only('where')
@@ -969,8 +937,7 @@ def where(condition, x=None, y=None):
     return nonzero(condition)
   elif x is not None and y is not None:
     x, y = _promote_dtype(x, y)
-    return np_utils.tensor_to_ndarray(
-        array_ops.where_v2(condition.data, x.data, y.data))
+    return array_ops.where_v2(condition, x, y)
   raise ValueError('Both x and y must be ndarrays, or both must be None.')
 
 
@@ -1044,8 +1011,7 @@ def split(ary, indices_or_sections, axis=0):
   ary = asarray(ary)
   if not isinstance(indices_or_sections, six.integer_types):
     indices_or_sections = _boundaries_to_sizes(ary, indices_or_sections, axis)
-  result = array_ops.split(ary.data, indices_or_sections, axis=axis)
-  return [np_utils.tensor_to_ndarray(a) for a in result]
+  return array_ops.split(ary, indices_or_sections, axis=axis)
 
 
 def _split_on_axis(np_fun_name, axis):
@@ -1077,7 +1043,7 @@ def stack(arrays, axis=0):  # pylint: disable=missing-function-docstring
       return swapaxes(arrays, 0, axis)
   arrays = _promote_dtype(*arrays)  # pylint: disable=protected-access
   unwrapped_arrays = [
-      a.data if isinstance(a, np_arrays.ndarray) else a for a in arrays
+      a if isinstance(a, np_arrays.ndarray) else a for a in arrays
   ]
   return asarray(array_ops.stack(unwrapped_arrays, axis))
 
@@ -1087,7 +1053,7 @@ def hstack(tup):
   arrays = [atleast_1d(a) for a in tup]
   arrays = _promote_dtype(*arrays)  # pylint: disable=protected-access
   unwrapped_arrays = [
-      a.data if isinstance(a, np_arrays.ndarray) else a for a in arrays
+      a if isinstance(a, np_arrays.ndarray) else a for a in arrays
   ]
   rank = array_ops.rank(unwrapped_arrays[0])
   return np_utils.cond(
@@ -1101,7 +1067,7 @@ def vstack(tup):
   arrays = [atleast_2d(a) for a in tup]
   arrays = _promote_dtype(*arrays)  # pylint: disable=protected-access
   unwrapped_arrays = [
-      a.data if isinstance(a, np_arrays.ndarray) else a for a in arrays
+      a if isinstance(a, np_arrays.ndarray) else a for a in arrays
   ]
   return array_ops.concat(unwrapped_arrays, axis=0)
 
@@ -1111,13 +1077,13 @@ def dstack(tup):
   arrays = [atleast_3d(a) for a in tup]
   arrays = _promote_dtype(*arrays)  # pylint: disable=protected-access
   unwrapped_arrays = [
-      a.data if isinstance(a, np_arrays.ndarray) else a for a in arrays
+      a if isinstance(a, np_arrays.ndarray) else a for a in arrays
   ]
   return array_ops.concat(unwrapped_arrays, axis=2)
 
 
 def _pad_left_to(n, old_shape):
-  old_shape = asarray(old_shape, dtype=np.int32).data
+  old_shape = asarray(old_shape, dtype=np.int32)
   new_shape = array_ops.pad(
       old_shape, [[math_ops.maximum(n - array_ops.size(old_shape), 0), 0]],
       constant_values=1)
@@ -1143,8 +1109,8 @@ def _atleast_nd(n, new_shape, *arys):
     return asarray(
         np_utils.cond(
             np_utils.greater(n, array_ops.rank(x)),
-            lambda: reshape(x, new_shape(n, array_ops.shape(x.data))).data,
-            lambda: x.data))
+            lambda: reshape(x, new_shape(n, array_ops.shape(x))),
+            lambda: x))
 
   arys = list(map(f, arys))
   if len(arys) == 1:
@@ -1182,16 +1148,14 @@ def atleast_3d(*arys):  # pylint: disable=missing-docstring
 
 @np_utils.np_doc('nonzero')
 def nonzero(a):
-  a = atleast_1d(a).data
+  a = atleast_1d(a)
   if a.shape.rank is None:
     raise ValueError("The rank of `a` is unknown, so we can't decide how many "
                      'arrays to return.')
-  return nest.map_structure(
-      np_arrays.tensor_to_ndarray,
-      array_ops.unstack(
-          array_ops.where_v2(math_ops.cast(a, dtypes.bool)),
-          a.shape.rank,
-          axis=1))
+  return array_ops.unstack(
+            array_ops.where_v2(math_ops.cast(a, dtypes.bool)),
+            a.shape.rank,
+            axis=1)
 
 
 @np_utils.np_doc('diag_indices')
@@ -1231,12 +1195,12 @@ def tri(N, M=None, k=0, dtype=None):  # pylint: disable=invalid-name,missing-doc
       r = o
     else:
       r = array_ops.matrix_band_part(o, -1, k)
-  return np_utils.tensor_to_ndarray(r)
+  return r
 
 
 @np_utils.np_doc('tril')
 def tril(m, k=0):  # pylint: disable=missing-docstring
-  m = asarray(m).data
+  m = asarray(m)
   if m.shape.ndims is None:
     raise ValueError('Argument to tril should have known rank')
   m_shape = m.shape.as_list()
@@ -1251,14 +1215,13 @@ def tril(m, k=0):  # pylint: disable=missing-docstring
   z = constant_op.constant(0, m.dtype)
 
   mask = tri(*m_shape[-2:], k=k, dtype=bool)
-  return np_utils.tensor_to_ndarray(
-      array_ops.where_v2(
-          array_ops.broadcast_to(mask, array_ops.shape(m)), m, z))
+  return array_ops.where_v2(
+      array_ops.broadcast_to(mask, array_ops.shape(m)), m, z)
 
 
 @np_utils.np_doc('triu')
 def triu(m, k=0):  # pylint: disable=missing-docstring
-  m = asarray(m).data
+  m = asarray(m)
   if m.shape.ndims is None:
     raise ValueError('Argument to triu should have known rank')
   m_shape = m.shape.as_list()
@@ -1273,22 +1236,20 @@ def triu(m, k=0):  # pylint: disable=missing-docstring
   z = constant_op.constant(0, m.dtype)
 
   mask = tri(*m_shape[-2:], k=k - 1, dtype=bool)
-  return np_utils.tensor_to_ndarray(
-      array_ops.where_v2(
-          array_ops.broadcast_to(mask, array_ops.shape(m)), z, m))
+  return array_ops.where_v2(
+      array_ops.broadcast_to(mask, array_ops.shape(m)), z, m)
 
 
 @np_utils.np_doc('flip')
 def flip(m, axis=None):  # pylint: disable=missing-docstring
-  m = asarray(m).data
+  m = asarray(m)
 
   if axis is None:
-    return np_utils.tensor_to_ndarray(
-        array_ops.reverse(m, math_ops.range(array_ops.rank(m))))
+    return array_ops.reverse(m, math_ops.range(array_ops.rank(m)))
 
   axis = np_utils._canonicalize_axis(axis, array_ops.rank(m))  # pylint: disable=protected-access
 
-  return np_utils.tensor_to_ndarray(array_ops.reverse(m, [axis]))
+  return array_ops.reverse(m, [axis])
 
 
 @np_utils.np_doc('flipud')
@@ -1303,15 +1264,15 @@ def fliplr(m):  # pylint: disable=missing-docstring
 
 @np_utils.np_doc('roll')
 def roll(a, shift, axis=None):  # pylint: disable=missing-docstring
-  a = asarray(a).data
+  a = asarray(a)
 
   if axis is not None:
-    return np_utils.tensor_to_ndarray(manip_ops.roll(a, shift, axis))
+    return manip_ops.roll(a, shift, axis)
 
   # If axis is None, the roll happens as a 1-d tensor.
   original_shape = array_ops.shape(a)
   a = manip_ops.roll(array_ops.reshape(a, [-1]), shift, 0)
-  return np_utils.tensor_to_ndarray(array_ops.reshape(a, original_shape))
+  return array_ops.reshape(a, original_shape)
 
 
 @np_utils.np_doc('rot90')
@@ -1336,7 +1297,7 @@ def rot90(m, k=1, axes=(0, 1)):  # pylint: disable=missing-docstring
 
 @np_utils.np_doc('vander')
 def vander(x, N=None, increasing=False):  # pylint: disable=missing-docstring,invalid-name
-  x = asarray(x).data
+  x = asarray(x)
 
   x_shape = array_ops.shape(x)
   N = N or x_shape[0]
@@ -1368,9 +1329,8 @@ def vander(x, N=None, increasing=False):  # pylint: disable=missing-docstring,in
     delta = -1
 
   x = array_ops.expand_dims(x, -1)
-  return np_utils.tensor_to_ndarray(
-      math_ops.pow(
-          x, math_ops.cast(math_ops.range(start, limit, delta), dtype=x.dtype)))
+  return math_ops.pow(
+      x, math_ops.cast(math_ops.range(start, limit, delta), dtype=x.dtype))
 
 
 @np_utils.np_doc('ix_')
@@ -1378,7 +1338,7 @@ def ix_(*args):  # pylint: disable=missing-docstring
   n = len(args)
   output = []
   for i, a in enumerate(args):
-    a = asarray(a).data
+    a = asarray(a)
     a_rank = array_ops.rank(a)
     a_rank_temp = np_utils.get_static_value(a_rank)
     if a_rank_temp is not None:
@@ -1393,11 +1353,9 @@ def ix_(*args):  # pylint: disable=missing-docstring
     new_shape[i] = -1
     dtype = a.dtype
     if dtype == dtypes.bool:
-      output.append(
-          np_utils.tensor_to_ndarray(
-              array_ops.reshape(nonzero(a)[0].data, new_shape)))
+      output.append(array_ops.reshape(nonzero(a)[0], new_shape))
     elif dtype.is_integer:
-      output.append(np_utils.tensor_to_ndarray(array_ops.reshape(a, new_shape)))
+      output.append(array_ops.reshape(a, new_shape))
     else:
       raise ValueError(
           'Only integer and bool dtypes are supported, got {}'.format(dtype))
@@ -1413,9 +1371,8 @@ def broadcast_arrays(*args, **kwargs):  # pylint: disable=missing-docstring
   if kwargs:
     raise ValueError('Received unsupported arguments {}'.format(kwargs.keys()))
 
-  args = [asarray(arg).data for arg in args]
-  args = np_utils.tf_broadcast(*args)
-  return [np_utils.tensor_to_ndarray(arg) for arg in args]
+  args = [asarray(arg) for arg in args]
+  return np_utils.tf_broadcast(*args)
 
 
 @np_utils.np_doc_only('sign')
@@ -1428,13 +1385,13 @@ def sign(x, out=None, where=None, **kwargs):  # pylint: disable=missing-docstrin
     raise ValueError('tf.numpy doesnt support setting {}'.format(kwargs.keys()))
 
   x = asarray(x)
-  dtype = x.dtype
+  dtype = x.dtype.as_numpy_dtype
   if np.issubdtype(dtype, np.complex):
-    result = math_ops.cast(math_ops.sign(math_ops.real(x.data)), dtype)
+    result = math_ops.cast(math_ops.sign(math_ops.real(x)), dtype)
   else:
-    result = math_ops.sign(x.data)
+    result = math_ops.sign(x)
 
-  return np_utils.tensor_to_ndarray(result)
+  return result
 
 
 # Note that np.take_along_axis may not be present in some supported versions of
@@ -1447,9 +1404,6 @@ def take_along_axis(arr, indices, axis):  # pylint: disable=missing-docstring
   if axis is None:
     return take_along_axis(arr.ravel(), indices, 0)
 
-  arr = arr.data
-  indices = indices.data
-
   rank = array_ops.rank(arr)
   axis = axis + rank if axis < 0 else axis
 
@@ -1475,7 +1429,7 @@ def take_along_axis(arr, indices, axis):  # pylint: disable=missing-docstring
   # Correct indices since gather doesn't correctly handle negative indices.
   indices = array_ops.where_v2(indices < 0, indices + arr_shape[axis], indices)
 
-  swapaxes_ = lambda t: swapaxes(np_utils.tensor_to_ndarray(t), axis, -1).data
+  swapaxes_ = lambda t: swapaxes(t, axis, -1)
 
   dont_move_axis_to_end = math_ops.equal(axis, np_utils.subtract(rank, 1))
   arr = np_utils.cond(dont_move_axis_to_end, lambda: arr,
@@ -1495,7 +1449,7 @@ def take_along_axis(arr, indices, axis):  # pylint: disable=missing-docstring
                          lambda: swapaxes_(result))
   result.set_shape(possible_result_shape)
 
-  return np_utils.tensor_to_ndarray(result)
+  return result
 
 
 _SLICE_ERORR = (
@@ -1519,7 +1473,7 @@ def _as_index(idx, need_scalar=True):
   """
   if isinstance(idx, (numbers.Integral, tensor_shape.Dimension)):
     return idx, True
-  data = asarray(idx).data
+  data = asarray(idx)
   if data.dtype == dtypes.bool:
     if data.shape.ndims != 1:
       # TODO(agarwal): handle higher rank boolean masks.
@@ -1730,14 +1684,14 @@ def _slice_helper(tensor, slice_spec, update_method=None, updates=None):
             dims_contiguous = False
             break
     indices = [advanced_indices_map[x] for x in dims]
-    indices = [x.data for x in _promote_dtype(*indices)]
+    indices = _promote_dtype(*indices)
     indices = np_utils.tf_broadcast(*indices)
     stacked_indices = array_ops.stack(indices, axis=-1)
     # Skip the contiguous-dims optimization for update because there is no
     # tf.*scatter* op that supports the `axis` argument.
     if not dims_contiguous or updates is not None:
       if range(len(dims)) != dims:
-        tensor = moveaxis(tensor, dims, range(len(dims))).data
+        tensor = moveaxis(tensor, dims, range(len(dims)))
       tensor_shape_prefix = array_ops.shape(
           tensor, out_type=stacked_indices.dtype)[:len(dims)]
       stacked_indices = array_ops.where_v2(
@@ -1763,7 +1717,7 @@ def _slice_helper(tensor, slice_spec, update_method=None, updates=None):
           def range_(start, length):
             return range(start, start + length)
           updates = moveaxis(updates, range_(batch_start, batch_size),
-                             range(batch_size)).data
+                             range(batch_size))
         if update_method == _UpdateMethod.UPDATE:
           update_op = array_ops.tensor_scatter_update
         elif update_method == _UpdateMethod.ADD:
@@ -1775,7 +1729,7 @@ def _slice_helper(tensor, slice_spec, update_method=None, updates=None):
         tensor = update_op(
             tensor, stacked_indices, updates)
         if range(len(dims)) != dims:
-          tensor = moveaxis(tensor, range(len(dims)), dims).data
+          tensor = moveaxis(tensor, range(len(dims)), dims)
         return array_ops.tensor_strided_slice_update(
             original_tensor,
             packed_begin,
@@ -1842,14 +1796,13 @@ def _getitem(self, slice_spec):
                                        slice_spec.dtype == dtypes.bool) or
       (isinstance(slice_spec, (np.ndarray, np_arrays.ndarray)) and
        slice_spec.dtype == np.bool)):
-    return np_utils.tensor_to_ndarray(
-        array_ops.boolean_mask(tensor=self.data, mask=slice_spec))
+    return array_ops.boolean_mask(tensor=self, mask=slice_spec)
 
   if not isinstance(slice_spec, tuple):
     slice_spec = _as_spec_tuple(slice_spec)
 
-  result_t = _slice_helper(self.data, slice_spec)
-  return np_utils.tensor_to_ndarray(result_t)
+  result_t = _slice_helper(self, slice_spec)
+  return result_t
 
 
 def _with_index_update_helper(update_method, a, slice_spec, updates):
@@ -1865,11 +1818,11 @@ def _with_index_update_helper(update_method, a, slice_spec, updates):
 
   a_dtype = a.dtype
   a, updates = _promote_dtype_binary(a, updates)
-  result_t = _slice_helper(a.data, slice_spec, update_method, updates.data)
-  return np_utils.tensor_to_ndarray(result_t).astype(a_dtype)
+  result_t = _slice_helper(a, slice_spec, update_method, updates)
+  return result_t.astype(a_dtype)
 
 
-setattr(np_arrays.ndarray, '__getitem__', _getitem)
+setattr(np_arrays.ndarray, '_numpy_style_getitem', _getitem)
 setattr(np_arrays.ndarray, '_with_index_update',
         functools.partial(_with_index_update_helper, _UpdateMethod.UPDATE))
 setattr(np_arrays.ndarray, '_with_index_add',
diff --git a/tensorflow/python/ops/numpy_ops/np_array_ops_test.py b/tensorflow/python/ops/numpy_ops/np_array_ops_test.py
index b3beb32793b..8fa324cbd8b 100644
--- a/tensorflow/python/ops/numpy_ops/np_array_ops_test.py
+++ b/tensorflow/python/ops/numpy_ops/np_array_ops_test.py
@@ -36,6 +36,7 @@ from tensorflow.python.framework import tensor_spec
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops.numpy_ops import np_array_ops
 from tensorflow.python.ops.numpy_ops import np_arrays
+from tensorflow.python.ops.numpy_ops import np_math_ops
 from tensorflow.python.platform import test
 
 
@@ -305,49 +306,47 @@ class ArrayCreationTest(test.TestCase):
 
     def test_copy_equal_false():
       # Backing tensor is the same if copy=False, other attributes being None.
-      self.assertIs(
-          np_array_ops.array(zeros_list, copy=False).data, zeros_list.data)
-      self.assertIs(
-          np_array_ops.array(zeros_list.data, copy=False).data, zeros_list.data)
+      self.assertIs(np_array_ops.array(zeros_list, copy=False), zeros_list)
+      self.assertIs(np_array_ops.array(zeros_list, copy=False), zeros_list)
 
       # Backing tensor is different if ndmin is not satisfied.
       self.assertIsNot(
-          np_array_ops.array(zeros_list, copy=False, ndmin=2).data,
-          zeros_list.data)
+          np_array_ops.array(zeros_list, copy=False, ndmin=2),
+          zeros_list)
       self.assertIsNot(
-          np_array_ops.array(zeros_list.data, copy=False, ndmin=2).data,
-          zeros_list.data)
+          np_array_ops.array(zeros_list, copy=False, ndmin=2),
+          zeros_list)
       self.assertIs(
-          np_array_ops.array(zeros_list, copy=False, ndmin=1).data,
-          zeros_list.data)
+          np_array_ops.array(zeros_list, copy=False, ndmin=1),
+          zeros_list)
       self.assertIs(
-          np_array_ops.array(zeros_list.data, copy=False, ndmin=1).data,
-          zeros_list.data)
+          np_array_ops.array(zeros_list, copy=False, ndmin=1),
+          zeros_list)
 
       # Backing tensor is different if dtype is not satisfied.
       self.assertIsNot(
-          np_array_ops.array(zeros_list, copy=False, dtype=int).data,
-          zeros_list.data)
+          np_array_ops.array(zeros_list, copy=False, dtype=int),
+          zeros_list)
       self.assertIsNot(
-          np_array_ops.array(zeros_list.data, copy=False, dtype=int).data,
-          zeros_list.data)
+          np_array_ops.array(zeros_list, copy=False, dtype=int),
+          zeros_list)
       self.assertIs(
-          np_array_ops.array(zeros_list, copy=False, dtype=float).data,
-          zeros_list.data)
+          np_array_ops.array(zeros_list, copy=False, dtype=float),
+          zeros_list)
       self.assertIs(
-          np_array_ops.array(zeros_list.data, copy=False, dtype=float).data,
-          zeros_list.data)
+          np_array_ops.array(zeros_list, copy=False, dtype=float),
+          zeros_list)
 
     test_copy_equal_false()
     with ops.device('CPU:1'):
       test_copy_equal_false()
 
-    self.assertNotIn('CPU:1', zeros_list.data.backing_device)
+    self.assertNotIn('CPU:1', zeros_list.backing_device)
     with ops.device('CPU:1'):
-      self.assertIn('CPU:1', np_array_ops.array(zeros_list, copy=True).data
-                    .backing_device)
-      self.assertIn('CPU:1', np_array_ops.array(np.array(0), copy=True).data
-                    .backing_device)
+      self.assertIn(
+          'CPU:1', np_array_ops.array(zeros_list, copy=True).backing_device)
+      self.assertIn(
+          'CPU:1', np_array_ops.array(np.array(0), copy=True).backing_device)
 
   def testAsArray(self):
     for a, dtype in itertools.product(self.all_arrays, self.all_types):
@@ -515,9 +514,6 @@ class ArrayCreationTest(test.TestCase):
       msg = 'Shape match failed for: {}. Expected: {} Actual: {}'.format(
           msg, expected.shape, actual.shape)
     self.assertEqual(actual.shape, expected.shape, msg=msg)
-    if msg:
-      msg = 'Shape: {} is not a tuple for {}'.format(actual.shape, msg)
-    self.assertIsInstance(actual.shape, tuple, msg=msg)
 
   def match_dtype(self, actual, expected, msg=None):
     if msg:
@@ -535,7 +531,7 @@ class ArrayCreationTest(test.TestCase):
     self.match_dtype(actual, expected, msg)
     self.match_shape(actual, expected, msg)
     if not almost:
-      if not actual.shape:
+      if not actual.shape.rank:
         self.assertEqual(actual.tolist(), expected.tolist())
       else:
         self.assertSequenceEqual(actual.tolist(), expected.tolist())
@@ -636,11 +632,11 @@ class ArrayMethodsTest(test.TestCase):
     run_test(np.arange(9).reshape((3, 3)).tolist())
 
     a = np_array_ops.asarray(0)
-    self.assertNotIn('CPU:1', a.data.backing_device)
+    self.assertNotIn('CPU:1', a.backing_device)
     with ops.device('CPU:1'):
-      self.assertIn('CPU:1', np_array_ops.array(a, copy=True).data
+      self.assertIn('CPU:1', np_array_ops.array(a, copy=True)
                     .backing_device)
-      self.assertIn('CPU:1', np_array_ops.array(np.array(0), copy=True).data
+      self.assertIn('CPU:1', np_array_ops.array(np.array(0), copy=True)
                     .backing_device)
 
   def testCumProdAndSum(self):
@@ -824,12 +820,13 @@ class ArrayMethodsTest(test.TestCase):
     self.assertRaises(NotImplementedError, np_array_ops.size, np.ones((2, 2)),
                       1)
 
-    @def_function.function(input_signature=[tensor_spec.TensorSpec(shape=None)])
+    @def_function.function(input_signature=[
+        tensor_spec.TensorSpec(dtype=dtypes.float64, shape=None)])
     def f(arr):
       arr = np_array_ops.asarray(arr)
       return np_array_ops.size(arr)
 
-    self.assertEqual(f(np_array_ops.ones((3, 2))).data.numpy(), 6)
+    self.assertEqual(f(np_array_ops.ones((3, 2))).numpy(), 6)
 
   def testRavel(self):
 
@@ -984,9 +981,6 @@ class ArrayMethodsTest(test.TestCase):
       msg = 'Shape match failed for: {}. Expected: {} Actual: {}'.format(
           msg, expected.shape, actual.shape)
     self.assertEqual(actual.shape, expected.shape, msg=msg)
-    if msg:
-      msg = 'Shape: {} is not a tuple for {}'.format(actual.shape, msg)
-    self.assertIsInstance(actual.shape, tuple, msg=msg)
 
   def match_dtype(self, actual, expected, msg=None):
     if msg:
@@ -1004,7 +998,7 @@ class ArrayMethodsTest(test.TestCase):
     if check_dtype:
       self.match_dtype(actual, expected, msg)
     self.match_shape(actual, expected, msg)
-    if not actual.shape:
+    if not actual.shape.rank:
       self.assertAllClose(actual.tolist(), expected.tolist())
     else:
       self.assertAllClose(actual.tolist(), expected.tolist())
@@ -1165,9 +1159,6 @@ class ArrayManipulationTest(test.TestCase):
       msg = 'Shape match failed for: {}. Expected: {} Actual: {}'.format(
           msg, expected.shape, actual.shape)
     self.assertEqual(actual.shape, expected.shape, msg=msg)
-    if msg:
-      msg = 'Shape: {} is not a tuple for {}'.format(actual.shape, msg)
-    self.assertIsInstance(actual.shape, tuple, msg=msg)
 
   def match_dtype(self, actual, expected, msg=None):
     if msg:
@@ -1184,7 +1175,7 @@ class ArrayManipulationTest(test.TestCase):
     self.assertIsInstance(actual, np_arrays.ndarray)
     self.match_dtype(actual, expected, msg)
     self.match_shape(actual, expected, msg)
-    if not actual.shape:
+    if not actual.shape.rank:
       self.assertEqual(actual.tolist(), expected.tolist())
     else:
       self.assertSequenceEqual(actual.tolist(), expected.tolist())
@@ -1192,4 +1183,6 @@ class ArrayManipulationTest(test.TestCase):
 
 if __name__ == '__main__':
   ops.enable_eager_execution()
+  ops.enable_numpy_style_type_promotion()
+  np_math_ops.enable_numpy_methods_on_tensor()
   test.main()
diff --git a/tensorflow/python/ops/numpy_ops/np_arrays.py b/tensorflow/python/ops/numpy_ops/np_arrays.py
index ade758d36d3..cd879204a3e 100644
--- a/tensorflow/python/ops/numpy_ops/np_arrays.py
+++ b/tensorflow/python/ops/numpy_ops/np_arrays.py
@@ -20,18 +20,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import numpy as np
 import six
 
-from tensorflow.python.framework import composite_tensor
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_spec
-from tensorflow.python.framework import type_spec
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
 from tensorflow.python.ops.numpy_ops import np_dtypes
-from tensorflow.python.ops.numpy_ops import np_export
 
 
 def convert_to_tensor(value, dtype=None, dtype_hint=None):
@@ -58,297 +51,4 @@ def convert_to_tensor(value, dtype=None, dtype_hint=None):
   return ops.convert_to_tensor(value, dtype=dtype, dtype_hint=dtype_hint)
 
 
-class NdarraySpec(type_spec.BatchableTypeSpec):
-  """Type specification for a `tf.experiemntal.numpy.ndarray`."""
-
-  value_type = property(lambda self: ndarray)
-
-  def __init__(self, data_spec):
-    if not isinstance(data_spec, tensor_spec.TensorSpec):
-      raise ValueError('NdarraySpec.__init__ was expecting a tf.TypeSpec, '
-                       'but got a {} instead.'.format(type(data_spec)))
-    self._data_spec = data_spec
-    self._hash = None
-
-  @property
-  def _component_specs(self):
-    return self._data_spec
-
-  def _to_components(self, value):
-    return value.data
-
-  def _from_components(self, data):
-    return tensor_to_ndarray(data)
-
-  def _serialize(self):
-    return (self._data_spec,)
-
-  def _batch(self, batch_size):
-    return NdarraySpec(self._data_spec._batch(batch_size))  # pylint: disable=protected-access
-
-  def _unbatch(self):
-    return NdarraySpec(self._data_spec._unbatch())  # pylint: disable=protected-access
-
-  def __hash__(self):
-    if self._hash is None:
-      self._hash = hash((type(self), self._data_spec))
-    return self._hash
-
-
-@np_export.np_export('ndarray')  # pylint: disable=invalid-name
-class ndarray(composite_tensor.CompositeTensor):
-  """Equivalent of numpy.ndarray backed by TensorFlow tensors.
-
-  This does not support all features of NumPy ndarrays e.g. strides and
-  memory order since, unlike NumPy, the backing storage is not a raw memory
-  buffer.
-
-  TODO(srbs): Clearly specify which attributes and methods are not supported
-  or if there are any differences in behavior.
-  """
-
-  __slots__ = ['_data', '_dtype', '_type_spec_internal']
-
-  def __init__(self, shape, dtype=float, buffer=None):  # pylint: disable=redefined-builtin
-    """Initializes an ndarray.
-
-    This is a low level interface for building ndarrays and should be avoided.
-    Users should instead use methods in array_creation.py.
-
-    This class provides a numpy.ndarray like interface for a TF Tensor with a
-    fully-defined shape. Note that, unlike the backing buffer of np.ndarray,
-    Tensors are immutable. So, operations like `__setitem__` are performed by
-    replacing the Tensor. This restricts the ability to implement NumPy `view`
-    semantics.
-
-    Compared to numpy.ndarray, this does not support `offset`, `strides`
-    and `order` arguments.
-
-    Args:
-      shape: The shape of the array. Must be a scalar, an iterable of integers
-        or a `TensorShape` object.
-      dtype: Optional. The dtype of the array. Must be a python type, a numpy
-        type or a tensorflow `DType` object.
-      buffer: Optional. The backing buffer of the array. Must have shape
-        `shape`. Must be a `ndarray`, `np.ndarray` or a `Tensor`.
-
-    Raises:
-      ValueError: If `buffer` is specified and its shape does not match
-       `shape`.
-    """
-    if dtype and not isinstance(dtype, dtypes.DType):
-      dtype = dtypes.as_dtype(np.dtype(dtype))
-    if buffer is None:
-      buffer = array_ops.zeros(shape, dtype=dtype)
-    else:
-      if isinstance(buffer, ndarray):
-        buffer = buffer.data
-      elif isinstance(buffer, np.ndarray):
-        # If `buffer` is a np.ndarray, the Tensor will share the underlying
-        # storage of the array.
-        buffer = convert_to_tensor(value=buffer, dtype=dtype)
-      elif not isinstance(buffer, ops.Tensor):
-        raise ValueError('Unexpected type for `buffer` {}. Must be an ndarray,'
-                         ' Tensor or np.ndarray.'.format(type(buffer)))
-
-      if shape is not None:
-        buffer.set_shape(shape)
-
-    assert isinstance(buffer, ops.Tensor)
-    if dtype and dtype != buffer.dtype:
-      buffer = math_ops.cast(buffer, dtype)
-    self._data = buffer
-    self._type_spec_internal = None
-    self._dtype = None
-
-  @classmethod
-  def from_tensor(cls, tensor):
-    o = cls.__new__(cls, None)
-    # pylint: disable=protected-access
-    o._data = tensor
-    o._dtype = None
-    o._type_spec_internal = None
-    # pylint: enable=protected-access
-    return o
-
-  @property
-  def _type_spec(self):
-    if self._type_spec_internal is None:
-      self._type_spec_internal = NdarraySpec(
-          type_spec.type_spec_from_value(self._data))
-    return self._type_spec_internal
-
-  @property
-  def data(self):
-    """Tensor object containing the array data.
-
-    This has a few key differences from the Python buffer object used in
-    NumPy arrays.
-    1. Tensors are immutable. So operations requiring in-place edit, e.g.
-       __setitem__, are performed by replacing the underlying buffer with a new
-       one.
-    2. Tensors do not provide access to their raw buffer.
-
-    Returns:
-      A Tensor.
-    """
-    return self._data
-
-  @property
-  def shape(self):
-    """Returns a tuple or tf.Tensor of array dimensions."""
-    shape = self.data.shape
-    if shape.is_fully_defined():
-      return tuple(shape.as_list())
-    else:
-      return array_ops.shape(self.data)
-
-  @property
-  def dtype(self):
-    if self._dtype is None:
-      self._dtype = np_dtypes._get_cached_dtype(self._data.dtype)  # pylint: disable=protected-access
-    return self._dtype
-
-  def _is_boolean(self):
-    return self._data.dtype == dtypes.bool
-
-  @property
-  def ndim(self):
-    ndims = self.data.shape.ndims
-    if ndims is None:
-      return array_ops.rank(self.data)
-    else:
-      return ndims
-
-  @property
-  def size(self):
-    """Returns the number of elements in the array."""
-    shape = self.shape
-    if isinstance(shape, ops.Tensor):
-      return array_ops.size(self.data)
-    else:
-      return np.prod(self.shape)
-
-  @property
-  def T(self):  # pylint: disable=invalid-name
-    return self.transpose()
-
-  def __len__(self):
-    shape = self.shape
-    if isinstance(shape, ops.Tensor):
-      raise TypeError('len() of symbolic tensor undefined')
-    elif shape:
-      return self.shape[0]
-    else:
-      raise TypeError('len() of unsized object.')
-
-  def astype(self, dtype):
-    if self.dtype == dtype:
-      return self
-    else:
-      return tensor_to_ndarray(math_ops.cast(self.data, dtype))
-
-  # Unary operations
-  def __neg__(self):
-    return tensor_to_ndarray(-self.data)  # pylint: disable=invalid-unary-operand-type
-
-  def __pos__(self):
-    return self
-
-  __hash__ = None
-
-  def __int__(self):
-    return int(self.data)
-
-  def __float__(self):
-    return float(self.data)
-
-  def __bool__(self):
-    return bool(self.data)
-
-  def __nonzero__(self):
-    return self.__bool__()
-
-  def __iter__(self):
-    if not isinstance(self.data, ops.EagerTensor):
-      raise TypeError('Iteration over symbolic tensor is not allowed')
-    for i in range(self.shape[0]):
-      result_t = self.data[i]
-      yield tensor_to_ndarray(result_t)
-    return
-
-  def __array__(self, dtype=None):
-    """Returns a NumPy ndarray.
-
-    This allows instances of this class to be directly used in NumPy routines.
-    However, doing that may force a copy to CPU.
-
-    Args:
-      dtype: A NumPy compatible type.
-
-    Returns:
-      A NumPy ndarray.
-    """
-    return np.asarray(self.data, dtype)
-
-  # NOTE: we currently prefer interop with TF to allow TF to take precedence.
-  __array_priority__ = 90
-
-  def __array_module__(self, types):
-    # Experimental support for NumPy's module dispatch with NEP-37:
-    # https://numpy.org/neps/nep-0037-array-module.html
-    # Currently requires https://github.com/seberg/numpy-dispatch
-
-    # pylint: disable=g-import-not-at-top
-    import tensorflow.compat.v2 as tf
-
-    if all(issubclass(t, (ndarray, np.ndarray)) for t in types):
-      return tf.experimental.numpy
-    else:
-      return NotImplemented
-
-  def __index__(self):
-    """Returns a python scalar.
-
-    This allows using an instance of this class as an array index.
-    Note that only arrays of integer types with size 1 can be used as array
-    indices.
-
-    Returns:
-      A Python scalar.
-
-    Raises:
-      TypeError: If the array is not of an integer type.
-      ValueError: If the array does not have size 1.
-    """
-    # TODO(wangpeng): Handle graph mode
-    if not isinstance(self.data, ops.EagerTensor):
-      raise TypeError('Indexing using symbolic tensor is not allowed')
-    return self.data.numpy().item()
-
-  def tolist(self):
-    return self.data.numpy().tolist()
-
-  def __str__(self):
-    return 'ndarray<{}>'.format(self.data.__str__())
-
-  def __repr__(self):
-    return 'ndarray<{}>'.format(self.data.__repr__())
-
-
-def tensor_to_ndarray(tensor):
-  return ndarray.from_tensor(tensor)
-
-
-def ndarray_to_tensor(arr, dtype=None, name=None, as_ref=False):
-  if as_ref:
-    raise ValueError('as_ref is not supported.')
-  if dtype and dtypes.as_dtype(arr.dtype) != dtype:
-    return math_ops.cast(arr.data, dtype)
-  result_t = arr.data
-  if name:
-    result_t = array_ops.identity(result_t, name=name)
-  return result_t
-
-
-ops.register_tensor_conversion_function(ndarray, ndarray_to_tensor)
+ndarray = ops.Tensor
diff --git a/tensorflow/python/ops/numpy_ops/np_arrays_test.py b/tensorflow/python/ops/numpy_ops/np_arrays_test.py
index ab407d2bfcf..782e3f36617 100644
--- a/tensorflow/python/ops/numpy_ops/np_arrays_test.py
+++ b/tensorflow/python/ops/numpy_ops/np_arrays_test.py
@@ -18,11 +18,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import collections
-
 import numpy as np
 
-from tensorflow.python.framework import constant_op
+from tensorflow.python.eager import def_function
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
@@ -32,48 +30,33 @@ from tensorflow.python.ops.numpy_ops import np_math_ops  # pylint: disable=unuse
 from tensorflow.python.platform import test
 from tensorflow.python.util import nest
 
-t2a = np_arrays.tensor_to_ndarray
-
 
 class ArrayTest(test.TestCase):
 
   def testDtype(self):
-    a = t2a(array_ops.zeros(shape=[1, 2], dtype=dtypes.int64))
-    self.assertIs(a.dtype.type, np.int64)
-    self.assertAllEqual(0, a.dtype.type(0))
+    a = array_ops.zeros(shape=[1, 2], dtype=dtypes.int64)
+    self.assertIs(a.dtype.as_numpy_dtype, np.int64)
+    np_dt = a.dtype.as_numpy_dtype
+    self.assertAllEqual(0, np_dt(0))
 
   def testAstype(self):
-    a = t2a(ops.convert_to_tensor(value=1.1,
-                                  dtype=dtypes.float32)).astype(np.int32)
-    self.assertIs(a.dtype.type, np.int32)
+    a = ops.convert_to_tensor(value=1.1, dtype=dtypes.float32).astype(np.int32)
+    self.assertIs(a.dtype.as_numpy_dtype, np.int32)
     self.assertAllEqual(1, a)
-    a = t2a(ops.convert_to_tensor(value=[0.0, 1.1],
-                                  dtype=dtypes.float32)).astype(np.bool_)
-    self.assertIs(a.dtype.type, np.bool_)
+    a = ops.convert_to_tensor(value=[0.0, 1.1], dtype=dtypes.float32).astype(
+        np.bool_)
+    self.assertIs(a.dtype.as_numpy_dtype, np.bool_)
     self.assertAllEqual([False, True], a)
 
-  def testConstructor(self):
-    t = constant_op.constant([[1], [1]])
-    a = np_arrays.ndarray(shape=(2, 1), buffer=t)
-    self.assertAllEqual(t, a)
-    self.assertEqual(dtypes.float64, a.dtype)
-
-    a = np_arrays.ndarray(shape=(2, 1), dtype=dtypes.int32, buffer=t)
-    self.assertAllEqual(t, a)
-    self.assertEqual(dtypes.int32, a.dtype)
-
-    with self.assertRaises(ValueError):  # bad shape
-      _ = np_arrays.ndarray((2, 2), buffer=t)
-
   def testNeg(self):
-    a = t2a(ops.convert_to_tensor(value=[1.0, 2.0]))
-    self.assertAllEqual([-1.0, -2.0], -a)
+    a = ops.convert_to_tensor(value=[1.0, 2.0])
+    self.assertAllEqual([-1.0, -2.0], -a)  # pylint: disable=invalid-unary-operand-type
 
   def _testBinOp(self, a, b, out, f, types=None):
-    a = t2a(ops.convert_to_tensor(value=a, dtype=np.int32))
-    b = t2a(ops.convert_to_tensor(value=b, dtype=np.int32))
+    a = ops.convert_to_tensor(value=a, dtype=np.int32)
+    b = ops.convert_to_tensor(value=b, dtype=np.int32)
     if not isinstance(out, np_arrays.ndarray):
-      out = t2a(ops.convert_to_tensor(value=out, dtype=np.int32))
+      out = ops.convert_to_tensor(value=out, dtype=np.int32)
     if types is None:
       types = [[np.int32, np.int32, np.int32], [np.int64, np.int32, np.int64],
                [np.int32, np.int64, np.int64],
@@ -84,7 +67,7 @@ class ArrayTest(test.TestCase):
                [np.float32, np.float64, np.float64]]
     for a_type, b_type, out_type in types:
       o = f(a.astype(a_type), b.astype(b_type))
-      self.assertIs(o.dtype.type, out_type)
+      self.assertIs(o.dtype.as_numpy_dtype, out_type)
       out = out.astype(out_type)
       if np.issubdtype(out_type, np.inexact):
         self.assertAllClose(out, o)
@@ -126,19 +109,20 @@ class ArrayTest(test.TestCase):
 
   def testTruediv(self):
     self._testBinOp([3, 5], [2, 4],
-                    t2a(ops.convert_to_tensor(value=[1.5, 1.25])),
+                    ops.convert_to_tensor(value=[1.5, 1.25]),
                     lambda a, b: a.__truediv__(b),
                     types=self._truediv_types)
 
   def testRtruediv(self):
     self._testBinOp([3, 5], [2, 4],
-                    t2a(ops.convert_to_tensor(value=[1.5, 1.25])),
+                    ops.convert_to_tensor(value=[1.5, 1.25]),
                     lambda a, b: b.__rtruediv__(a),
                     types=self._truediv_types)
 
   def _testCmp(self, a, b, out, f):
-    a = t2a(ops.convert_to_tensor(value=a, dtype=np.int32))
-    b = t2a(ops.convert_to_tensor(value=b, dtype=np.int32))
+    a = ops.convert_to_tensor(value=a, dtype=np.int32)
+    b = ops.convert_to_tensor(value=b, dtype=np.int32)
+
     types = [[np.int32, np.int32], [np.int64, np.int32], [np.int32, np.int64],
              [np.float32, np.int32], [np.int32, np.float32],
              [np.float32, np.float32], [np.float64, np.float32],
@@ -173,32 +157,41 @@ class ArrayTest(test.TestCase):
 
   def testInt(self):
     v = 10
-    u = int(t2a(ops.convert_to_tensor(value=v)))
+    u = int(ops.convert_to_tensor(value=v))
     self.assertIsInstance(u, int)
     self.assertAllEqual(v, u)
 
   def testFloat(self):
     v = 21.32
-    u = float(t2a(ops.convert_to_tensor(value=v)))
+    u = float(ops.convert_to_tensor(value=v))
     self.assertIsInstance(u, float)
     self.assertAllClose(v, u)
 
   def testBool(self):
-    b = bool(t2a(ops.convert_to_tensor(value=10)))
+    b = bool(ops.convert_to_tensor(value=10))
     self.assertIsInstance(b, bool)
     self.assertTrue(b)
-    self.assertFalse(bool(t2a(ops.convert_to_tensor(value=0))))
-    self.assertTrue(bool(t2a(ops.convert_to_tensor(value=0.1))))
-    self.assertFalse(bool(t2a(ops.convert_to_tensor(value=0.0))))
+    self.assertFalse(bool(ops.convert_to_tensor(value=0)))
+    self.assertTrue(bool(ops.convert_to_tensor(value=0.1)))
+    self.assertFalse(bool(ops.convert_to_tensor(value=0.0)))
 
   def testHash(self):
-    a = t2a(ops.convert_to_tensor(value=10))
-    self.assertNotIsInstance(a, collections.Hashable)
-    with self.assertRaisesWithPredicateMatch(TypeError, r'unhashable type'):
+    a = ops.convert_to_tensor(value=10)
+    def eager():
       hash(a)
+    def graph():
+      @def_function.function
+      def f(x):
+        hash(x)
+      f(a)
+    for f in [eager, graph]:
+      with self.assertRaisesRegexp(
+          TypeError,
+          r'Tensor is unhashable. Instead, use tensor.ref\(\) as the key.'):
+        f()
 
   def testFromToCompositeTensor(self):
-    tensors = [t2a(ops.convert_to_tensor(0.1)), t2a(ops.convert_to_tensor(0.2))]
+    tensors = [ops.convert_to_tensor(0.1), ops.convert_to_tensor(0.2)]
 
     flattened = nest.flatten(tensors, expand_composites=True)
     # Each ndarray contains only one tensor, so the flattened output should be
@@ -216,6 +209,10 @@ class ArrayTest(test.TestCase):
 
 
 if __name__ == '__main__':
-  # TODO(wangpeng): Test in graph mode as well.
+  # TODO(wangpeng): Test in graph mode as well. Also test in V2 (the requirement
+  # for setting _USE_EQUALITY points to V2 behavior not being on).
   ops.enable_eager_execution()
+  ops.Tensor._USE_EQUALITY = True
+  ops.enable_numpy_style_type_promotion()
+  np_math_ops.enable_numpy_methods_on_tensor()
   test.main()
diff --git a/tensorflow/python/ops/numpy_ops/np_config.py b/tensorflow/python/ops/numpy_ops/np_config.py
new file mode 100644
index 00000000000..05fc64ffd76
--- /dev/null
+++ b/tensorflow/python/ops/numpy_ops/np_config.py
@@ -0,0 +1,39 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Config functions for TF NumPy."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops.numpy_ops import np_dtypes
+from tensorflow.python.ops.numpy_ops import np_math_ops
+
+
+def enable_numpy_behavior(prefer_float32=False):
+  """Enable NumPy behavior on Tensors.
+
+  Includes addition of methods, type promotion on operator overloads and
+  support for NumPy-style slicing.
+
+  Args:
+    prefer_float32: Whether to allow type inference to use float32, or use
+    float64 similar to NumPy.
+  """
+  ops.enable_numpy_style_type_promotion()
+  ops.enable_numpy_style_slicing()
+  np_math_ops.enable_numpy_methods_on_tensor()
+  np_dtypes.set_prefer_float32(prefer_float32)
diff --git a/tensorflow/python/ops/numpy_ops/np_dtypes.py b/tensorflow/python/ops/numpy_ops/np_dtypes.py
index cde3883d3d9..1f4bb97e380 100644
--- a/tensorflow/python/ops/numpy_ops/np_dtypes.py
+++ b/tensorflow/python/ops/numpy_ops/np_dtypes.py
@@ -20,6 +20,7 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.python.framework import dtypes
 from tensorflow.python.ops.numpy_ops import np_export
 
 
@@ -63,9 +64,27 @@ _to_float32 = {
 
 _cached_np_dtypes = {}
 
+
+# Difference between is_prefer_float32 and is_allow_float64: is_prefer_float32
+# only decides which dtype to use for Python floats; is_allow_float64 decides
+# whether float64 dtypes can ever appear in programs. The latter is more
+# restrictive than the former.
+_prefer_float32 = False
+
+
+# TODO(b/178862061): Consider removing this knob
 _allow_float64 = True
 
 
+def is_prefer_float32():
+  return _prefer_float32
+
+
+def set_prefer_float32(b):
+  global _prefer_float32
+  _prefer_float32 = b
+
+
 def is_allow_float64():
   return _allow_float64
 
@@ -85,8 +104,13 @@ def canonicalize_dtype(dtype):
 
 
 def _result_type(*arrays_and_dtypes):
+  def preprocess_float(x):
+    if is_prefer_float32() and isinstance(x, float):
+      return np.float32(x)
+    return x
+  arrays_and_dtypes = [preprocess_float(x) for x in arrays_and_dtypes]
   dtype = np.result_type(*arrays_and_dtypes)
-  return canonicalize_dtype(dtype)
+  return dtypes.as_dtype(canonicalize_dtype(dtype))
 
 
 def _get_cached_dtype(dtype):
@@ -105,9 +129,10 @@ def default_float_type():
   """Gets the default float type.
 
   Returns:
-    If `is_allow_float64()` is true, returns float64; otherwise returns float32.
+    If `is_prefer_float32()` is false and `is_allow_float64()` is true, returns
+    float64; otherwise returns float32.
   """
-  if is_allow_float64():
+  if not is_prefer_float32() and is_allow_float64():
     return float64
   else:
     return float32
diff --git a/tensorflow/python/ops/numpy_ops/np_dtypes_test.py b/tensorflow/python/ops/numpy_ops/np_dtypes_test.py
new file mode 100644
index 00000000000..b5a3ab9c325
--- /dev/null
+++ b/tensorflow/python/ops/numpy_ops/np_dtypes_test.py
@@ -0,0 +1,57 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf-numpy dtype utilities."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops.numpy_ops import np_dtypes
+from tensorflow.python.platform import test
+
+
+class DTypeTest(test.TestCase, parameterized.TestCase):
+
+  @parameterized.parameters([False, True])
+  def testAllowF64False(self, prefer_f32):
+    np_dtypes.set_allow_float64(False)
+    np_dtypes.set_prefer_float32(prefer_f32)
+    self.assertEqual(dtypes.float32, np_dtypes.default_float_type())
+    self.assertEqual(dtypes.float32,
+                     np_dtypes._result_type(np.zeros([], np.float64), 1.1))
+
+  def testAllowF64TruePreferF32False(self):
+    np_dtypes.set_allow_float64(True)
+    np_dtypes.set_prefer_float32(False)
+    self.assertEqual(dtypes.float64, np_dtypes.default_float_type())
+    self.assertEqual(dtypes.float64, np_dtypes._result_type(1.1))
+
+  def testAllowF64TruePreferF32True(self):
+    np_dtypes.set_allow_float64(True)
+    np_dtypes.set_prefer_float32(True)
+    self.assertEqual(dtypes.float32, np_dtypes.default_float_type())
+    self.assertEqual(dtypes.float32, np_dtypes._result_type(1.1))
+    self.assertEqual(dtypes.float64,
+                     np_dtypes._result_type(np.zeros([], np.float64), 1.1))
+
+
+if __name__ == '__main__':
+  ops.enable_eager_execution()
+  test.main()
diff --git a/tensorflow/python/ops/numpy_ops/np_interop_test.py b/tensorflow/python/ops/numpy_ops/np_interop_test.py
index 8999c8f832e..d265b5e7d66 100644
--- a/tensorflow/python/ops/numpy_ops/np_interop_test.py
+++ b/tensorflow/python/ops/numpy_ops/np_interop_test.py
@@ -21,7 +21,9 @@ from __future__ import print_function
 import numpy as onp
 import tensorflow.compat.v2 as tf
 
+from tensorflow.python.framework import ops
 from tensorflow.python.ops import numpy_ops as np
+from tensorflow.python.ops.numpy_ops import np_math_ops
 
 
 # Tests for code snippet put in README.md
@@ -174,27 +176,26 @@ class InteropTest(tf.test.TestCase):
     self.assertIsInstance(sq, onp.ndarray)
     self.assertEqual(100., sq[0])
 
+# TODO(b/171313773): why doesn't tensor have __array_module__
   def testArrayModule(self):
+    self.skipTest("Tensor doesn't have __array_module__")
     arr = np.asarray([10])
 
-    module = arr.__array_module__((np.ndarray,))
+    module = arr.__array_module__((tf.Tensor,))
     self.assertIs(module, tf.experimental.numpy)
 
     class Dummy:
       pass
-    module = arr.__array_module__((np.ndarray, Dummy))
+    module = arr.__array_module__((tf.Tensor, Dummy))
     self.assertIs(module, NotImplemented)
 
-    # TODO(nareshmodi): Fails since the autopacking code doesn't use
-    # nest.flatten.
-
-
+# TODO(nareshmodi): Fails since the autopacking code doesn't use
+# nest.flatten.
 #   def testAutopacking(self):
 #     arr1 = np.asarray(1.)
 #     arr2 = np.asarray(2.)
 #     arr3 = np.asarray(3.)
 #     t = ops.convert_to_tensor_v2([arr1, arr2, arr3])
-
 #     self.assertEqual(t.numpy(), [1., 2., 3.])
 
   def testDistStratInterop(self):
@@ -409,7 +410,9 @@ class FunctionTest(InteropTest):
 
   def testLen(self):
 
-    @tf.function
+    # len can be fixed by autograph.
+    # TODO(wangpeng): this test can just be removed
+    @tf.function(autograph=False)
     def f(x):
       # Note that shape of input to len is data dependent.
       return len(np.where(x)[0])
@@ -451,5 +454,7 @@ class VariableTest(InteropTest):
 
 
 if __name__ == '__main__':
+  ops.enable_numpy_style_type_promotion()
+  np_math_ops.enable_numpy_methods_on_tensor()
   tf.compat.v1.enable_eager_execution()
   tf.test.main()
diff --git a/tensorflow/python/ops/numpy_ops/np_logic_test.py b/tensorflow/python/ops/numpy_ops/np_logic_test.py
index 85826873356..9e38a87e70c 100644
--- a/tensorflow/python/ops/numpy_ops/np_logic_test.py
+++ b/tensorflow/python/ops/numpy_ops/np_logic_test.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for tf numpy random number methods."""
+"""Tests for tf numpy logical methods."""
 
 from __future__ import absolute_import
 from __future__ import division
@@ -76,9 +76,6 @@ class LogicTest(test.TestCase):
       msg = 'Shape match failed for: {}. Expected: {} Actual: {}'.format(
           msg, expected.shape, actual.shape)
     self.assertEqual(actual.shape, expected.shape, msg=msg)
-    if msg:
-      msg = 'Shape: {} is not a tuple for {}'.format(actual.shape, msg)
-    self.assertIsInstance(actual.shape, tuple, msg=msg)
 
   def match_dtype(self, actual, expected, msg=None):
     if msg:
@@ -95,16 +92,17 @@ class LogicTest(test.TestCase):
     self.assertIsInstance(actual, np_arrays.ndarray)
     self.match_dtype(actual, expected, msg)
     self.match_shape(actual, expected, msg)
-    if not actual.shape:
+    if not actual.shape.rank:
       self.assertEqual(actual.tolist(), expected.tolist())
     else:
       self.assertSequenceEqual(actual.tolist(), expected.tolist())
 
 
 def make_numpy_compatible(s):
-  return s if not isinstance(s, np_arrays.ndarray) else s.data.numpy()
+  return s if not isinstance(s, np_arrays.ndarray) else s.numpy()
 
 
 if __name__ == '__main__':
   ops.enable_eager_execution()
+  np_math_ops.enable_numpy_methods_on_tensor()
   test.main()
diff --git a/tensorflow/python/ops/numpy_ops/np_math_ops.py b/tensorflow/python/ops/numpy_ops/np_math_ops.py
index 85cfdf6c5b8..1fd90df06c8 100644
--- a/tensorflow/python/ops/numpy_ops/np_math_ops.py
+++ b/tensorflow/python/ops/numpy_ops/np_math_ops.py
@@ -74,7 +74,7 @@ def _bin_op(tf_fun, a, b, promote=True):
   else:
     a = np_array_ops.array(a)
     b = np_array_ops.array(b)
-  return np_utils.tensor_to_ndarray(tf_fun(a.data, b.data))
+  return tf_fun(a, b)
 
 
 @np_utils.np_doc('add')
@@ -177,9 +177,8 @@ def maximum(x1, x2):  # pylint: disable=missing-function-docstring
   # Fast path for when maximum is used as relu.
   if isinstance(
       x2, numbers.Real) and not isinstance(x2, bool) and x2 == 0 and isinstance(
-          x1, np_arrays.ndarray) and not x1._is_boolean():  # pylint: disable=protected-access
-    return np_utils.tensor_to_ndarray(
-        nn_ops.relu(np_array_ops.asarray(x1).data))
+          x1, np_arrays.ndarray) and x1.dtype != dtypes.bool:
+    return nn_ops.relu(np_array_ops.asarray(x1))
 
   def max_or_or(x1, x2):
     if x1.dtype == dtypes.bool:
@@ -212,12 +211,7 @@ def clip(a, a_min, a_max):  # pylint: disable=missing-docstring
     return maximum(a, a_min)
   else:
     a, a_min, a_max = np_array_ops._promote_dtype(a, a_min, a_max)  # pylint: disable=protected-access
-    return np_utils.tensor_to_ndarray(
-        clip_ops.clip_by_value(
-            *np_utils.tf_broadcast(a.data, a_min.data, a_max.data)))
-
-
-setattr(np_arrays.ndarray, 'clip', clip)
+    return clip_ops.clip_by_value(*np_utils.tf_broadcast(a, a_min, a_max))
 
 
 @np_utils.np_doc('matmul')
@@ -241,6 +235,12 @@ def matmul(x1, x2):  # pylint: disable=missing-docstring
   return _bin_op(f, x1, x2)
 
 
+# Exported so it can be called from Tensor.__matmul__. NumPy's matmul handles
+# batched matmul as well, so simply including promotion in TF's current
+# __matmul__ implementation was not sufficient.
+setattr(np_arrays.ndarray, '_matmul', matmul)
+
+
 @np_utils.np_doc('tensordot')
 def tensordot(a, b, axes=2):
   return _bin_op(lambda a, b: math_ops.tensordot(a, b, axes=axes), a, b)
@@ -375,7 +375,7 @@ def heaviside(x1, x2):  # pylint: disable=missing-function-docstring
         array_ops.where_v2(x1 > 0, constant_op.constant(1, dtype=x2.dtype), x2))
 
   y = _bin_op(f, x1, x2)
-  if not np.issubdtype(y.dtype, np.inexact):
+  if not np.issubdtype(y.dtype.as_numpy_dtype, np.inexact):
     y = y.astype(np_dtypes.default_float_type())
   return y
 
@@ -392,13 +392,13 @@ def kron(a, b):  # pylint: disable=missing-function-docstring
   t_a = np_utils.cond(
       a.ndim < b.ndim,
       lambda: np_array_ops.reshape(  # pylint: disable=g-long-lambda
-          a.data, np_array_ops._pad_left_to(b.ndim, a.shape)),
-      lambda: a.data)
+          a, np_array_ops._pad_left_to(b.ndim, a.shape)),
+      lambda: a)
   t_b = np_utils.cond(
       b.ndim < a.ndim,
       lambda: np_array_ops.reshape(  # pylint: disable=g-long-lambda
-          b.data, np_array_ops._pad_left_to(a.ndim, b.shape)),
-      lambda: b.data)
+          b, np_array_ops._pad_left_to(a.ndim, b.shape)),
+      lambda: b)
 
   def _make_shape(shape, prepend):
     ones = array_ops.ones_like(shape)
@@ -596,9 +596,9 @@ def _scalar(tf_fn, x, promote_to_float=False):
     floating point type, in which case the output type is same as x.dtype.
   """
   x = np_array_ops.asarray(x)
-  if promote_to_float and not np.issubdtype(x.dtype, np.inexact):
+  if promote_to_float and not np.issubdtype(x.dtype.as_numpy_dtype, np.inexact):
     x = x.astype(np_dtypes.default_float_type())
-  return np_utils.tensor_to_ndarray(tf_fn(x.data))
+  return tf_fn(x)
 
 
 @np_utils.np_doc('log')
@@ -814,7 +814,7 @@ def isreal(x):
 @np_utils.np_doc('iscomplexobj')
 def iscomplexobj(x):
   x = np_array_ops.array(x)
-  return np.issubdtype(x.dtype, np.complexfloating)
+  return np.issubdtype(x.dtype.as_numpy_dtype, np.complexfloating)
 
 
 @np_utils.np_doc('isrealobj')
@@ -850,11 +850,12 @@ nanprod = _make_nan_reduction('nanprod', np_array_ops.prod, 1)
 @np_utils.np_doc('nanmean')
 def nanmean(a, axis=None, dtype=None, keepdims=None):  # pylint: disable=missing-docstring
   a = np_array_ops.array(a)
-  if np.issubdtype(a.dtype, np.bool_) or np.issubdtype(a.dtype, np.integer):
+  if np.issubdtype(a.dtype.as_numpy_dtype, np.bool_) or np.issubdtype(
+      a.dtype.as_numpy_dtype, np.integer):
     return np_array_ops.mean(a, axis=axis, dtype=dtype, keepdims=keepdims)
   nan_mask = logical_not(isnan(a))
   if dtype is None:
-    dtype = a.dtype
+    dtype = a.dtype.as_numpy_dtype
   normalizer = np_array_ops.sum(
       nan_mask, axis=axis, dtype=dtype, keepdims=keepdims)
   return nansum(a, axis=axis, dtype=dtype, keepdims=keepdims) / normalizer
@@ -960,37 +961,16 @@ def _wrap(f, reverse=False):
   return _f
 
 
-setattr(np_arrays.ndarray, '__abs__', absolute)
-setattr(np_arrays.ndarray, '__floordiv__', _wrap(floor_divide))
-setattr(np_arrays.ndarray, '__rfloordiv__', _wrap(floor_divide, True))
-setattr(np_arrays.ndarray, '__mod__', _wrap(mod))
-setattr(np_arrays.ndarray, '__rmod__', _wrap(mod, True))
-setattr(np_arrays.ndarray, '__add__', _wrap(add))
-setattr(np_arrays.ndarray, '__radd__', _wrap(add, True))
-setattr(np_arrays.ndarray, '__sub__', _wrap(subtract))
-setattr(np_arrays.ndarray, '__rsub__', _wrap(subtract, True))
-setattr(np_arrays.ndarray, '__mul__', _wrap(multiply))
-setattr(np_arrays.ndarray, '__rmul__', _wrap(multiply, True))
-setattr(np_arrays.ndarray, '__matmul__', _wrap(matmul))
-setattr(np_arrays.ndarray, '__rmatmul__', _wrap(matmul, True))
-setattr(np_arrays.ndarray, '__pow__', _wrap(power))
-setattr(np_arrays.ndarray, '__rpow__', _wrap(power, True))
-setattr(np_arrays.ndarray, '__truediv__', _wrap(true_divide))
-setattr(np_arrays.ndarray, '__rtruediv__', _wrap(true_divide, True))
-
-
 def _comparison(tf_fun, x1, x2, cast_bool_to_int=False):
   """Helper function for comparision."""
   dtype = np_utils.result_type(x1, x2)
   # Cast x1 and x2 to the result_type if needed.
   x1 = np_array_ops.array(x1, dtype=dtype)
   x2 = np_array_ops.array(x2, dtype=dtype)
-  x1 = x1.data
-  x2 = x2.data
   if cast_bool_to_int and x1.dtype == dtypes.bool:
     x1 = math_ops.cast(x1, dtypes.int32)
     x2 = math_ops.cast(x2, dtypes.int32)
-  return np_utils.tensor_to_ndarray(tf_fun(x1, x2))
+  return tf_fun(x1, x2)
 
 
 @np_utils.np_doc('equal')
@@ -1043,7 +1023,7 @@ def array_equal(a1, a2):  # pylint: disable=missing-function-docstring
 def _logical_binary_op(tf_fun, x1, x2):
   x1 = np_array_ops.array(x1, dtype=np.bool_)
   x2 = np_array_ops.array(x2, dtype=np.bool_)
-  return np_utils.tensor_to_ndarray(tf_fun(x1.data, x2.data))
+  return tf_fun(x1, x2)
 
 
 @np_utils.np_doc('logical_and')
@@ -1064,16 +1044,7 @@ def logical_xor(x1, x2):
 @np_utils.np_doc('logical_not')
 def logical_not(x):
   x = np_array_ops.array(x, dtype=np.bool_)
-  return np_utils.tensor_to_ndarray(math_ops.logical_not(x.data))
-
-
-setattr(np_arrays.ndarray, '__invert__', logical_not)
-setattr(np_arrays.ndarray, '__lt__', _wrap(less))
-setattr(np_arrays.ndarray, '__le__', _wrap(less_equal))
-setattr(np_arrays.ndarray, '__gt__', _wrap(greater))
-setattr(np_arrays.ndarray, '__ge__', _wrap(greater_equal))
-setattr(np_arrays.ndarray, '__eq__', _wrap(equal))
-setattr(np_arrays.ndarray, '__ne__', _wrap(not_equal))
+  return math_ops.logical_not(x)
 
 
 @np_utils.np_doc('linspace')
@@ -1087,8 +1058,8 @@ def linspace(  # pylint: disable=missing-docstring
     axis=0):
   if dtype:
     dtype = np_utils.result_type(dtype)
-  start = np_array_ops.array(start, dtype=dtype).data
-  stop = np_array_ops.array(stop, dtype=dtype).data
+  start = np_array_ops.array(start, dtype=dtype)
+  stop = np_array_ops.array(stop, dtype=dtype)
   if num < 0:
     raise ValueError('Number of samples {} must be non-negative.'.format(num))
   step = ops.convert_to_tensor(np.nan)
@@ -1109,28 +1080,27 @@ def linspace(  # pylint: disable=missing-docstring
   if dtype:
     result = math_ops.cast(result, dtype)
   if retstep:
-    return (np_arrays.tensor_to_ndarray(result),
-            np_arrays.tensor_to_ndarray(step))
+    return (result, step)
   else:
-    return np_arrays.tensor_to_ndarray(result)
+    return result
 
 
 @np_utils.np_doc('logspace')
 def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0):
   dtype = np_utils.result_type(start, stop, dtype)
   result = linspace(
-      start, stop, num=num, endpoint=endpoint, dtype=dtype, axis=axis).data
+      start, stop, num=num, endpoint=endpoint, dtype=dtype, axis=axis)
   result = math_ops.pow(math_ops.cast(base, result.dtype), result)
   if dtype:
     result = math_ops.cast(result, dtype)
-  return np_arrays.tensor_to_ndarray(result)
+  return result
 
 
 @np_utils.np_doc('geomspace')
 def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):  # pylint: disable=missing-docstring
-  dtype = dtype or np_utils.result_type(start, stop, float(num),
-                                        np_array_ops.zeros((), dtype))
-  computation_dtype = np.promote_types(dtype, np.float32)
+  dtype = dtypes.as_dtype(dtype) if dtype else np_utils.result_type(
+      start, stop, float(num), np_array_ops.zeros((), dtype))
+  computation_dtype = np.promote_types(dtype.as_numpy_dtype, np.float32)
   start = np_array_ops.asarray(start, dtype=computation_dtype)
   stop = np_array_ops.asarray(stop, dtype=computation_dtype)
   # follow the numpy geomspace convention for negative and complex endpoints
@@ -1147,7 +1117,7 @@ def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):  # pylint
       axis=0)
   if axis != 0:
     res = np_array_ops.moveaxis(res, 0, axis)
-  return np_utils.tensor_to_ndarray(math_ops.cast(res, dtype))
+  return math_ops.cast(res, dtype)
 
 
 @np_utils.np_doc('ptp')
@@ -1163,14 +1133,14 @@ def concatenate(arys, axis=0):
   if not arys:
     raise ValueError('Need at least one array to concatenate.')
   dtype = np_utils.result_type(*arys)
-  arys = [np_array_ops.array(array, dtype=dtype).data for array in arys]
-  return np_arrays.tensor_to_ndarray(array_ops.concat(arys, axis))
+  arys = [np_array_ops.array(array, dtype=dtype) for array in arys]
+  return array_ops.concat(arys, axis)
 
 
 @np_utils.np_doc_only('tile')
 def tile(a, reps):  # pylint: disable=missing-function-docstring
-  a = np_array_ops.array(a).data
-  reps = np_array_ops.array(reps, dtype=dtypes.int32).reshape([-1]).data
+  a = np_array_ops.array(a)
+  reps = np_array_ops.array(reps, dtype=dtypes.int32).reshape([-1])
 
   a_rank = array_ops.rank(a)
   reps_size = array_ops.size(reps)
@@ -1181,13 +1151,12 @@ def tile(a, reps):  # pylint: disable=missing-function-docstring
       constant_values=1)
   a = array_ops.reshape(a, a_shape)
 
-  return np_arrays.tensor_to_ndarray(array_ops.tile(a, reps))
+  return array_ops.tile(a, reps)
 
 
 @np_utils.np_doc('count_nonzero')
 def count_nonzero(a, axis=None):
-  return np_arrays.tensor_to_ndarray(
-      math_ops.count_nonzero(np_array_ops.array(a).data, axis))
+  return math_ops.count_nonzero(np_array_ops.array(a), axis)
 
 
 @np_utils.np_doc('argsort')
@@ -1199,7 +1168,7 @@ def argsort(a, axis=-1, kind='quicksort', order=None):  # pylint: disable=missin
     raise ValueError("'order' argument to sort is not supported.")
   stable = (kind == 'stable')
 
-  a = np_array_ops.array(a).data
+  a = np_array_ops.array(a)
 
   def _argsort(a, axis, stable):
     if axis is None:
@@ -1225,20 +1194,19 @@ def sort(a, axis=-1, kind='quicksort', order=None):  # pylint: disable=missing-d
   a = np_array_ops.array(a)
 
   if axis is None:
-    result_t = sort_ops.sort(array_ops.reshape(a.data, [-1]), 0)
-    return np_utils.tensor_to_ndarray(result_t)
+    return sort_ops.sort(array_ops.reshape(a, [-1]), 0)
   else:
-    return np_utils.tensor_to_ndarray(sort_ops.sort(a.data, axis))
+    return sort_ops.sort(a, axis)
 
 
 def _argminmax(fn, a, axis=None):
   a = np_array_ops.array(a)
   if axis is None:
     # When axis is None numpy flattens the array.
-    a_t = array_ops.reshape(a.data, [-1])
+    a_t = array_ops.reshape(a, [-1])
   else:
-    a_t = np_array_ops.atleast_1d(a).data
-  return np_utils.tensor_to_ndarray(fn(input=a_t, axis=axis))
+    a_t = np_array_ops.atleast_1d(a)
+  return fn(input=a_t, axis=axis)
 
 
 @np_utils.np_doc('argmax')
@@ -1267,24 +1235,24 @@ def average(a, axis=None, weights=None, returned=False):  # pylint: disable=miss
                      'supported yet. Got type: %s' % type(axis))
   a = np_array_ops.array(a)
   if weights is None:  # Treat all weights as 1
-    if not np.issubdtype(a.dtype, np.inexact):
+    if not np.issubdtype(a.dtype.as_numpy_dtype, np.inexact):
       a = a.astype(
           np_utils.result_type(a.dtype, np_dtypes.default_float_type()))
-    avg = math_ops.reduce_mean(a.data, axis=axis)
+    avg = math_ops.reduce_mean(a, axis=axis)
     if returned:
       if axis is None:
-        weights_sum = array_ops.size(a.data)
+        weights_sum = array_ops.size(a)
       else:
-        weights_sum = array_ops.shape(a.data)[axis]
-      weights_sum = math_ops.cast(weights_sum, a.data.dtype)
+        weights_sum = array_ops.shape(a)[axis]
+      weights_sum = math_ops.cast(weights_sum, a.dtype)
   else:
-    if np.issubdtype(a.dtype, np.inexact):
+    if np.issubdtype(a.dtype.as_numpy_dtype, np.inexact):
       out_dtype = np_utils.result_type(a.dtype, weights)
     else:
       out_dtype = np_utils.result_type(a.dtype, weights,
                                        np_dtypes.default_float_type())
-    a = np_array_ops.array(a, out_dtype).data
-    weights = np_array_ops.array(weights, out_dtype).data
+    a = np_array_ops.array(a, out_dtype)
+    weights = np_array_ops.array(weights, out_dtype)
 
     def rank_equal_case():
       control_flow_ops.Assert(
@@ -1316,8 +1284,7 @@ def average(a, axis=None, weights=None, returned=False):  # pylint: disable=miss
 
   avg = np_array_ops.array(avg)
   if returned:
-    weights_sum = np_array_ops.broadcast_to(weights_sum,
-                                            array_ops.shape(avg.data))
+    weights_sum = np_array_ops.broadcast_to(weights_sum, array_ops.shape(avg))
     return avg, weights_sum
   return avg
 
@@ -1326,7 +1293,7 @@ def average(a, axis=None, weights=None, returned=False):  # pylint: disable=miss
 def trace(a, offset=0, axis1=0, axis2=1, dtype=None):  # pylint: disable=missing-docstring
   if dtype:
     dtype = np_utils.result_type(dtype)
-  a = np_array_ops.asarray(a, dtype).data
+  a = np_array_ops.asarray(a, dtype)
 
   if offset == 0:
     a_shape = a.shape
@@ -1334,7 +1301,7 @@ def trace(a, offset=0, axis1=0, axis2=1, dtype=None):  # pylint: disable=missing
       rank = len(a_shape)
       if (axis1 == -2 or axis1 == rank - 2) and (axis2 == -1 or
                                                  axis2 == rank - 1):
-        return np_utils.tensor_to_ndarray(math_ops.trace(a))
+        return math_ops.trace(a)
 
   a = np_array_ops.diagonal(a, offset, axis1, axis2)
   return np_array_ops.sum(a, -1, dtype)
@@ -1353,11 +1320,10 @@ def meshgrid(*xi, **kwargs):
 
   indexing = kwargs.get('indexing', 'xy')
 
-  xi = [np_array_ops.asarray(arg).data for arg in xi]
+  xi = [np_array_ops.asarray(arg) for arg in xi]
   kwargs = {'indexing': indexing}
 
   outputs = array_ops.meshgrid(*xi, **kwargs)
-  outputs = [np_utils.tensor_to_ndarray(output) for output in outputs]
 
   return outputs
 
@@ -1387,7 +1353,62 @@ def einsum(subscripts, *operands, **kwargs):  # pylint: disable=missing-docstrin
     tf_optimize = 'optimal'
   else:
     raise ValueError('`optimize` method not supported: %s' % optimize)
-  operands = [x.data for x in operands]
   res = special_math_ops.einsum(subscripts, *operands, optimize=tf_optimize)
-  res = np_utils.tensor_to_ndarray(res)
   return res
+
+
+def _tensor_t(self):
+  """Returns a Tensor which is the transpose of this Tensor."""
+  return self.transpose()
+
+
+def _tensor_ndim(self):
+  """Returns the rank of the Tensor."""
+  return self.shape.ndims
+
+
+def _tensor_pos(self):
+  """Returns self, for unary operator `+`."""
+  return self
+
+
+def _tensor_size(self):
+  """Returns the number of elements in this Tensor, if fully known."""
+  if not self.shape.is_fully_defined():
+    return None
+  return np.prod(self.shape.as_list())
+
+
+def _tensor_tolist(self):
+  if isinstance(self, ops.EagerTensor):
+    return self._numpy().tolist()  # pylint: disable=protected-access
+
+  raise ValueError('Symbolic Tensors do not support the tolist API.')
+
+
+def enable_numpy_methods_on_tensor():
+  """Adds additional NumPy methods on tf.Tensor class."""
+  t = property(_tensor_t)
+  setattr(ops.Tensor, 'T', t)
+
+  ndim = property(_tensor_ndim)
+  setattr(ops.Tensor, 'ndim', ndim)
+
+  size = property(_tensor_size)
+  setattr(ops.Tensor, 'size', size)
+
+  setattr(ops.Tensor, '__pos__', _tensor_pos)
+  setattr(ops.Tensor, 'tolist', _tensor_tolist)
+
+  # TODO(b/178540516): Make a custom `setattr` that changes the method's
+  #   docstring to the TF one.
+  setattr(ops.Tensor, 'transpose', np_array_ops.transpose)
+  setattr(ops.Tensor, 'reshape', np_array_ops._reshape_method_wrapper)  # pylint: disable=protected-access
+  setattr(ops.Tensor, 'ravel', np_array_ops.ravel)
+  setattr(ops.Tensor, 'clip', clip)
+  setattr(ops.Tensor, 'astype', math_ops.cast)
+  setattr(ops.Tensor, '__round__', np_array_ops.around)
+
+  # TODO(wangpeng): Remove `data` when all uses of it are removed
+  data = property(lambda self: self)
+  setattr(ops.Tensor, 'data', data)
diff --git a/tensorflow/python/ops/numpy_ops/np_math_ops_test.py b/tensorflow/python/ops/numpy_ops/np_math_ops_test.py
index cb5326bcded..fd9dc18abfc 100644
--- a/tensorflow/python/ops/numpy_ops/np_math_ops_test.py
+++ b/tensorflow/python/ops/numpy_ops/np_math_ops_test.py
@@ -160,7 +160,7 @@ class MathTest(test.TestCase, parameterized.TestCase):
       self.assertEqual(
           actual.dtype, expected.dtype,
           'Dtype mismatch.\nActual: {}\nExpected: {}\n{}'.format(
-              actual.dtype, expected.dtype, msg))
+              actual.dtype.as_numpy_dtype, expected.dtype, msg))
     self.assertEqual(
         actual.shape, expected.shape,
         'Shape mismatch.\nActual: {}\nExpected: {}\n{}'.format(
@@ -350,4 +350,6 @@ class MathTest(test.TestCase, parameterized.TestCase):
 
 if __name__ == '__main__':
   ops.enable_eager_execution()
+  ops.enable_numpy_style_type_promotion()
+  np_math_ops.enable_numpy_methods_on_tensor()
   test.main()
diff --git a/tensorflow/python/ops/numpy_ops/np_random.py b/tensorflow/python/ops/numpy_ops/np_random.py
index d0b199b2d21..f6a6462760f 100644
--- a/tensorflow/python/ops/numpy_ops/np_random.py
+++ b/tensorflow/python/ops/numpy_ops/np_random.py
@@ -73,7 +73,7 @@ def standard_normal(size=None):
   elif np_utils.isscalar(size):
     size = (size,)
   dtype = np_dtypes.default_float_type()
-  return np_utils.tensor_to_ndarray(random_ops.random_normal(size, dtype=dtype))
+  return random_ops.random_normal(size, dtype=dtype)
 
 
 @np_utils.np_doc('random.uniform')
@@ -83,9 +83,8 @@ def uniform(low=0.0, high=1.0, size=None):
   high = np_array_ops.asarray(high, dtype=dtype)
   if size is None:
     size = array_ops.broadcast_dynamic_shape(low.shape, high.shape)
-  return np_utils.tensor_to_ndarray(
-      random_ops.random_uniform(
-          shape=size, minval=low, maxval=high, dtype=dtype))
+  return random_ops.random_uniform(
+      shape=size, minval=low, maxval=high, dtype=dtype)
 
 
 @np_utils.np_doc('random.poisson')
@@ -94,8 +93,7 @@ def poisson(lam=1.0, size=None):
     size = ()
   elif np_utils.isscalar(size):
     size = (size,)
-  return np_utils.tensor_to_ndarray(
-      random_ops.random_poisson(shape=size, lam=lam, dtype=np_dtypes.int_))
+  return random_ops.random_poisson(shape=size, lam=lam, dtype=np_dtypes.int_)
 
 
 @np_utils.np_doc('random.random')
@@ -121,6 +119,5 @@ def randint(low, high=None, size=None, dtype=onp.int):  # pylint: disable=missin
   dtype = np_utils.result_type(dtype)
   if dtype not in (onp.int32, onp.int64):
     raise ValueError('Only np.int32 or np.int64 types are supported')
-  return np_utils.tensor_to_ndarray(
-      random_ops.random_uniform(
-          shape=size, minval=low, maxval=high, dtype=dtype))
+  return random_ops.random_uniform(
+      shape=size, minval=low, maxval=high, dtype=dtype)
diff --git a/tensorflow/python/ops/numpy_ops/np_random_test.py b/tensorflow/python/ops/numpy_ops/np_random_test.py
index 61ddbcaf47b..de0ee32802e 100644
--- a/tensorflow/python/ops/numpy_ops/np_random_test.py
+++ b/tensorflow/python/ops/numpy_ops/np_random_test.py
@@ -28,6 +28,7 @@ from tensorflow.python.ops import numpy_ops as np
 # Needed for ndarray.reshape.
 from tensorflow.python.ops.numpy_ops import np_array_ops  # pylint: disable=unused-import
 from tensorflow.python.ops.numpy_ops import np_dtypes
+from tensorflow.python.ops.numpy_ops import np_math_ops
 from tensorflow.python.ops.numpy_ops import np_random
 from tensorflow.python.platform import test
 
@@ -192,7 +193,7 @@ class RandNDistriutionTest(test.TestCase):
         self.assertEqual(output.shape, tuple(args))
         default_dtype = (
             np.float64 if np_dtypes.is_allow_float64() else np.float32)
-        self.assertEqual(output.dtype.type, default_dtype)
+        self.assertEqual(output.dtype.as_numpy_dtype, default_dtype)
 
       if np.prod(args):  # Don't bother with empty arrays.
         outputs = [output.tolist() for output in outputs]
@@ -230,4 +231,5 @@ class RandNDistriutionTest(test.TestCase):
 
 if __name__ == '__main__':
   ops.enable_eager_execution()
+  np_math_ops.enable_numpy_methods_on_tensor()
   test.main()
diff --git a/tensorflow/python/ops/numpy_ops/np_utils.py b/tensorflow/python/ops/numpy_ops/np_utils.py
index ca09624de76..90f3c3913a6 100644
--- a/tensorflow/python/ops/numpy_ops/np_utils.py
+++ b/tensorflow/python/ops/numpy_ops/np_utils.py
@@ -38,9 +38,6 @@ from tensorflow.python.types import core
 from tensorflow.python.util import nest
 
 
-tensor_to_ndarray = np_arrays.tensor_to_ndarray
-
-
 def _canonicalize_axis(axis, rank):
   return _canonicalize_axes([axis], rank)[0]
 
@@ -478,8 +475,6 @@ def _maybe_get_dtype(x):
   """Returns a numpy type if available from x. Skips if x is numpy.ndarray."""
   # Don't put np.ndarray in this list, because np.result_type looks at the
   # value (not just dtype) of np.ndarray to decide the result type.
-  if isinstance(x, np_arrays.ndarray):
-    return x.dtype
   if isinstance(x, numbers.Real):
     return x
   if isinstance(x, (core.Tensor, indexed_slices.IndexedSlices)):
diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops.py b/tensorflow/python/ops/parallel_for/control_flow_ops.py
index 169eb17cda1..569ef81b410 100644
--- a/tensorflow/python/ops/parallel_for/control_flow_ops.py
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops.py
@@ -34,7 +34,6 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import tensor_array_ops
-from tensorflow.python.ops.numpy_ops import np_arrays
 from tensorflow.python.ops.parallel_for.pfor import PFor
 from tensorflow.python.ops.parallel_for.pfor import PForConfig
 from tensorflow.python.platform import tf_logging as logging
@@ -289,7 +288,6 @@ def _pfor_impl(loop_fn,
                                                 loop_fn_outputs)
 
   # Convert outputs to Tensor if needed.
-  rewrap_as_ndarray = False
   tmp_loop_fn_outputs = []
   for loop_fn_output in nest.flatten(loop_fn_output_tensors):
     if (loop_fn_output is not None and not isinstance(
@@ -301,9 +299,6 @@ def _pfor_impl(loop_fn,
                      " IndexedSlices separately, and handle the vectorized"
                      " outputs directly." % loop_fn_output)
         loop_fn_output = ops.convert_to_tensor(loop_fn_output)
-      elif isinstance(loop_fn_output, np_arrays.ndarray):
-        loop_fn_output = loop_fn_output.data
-        rewrap_as_ndarray = True
       else:
         loop_fn_output = ops.convert_to_tensor(loop_fn_output)
     tmp_loop_fn_outputs.append(loop_fn_output)
@@ -327,8 +322,6 @@ def _pfor_impl(loop_fn,
       flattened_output_tensors = []
       for loop_fn_output in nest.flatten(loop_fn_output_tensors):
         output = converter.convert(loop_fn_output)
-        if rewrap_as_ndarray:
-          output = np_arrays.tensor_to_ndarray(output)
         flattened_output_tensors.append(output)
   else:
     if pfor_config is not None and pfor_config._has_reductions():  # pylint: disable=protected-access
@@ -346,8 +339,6 @@ def _pfor_impl(loop_fn,
       flattened_output_tensors = nest.flatten(loop_fn_output_tensors)
       for loop_fn_output in flattened_output_tensors:
         output = converter.convert(loop_fn_output)
-        if rewrap_as_ndarray:
-          output = np_arrays.tensor_to_ndarray(output)
         remaining_output_tensors.append(output)
 
     with ops.name_scope("pfor_tiled"):
@@ -398,10 +389,6 @@ def _pfor_impl(loop_fn,
             tensor_shape.TensorShape([iters_value]).concatenate(
                 original_output.shape))
 
-      if rewrap_as_ndarray:
-        flattened_output_tensors = [
-            np_arrays.tensor_to_ndarray(x) for x in flattened_output_tensors]
-
   return nest.map_structure_up_to(
       loop_fn_outputs,
       functools.partial(_composite_from_tensors, batch_size=iters_value),
@@ -418,8 +405,6 @@ def _broadcasting_gather(x, i):
   elif static_first_dim is None:
     i = array_ops.where_v2(array_ops.shape(x)[0] > 1, i, 0)
   result = array_ops.gather(x, i)
-  if isinstance(x, np_arrays.ndarray):
-    result = np_arrays.ndarray.from_tensor(result)
   return result
 
 
@@ -548,8 +533,6 @@ def vectorized_map(fn, elems, fallback_to_while_loop=True):
                             is_batched=True),
           elems))
   def _get_shape(x):
-    if isinstance(x, np_arrays.ndarray):
-      x = x.data
     if x.shape.rank is None:
       return None
     return x.shape.as_list()[0]
diff --git a/tensorflow/python/saved_model/load_test.py b/tensorflow/python/saved_model/load_test.py
index 69c59744efe..ecdc53d2572 100644
--- a/tensorflow/python/saved_model/load_test.py
+++ b/tensorflow/python/saved_model/load_test.py
@@ -50,12 +50,10 @@ from tensorflow.python.ops import cond_v2
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import lookup_ops
 from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import numpy_ops as tnp
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import string_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
-from tensorflow.python.ops.numpy_ops import np_arrays
 from tensorflow.python.ops.ragged import ragged_factory_ops
 from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.saved_model import load
@@ -1967,34 +1965,6 @@ class LoadTest(test.TestCase, parameterized.TestCase):
     self.assertEqual(self.evaluate(imported.lookup("foo")), 15)
     self.assertEqual(self.evaluate(imported.lookup("idk")), -1)
 
-  def test_saving_ndarray_specs(self, cycles):
-    class NdarrayModule(module.Module):
-
-      @def_function.function
-      def plain(self, x):
-        return tnp.add(x, 1)
-
-      @def_function.function(input_signature=[
-          np_arrays.NdarraySpec(tensor_spec.TensorSpec([], dtypes.float32))])
-      def with_signature(self, x):
-        return tnp.add(x, 1)
-
-    m = NdarrayModule()
-    c = tnp.asarray(3.0, tnp.float32)
-    output_plain, output_with_signature = m.plain(c), m.with_signature(c)
-
-    loaded_m = cycle(m, cycles)
-
-    load_output_plain, load_output_with_signature = (
-        loaded_m.plain(c), loaded_m.with_signature(c))
-
-    self.assertIsInstance(output_plain, tnp.ndarray)
-    self.assertIsInstance(load_output_plain, tnp.ndarray)
-    self.assertIsInstance(output_with_signature, tnp.ndarray)
-    self.assertIsInstance(load_output_with_signature, tnp.ndarray)
-    self.assertAllClose(output_plain, load_output_plain)
-    self.assertAllClose(output_with_signature, load_output_with_signature)
-
 
 class SingleCycleTests(test.TestCase, parameterized.TestCase):
 
diff --git a/tensorflow/python/saved_model/nested_structure_coder.py b/tensorflow/python/saved_model/nested_structure_coder.py
index a7e5548ee06..9c71b853675 100644
--- a/tensorflow/python/saved_model/nested_structure_coder.py
+++ b/tensorflow/python/saved_model/nested_structure_coder.py
@@ -48,7 +48,6 @@ from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import tensor_array_ops
-from tensorflow.python.ops.numpy_ops import np_arrays
 from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.ops.ragged import row_partition
 from tensorflow.python.util import compat
@@ -517,8 +516,6 @@ class _TypeSpecCodec(object):
           resource_variable_ops.VariableSpec,
       struct_pb2.TypeSpecProto.ROW_PARTITION_SPEC:
           row_partition.RowPartitionSpec,
-      struct_pb2.TypeSpecProto.NDARRAY_SPEC:
-          np_arrays.NdarraySpec,
   }
 
   # Mapping from type (TypeSpec subclass) to enum value.
diff --git a/tensorflow/python/saved_model/nested_structure_coder_test.py b/tensorflow/python/saved_model/nested_structure_coder_test.py
index fb074f76eb0..9951ea64a49 100644
--- a/tensorflow/python/saved_model/nested_structure_coder_test.py
+++ b/tensorflow/python/saved_model/nested_structure_coder_test.py
@@ -28,7 +28,6 @@ from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import tensor_util
-from tensorflow.python.ops.numpy_ops import np_arrays
 from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.platform import test
 from tensorflow.python.saved_model import nested_structure_coder
@@ -332,14 +331,6 @@ class NestedStructureTest(test.TestCase):
     decoded = self._coder.decode_proto(encoded)
     self.assertEqual(structure, decoded)
 
-  def testEncodeDecodeNdarraySpec(self):
-    structure = [np_arrays.NdarraySpec(
-        tensor_spec.TensorSpec([4, 2], dtypes.float32))]
-    self.assertTrue(self._coder.can_encode(structure))
-    encoded = self._coder.encode_structure(structure)
-    decoded = self._coder.decode_proto(encoded)
-    self.assertEqual(structure, decoded)
-
   def testNotEncodable(self):
 
     class NotEncodable(object):
diff --git a/third_party/py/numpy/tf_numpy_api/tensorflow.experimental.numpy.ndarray.pbtxt b/third_party/py/numpy/tf_numpy_api/tensorflow.experimental.numpy.ndarray.pbtxt
index 4edc5f08e84..104e85835cd 100644
--- a/third_party/py/numpy/tf_numpy_api/tensorflow.experimental.numpy.ndarray.pbtxt
+++ b/third_party/py/numpy/tf_numpy_api/tensorflow.experimental.numpy.ndarray.pbtxt
@@ -1,14 +1,15 @@
 path: "tensorflow.experimental.numpy.ndarray"
 tf_class {
-  is_instance: "<class \'tensorflow.python.ops.numpy_ops.np_arrays.ndarray\'>"
-  is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
+  is_instance: "<class \'tensorflow.python.framework.ops.Tensor\'>"
+  is_instance: "<class \'tensorflow.python.types.internal.NativeObject\'>"
+  is_instance: "<class \'tensorflow.python.types.core.Tensor\'>"
   is_instance: "<type \'object\'>"
   member {
-    name: "T"
-    mtype: "<type \'property\'>"
+    name: "OVERLOADABLE_OPERATORS"
+    mtype: "<type \'set\'>"
   }
   member {
-    name: "data"
+    name: "device"
     mtype: "<type \'property\'>"
   }
   member {
@@ -16,7 +17,15 @@ tf_class {
     mtype: "<type \'property\'>"
   }
   member {
-    name: "ndim"
+    name: "graph"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "name"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "op"
     mtype: "<type \'property\'>"
   }
   member {
@@ -24,39 +33,35 @@ tf_class {
     mtype: "<type \'property\'>"
   }
   member {
-    name: "size"
+    name: "value_index"
     mtype: "<type \'property\'>"
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'shape\', \'dtype\', \'buffer\'], varargs=None, keywords=None, defaults=[\"<class \'float\'>\", \'None\'], "
+    argspec: "args=[\'self\', \'op\', \'value_index\', \'dtype\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
-    name: "astype"
-    argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=None"
-  }
-  member_method {
-    name: "clip"
-    argspec: "args=[\'a\', \'a_min\', \'a_max\'], varargs=None, keywords=None, defaults=None"
-  }
-  member_method {
-    name: "from_tensor"
-    argspec: "args=[\'cls\', \'tensor\'], varargs=None, keywords=None, defaults=None"
-  }
-  member_method {
-    name: "ravel"
-    argspec: "args=[\'a\'], varargs=None, keywords=None, defaults=None"
-  }
-  member_method {
-    name: "reshape"
-    argspec: "args=[\'a\'], varargs=newshape, keywords=kwargs, defaults=None"
-  }
-  member_method {
-    name: "tolist"
+    name: "consumers"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
-    name: "transpose"
-    argspec: "args=[\'a\', \'axes\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    name: "eval"
+    argspec: "args=[\'self\', \'feed_dict\', \'session\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+  }
+  member_method {
+    name: "experimental_ref"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_shape"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "ref"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_shape"
+    argspec: "args=[\'self\', \'shape\'], varargs=None, keywords=None, defaults=None"
   }
 }