Remove ndarray wrapper from TF Numpy. We return tensors directly.

PiperOrigin-RevId: 355761429
Change-Id: I1ab012bcd831550cd2aa2a8de3d758c23bc6332a
This commit is contained in:
Peng Wang 2021-02-04 19:13:46 -08:00 committed by TensorFlower Gardener
parent 6487888f1d
commit 0b9ff2eb1a
34 changed files with 744 additions and 859 deletions

View File

@ -136,7 +136,7 @@ message TypeSpecProto {
PER_REPLICA_SPEC = 8; // PerReplicaSpec from distribute/values.py
VARIABLE_SPEC = 9; // tf.VariableSpec
ROW_PARTITION_SPEC = 10; // RowPartitionSpec from ragged/row_partition.py
NDARRAY_SPEC = 11; // TF Numpy NDarray spec
reserved 11;
}
TypeSpecClass type_spec_class = 1;

View File

@ -62,9 +62,6 @@ from tensorflow.python.util.tf_export import tf_export
pfor_ops = LazyLoader(
"pfor_ops", globals(),
"tensorflow.python.ops.parallel_for.control_flow_ops")
np_arrays = LazyLoader(
"np_arrays", globals(),
"tensorflow.python.ops.numpy_ops.np_arrays")
function = LazyLoader("function", globals(),
"tensorflow.python.eager.function")
@ -727,8 +724,6 @@ def _handle_or_self(x):
"""Unwrap resource variable/ndarray to return tensors."""
if resource_variable_ops.is_resource_variable(x):
return x.handle
if isinstance(x, np_arrays.ndarray):
return x.data
return x
@ -1034,7 +1029,6 @@ class GradientTape(object):
" of Tensors or Variables to be differentiated,"
" but recieved %r" % (target))
num_ndarrays = 0
flat_targets = []
for t in nest.flatten(target):
if not backprop_util.IsTrainable(t):
@ -1045,12 +1039,7 @@ class GradientTape(object):
if resource_variable_ops.is_resource_variable(t):
with self:
t = ops.convert_to_tensor(t)
elif isinstance(t, np_arrays.ndarray):
t = t.data
num_ndarrays += 1
flat_targets.append(t)
# Only rewrap if all targets are ndarray. If not, prefer tensors.
rewrap_as_ndarray = num_ndarrays == len(flat_targets)
flat_sources = nest.flatten(sources)
flat_sources_raw = flat_sources
@ -1083,13 +1072,6 @@ class GradientTape(object):
self._watched_variables = self._tape.watched_variables()
self._tape = None
if rewrap_as_ndarray:
def _tensor_to_ndarray(x):
if x is not None:
return np_arrays.tensor_to_ndarray(x)
return None
flat_grad = nest.map_structure(_tensor_to_ndarray, flat_grad)
grad = nest.pack_sequence_as(sources, flat_grad)
return grad
@ -1158,10 +1140,6 @@ class GradientTape(object):
"compute one set of gradients (or jacobians)")
flat_sources = nest.flatten(sources)
rewrap_as_ndarray = False
if isinstance(target, np_arrays.ndarray):
target = target.data
rewrap_as_ndarray = True
target_static_shape = target.shape
target_shape = array_ops.shape(target)
# Note that we push and pop the tape here and below. This is needed since we
@ -1211,8 +1189,6 @@ class GradientTape(object):
out = array_ops.reshape(out, new_shape)
if context.executing_eagerly():
out.set_shape(target_static_shape.concatenate(flat_sources[i].shape))
if rewrap_as_ndarray:
out = np_arrays.tensor_to_ndarray(out)
output[i] = out
return nest.pack_sequence_as(sources, output)
@ -1287,12 +1263,6 @@ class GradientTape(object):
if self._tape is None:
raise RuntimeError("A non-persistent GradientTape can only be used to"
"compute one set of gradients (or jacobians)")
rewrap_as_ndarray = False
if isinstance(target, np_arrays.ndarray):
target = target.data
rewrap_as_ndarray = True
if isinstance(source, np_arrays.ndarray):
source = source.data
target_shape = target.shape
if target_shape.rank is None:
dim = tensor_shape.Dimension(None)
@ -1354,8 +1324,6 @@ class GradientTape(object):
# represent unconnected gradients. This is to maintain compatibility with
# the previous behavior, which ignored `unconnected_gradients`.
output = array_ops.zeros(new_shape, target.dtype)
if rewrap_as_ndarray:
output = np_arrays.tensor_to_ndarray(output)
return output
else:
output = array_ops.reshape(output,
@ -1363,6 +1331,4 @@ class GradientTape(object):
output = array_ops.transpose(output, [1, 0, 2])
output = array_ops.reshape(output, new_shape)
if rewrap_as_ndarray:
output = np_arrays.tensor_to_ndarray(output)
return output

View File

@ -32,7 +32,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.numpy_ops import np_arrays
from tensorflow.python.ops.parallel_for import control_flow_ops
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.platform import tf_logging as logging
@ -441,16 +440,11 @@ class ForwardAccumulator():
if hasattr(tensor, "handle"):
unwrapped_tensor = ops.convert_to_tensor(tensor.handle)
else:
if isinstance(tensor, np_arrays.ndarray):
unwrapped_tensor = tensor.data
else:
unwrapped_tensor = tensor
unwrapped_tensor = tensor
result = pywrap_tfe.TFE_Py_ForwardAccumulatorJVP(self._accumulator,
unwrapped_tensor)
if result is None and unconnected_gradients == UnconnectedGradients.ZERO:
result = array_ops.zeros_like(tensor)
if result is not None and isinstance(tensor, np_arrays.ndarray):
return np_arrays.tensor_to_ndarray(result)
return result
return nest.map_structure(_fetch_jvp, primals)

View File

@ -1522,11 +1522,6 @@ class ConcreteFunction(object):
self._func_graph = func_graph
self._captured_inputs = self._func_graph.external_captures
self._captured_closures = self._func_graph.deferred_external_captures
structured_outputs = self._func_graph.structured_outputs
self._ndarrays_list = (
isinstance(structured_outputs, (list, tuple)) and structured_outputs and
all(isinstance(o, np_arrays.ndarray) for o in structured_outputs))
self._ndarray_singleton = isinstance(structured_outputs, np_arrays.ndarray)
# function_spec defines the structured signature.
self._set_function_spec(function_spec)
@ -2176,12 +2171,6 @@ class ConcreteFunction(object):
if self._func_graph.structured_outputs is None:
return result
if result:
if self._ndarrays_list:
return [np_arrays.tensor_to_ndarray(o) for o in result]
elif self._ndarray_singleton:
return np_arrays.tensor_to_ndarray(result[0])
# Replace outputs with results, skipping over any 'None' values.
outputs_list = nest.flatten(
self._func_graph.structured_outputs, expand_composites=True)

View File

@ -257,7 +257,7 @@ def disable_tensor_equality():
# TODO(mdan): This object should subclass Symbol, not just Tensor.
@tf_export("Tensor")
@tf_export("Tensor", "experimental.numpy.ndarray", v1=["Tensor"])
class Tensor(internal.NativeObject, core_tf_types.Tensor):
"""A tensor is a multidimensional array of elements represented by a
@ -386,6 +386,16 @@ class Tensor(internal.NativeObject, core_tf_types.Tensor):
self._id = uid()
self._name = None
def __getattr__(self, name):
if name in {"T", "astype", "ravel", "transpose", "reshape", "clip", "size",
"tolist", "data"}:
# TODO(wangpeng): Export the enable_numpy_behavior knob
raise AttributeError("""
If you are looking for numpy-related methods, please run the following:
import tensorflow.python.ops.numpy_ops.np_config
np_config.enable_numpy_behavior()""")
self.__getattribute__(name)
@staticmethod
def _create_with_tf_output(op, value_index, dtype, tf_output):
ret = Tensor(op, value_index, dtype)
@ -6943,6 +6953,30 @@ def _reconstruct_sequence_inputs(op_def, inputs, attrs):
return grouped_inputs
_numpy_style_type_promotion = False
def enable_numpy_style_type_promotion():
"""If called, follows NumPy's rules for type promotion.
Used for enabling NumPy behavior on methods for TF NumPy.
"""
global _numpy_style_type_promotion
_numpy_style_type_promotion = True
_numpy_style_slicing = False
def enable_numpy_style_slicing():
"""If called, follows NumPy's rules for slicing Tensors.
Used for enabling NumPy behavior on slicing for TF NumPy.
"""
global _numpy_style_slicing
_numpy_style_slicing = True
class _TensorIterator(object):
"""Iterates over the leading dim of a Tensor. Performs no error checks."""

View File

@ -202,6 +202,17 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
self.assertAllEqual(np.array(x), np.ones((3, 4)))
self.assertEqual(len(x), 3)
def testConstructor(self):
a = array_ops.ones([])
for name in ["T", "astype", "ravel", "transpose", "reshape", "clip", "size",
"tolist", "data"]:
with self.assertRaisesRegex(
AttributeError, r"If you are looking for numpy-related methods"):
getattr(a, name)
with self.assertRaisesRegex(
AttributeError, r"object has no attribute"):
a.foo_bar()
def testRef(self):
x1 = constant_op.constant(3)
x2 = x1

View File

@ -250,6 +250,11 @@ class Dimension(object):
# Allow use in Python 3 range
return self._value
def __hash__(self):
if self._value is None:
raise ValueError("Unable to hash Dimension with value 'None'")
return hash(self._value)
@property
def value(self):
"""The value of this dimension, or None if it is unknown."""
@ -986,6 +991,11 @@ class TensorShape(object):
other = TensorShape(other)
return other.concatenate(self)
def __hash__(self):
if not self.is_fully_defined():
raise ValueError("Unable to hash partially defined TensorShape.")
return hash(tuple(self._dims))
def concatenate(self, other):
"""Returns the concatenation of the dimension in `self` and `other`.

View File

@ -384,6 +384,20 @@ class ShapeTest(test_util.TensorFlowTestCase, parameterized.TestCase):
else:
self.assertEqual(expected, mcs.as_list())
def testHash(self):
base = tensor_shape.TensorShape([1, 2, 3, 4])
base_copy = tensor_shape.TensorShape([1, 2, 3, 4])
self.assertEqual(hash(base), hash(base_copy))
with self.assertRaisesRegex(ValueError, r"Unable to hash partially"):
hash(tensor_shape.TensorShape([1, 2, 3, 4, None]))
with self.assertRaisesRegex(ValueError, r"Unable to hash partially"):
hash(tensor_shape.TensorShape(None))
with self.assertRaisesRegex(ValueError, r"Unable to hash Dimension"):
hash(tensor_shape.Dimension(None))
def testMostSpecificCompatibleShape(self):
self._testMostSpecificCompatibleShapeHelper([1, 2], None, None)
self._testMostSpecificCompatibleShapeHelper(None, [1, 2], None)

View File

@ -368,9 +368,7 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase):
# mode. Variable.assign(...).op is None in Eager mode and an op in Graph
# mode or a tf.function. We test this is also true of AutoCastVariable.
if context.executing_eagerly():
with self.assertRaisesRegex(
AttributeError,
'Tensor.op is meaningless when eager execution is enabled'):
with self.assertRaises(AttributeError):
x.op # pylint: disable=pointless-statement
self.assertIsNone(x.assign(1.0).op)
self.assertIsNone(x.assign_add(1.0).op)

View File

@ -962,6 +962,9 @@ def _slice_helper(tensor, slice_spec, var=None):
tf.newaxis or scalar int32/int64 tensors.
"""
tensor = ops.convert_to_tensor(tensor)
# TODO(wangpeng): Consider supporting var
if var is None and ops._numpy_style_slicing: # pylint: disable=protected-access
return tensor._numpy_style_getitem(slice_spec) # pylint: disable=protected-access
if isinstance(slice_spec, bool) or \
(isinstance(slice_spec, ops.Tensor) and slice_spec.dtype == dtypes.bool) or \

View File

@ -38,16 +38,10 @@ from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
from tensorflow.python.util import lazy_loader
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
np_arrays = lazy_loader.LazyLoader(
"np_arrays", globals(),
"tensorflow.python.ops.numpy_ops.np_arrays")
@tf_export(v1=["map_fn"])
@deprecation.deprecated_args(None, "Use fn_output_signature instead", "dtype")
def map_fn(fn,
@ -426,8 +420,6 @@ def map_fn(fn,
# Check that inputs are not scalars.
first_elem = elems_flat[0]
if isinstance(first_elem, np_arrays.ndarray):
first_elem = first_elem.data
elems_static_shape = first_elem.shape
if elems_static_shape.ndims is not None and elems_static_shape.ndims < 1:
if len(elems_flat) == 1:

View File

@ -70,6 +70,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numbers
import numpy as np
import six
from six.moves import builtins
@ -99,9 +100,17 @@ from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export
np_dtypes = LazyLoader(
"np_dtypes", globals(),
"tensorflow.python.ops.numpy_ops.np_dtypes")
# Aliases for some automatically-generated names.
nextafter = gen_math_ops.next_after
@ -1130,6 +1139,48 @@ ops.Tensor._override_operator("__neg__", gen_math_ops.neg)
ops.Tensor._override_operator("__abs__", abs)
def _maybe_get_dtype(x):
"""Returns a numpy type if available from x. Skips if x is numpy.ndarray."""
# Don't put np.ndarray in this list, because np.result_type looks at the
# value (not just dtype) of np.ndarray to decide the result type.
if isinstance(x, numbers.Real):
return x
if isinstance(x, ops.Tensor):
return x.dtype.as_numpy_dtype
if isinstance(x, dtypes.DType):
return x.as_numpy_dtype
if isinstance(x, tensor_shape.TensorShape):
return np.int32
if isinstance(x, (list, tuple)):
raise ValueError("Got sequence {}".format(x))
return x
def maybe_promote_tensors(*tensors, force_same_dtype=True):
"""Promote tensors if numpy style promotion is enabled."""
if not tensors:
return tensors
if not ops._numpy_style_type_promotion:
if not force_same_dtype:
return tensors
promoted_tensors = []
promoted_tensors.append(tensors[0])
dtype = tensors[0].dtype.base_dtype
for tensor in tensors[1:]:
promoted_tensors.append(
ops.convert_to_tensor(tensor, dtype, name="x"))
return promoted_tensors
result_type = np_dtypes._result_type(
*[_maybe_get_dtype(x) for x in nest.flatten(tensors)])
def _promote_or_cast(x):
if isinstance(x, ops.Tensor):
x = cast(x, result_type)
else:
x = ops.convert_to_tensor(x, result_type)
return x
return [_promote_or_cast(x) for x in tensors]
def _OverrideBinaryOperatorHelper(func, op_name, clazz_object=ops.Tensor):
"""Register operators with different tensor and scalar versions.
@ -1145,6 +1196,10 @@ def _OverrideBinaryOperatorHelper(func, op_name, clazz_object=ops.Tensor):
def binary_op_wrapper(x, y):
with ops.name_scope(None, op_name, [x, y]) as name:
try:
# force_same_dtype=False to preserve existing TF behavior
# TODO(b/178860388): Figure out why binary_op_wrapper and
# r_binary_op_wrapper use different force_same_dtype values.
x, y = maybe_promote_tensors(x, y, force_same_dtype=False)
return func(x, y, name=name)
except (TypeError, ValueError) as e:
# Even if dispatching the op failed, the RHS may be a tensor aware
@ -1175,7 +1230,9 @@ def _OverrideBinaryOperatorHelper(func, op_name, clazz_object=ops.Tensor):
def r_binary_op_wrapper(y, x):
with ops.name_scope(None, op_name, [x, y]) as name:
x = ops.convert_to_tensor(x, dtype=y.dtype.base_dtype, name="x")
# TODO(b/178860388): Figure out why binary_op_wrapper and
# r_binary_op_wrapper use different force_same_dtype values.
y, x = maybe_promote_tensors(y, x)
return func(x, y, name=name)
# Propagate func.__doc__ to the wrappers
@ -1581,10 +1638,21 @@ _OverrideBinaryOperatorHelper(xor_, "xor")
ops.Tensor._override_operator("__invert__", invert_)
ops.Tensor._override_operator("__lt__", gen_math_ops.less)
ops.Tensor._override_operator("__le__", gen_math_ops.less_equal)
ops.Tensor._override_operator("__gt__", gen_math_ops.greater)
ops.Tensor._override_operator("__ge__", gen_math_ops.greater_equal)
def _promote_dtypes_decorator(fn):
def wrapper(x, y, *args, **kwargs):
x, y = maybe_promote_tensors(x, y, force_same_dtype=False)
return fn(x, y, *args, **kwargs)
return tf_decorator.make_decorator(fn, wrapper)
ops.Tensor._override_operator("__lt__", _promote_dtypes_decorator(
gen_math_ops.less))
ops.Tensor._override_operator("__le__", _promote_dtypes_decorator(
gen_math_ops.less_equal))
ops.Tensor._override_operator("__gt__", _promote_dtypes_decorator(
gen_math_ops.greater))
ops.Tensor._override_operator("__ge__", _promote_dtypes_decorator(
gen_math_ops.greater_equal))
@tf_export("math.equal", "equal")
@ -1691,6 +1759,7 @@ def tensor_equals(self, other):
g = getattr(self, "graph", None)
if (ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions() and
(g is None or g.building_function)):
self, other = maybe_promote_tensors(self, other)
return gen_math_ops.equal(self, other, incompatible_shape_error=False)
else:
# In legacy graph mode, tensor equality is object equality
@ -1727,6 +1796,7 @@ def tensor_not_equals(self, other):
if other is None:
return True
if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions():
self, other = maybe_promote_tensors(self, other)
return gen_math_ops.not_equal(self, other, incompatible_shape_error=False)
else:
# In legacy graph mode, tensor equality is object equality
@ -3482,7 +3552,14 @@ def matvec(a,
return array_ops.squeeze(output, axis=-1)
_OverrideBinaryOperatorHelper(matmul, "matmul")
# TODO(b/178650720): Also support numpy-style type promotion in freestanding TF
# functions (e.g. tf.add).
def matmul_wrapper(a, b, name=None): # pylint: disable=missing-function-docstring
if ops._numpy_style_type_promotion:
return a._matmul(b)
return matmul(a, b, name=name)
matmul_wrapper.__doc__ = matmul.__doc__
_OverrideBinaryOperatorHelper(matmul_wrapper, "matmul")
sparse_matmul = deprecation.deprecated(None, "Use `tf.linalg.matmul` instead")(
gen_math_ops.sparse_mat_mul)

View File

@ -13,6 +13,7 @@ py_library(
"__init__.py",
"np_array_ops.py",
"np_arrays.py",
"np_config.py",
"np_dtypes.py",
"np_export.py",
"np_math_ops.py",
@ -40,6 +41,17 @@ py_library(
],
)
cuda_py_test(
name = "np_dtypes_test",
srcs = ["np_dtypes_test.py"],
deps = [
":numpy",
"//tensorflow/python:platform",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)
cuda_py_test(
name = "np_arrays_test",
srcs = ["np_arrays_test.py"],

View File

@ -1,4 +1,5 @@
load("//tensorflow:tensorflow.bzl", "py_test")
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
licenses(["notice"])
@ -10,3 +11,13 @@ py_test(
"//tensorflow:tensorflow_py",
],
)
cuda_py_test(
name = "np_config_test",
srcs = ["np_config_test.py"],
deps = [
"//tensorflow:tensorflow_py",
"//tensorflow/python/ops/numpy_ops:numpy",
"//third_party/py/numpy",
],
)

View File

@ -0,0 +1,44 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests that an error is raised when numpy functions are called."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v2 as tf
from tensorflow.python.ops.numpy_ops import np_config
class ConfigTest(tf.test.TestCase):
def testMethods(self):
a = tf.constant(1.)
for name in {'T', 'astype', 'ravel', 'transpose', 'reshape', 'clip', 'size',
'tolist'}:
with self.assertRaisesRegex(AttributeError, 'enable_numpy_behavior'):
getattr(a, name)
np_config.enable_numpy_behavior()
for name in {'T', 'astype', 'ravel', 'transpose', 'reshape', 'clip', 'size',
'tolist'}:
_ = getattr(a, name)
if __name__ == '__main__':
tf.compat.v1.enable_eager_execution()
tf.test.main()

View File

@ -61,15 +61,11 @@ def empty_like(a, dtype=None):
def zeros(shape, dtype=float): # pylint: disable=redefined-outer-name
dtype = (
np_utils.result_type(dtype) if dtype else np_dtypes.default_float_type())
if isinstance(shape, np_arrays.ndarray):
shape = shape.data
return np_arrays.tensor_to_ndarray(array_ops.zeros(shape, dtype=dtype))
return array_ops.zeros(shape, dtype=dtype)
@np_utils.np_doc('zeros_like')
def zeros_like(a, dtype=None): # pylint: disable=missing-docstring
if isinstance(a, np_arrays.ndarray):
a = a.data
if dtype is None:
# We need to let np_utils.result_type decide the dtype, not tf.zeros_like
dtype = np_utils.result_type(a)
@ -78,27 +74,23 @@ def zeros_like(a, dtype=None): # pylint: disable=missing-docstring
# `float`, so we let `np_utils.result_type` decide.
dtype = np_utils.result_type(dtype)
dtype = dtypes.as_dtype(dtype) # Work around b/149877262
return np_arrays.tensor_to_ndarray(array_ops.zeros_like(a, dtype))
return array_ops.zeros_like(a, dtype)
@np_utils.np_doc('ones')
def ones(shape, dtype=float): # pylint: disable=redefined-outer-name
if dtype:
dtype = np_utils.result_type(dtype)
if isinstance(shape, np_arrays.ndarray):
shape = shape.data
return np_arrays.tensor_to_ndarray(array_ops.ones(shape, dtype=dtype))
return array_ops.ones(shape, dtype=dtype)
@np_utils.np_doc('ones_like')
def ones_like(a, dtype=None):
if isinstance(a, np_arrays.ndarray):
a = a.data
if dtype is None:
dtype = np_utils.result_type(a)
else:
dtype = np_utils.result_type(dtype)
return np_arrays.tensor_to_ndarray(array_ops.ones_like(a, dtype))
return array_ops.ones_like(a, dtype)
@np_utils.np_doc('eye')
@ -115,7 +107,7 @@ def eye(N, M=None, k=0, dtype=float): # pylint: disable=invalid-name,missing-do
# tf.linalg.diag will raise an error in this case
return zeros([N, M], dtype=dtype)
if k == 0:
return np_arrays.tensor_to_ndarray(linalg_ops.eye(N, M, dtype=dtype))
return linalg_ops.eye(N, M, dtype=dtype)
# We need the precise length, otherwise tf.linalg.diag will raise an error
diag_len = min(N, M)
if k > 0:
@ -129,8 +121,7 @@ def eye(N, M=None, k=0, dtype=float): # pylint: disable=invalid-name,missing-do
elif M - k > N:
diag_len = N + k
diagonal_ = array_ops.ones([diag_len], dtype=dtype)
return np_arrays.tensor_to_ndarray(
array_ops.matrix_diag(diagonal=diagonal_, num_rows=N, num_cols=M, k=k))
return array_ops.matrix_diag(diagonal=diagonal_, num_rows=N, num_cols=M, k=k)
@np_utils.np_doc('identity')
@ -142,10 +133,9 @@ def identity(n, dtype=float):
def full(shape, fill_value, dtype=None): # pylint: disable=redefined-outer-name
if not isinstance(shape, np_arrays.ndarray):
shape = asarray(np_arrays.convert_to_tensor(shape, dtype_hint=np.int32))
shape = atleast_1d(shape).data
shape = atleast_1d(shape)
fill_value = asarray(fill_value, dtype=dtype)
return np_arrays.tensor_to_ndarray(
array_ops.broadcast_to(fill_value.data, shape))
return array_ops.broadcast_to(fill_value, shape)
# Using doc only here since np full_like signature doesn't seem to have the
@ -160,19 +150,15 @@ def full_like(a, fill_value, dtype=None, order='K', subok=True, shape=None): #
if shape:
raise ValueError('Overriding the shape is not supported.')
a = asarray(a).data
a = asarray(a)
dtype = dtype or np_utils.result_type(a)
fill_value = asarray(fill_value, dtype=dtype)
return np_arrays.tensor_to_ndarray(
array_ops.broadcast_to(fill_value.data, array_ops.shape(a)))
return array_ops.broadcast_to(fill_value, array_ops.shape(a))
def _array_internal(val, dtype=None, copy=True, ndmin=0): # pylint: disable=redefined-outer-name
"""Main implementation of np.array()."""
if isinstance(val, np_arrays.ndarray):
result_t = val.data
else:
result_t = val
result_t = val
if not isinstance(result_t, ops.Tensor):
if not dtype:
@ -180,13 +166,7 @@ def _array_internal(val, dtype=None, copy=True, ndmin=0): # pylint: disable=red
# We can't call `convert_to_tensor(result_t, dtype=dtype)` here because
# convert_to_tensor doesn't allow incompatible arguments such as (5.5, int)
# while np.array allows them. We need to convert-then-cast.
def maybe_data(x):
if isinstance(x, np_arrays.ndarray):
return x.data
return x
# Handles lists of ndarrays
result_t = nest.map_structure(maybe_data, result_t)
# EagerTensor conversion complains about "mixed types" when converting
# tensors with no dtype information. This is because it infers types based
# on one selected item in the list. So e.g. when converting [2., 2j]
@ -204,7 +184,7 @@ def _array_internal(val, dtype=None, copy=True, ndmin=0): # pylint: disable=red
result_t = array_ops.identity(result_t)
if ndmin == 0:
return np_arrays.tensor_to_ndarray(result_t)
return result_t
ndims = array_ops.rank(result_t)
@ -216,7 +196,7 @@ def _array_internal(val, dtype=None, copy=True, ndmin=0): # pylint: disable=red
result_t = np_utils.cond(
np_utils.greater(ndmin, ndims), true_fn, lambda: result_t)
return np_arrays.tensor_to_ndarray(result_t)
return result_t
# TODO(wangpeng): investigate whether we can make `copy` default to False.
@ -241,7 +221,8 @@ def array(val, dtype=None, copy=True, ndmin=0): # pylint: disable=redefined-out
def asarray(a, dtype=None):
if dtype:
dtype = np_utils.result_type(dtype)
if isinstance(a, np_arrays.ndarray) and (not dtype or dtype == a.dtype):
if isinstance(a, np_arrays.ndarray) and (
not dtype or dtype == a.dtype.as_numpy_dtype):
return a
return array(a, dtype, copy=False)
@ -294,15 +275,15 @@ def arange(start, stop=None, step=1, dtype=None):
return array([], dtype=dtype)
# TODO(srbs): There are some bugs when start or stop is float type and dtype
# is integer type.
return np_arrays.tensor_to_ndarray(
math_ops.cast(math_ops.range(start, limit=stop, delta=step), dtype=dtype))
return math_ops.cast(
math_ops.range(start, limit=stop, delta=step), dtype=dtype)
# Building matrices.
@np_utils.np_doc('diag')
def diag(v, k=0): # pylint: disable=missing-docstring
"""Raises an error if input is not 1- or 2-d."""
v = asarray(v).data
v = asarray(v)
v_rank = array_ops.rank(v)
v.shape.with_rank_at_most(2)
@ -331,20 +312,20 @@ def diag(v, k=0): # pylint: disable=missing-docstring
result = np_utils.cond(
math_ops.equal(v_rank, 1), lambda: _diag(v, k), lambda: _diag_part(v, k))
return np_utils.tensor_to_ndarray(result)
return result
@np_utils.np_doc('diagonal')
def diagonal(a, offset=0, axis1=0, axis2=1): # pylint: disable=missing-docstring
a = asarray(a).data
a = asarray(a)
maybe_rank = a.shape.rank
if maybe_rank is not None and offset == 0 and (
axis1 == maybe_rank - 2 or axis1 == -2) and (axis2 == maybe_rank - 1 or
axis2 == -1):
return np_utils.tensor_to_ndarray(array_ops.matrix_diag_part(a))
return array_ops.matrix_diag_part(a)
a = moveaxis(np_utils.tensor_to_ndarray(a), (axis1, axis2), (-2, -1)).data
a = moveaxis(a, (axis1, axis2), (-2, -1))
a_shape = array_ops.shape(a)
@ -361,20 +342,20 @@ def diagonal(a, offset=0, axis1=0, axis2=1): # pylint: disable=missing-docstrin
np_utils.greater_equal(offset, np_utils.getitem(a_shape, -1)),
), _zeros, lambda: (a, offset))
a = np_utils.tensor_to_ndarray(array_ops.matrix_diag_part(a, k=offset))
a = array_ops.matrix_diag_part(a, k=offset)
return a
@np_utils.np_doc('diagflat')
def diagflat(v, k=0):
v = asarray(v)
return diag(array_ops.reshape(v.data, [-1]), k)
return diag(array_ops.reshape(v, [-1]), k)
def _promote_dtype(*arrays):
dtype = np_utils.result_type(*arrays)
def _fast_asarray(a):
if isinstance(a, np_arrays.ndarray) and dtype == a.dtype:
if isinstance(a, np_arrays.ndarray) and dtype == a.dtype.as_numpy_dtype:
return a
return _array_internal(a, dtype=dtype, copy=False)
return [_fast_asarray(a) for a in arrays]
@ -382,9 +363,11 @@ def _promote_dtype(*arrays):
def _promote_dtype_binary(t1, t2):
dtype = np_utils._result_type_binary(t1, t2) # pylint: disable=protected-access
if not(isinstance(t1, np_arrays.ndarray) and dtype == t1.dtype):
if not(
isinstance(t1, np_arrays.ndarray) and dtype == t1.dtype.as_numpy_dtype):
t1 = _array_internal(t1, dtype=dtype, copy=False)
if not(isinstance(t2, np_arrays.ndarray) and dtype == t2.dtype):
if not(
isinstance(t2, np_arrays.ndarray) and dtype == t2.dtype.as_numpy_dtype):
t2 = _array_internal(t2, dtype=dtype, copy=False)
return t1, t2
@ -392,15 +375,13 @@ def _promote_dtype_binary(t1, t2):
@np_utils.np_doc('all')
def all(a, axis=None, keepdims=None): # pylint: disable=redefined-builtin
a = asarray(a, dtype=bool)
return np_utils.tensor_to_ndarray(
math_ops.reduce_all(input_tensor=a.data, axis=axis, keepdims=keepdims))
return math_ops.reduce_all(input_tensor=a, axis=axis, keepdims=keepdims)
@np_utils.np_doc('any')
def any(a, axis=None, keepdims=None): # pylint: disable=redefined-builtin
a = asarray(a, dtype=bool)
return np_utils.tensor_to_ndarray(
math_ops.reduce_any(input_tensor=a.data, axis=axis, keepdims=keepdims))
return math_ops.reduce_any(input_tensor=a, axis=axis, keepdims=keepdims)
@np_utils.np_doc('compress')
@ -425,13 +406,12 @@ def compress(condition, a, axis=None): # pylint: disable=redefined-outer-name,m
# `tf.boolean_mask` requires the first dimensions of array and condition to
# match. `np.compress` pads condition with False when it is shorter.
condition_t = condition.data
a_t = a.data
condition_t = condition
a_t = a
if condition.shape[0] < a.shape[axis]:
padding = array_ops.fill([a.shape[axis] - condition.shape[0]], False)
condition_t = array_ops.concat([condition_t, padding], axis=0)
return np_utils.tensor_to_ndarray(
array_ops.boolean_mask(tensor=a_t, mask=condition_t, axis=axis))
return array_ops.boolean_mask(tensor=a_t, mask=condition_t, axis=axis)
@np_utils.np_doc('copy')
@ -443,8 +423,9 @@ def _maybe_promote_to_int(a):
if dtypes.as_dtype(a.dtype).is_integer:
# If a is an integer type and its precision is less than that of `int`,
# the output type will be `int`.
output_type = np.promote_types(a.dtype, int)
if output_type != a.dtype:
a_numpy_dtype = a.dtype.as_numpy_dtype
output_type = np.promote_types(a_numpy_dtype, int)
if output_type != a_numpy_dtype:
a = asarray(a, dtype=output_type)
return a
@ -462,8 +443,8 @@ def cumprod(a, axis=None, dtype=None): # pylint: disable=missing-docstring
a = ravel(a)
axis = 0
elif axis < 0:
axis += array_ops.rank(a.data)
return np_utils.tensor_to_ndarray(math_ops.cumprod(a.data, axis))
axis += array_ops.rank(a)
return math_ops.cumprod(a, axis)
@np_utils.np_doc('cumsum')
@ -478,8 +459,8 @@ def cumsum(a, axis=None, dtype=None): # pylint: disable=missing-docstring
a = ravel(a)
axis = 0
elif axis < 0:
axis += array_ops.rank(a.data)
return np_utils.tensor_to_ndarray(math_ops.cumsum(a.data, axis))
axis += array_ops.rank(a)
return math_ops.cumsum(a, axis)
@np_utils.np_doc('imag')
@ -487,7 +468,7 @@ def imag(val):
val = asarray(val)
# TODO(srbs): np.imag returns a scalar if `val` is a scalar, whereas we always
# return an ndarray.
return np_utils.tensor_to_ndarray(math_ops.imag(val.data))
return math_ops.imag(val)
_TO_INT_ = 0
@ -532,10 +513,9 @@ def _reduce(tf_fn,
a = asarray(a, dtype=dtype)
if ((dtype == np.bool_ or preserve_bool and a.dtype == np.bool_) and
tf_bool_fn is not None):
return np_utils.tensor_to_ndarray(
tf_bool_fn(input_tensor=a.data, axis=axis, keepdims=keepdims))
return tf_bool_fn(input_tensor=a, axis=axis, keepdims=keepdims)
if dtype is None:
dtype = a.dtype
dtype = a.dtype.as_numpy_dtype
if np.issubdtype(dtype, np.integer) or dtype == np.bool_:
if promote_int == _TO_INT_:
# If a is an integer/bool type and whose bit width is less than np.int_,
@ -554,12 +534,15 @@ def _reduce(tf_fn,
dtype = np.int_
else:
dtype = np.uint
a = a.astype(dtype)
a = math_ops.cast(a, dtype)
elif promote_int == _TO_FLOAT:
a = a.astype(np_dtypes.default_float_type())
a = math_ops.cast(a, np_dtypes.default_float_type())
return np_utils.tensor_to_ndarray(
tf_fn(input_tensor=a.data, axis=axis, keepdims=keepdims))
if isinstance(axis, ops.Tensor) and axis.dtype not in (
dtypes.int32, dtypes.int64):
axis = math_ops.cast(axis, dtypes.int64)
return tf_fn(input_tensor=a, axis=axis, keepdims=keepdims)
# TODO (DarrenZhang01): Add `axis` support to the `size` API.
@ -570,11 +553,11 @@ def size(x, axis=None): # pylint: disable=missing-docstring
'`np.size` implementation')
if isinstance(x, (int, float, np.int32, np.int64, np.float32, np.float64)):
return 1
x = asarray(x).data
x = asarray(x)
if x.shape.is_fully_defined():
return np.prod(x.shape.as_list(), dtype=int)
else:
return np_utils.tensor_to_ndarray(array_ops.size_v2(x))
return array_ops.size_v2(x)
@np_utils.np_doc('sum')
@ -677,10 +660,10 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=None): # pylint: d
axis=axis,
dtype=working_dtype,
keepdims=keepdims,
promote_int=_TO_FLOAT).data
promote_int=_TO_FLOAT)
if dtype:
result = math_ops.cast(result, dtype)
return np_utils.tensor_to_ndarray(result)
return result
@np_utils.np_doc('std')
@ -697,13 +680,7 @@ def std(a, axis=None, keepdims=None): # pylint: disable=missing-function-docstr
@np_utils.np_doc('ravel')
def ravel(a): # pylint: disable=missing-docstring
a = asarray(a)
out = np_utils.cond(
math_ops.equal(a.ndim, 1), lambda: a.data,
lambda: array_ops.reshape(a.data, [-1]))
return np_utils.tensor_to_ndarray(out)
setattr(np_arrays.ndarray, 'ravel', ravel)
return array_ops.reshape(a, [-1])
@np_utils.np_doc('real')
@ -711,12 +688,12 @@ def real(val):
val = asarray(val)
# TODO(srbs): np.real returns a scalar if val is a scalar, whereas we always
# return an ndarray.
return np_utils.tensor_to_ndarray(math_ops.real(val.data))
return math_ops.real(val)
@np_utils.np_doc('repeat')
def repeat(a, repeats, axis=None): # pylint: disable=missing-docstring
a = asarray(a).data
a = asarray(a)
original_shape = a._shape_as_list() # pylint: disable=protected-access
# Best effort recovery of the shape.
known_shape = original_shape is not None and None not in original_shape
@ -737,18 +714,18 @@ def repeat(a, repeats, axis=None): # pylint: disable=missing-docstring
else:
original_shape[axis] = repeats_np.sum()
repeats = asarray(repeats).data
repeats = asarray(repeats)
result = array_ops.repeat(a, repeats, axis)
if known_shape:
result.set_shape(original_shape)
return np_utils.tensor_to_ndarray(result)
return result
@np_utils.np_doc('around')
def around(a, decimals=0): # pylint: disable=missing-docstring
a = asarray(a)
dtype = a.dtype
dtype = a.dtype.as_numpy_dtype
factor = math.pow(10, decimals)
if np.issubdtype(dtype, np.inexact):
factor = math_ops.cast(factor, dtype)
@ -756,12 +733,12 @@ def around(a, decimals=0): # pylint: disable=missing-docstring
# Use float as the working dtype when a.dtype is exact (e.g. integer),
# because `decimals` can be negative.
float_dtype = np_dtypes.default_float_type()
a = a.astype(float_dtype).data
a = a.astype(float_dtype)
factor = math_ops.cast(factor, float_dtype)
a = math_ops.multiply(a, factor)
a = math_ops.round(a)
a = math_ops.divide(a, factor)
return np_utils.tensor_to_ndarray(a).astype(dtype)
return a.astype(dtype)
setattr(np_arrays.ndarray, '__round__', around)
@ -774,18 +751,16 @@ def reshape(a, newshape, order='C'):
raise ValueError('Unsupported order argument {}'.format(order))
a = asarray(a)
if isinstance(newshape, np_arrays.ndarray):
newshape = newshape.data
if isinstance(newshape, int):
newshape = [newshape]
if order == 'F':
r = array_ops.transpose(
array_ops.reshape(array_ops.transpose(a.data), newshape[::-1]))
array_ops.reshape(array_ops.transpose(a), newshape[::-1]))
else:
r = array_ops.reshape(a.data, newshape)
r = array_ops.reshape(a, newshape)
return np_utils.tensor_to_ndarray(r)
return r
def _reshape_method_wrapper(a, *newshape, **kwargs):
@ -802,13 +777,13 @@ def _reshape_method_wrapper(a, *newshape, **kwargs):
@np_utils.np_doc('expand_dims')
def expand_dims(a, axis):
a = asarray(a)
return np_utils.tensor_to_ndarray(array_ops.expand_dims(a.data, axis=axis))
return array_ops.expand_dims(a, axis=axis)
@np_utils.np_doc('squeeze')
def squeeze(a, axis=None):
a = asarray(a)
return np_utils.tensor_to_ndarray(array_ops.squeeze(a, axis))
return array_ops.squeeze(a, axis)
@np_utils.np_doc('transpose')
@ -816,12 +791,12 @@ def transpose(a, axes=None):
a = asarray(a)
if axes is not None:
axes = asarray(axes)
return np_utils.tensor_to_ndarray(array_ops.transpose(a=a.data, perm=axes))
return array_ops.transpose(a=a, perm=axes)
@np_utils.np_doc('swapaxes')
def swapaxes(a, axis1, axis2): # pylint: disable=missing-docstring
a = asarray(a).data
a = asarray(a)
def adjust_axes(axes, rank):
def f(x):
if isinstance(x, int):
@ -848,7 +823,7 @@ def swapaxes(a, axis1, axis2): # pylint: disable=missing-docstring
perm = array_ops.tensor_scatter_update(perm, [[axis1], [axis2]],
[axis2, axis1])
a = array_ops.transpose(a, perm)
return np_utils.tensor_to_ndarray(a)
return a
@np_utils.np_doc('moveaxis')
@ -857,7 +832,7 @@ def moveaxis(a, source, destination): # pylint: disable=missing-docstring
if not source and not destination:
return a
a = asarray(a).data
a = asarray(a)
if isinstance(source, int):
source = (source,)
@ -908,13 +883,7 @@ def moveaxis(a, source, destination): # pylint: disable=missing-docstring
perm, array_ops.expand_dims(destination, 1), source)
a = array_ops.transpose(a, perm)
return np_utils.tensor_to_ndarray(a)
# TODO(wangpeng): Make a custom `setattr` that also sets docstring for the
# method.
setattr(np_arrays.ndarray, 'transpose', transpose)
setattr(np_arrays.ndarray, 'reshape', _reshape_method_wrapper)
return a
@np_utils.np_doc('pad')
@ -926,12 +895,11 @@ def pad(array, pad_width, mode, **kwargs): # pylint: disable=redefined-outer-na
mode = mode.upper()
array = asarray(array)
pad_width = asarray(pad_width, dtype=dtypes.int32)
return np_utils.tensor_to_ndarray(
array_ops.pad(
tensor=array.data,
paddings=pad_width.data,
mode=mode,
constant_values=constant_values))
return array_ops.pad(
tensor=array,
paddings=pad_width,
mode=mode,
constant_values=constant_values)
@np_utils.np_doc('take')
@ -943,8 +911,8 @@ def take(a, indices, axis=None, out=None, mode='clip'):
if mode not in {'raise', 'clip', 'wrap'}:
raise ValueError("Invalid mode '{}' for take".format(mode))
a = asarray(a).data
indices = asarray(indices).data
a = asarray(a)
indices = asarray(indices)
if axis is None:
a = array_ops.reshape(a, [-1])
@ -958,7 +926,7 @@ def take(a, indices, axis=None, out=None, mode='clip'):
else:
raise ValueError("The 'raise' mode to take is not supported.")
return np_utils.tensor_to_ndarray(array_ops.gather(a, indices, axis=axis))
return array_ops.gather(a, indices, axis=axis)
@np_utils.np_doc_only('where')
@ -969,8 +937,7 @@ def where(condition, x=None, y=None):
return nonzero(condition)
elif x is not None and y is not None:
x, y = _promote_dtype(x, y)
return np_utils.tensor_to_ndarray(
array_ops.where_v2(condition.data, x.data, y.data))
return array_ops.where_v2(condition, x, y)
raise ValueError('Both x and y must be ndarrays, or both must be None.')
@ -1044,8 +1011,7 @@ def split(ary, indices_or_sections, axis=0):
ary = asarray(ary)
if not isinstance(indices_or_sections, six.integer_types):
indices_or_sections = _boundaries_to_sizes(ary, indices_or_sections, axis)
result = array_ops.split(ary.data, indices_or_sections, axis=axis)
return [np_utils.tensor_to_ndarray(a) for a in result]
return array_ops.split(ary, indices_or_sections, axis=axis)
def _split_on_axis(np_fun_name, axis):
@ -1077,7 +1043,7 @@ def stack(arrays, axis=0): # pylint: disable=missing-function-docstring
return swapaxes(arrays, 0, axis)
arrays = _promote_dtype(*arrays) # pylint: disable=protected-access
unwrapped_arrays = [
a.data if isinstance(a, np_arrays.ndarray) else a for a in arrays
a if isinstance(a, np_arrays.ndarray) else a for a in arrays
]
return asarray(array_ops.stack(unwrapped_arrays, axis))
@ -1087,7 +1053,7 @@ def hstack(tup):
arrays = [atleast_1d(a) for a in tup]
arrays = _promote_dtype(*arrays) # pylint: disable=protected-access
unwrapped_arrays = [
a.data if isinstance(a, np_arrays.ndarray) else a for a in arrays
a if isinstance(a, np_arrays.ndarray) else a for a in arrays
]
rank = array_ops.rank(unwrapped_arrays[0])
return np_utils.cond(
@ -1101,7 +1067,7 @@ def vstack(tup):
arrays = [atleast_2d(a) for a in tup]
arrays = _promote_dtype(*arrays) # pylint: disable=protected-access
unwrapped_arrays = [
a.data if isinstance(a, np_arrays.ndarray) else a for a in arrays
a if isinstance(a, np_arrays.ndarray) else a for a in arrays
]
return array_ops.concat(unwrapped_arrays, axis=0)
@ -1111,13 +1077,13 @@ def dstack(tup):
arrays = [atleast_3d(a) for a in tup]
arrays = _promote_dtype(*arrays) # pylint: disable=protected-access
unwrapped_arrays = [
a.data if isinstance(a, np_arrays.ndarray) else a for a in arrays
a if isinstance(a, np_arrays.ndarray) else a for a in arrays
]
return array_ops.concat(unwrapped_arrays, axis=2)
def _pad_left_to(n, old_shape):
old_shape = asarray(old_shape, dtype=np.int32).data
old_shape = asarray(old_shape, dtype=np.int32)
new_shape = array_ops.pad(
old_shape, [[math_ops.maximum(n - array_ops.size(old_shape), 0), 0]],
constant_values=1)
@ -1143,8 +1109,8 @@ def _atleast_nd(n, new_shape, *arys):
return asarray(
np_utils.cond(
np_utils.greater(n, array_ops.rank(x)),
lambda: reshape(x, new_shape(n, array_ops.shape(x.data))).data,
lambda: x.data))
lambda: reshape(x, new_shape(n, array_ops.shape(x))),
lambda: x))
arys = list(map(f, arys))
if len(arys) == 1:
@ -1182,16 +1148,14 @@ def atleast_3d(*arys): # pylint: disable=missing-docstring
@np_utils.np_doc('nonzero')
def nonzero(a):
a = atleast_1d(a).data
a = atleast_1d(a)
if a.shape.rank is None:
raise ValueError("The rank of `a` is unknown, so we can't decide how many "
'arrays to return.')
return nest.map_structure(
np_arrays.tensor_to_ndarray,
array_ops.unstack(
array_ops.where_v2(math_ops.cast(a, dtypes.bool)),
a.shape.rank,
axis=1))
return array_ops.unstack(
array_ops.where_v2(math_ops.cast(a, dtypes.bool)),
a.shape.rank,
axis=1)
@np_utils.np_doc('diag_indices')
@ -1231,12 +1195,12 @@ def tri(N, M=None, k=0, dtype=None): # pylint: disable=invalid-name,missing-doc
r = o
else:
r = array_ops.matrix_band_part(o, -1, k)
return np_utils.tensor_to_ndarray(r)
return r
@np_utils.np_doc('tril')
def tril(m, k=0): # pylint: disable=missing-docstring
m = asarray(m).data
m = asarray(m)
if m.shape.ndims is None:
raise ValueError('Argument to tril should have known rank')
m_shape = m.shape.as_list()
@ -1251,14 +1215,13 @@ def tril(m, k=0): # pylint: disable=missing-docstring
z = constant_op.constant(0, m.dtype)
mask = tri(*m_shape[-2:], k=k, dtype=bool)
return np_utils.tensor_to_ndarray(
array_ops.where_v2(
array_ops.broadcast_to(mask, array_ops.shape(m)), m, z))
return array_ops.where_v2(
array_ops.broadcast_to(mask, array_ops.shape(m)), m, z)
@np_utils.np_doc('triu')
def triu(m, k=0): # pylint: disable=missing-docstring
m = asarray(m).data
m = asarray(m)
if m.shape.ndims is None:
raise ValueError('Argument to triu should have known rank')
m_shape = m.shape.as_list()
@ -1273,22 +1236,20 @@ def triu(m, k=0): # pylint: disable=missing-docstring
z = constant_op.constant(0, m.dtype)
mask = tri(*m_shape[-2:], k=k - 1, dtype=bool)
return np_utils.tensor_to_ndarray(
array_ops.where_v2(
array_ops.broadcast_to(mask, array_ops.shape(m)), z, m))
return array_ops.where_v2(
array_ops.broadcast_to(mask, array_ops.shape(m)), z, m)
@np_utils.np_doc('flip')
def flip(m, axis=None): # pylint: disable=missing-docstring
m = asarray(m).data
m = asarray(m)
if axis is None:
return np_utils.tensor_to_ndarray(
array_ops.reverse(m, math_ops.range(array_ops.rank(m))))
return array_ops.reverse(m, math_ops.range(array_ops.rank(m)))
axis = np_utils._canonicalize_axis(axis, array_ops.rank(m)) # pylint: disable=protected-access
return np_utils.tensor_to_ndarray(array_ops.reverse(m, [axis]))
return array_ops.reverse(m, [axis])
@np_utils.np_doc('flipud')
@ -1303,15 +1264,15 @@ def fliplr(m): # pylint: disable=missing-docstring
@np_utils.np_doc('roll')
def roll(a, shift, axis=None): # pylint: disable=missing-docstring
a = asarray(a).data
a = asarray(a)
if axis is not None:
return np_utils.tensor_to_ndarray(manip_ops.roll(a, shift, axis))
return manip_ops.roll(a, shift, axis)
# If axis is None, the roll happens as a 1-d tensor.
original_shape = array_ops.shape(a)
a = manip_ops.roll(array_ops.reshape(a, [-1]), shift, 0)
return np_utils.tensor_to_ndarray(array_ops.reshape(a, original_shape))
return array_ops.reshape(a, original_shape)
@np_utils.np_doc('rot90')
@ -1336,7 +1297,7 @@ def rot90(m, k=1, axes=(0, 1)): # pylint: disable=missing-docstring
@np_utils.np_doc('vander')
def vander(x, N=None, increasing=False): # pylint: disable=missing-docstring,invalid-name
x = asarray(x).data
x = asarray(x)
x_shape = array_ops.shape(x)
N = N or x_shape[0]
@ -1368,9 +1329,8 @@ def vander(x, N=None, increasing=False): # pylint: disable=missing-docstring,in
delta = -1
x = array_ops.expand_dims(x, -1)
return np_utils.tensor_to_ndarray(
math_ops.pow(
x, math_ops.cast(math_ops.range(start, limit, delta), dtype=x.dtype)))
return math_ops.pow(
x, math_ops.cast(math_ops.range(start, limit, delta), dtype=x.dtype))
@np_utils.np_doc('ix_')
@ -1378,7 +1338,7 @@ def ix_(*args): # pylint: disable=missing-docstring
n = len(args)
output = []
for i, a in enumerate(args):
a = asarray(a).data
a = asarray(a)
a_rank = array_ops.rank(a)
a_rank_temp = np_utils.get_static_value(a_rank)
if a_rank_temp is not None:
@ -1393,11 +1353,9 @@ def ix_(*args): # pylint: disable=missing-docstring
new_shape[i] = -1
dtype = a.dtype
if dtype == dtypes.bool:
output.append(
np_utils.tensor_to_ndarray(
array_ops.reshape(nonzero(a)[0].data, new_shape)))
output.append(array_ops.reshape(nonzero(a)[0], new_shape))
elif dtype.is_integer:
output.append(np_utils.tensor_to_ndarray(array_ops.reshape(a, new_shape)))
output.append(array_ops.reshape(a, new_shape))
else:
raise ValueError(
'Only integer and bool dtypes are supported, got {}'.format(dtype))
@ -1413,9 +1371,8 @@ def broadcast_arrays(*args, **kwargs): # pylint: disable=missing-docstring
if kwargs:
raise ValueError('Received unsupported arguments {}'.format(kwargs.keys()))
args = [asarray(arg).data for arg in args]
args = np_utils.tf_broadcast(*args)
return [np_utils.tensor_to_ndarray(arg) for arg in args]
args = [asarray(arg) for arg in args]
return np_utils.tf_broadcast(*args)
@np_utils.np_doc_only('sign')
@ -1428,13 +1385,13 @@ def sign(x, out=None, where=None, **kwargs): # pylint: disable=missing-docstrin
raise ValueError('tf.numpy doesnt support setting {}'.format(kwargs.keys()))
x = asarray(x)
dtype = x.dtype
dtype = x.dtype.as_numpy_dtype
if np.issubdtype(dtype, np.complex):
result = math_ops.cast(math_ops.sign(math_ops.real(x.data)), dtype)
result = math_ops.cast(math_ops.sign(math_ops.real(x)), dtype)
else:
result = math_ops.sign(x.data)
result = math_ops.sign(x)
return np_utils.tensor_to_ndarray(result)
return result
# Note that np.take_along_axis may not be present in some supported versions of
@ -1447,9 +1404,6 @@ def take_along_axis(arr, indices, axis): # pylint: disable=missing-docstring
if axis is None:
return take_along_axis(arr.ravel(), indices, 0)
arr = arr.data
indices = indices.data
rank = array_ops.rank(arr)
axis = axis + rank if axis < 0 else axis
@ -1475,7 +1429,7 @@ def take_along_axis(arr, indices, axis): # pylint: disable=missing-docstring
# Correct indices since gather doesn't correctly handle negative indices.
indices = array_ops.where_v2(indices < 0, indices + arr_shape[axis], indices)
swapaxes_ = lambda t: swapaxes(np_utils.tensor_to_ndarray(t), axis, -1).data
swapaxes_ = lambda t: swapaxes(t, axis, -1)
dont_move_axis_to_end = math_ops.equal(axis, np_utils.subtract(rank, 1))
arr = np_utils.cond(dont_move_axis_to_end, lambda: arr,
@ -1495,7 +1449,7 @@ def take_along_axis(arr, indices, axis): # pylint: disable=missing-docstring
lambda: swapaxes_(result))
result.set_shape(possible_result_shape)
return np_utils.tensor_to_ndarray(result)
return result
_SLICE_ERORR = (
@ -1519,7 +1473,7 @@ def _as_index(idx, need_scalar=True):
"""
if isinstance(idx, (numbers.Integral, tensor_shape.Dimension)):
return idx, True
data = asarray(idx).data
data = asarray(idx)
if data.dtype == dtypes.bool:
if data.shape.ndims != 1:
# TODO(agarwal): handle higher rank boolean masks.
@ -1730,14 +1684,14 @@ def _slice_helper(tensor, slice_spec, update_method=None, updates=None):
dims_contiguous = False
break
indices = [advanced_indices_map[x] for x in dims]
indices = [x.data for x in _promote_dtype(*indices)]
indices = _promote_dtype(*indices)
indices = np_utils.tf_broadcast(*indices)
stacked_indices = array_ops.stack(indices, axis=-1)
# Skip the contiguous-dims optimization for update because there is no
# tf.*scatter* op that supports the `axis` argument.
if not dims_contiguous or updates is not None:
if range(len(dims)) != dims:
tensor = moveaxis(tensor, dims, range(len(dims))).data
tensor = moveaxis(tensor, dims, range(len(dims)))
tensor_shape_prefix = array_ops.shape(
tensor, out_type=stacked_indices.dtype)[:len(dims)]
stacked_indices = array_ops.where_v2(
@ -1763,7 +1717,7 @@ def _slice_helper(tensor, slice_spec, update_method=None, updates=None):
def range_(start, length):
return range(start, start + length)
updates = moveaxis(updates, range_(batch_start, batch_size),
range(batch_size)).data
range(batch_size))
if update_method == _UpdateMethod.UPDATE:
update_op = array_ops.tensor_scatter_update
elif update_method == _UpdateMethod.ADD:
@ -1775,7 +1729,7 @@ def _slice_helper(tensor, slice_spec, update_method=None, updates=None):
tensor = update_op(
tensor, stacked_indices, updates)
if range(len(dims)) != dims:
tensor = moveaxis(tensor, range(len(dims)), dims).data
tensor = moveaxis(tensor, range(len(dims)), dims)
return array_ops.tensor_strided_slice_update(
original_tensor,
packed_begin,
@ -1842,14 +1796,13 @@ def _getitem(self, slice_spec):
slice_spec.dtype == dtypes.bool) or
(isinstance(slice_spec, (np.ndarray, np_arrays.ndarray)) and
slice_spec.dtype == np.bool)):
return np_utils.tensor_to_ndarray(
array_ops.boolean_mask(tensor=self.data, mask=slice_spec))
return array_ops.boolean_mask(tensor=self, mask=slice_spec)
if not isinstance(slice_spec, tuple):
slice_spec = _as_spec_tuple(slice_spec)
result_t = _slice_helper(self.data, slice_spec)
return np_utils.tensor_to_ndarray(result_t)
result_t = _slice_helper(self, slice_spec)
return result_t
def _with_index_update_helper(update_method, a, slice_spec, updates):
@ -1865,11 +1818,11 @@ def _with_index_update_helper(update_method, a, slice_spec, updates):
a_dtype = a.dtype
a, updates = _promote_dtype_binary(a, updates)
result_t = _slice_helper(a.data, slice_spec, update_method, updates.data)
return np_utils.tensor_to_ndarray(result_t).astype(a_dtype)
result_t = _slice_helper(a, slice_spec, update_method, updates)
return result_t.astype(a_dtype)
setattr(np_arrays.ndarray, '__getitem__', _getitem)
setattr(np_arrays.ndarray, '_numpy_style_getitem', _getitem)
setattr(np_arrays.ndarray, '_with_index_update',
functools.partial(_with_index_update_helper, _UpdateMethod.UPDATE))
setattr(np_arrays.ndarray, '_with_index_add',

View File

@ -36,6 +36,7 @@ from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.numpy_ops import np_array_ops
from tensorflow.python.ops.numpy_ops import np_arrays
from tensorflow.python.ops.numpy_ops import np_math_ops
from tensorflow.python.platform import test
@ -305,49 +306,47 @@ class ArrayCreationTest(test.TestCase):
def test_copy_equal_false():
# Backing tensor is the same if copy=False, other attributes being None.
self.assertIs(
np_array_ops.array(zeros_list, copy=False).data, zeros_list.data)
self.assertIs(
np_array_ops.array(zeros_list.data, copy=False).data, zeros_list.data)
self.assertIs(np_array_ops.array(zeros_list, copy=False), zeros_list)
self.assertIs(np_array_ops.array(zeros_list, copy=False), zeros_list)
# Backing tensor is different if ndmin is not satisfied.
self.assertIsNot(
np_array_ops.array(zeros_list, copy=False, ndmin=2).data,
zeros_list.data)
np_array_ops.array(zeros_list, copy=False, ndmin=2),
zeros_list)
self.assertIsNot(
np_array_ops.array(zeros_list.data, copy=False, ndmin=2).data,
zeros_list.data)
np_array_ops.array(zeros_list, copy=False, ndmin=2),
zeros_list)
self.assertIs(
np_array_ops.array(zeros_list, copy=False, ndmin=1).data,
zeros_list.data)
np_array_ops.array(zeros_list, copy=False, ndmin=1),
zeros_list)
self.assertIs(
np_array_ops.array(zeros_list.data, copy=False, ndmin=1).data,
zeros_list.data)
np_array_ops.array(zeros_list, copy=False, ndmin=1),
zeros_list)
# Backing tensor is different if dtype is not satisfied.
self.assertIsNot(
np_array_ops.array(zeros_list, copy=False, dtype=int).data,
zeros_list.data)
np_array_ops.array(zeros_list, copy=False, dtype=int),
zeros_list)
self.assertIsNot(
np_array_ops.array(zeros_list.data, copy=False, dtype=int).data,
zeros_list.data)
np_array_ops.array(zeros_list, copy=False, dtype=int),
zeros_list)
self.assertIs(
np_array_ops.array(zeros_list, copy=False, dtype=float).data,
zeros_list.data)
np_array_ops.array(zeros_list, copy=False, dtype=float),
zeros_list)
self.assertIs(
np_array_ops.array(zeros_list.data, copy=False, dtype=float).data,
zeros_list.data)
np_array_ops.array(zeros_list, copy=False, dtype=float),
zeros_list)
test_copy_equal_false()
with ops.device('CPU:1'):
test_copy_equal_false()
self.assertNotIn('CPU:1', zeros_list.data.backing_device)
self.assertNotIn('CPU:1', zeros_list.backing_device)
with ops.device('CPU:1'):
self.assertIn('CPU:1', np_array_ops.array(zeros_list, copy=True).data
.backing_device)
self.assertIn('CPU:1', np_array_ops.array(np.array(0), copy=True).data
.backing_device)
self.assertIn(
'CPU:1', np_array_ops.array(zeros_list, copy=True).backing_device)
self.assertIn(
'CPU:1', np_array_ops.array(np.array(0), copy=True).backing_device)
def testAsArray(self):
for a, dtype in itertools.product(self.all_arrays, self.all_types):
@ -515,9 +514,6 @@ class ArrayCreationTest(test.TestCase):
msg = 'Shape match failed for: {}. Expected: {} Actual: {}'.format(
msg, expected.shape, actual.shape)
self.assertEqual(actual.shape, expected.shape, msg=msg)
if msg:
msg = 'Shape: {} is not a tuple for {}'.format(actual.shape, msg)
self.assertIsInstance(actual.shape, tuple, msg=msg)
def match_dtype(self, actual, expected, msg=None):
if msg:
@ -535,7 +531,7 @@ class ArrayCreationTest(test.TestCase):
self.match_dtype(actual, expected, msg)
self.match_shape(actual, expected, msg)
if not almost:
if not actual.shape:
if not actual.shape.rank:
self.assertEqual(actual.tolist(), expected.tolist())
else:
self.assertSequenceEqual(actual.tolist(), expected.tolist())
@ -636,11 +632,11 @@ class ArrayMethodsTest(test.TestCase):
run_test(np.arange(9).reshape((3, 3)).tolist())
a = np_array_ops.asarray(0)
self.assertNotIn('CPU:1', a.data.backing_device)
self.assertNotIn('CPU:1', a.backing_device)
with ops.device('CPU:1'):
self.assertIn('CPU:1', np_array_ops.array(a, copy=True).data
self.assertIn('CPU:1', np_array_ops.array(a, copy=True)
.backing_device)
self.assertIn('CPU:1', np_array_ops.array(np.array(0), copy=True).data
self.assertIn('CPU:1', np_array_ops.array(np.array(0), copy=True)
.backing_device)
def testCumProdAndSum(self):
@ -824,12 +820,13 @@ class ArrayMethodsTest(test.TestCase):
self.assertRaises(NotImplementedError, np_array_ops.size, np.ones((2, 2)),
1)
@def_function.function(input_signature=[tensor_spec.TensorSpec(shape=None)])
@def_function.function(input_signature=[
tensor_spec.TensorSpec(dtype=dtypes.float64, shape=None)])
def f(arr):
arr = np_array_ops.asarray(arr)
return np_array_ops.size(arr)
self.assertEqual(f(np_array_ops.ones((3, 2))).data.numpy(), 6)
self.assertEqual(f(np_array_ops.ones((3, 2))).numpy(), 6)
def testRavel(self):
@ -984,9 +981,6 @@ class ArrayMethodsTest(test.TestCase):
msg = 'Shape match failed for: {}. Expected: {} Actual: {}'.format(
msg, expected.shape, actual.shape)
self.assertEqual(actual.shape, expected.shape, msg=msg)
if msg:
msg = 'Shape: {} is not a tuple for {}'.format(actual.shape, msg)
self.assertIsInstance(actual.shape, tuple, msg=msg)
def match_dtype(self, actual, expected, msg=None):
if msg:
@ -1004,7 +998,7 @@ class ArrayMethodsTest(test.TestCase):
if check_dtype:
self.match_dtype(actual, expected, msg)
self.match_shape(actual, expected, msg)
if not actual.shape:
if not actual.shape.rank:
self.assertAllClose(actual.tolist(), expected.tolist())
else:
self.assertAllClose(actual.tolist(), expected.tolist())
@ -1165,9 +1159,6 @@ class ArrayManipulationTest(test.TestCase):
msg = 'Shape match failed for: {}. Expected: {} Actual: {}'.format(
msg, expected.shape, actual.shape)
self.assertEqual(actual.shape, expected.shape, msg=msg)
if msg:
msg = 'Shape: {} is not a tuple for {}'.format(actual.shape, msg)
self.assertIsInstance(actual.shape, tuple, msg=msg)
def match_dtype(self, actual, expected, msg=None):
if msg:
@ -1184,7 +1175,7 @@ class ArrayManipulationTest(test.TestCase):
self.assertIsInstance(actual, np_arrays.ndarray)
self.match_dtype(actual, expected, msg)
self.match_shape(actual, expected, msg)
if not actual.shape:
if not actual.shape.rank:
self.assertEqual(actual.tolist(), expected.tolist())
else:
self.assertSequenceEqual(actual.tolist(), expected.tolist())
@ -1192,4 +1183,6 @@ class ArrayManipulationTest(test.TestCase):
if __name__ == '__main__':
ops.enable_eager_execution()
ops.enable_numpy_style_type_promotion()
np_math_ops.enable_numpy_methods_on_tensor()
test.main()

View File

@ -20,18 +20,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import six
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.numpy_ops import np_dtypes
from tensorflow.python.ops.numpy_ops import np_export
def convert_to_tensor(value, dtype=None, dtype_hint=None):
@ -58,297 +51,4 @@ def convert_to_tensor(value, dtype=None, dtype_hint=None):
return ops.convert_to_tensor(value, dtype=dtype, dtype_hint=dtype_hint)
class NdarraySpec(type_spec.BatchableTypeSpec):
"""Type specification for a `tf.experiemntal.numpy.ndarray`."""
value_type = property(lambda self: ndarray)
def __init__(self, data_spec):
if not isinstance(data_spec, tensor_spec.TensorSpec):
raise ValueError('NdarraySpec.__init__ was expecting a tf.TypeSpec, '
'but got a {} instead.'.format(type(data_spec)))
self._data_spec = data_spec
self._hash = None
@property
def _component_specs(self):
return self._data_spec
def _to_components(self, value):
return value.data
def _from_components(self, data):
return tensor_to_ndarray(data)
def _serialize(self):
return (self._data_spec,)
def _batch(self, batch_size):
return NdarraySpec(self._data_spec._batch(batch_size)) # pylint: disable=protected-access
def _unbatch(self):
return NdarraySpec(self._data_spec._unbatch()) # pylint: disable=protected-access
def __hash__(self):
if self._hash is None:
self._hash = hash((type(self), self._data_spec))
return self._hash
@np_export.np_export('ndarray') # pylint: disable=invalid-name
class ndarray(composite_tensor.CompositeTensor):
"""Equivalent of numpy.ndarray backed by TensorFlow tensors.
This does not support all features of NumPy ndarrays e.g. strides and
memory order since, unlike NumPy, the backing storage is not a raw memory
buffer.
TODO(srbs): Clearly specify which attributes and methods are not supported
or if there are any differences in behavior.
"""
__slots__ = ['_data', '_dtype', '_type_spec_internal']
def __init__(self, shape, dtype=float, buffer=None): # pylint: disable=redefined-builtin
"""Initializes an ndarray.
This is a low level interface for building ndarrays and should be avoided.
Users should instead use methods in array_creation.py.
This class provides a numpy.ndarray like interface for a TF Tensor with a
fully-defined shape. Note that, unlike the backing buffer of np.ndarray,
Tensors are immutable. So, operations like `__setitem__` are performed by
replacing the Tensor. This restricts the ability to implement NumPy `view`
semantics.
Compared to numpy.ndarray, this does not support `offset`, `strides`
and `order` arguments.
Args:
shape: The shape of the array. Must be a scalar, an iterable of integers
or a `TensorShape` object.
dtype: Optional. The dtype of the array. Must be a python type, a numpy
type or a tensorflow `DType` object.
buffer: Optional. The backing buffer of the array. Must have shape
`shape`. Must be a `ndarray`, `np.ndarray` or a `Tensor`.
Raises:
ValueError: If `buffer` is specified and its shape does not match
`shape`.
"""
if dtype and not isinstance(dtype, dtypes.DType):
dtype = dtypes.as_dtype(np.dtype(dtype))
if buffer is None:
buffer = array_ops.zeros(shape, dtype=dtype)
else:
if isinstance(buffer, ndarray):
buffer = buffer.data
elif isinstance(buffer, np.ndarray):
# If `buffer` is a np.ndarray, the Tensor will share the underlying
# storage of the array.
buffer = convert_to_tensor(value=buffer, dtype=dtype)
elif not isinstance(buffer, ops.Tensor):
raise ValueError('Unexpected type for `buffer` {}. Must be an ndarray,'
' Tensor or np.ndarray.'.format(type(buffer)))
if shape is not None:
buffer.set_shape(shape)
assert isinstance(buffer, ops.Tensor)
if dtype and dtype != buffer.dtype:
buffer = math_ops.cast(buffer, dtype)
self._data = buffer
self._type_spec_internal = None
self._dtype = None
@classmethod
def from_tensor(cls, tensor):
o = cls.__new__(cls, None)
# pylint: disable=protected-access
o._data = tensor
o._dtype = None
o._type_spec_internal = None
# pylint: enable=protected-access
return o
@property
def _type_spec(self):
if self._type_spec_internal is None:
self._type_spec_internal = NdarraySpec(
type_spec.type_spec_from_value(self._data))
return self._type_spec_internal
@property
def data(self):
"""Tensor object containing the array data.
This has a few key differences from the Python buffer object used in
NumPy arrays.
1. Tensors are immutable. So operations requiring in-place edit, e.g.
__setitem__, are performed by replacing the underlying buffer with a new
one.
2. Tensors do not provide access to their raw buffer.
Returns:
A Tensor.
"""
return self._data
@property
def shape(self):
"""Returns a tuple or tf.Tensor of array dimensions."""
shape = self.data.shape
if shape.is_fully_defined():
return tuple(shape.as_list())
else:
return array_ops.shape(self.data)
@property
def dtype(self):
if self._dtype is None:
self._dtype = np_dtypes._get_cached_dtype(self._data.dtype) # pylint: disable=protected-access
return self._dtype
def _is_boolean(self):
return self._data.dtype == dtypes.bool
@property
def ndim(self):
ndims = self.data.shape.ndims
if ndims is None:
return array_ops.rank(self.data)
else:
return ndims
@property
def size(self):
"""Returns the number of elements in the array."""
shape = self.shape
if isinstance(shape, ops.Tensor):
return array_ops.size(self.data)
else:
return np.prod(self.shape)
@property
def T(self): # pylint: disable=invalid-name
return self.transpose()
def __len__(self):
shape = self.shape
if isinstance(shape, ops.Tensor):
raise TypeError('len() of symbolic tensor undefined')
elif shape:
return self.shape[0]
else:
raise TypeError('len() of unsized object.')
def astype(self, dtype):
if self.dtype == dtype:
return self
else:
return tensor_to_ndarray(math_ops.cast(self.data, dtype))
# Unary operations
def __neg__(self):
return tensor_to_ndarray(-self.data) # pylint: disable=invalid-unary-operand-type
def __pos__(self):
return self
__hash__ = None
def __int__(self):
return int(self.data)
def __float__(self):
return float(self.data)
def __bool__(self):
return bool(self.data)
def __nonzero__(self):
return self.__bool__()
def __iter__(self):
if not isinstance(self.data, ops.EagerTensor):
raise TypeError('Iteration over symbolic tensor is not allowed')
for i in range(self.shape[0]):
result_t = self.data[i]
yield tensor_to_ndarray(result_t)
return
def __array__(self, dtype=None):
"""Returns a NumPy ndarray.
This allows instances of this class to be directly used in NumPy routines.
However, doing that may force a copy to CPU.
Args:
dtype: A NumPy compatible type.
Returns:
A NumPy ndarray.
"""
return np.asarray(self.data, dtype)
# NOTE: we currently prefer interop with TF to allow TF to take precedence.
__array_priority__ = 90
def __array_module__(self, types):
# Experimental support for NumPy's module dispatch with NEP-37:
# https://numpy.org/neps/nep-0037-array-module.html
# Currently requires https://github.com/seberg/numpy-dispatch
# pylint: disable=g-import-not-at-top
import tensorflow.compat.v2 as tf
if all(issubclass(t, (ndarray, np.ndarray)) for t in types):
return tf.experimental.numpy
else:
return NotImplemented
def __index__(self):
"""Returns a python scalar.
This allows using an instance of this class as an array index.
Note that only arrays of integer types with size 1 can be used as array
indices.
Returns:
A Python scalar.
Raises:
TypeError: If the array is not of an integer type.
ValueError: If the array does not have size 1.
"""
# TODO(wangpeng): Handle graph mode
if not isinstance(self.data, ops.EagerTensor):
raise TypeError('Indexing using symbolic tensor is not allowed')
return self.data.numpy().item()
def tolist(self):
return self.data.numpy().tolist()
def __str__(self):
return 'ndarray<{}>'.format(self.data.__str__())
def __repr__(self):
return 'ndarray<{}>'.format(self.data.__repr__())
def tensor_to_ndarray(tensor):
return ndarray.from_tensor(tensor)
def ndarray_to_tensor(arr, dtype=None, name=None, as_ref=False):
if as_ref:
raise ValueError('as_ref is not supported.')
if dtype and dtypes.as_dtype(arr.dtype) != dtype:
return math_ops.cast(arr.data, dtype)
result_t = arr.data
if name:
result_t = array_ops.identity(result_t, name=name)
return result_t
ops.register_tensor_conversion_function(ndarray, ndarray_to_tensor)
ndarray = ops.Tensor

View File

@ -18,11 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@ -32,48 +30,33 @@ from tensorflow.python.ops.numpy_ops import np_math_ops # pylint: disable=unuse
from tensorflow.python.platform import test
from tensorflow.python.util import nest
t2a = np_arrays.tensor_to_ndarray
class ArrayTest(test.TestCase):
def testDtype(self):
a = t2a(array_ops.zeros(shape=[1, 2], dtype=dtypes.int64))
self.assertIs(a.dtype.type, np.int64)
self.assertAllEqual(0, a.dtype.type(0))
a = array_ops.zeros(shape=[1, 2], dtype=dtypes.int64)
self.assertIs(a.dtype.as_numpy_dtype, np.int64)
np_dt = a.dtype.as_numpy_dtype
self.assertAllEqual(0, np_dt(0))
def testAstype(self):
a = t2a(ops.convert_to_tensor(value=1.1,
dtype=dtypes.float32)).astype(np.int32)
self.assertIs(a.dtype.type, np.int32)
a = ops.convert_to_tensor(value=1.1, dtype=dtypes.float32).astype(np.int32)
self.assertIs(a.dtype.as_numpy_dtype, np.int32)
self.assertAllEqual(1, a)
a = t2a(ops.convert_to_tensor(value=[0.0, 1.1],
dtype=dtypes.float32)).astype(np.bool_)
self.assertIs(a.dtype.type, np.bool_)
a = ops.convert_to_tensor(value=[0.0, 1.1], dtype=dtypes.float32).astype(
np.bool_)
self.assertIs(a.dtype.as_numpy_dtype, np.bool_)
self.assertAllEqual([False, True], a)
def testConstructor(self):
t = constant_op.constant([[1], [1]])
a = np_arrays.ndarray(shape=(2, 1), buffer=t)
self.assertAllEqual(t, a)
self.assertEqual(dtypes.float64, a.dtype)
a = np_arrays.ndarray(shape=(2, 1), dtype=dtypes.int32, buffer=t)
self.assertAllEqual(t, a)
self.assertEqual(dtypes.int32, a.dtype)
with self.assertRaises(ValueError): # bad shape
_ = np_arrays.ndarray((2, 2), buffer=t)
def testNeg(self):
a = t2a(ops.convert_to_tensor(value=[1.0, 2.0]))
self.assertAllEqual([-1.0, -2.0], -a)
a = ops.convert_to_tensor(value=[1.0, 2.0])
self.assertAllEqual([-1.0, -2.0], -a) # pylint: disable=invalid-unary-operand-type
def _testBinOp(self, a, b, out, f, types=None):
a = t2a(ops.convert_to_tensor(value=a, dtype=np.int32))
b = t2a(ops.convert_to_tensor(value=b, dtype=np.int32))
a = ops.convert_to_tensor(value=a, dtype=np.int32)
b = ops.convert_to_tensor(value=b, dtype=np.int32)
if not isinstance(out, np_arrays.ndarray):
out = t2a(ops.convert_to_tensor(value=out, dtype=np.int32))
out = ops.convert_to_tensor(value=out, dtype=np.int32)
if types is None:
types = [[np.int32, np.int32, np.int32], [np.int64, np.int32, np.int64],
[np.int32, np.int64, np.int64],
@ -84,7 +67,7 @@ class ArrayTest(test.TestCase):
[np.float32, np.float64, np.float64]]
for a_type, b_type, out_type in types:
o = f(a.astype(a_type), b.astype(b_type))
self.assertIs(o.dtype.type, out_type)
self.assertIs(o.dtype.as_numpy_dtype, out_type)
out = out.astype(out_type)
if np.issubdtype(out_type, np.inexact):
self.assertAllClose(out, o)
@ -126,19 +109,20 @@ class ArrayTest(test.TestCase):
def testTruediv(self):
self._testBinOp([3, 5], [2, 4],
t2a(ops.convert_to_tensor(value=[1.5, 1.25])),
ops.convert_to_tensor(value=[1.5, 1.25]),
lambda a, b: a.__truediv__(b),
types=self._truediv_types)
def testRtruediv(self):
self._testBinOp([3, 5], [2, 4],
t2a(ops.convert_to_tensor(value=[1.5, 1.25])),
ops.convert_to_tensor(value=[1.5, 1.25]),
lambda a, b: b.__rtruediv__(a),
types=self._truediv_types)
def _testCmp(self, a, b, out, f):
a = t2a(ops.convert_to_tensor(value=a, dtype=np.int32))
b = t2a(ops.convert_to_tensor(value=b, dtype=np.int32))
a = ops.convert_to_tensor(value=a, dtype=np.int32)
b = ops.convert_to_tensor(value=b, dtype=np.int32)
types = [[np.int32, np.int32], [np.int64, np.int32], [np.int32, np.int64],
[np.float32, np.int32], [np.int32, np.float32],
[np.float32, np.float32], [np.float64, np.float32],
@ -173,32 +157,41 @@ class ArrayTest(test.TestCase):
def testInt(self):
v = 10
u = int(t2a(ops.convert_to_tensor(value=v)))
u = int(ops.convert_to_tensor(value=v))
self.assertIsInstance(u, int)
self.assertAllEqual(v, u)
def testFloat(self):
v = 21.32
u = float(t2a(ops.convert_to_tensor(value=v)))
u = float(ops.convert_to_tensor(value=v))
self.assertIsInstance(u, float)
self.assertAllClose(v, u)
def testBool(self):
b = bool(t2a(ops.convert_to_tensor(value=10)))
b = bool(ops.convert_to_tensor(value=10))
self.assertIsInstance(b, bool)
self.assertTrue(b)
self.assertFalse(bool(t2a(ops.convert_to_tensor(value=0))))
self.assertTrue(bool(t2a(ops.convert_to_tensor(value=0.1))))
self.assertFalse(bool(t2a(ops.convert_to_tensor(value=0.0))))
self.assertFalse(bool(ops.convert_to_tensor(value=0)))
self.assertTrue(bool(ops.convert_to_tensor(value=0.1)))
self.assertFalse(bool(ops.convert_to_tensor(value=0.0)))
def testHash(self):
a = t2a(ops.convert_to_tensor(value=10))
self.assertNotIsInstance(a, collections.Hashable)
with self.assertRaisesWithPredicateMatch(TypeError, r'unhashable type'):
a = ops.convert_to_tensor(value=10)
def eager():
hash(a)
def graph():
@def_function.function
def f(x):
hash(x)
f(a)
for f in [eager, graph]:
with self.assertRaisesRegexp(
TypeError,
r'Tensor is unhashable. Instead, use tensor.ref\(\) as the key.'):
f()
def testFromToCompositeTensor(self):
tensors = [t2a(ops.convert_to_tensor(0.1)), t2a(ops.convert_to_tensor(0.2))]
tensors = [ops.convert_to_tensor(0.1), ops.convert_to_tensor(0.2)]
flattened = nest.flatten(tensors, expand_composites=True)
# Each ndarray contains only one tensor, so the flattened output should be
@ -216,6 +209,10 @@ class ArrayTest(test.TestCase):
if __name__ == '__main__':
# TODO(wangpeng): Test in graph mode as well.
# TODO(wangpeng): Test in graph mode as well. Also test in V2 (the requirement
# for setting _USE_EQUALITY points to V2 behavior not being on).
ops.enable_eager_execution()
ops.Tensor._USE_EQUALITY = True
ops.enable_numpy_style_type_promotion()
np_math_ops.enable_numpy_methods_on_tensor()
test.main()

View File

@ -0,0 +1,39 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Config functions for TF NumPy."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops.numpy_ops import np_dtypes
from tensorflow.python.ops.numpy_ops import np_math_ops
def enable_numpy_behavior(prefer_float32=False):
"""Enable NumPy behavior on Tensors.
Includes addition of methods, type promotion on operator overloads and
support for NumPy-style slicing.
Args:
prefer_float32: Whether to allow type inference to use float32, or use
float64 similar to NumPy.
"""
ops.enable_numpy_style_type_promotion()
ops.enable_numpy_style_slicing()
np_math_ops.enable_numpy_methods_on_tensor()
np_dtypes.set_prefer_float32(prefer_float32)

View File

@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.ops.numpy_ops import np_export
@ -63,9 +64,27 @@ _to_float32 = {
_cached_np_dtypes = {}
# Difference between is_prefer_float32 and is_allow_float64: is_prefer_float32
# only decides which dtype to use for Python floats; is_allow_float64 decides
# whether float64 dtypes can ever appear in programs. The latter is more
# restrictive than the former.
_prefer_float32 = False
# TODO(b/178862061): Consider removing this knob
_allow_float64 = True
def is_prefer_float32():
return _prefer_float32
def set_prefer_float32(b):
global _prefer_float32
_prefer_float32 = b
def is_allow_float64():
return _allow_float64
@ -85,8 +104,13 @@ def canonicalize_dtype(dtype):
def _result_type(*arrays_and_dtypes):
def preprocess_float(x):
if is_prefer_float32() and isinstance(x, float):
return np.float32(x)
return x
arrays_and_dtypes = [preprocess_float(x) for x in arrays_and_dtypes]
dtype = np.result_type(*arrays_and_dtypes)
return canonicalize_dtype(dtype)
return dtypes.as_dtype(canonicalize_dtype(dtype))
def _get_cached_dtype(dtype):
@ -105,9 +129,10 @@ def default_float_type():
"""Gets the default float type.
Returns:
If `is_allow_float64()` is true, returns float64; otherwise returns float32.
If `is_prefer_float32()` is false and `is_allow_float64()` is true, returns
float64; otherwise returns float32.
"""
if is_allow_float64():
if not is_prefer_float32() and is_allow_float64():
return float64
else:
return float32

View File

@ -0,0 +1,57 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf-numpy dtype utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops.numpy_ops import np_dtypes
from tensorflow.python.platform import test
class DTypeTest(test.TestCase, parameterized.TestCase):
@parameterized.parameters([False, True])
def testAllowF64False(self, prefer_f32):
np_dtypes.set_allow_float64(False)
np_dtypes.set_prefer_float32(prefer_f32)
self.assertEqual(dtypes.float32, np_dtypes.default_float_type())
self.assertEqual(dtypes.float32,
np_dtypes._result_type(np.zeros([], np.float64), 1.1))
def testAllowF64TruePreferF32False(self):
np_dtypes.set_allow_float64(True)
np_dtypes.set_prefer_float32(False)
self.assertEqual(dtypes.float64, np_dtypes.default_float_type())
self.assertEqual(dtypes.float64, np_dtypes._result_type(1.1))
def testAllowF64TruePreferF32True(self):
np_dtypes.set_allow_float64(True)
np_dtypes.set_prefer_float32(True)
self.assertEqual(dtypes.float32, np_dtypes.default_float_type())
self.assertEqual(dtypes.float32, np_dtypes._result_type(1.1))
self.assertEqual(dtypes.float64,
np_dtypes._result_type(np.zeros([], np.float64), 1.1))
if __name__ == '__main__':
ops.enable_eager_execution()
test.main()

View File

@ -21,7 +21,9 @@ from __future__ import print_function
import numpy as onp
import tensorflow.compat.v2 as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import numpy_ops as np
from tensorflow.python.ops.numpy_ops import np_math_ops
# Tests for code snippet put in README.md
@ -174,27 +176,26 @@ class InteropTest(tf.test.TestCase):
self.assertIsInstance(sq, onp.ndarray)
self.assertEqual(100., sq[0])
# TODO(b/171313773): why doesn't tensor have __array_module__
def testArrayModule(self):
self.skipTest("Tensor doesn't have __array_module__")
arr = np.asarray([10])
module = arr.__array_module__((np.ndarray,))
module = arr.__array_module__((tf.Tensor,))
self.assertIs(module, tf.experimental.numpy)
class Dummy:
pass
module = arr.__array_module__((np.ndarray, Dummy))
module = arr.__array_module__((tf.Tensor, Dummy))
self.assertIs(module, NotImplemented)
# TODO(nareshmodi): Fails since the autopacking code doesn't use
# nest.flatten.
# TODO(nareshmodi): Fails since the autopacking code doesn't use
# nest.flatten.
# def testAutopacking(self):
# arr1 = np.asarray(1.)
# arr2 = np.asarray(2.)
# arr3 = np.asarray(3.)
# t = ops.convert_to_tensor_v2([arr1, arr2, arr3])
# self.assertEqual(t.numpy(), [1., 2., 3.])
def testDistStratInterop(self):
@ -409,7 +410,9 @@ class FunctionTest(InteropTest):
def testLen(self):
@tf.function
# len can be fixed by autograph.
# TODO(wangpeng): this test can just be removed
@tf.function(autograph=False)
def f(x):
# Note that shape of input to len is data dependent.
return len(np.where(x)[0])
@ -451,5 +454,7 @@ class VariableTest(InteropTest):
if __name__ == '__main__':
ops.enable_numpy_style_type_promotion()
np_math_ops.enable_numpy_methods_on_tensor()
tf.compat.v1.enable_eager_execution()
tf.test.main()

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf numpy random number methods."""
"""Tests for tf numpy logical methods."""
from __future__ import absolute_import
from __future__ import division
@ -76,9 +76,6 @@ class LogicTest(test.TestCase):
msg = 'Shape match failed for: {}. Expected: {} Actual: {}'.format(
msg, expected.shape, actual.shape)
self.assertEqual(actual.shape, expected.shape, msg=msg)
if msg:
msg = 'Shape: {} is not a tuple for {}'.format(actual.shape, msg)
self.assertIsInstance(actual.shape, tuple, msg=msg)
def match_dtype(self, actual, expected, msg=None):
if msg:
@ -95,16 +92,17 @@ class LogicTest(test.TestCase):
self.assertIsInstance(actual, np_arrays.ndarray)
self.match_dtype(actual, expected, msg)
self.match_shape(actual, expected, msg)
if not actual.shape:
if not actual.shape.rank:
self.assertEqual(actual.tolist(), expected.tolist())
else:
self.assertSequenceEqual(actual.tolist(), expected.tolist())
def make_numpy_compatible(s):
return s if not isinstance(s, np_arrays.ndarray) else s.data.numpy()
return s if not isinstance(s, np_arrays.ndarray) else s.numpy()
if __name__ == '__main__':
ops.enable_eager_execution()
np_math_ops.enable_numpy_methods_on_tensor()
test.main()

View File

@ -74,7 +74,7 @@ def _bin_op(tf_fun, a, b, promote=True):
else:
a = np_array_ops.array(a)
b = np_array_ops.array(b)
return np_utils.tensor_to_ndarray(tf_fun(a.data, b.data))
return tf_fun(a, b)
@np_utils.np_doc('add')
@ -177,9 +177,8 @@ def maximum(x1, x2): # pylint: disable=missing-function-docstring
# Fast path for when maximum is used as relu.
if isinstance(
x2, numbers.Real) and not isinstance(x2, bool) and x2 == 0 and isinstance(
x1, np_arrays.ndarray) and not x1._is_boolean(): # pylint: disable=protected-access
return np_utils.tensor_to_ndarray(
nn_ops.relu(np_array_ops.asarray(x1).data))
x1, np_arrays.ndarray) and x1.dtype != dtypes.bool:
return nn_ops.relu(np_array_ops.asarray(x1))
def max_or_or(x1, x2):
if x1.dtype == dtypes.bool:
@ -212,12 +211,7 @@ def clip(a, a_min, a_max): # pylint: disable=missing-docstring
return maximum(a, a_min)
else:
a, a_min, a_max = np_array_ops._promote_dtype(a, a_min, a_max) # pylint: disable=protected-access
return np_utils.tensor_to_ndarray(
clip_ops.clip_by_value(
*np_utils.tf_broadcast(a.data, a_min.data, a_max.data)))
setattr(np_arrays.ndarray, 'clip', clip)
return clip_ops.clip_by_value(*np_utils.tf_broadcast(a, a_min, a_max))
@np_utils.np_doc('matmul')
@ -241,6 +235,12 @@ def matmul(x1, x2): # pylint: disable=missing-docstring
return _bin_op(f, x1, x2)
# Exported so it can be called from Tensor.__matmul__. NumPy's matmul handles
# batched matmul as well, so simply including promotion in TF's current
# __matmul__ implementation was not sufficient.
setattr(np_arrays.ndarray, '_matmul', matmul)
@np_utils.np_doc('tensordot')
def tensordot(a, b, axes=2):
return _bin_op(lambda a, b: math_ops.tensordot(a, b, axes=axes), a, b)
@ -375,7 +375,7 @@ def heaviside(x1, x2): # pylint: disable=missing-function-docstring
array_ops.where_v2(x1 > 0, constant_op.constant(1, dtype=x2.dtype), x2))
y = _bin_op(f, x1, x2)
if not np.issubdtype(y.dtype, np.inexact):
if not np.issubdtype(y.dtype.as_numpy_dtype, np.inexact):
y = y.astype(np_dtypes.default_float_type())
return y
@ -392,13 +392,13 @@ def kron(a, b): # pylint: disable=missing-function-docstring
t_a = np_utils.cond(
a.ndim < b.ndim,
lambda: np_array_ops.reshape( # pylint: disable=g-long-lambda
a.data, np_array_ops._pad_left_to(b.ndim, a.shape)),
lambda: a.data)
a, np_array_ops._pad_left_to(b.ndim, a.shape)),
lambda: a)
t_b = np_utils.cond(
b.ndim < a.ndim,
lambda: np_array_ops.reshape( # pylint: disable=g-long-lambda
b.data, np_array_ops._pad_left_to(a.ndim, b.shape)),
lambda: b.data)
b, np_array_ops._pad_left_to(a.ndim, b.shape)),
lambda: b)
def _make_shape(shape, prepend):
ones = array_ops.ones_like(shape)
@ -596,9 +596,9 @@ def _scalar(tf_fn, x, promote_to_float=False):
floating point type, in which case the output type is same as x.dtype.
"""
x = np_array_ops.asarray(x)
if promote_to_float and not np.issubdtype(x.dtype, np.inexact):
if promote_to_float and not np.issubdtype(x.dtype.as_numpy_dtype, np.inexact):
x = x.astype(np_dtypes.default_float_type())
return np_utils.tensor_to_ndarray(tf_fn(x.data))
return tf_fn(x)
@np_utils.np_doc('log')
@ -814,7 +814,7 @@ def isreal(x):
@np_utils.np_doc('iscomplexobj')
def iscomplexobj(x):
x = np_array_ops.array(x)
return np.issubdtype(x.dtype, np.complexfloating)
return np.issubdtype(x.dtype.as_numpy_dtype, np.complexfloating)
@np_utils.np_doc('isrealobj')
@ -850,11 +850,12 @@ nanprod = _make_nan_reduction('nanprod', np_array_ops.prod, 1)
@np_utils.np_doc('nanmean')
def nanmean(a, axis=None, dtype=None, keepdims=None): # pylint: disable=missing-docstring
a = np_array_ops.array(a)
if np.issubdtype(a.dtype, np.bool_) or np.issubdtype(a.dtype, np.integer):
if np.issubdtype(a.dtype.as_numpy_dtype, np.bool_) or np.issubdtype(
a.dtype.as_numpy_dtype, np.integer):
return np_array_ops.mean(a, axis=axis, dtype=dtype, keepdims=keepdims)
nan_mask = logical_not(isnan(a))
if dtype is None:
dtype = a.dtype
dtype = a.dtype.as_numpy_dtype
normalizer = np_array_ops.sum(
nan_mask, axis=axis, dtype=dtype, keepdims=keepdims)
return nansum(a, axis=axis, dtype=dtype, keepdims=keepdims) / normalizer
@ -960,37 +961,16 @@ def _wrap(f, reverse=False):
return _f
setattr(np_arrays.ndarray, '__abs__', absolute)
setattr(np_arrays.ndarray, '__floordiv__', _wrap(floor_divide))
setattr(np_arrays.ndarray, '__rfloordiv__', _wrap(floor_divide, True))
setattr(np_arrays.ndarray, '__mod__', _wrap(mod))
setattr(np_arrays.ndarray, '__rmod__', _wrap(mod, True))
setattr(np_arrays.ndarray, '__add__', _wrap(add))
setattr(np_arrays.ndarray, '__radd__', _wrap(add, True))
setattr(np_arrays.ndarray, '__sub__', _wrap(subtract))
setattr(np_arrays.ndarray, '__rsub__', _wrap(subtract, True))
setattr(np_arrays.ndarray, '__mul__', _wrap(multiply))
setattr(np_arrays.ndarray, '__rmul__', _wrap(multiply, True))
setattr(np_arrays.ndarray, '__matmul__', _wrap(matmul))
setattr(np_arrays.ndarray, '__rmatmul__', _wrap(matmul, True))
setattr(np_arrays.ndarray, '__pow__', _wrap(power))
setattr(np_arrays.ndarray, '__rpow__', _wrap(power, True))
setattr(np_arrays.ndarray, '__truediv__', _wrap(true_divide))
setattr(np_arrays.ndarray, '__rtruediv__', _wrap(true_divide, True))
def _comparison(tf_fun, x1, x2, cast_bool_to_int=False):
"""Helper function for comparision."""
dtype = np_utils.result_type(x1, x2)
# Cast x1 and x2 to the result_type if needed.
x1 = np_array_ops.array(x1, dtype=dtype)
x2 = np_array_ops.array(x2, dtype=dtype)
x1 = x1.data
x2 = x2.data
if cast_bool_to_int and x1.dtype == dtypes.bool:
x1 = math_ops.cast(x1, dtypes.int32)
x2 = math_ops.cast(x2, dtypes.int32)
return np_utils.tensor_to_ndarray(tf_fun(x1, x2))
return tf_fun(x1, x2)
@np_utils.np_doc('equal')
@ -1043,7 +1023,7 @@ def array_equal(a1, a2): # pylint: disable=missing-function-docstring
def _logical_binary_op(tf_fun, x1, x2):
x1 = np_array_ops.array(x1, dtype=np.bool_)
x2 = np_array_ops.array(x2, dtype=np.bool_)
return np_utils.tensor_to_ndarray(tf_fun(x1.data, x2.data))
return tf_fun(x1, x2)
@np_utils.np_doc('logical_and')
@ -1064,16 +1044,7 @@ def logical_xor(x1, x2):
@np_utils.np_doc('logical_not')
def logical_not(x):
x = np_array_ops.array(x, dtype=np.bool_)
return np_utils.tensor_to_ndarray(math_ops.logical_not(x.data))
setattr(np_arrays.ndarray, '__invert__', logical_not)
setattr(np_arrays.ndarray, '__lt__', _wrap(less))
setattr(np_arrays.ndarray, '__le__', _wrap(less_equal))
setattr(np_arrays.ndarray, '__gt__', _wrap(greater))
setattr(np_arrays.ndarray, '__ge__', _wrap(greater_equal))
setattr(np_arrays.ndarray, '__eq__', _wrap(equal))
setattr(np_arrays.ndarray, '__ne__', _wrap(not_equal))
return math_ops.logical_not(x)
@np_utils.np_doc('linspace')
@ -1087,8 +1058,8 @@ def linspace( # pylint: disable=missing-docstring
axis=0):
if dtype:
dtype = np_utils.result_type(dtype)
start = np_array_ops.array(start, dtype=dtype).data
stop = np_array_ops.array(stop, dtype=dtype).data
start = np_array_ops.array(start, dtype=dtype)
stop = np_array_ops.array(stop, dtype=dtype)
if num < 0:
raise ValueError('Number of samples {} must be non-negative.'.format(num))
step = ops.convert_to_tensor(np.nan)
@ -1109,28 +1080,27 @@ def linspace( # pylint: disable=missing-docstring
if dtype:
result = math_ops.cast(result, dtype)
if retstep:
return (np_arrays.tensor_to_ndarray(result),
np_arrays.tensor_to_ndarray(step))
return (result, step)
else:
return np_arrays.tensor_to_ndarray(result)
return result
@np_utils.np_doc('logspace')
def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0):
dtype = np_utils.result_type(start, stop, dtype)
result = linspace(
start, stop, num=num, endpoint=endpoint, dtype=dtype, axis=axis).data
start, stop, num=num, endpoint=endpoint, dtype=dtype, axis=axis)
result = math_ops.pow(math_ops.cast(base, result.dtype), result)
if dtype:
result = math_ops.cast(result, dtype)
return np_arrays.tensor_to_ndarray(result)
return result
@np_utils.np_doc('geomspace')
def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0): # pylint: disable=missing-docstring
dtype = dtype or np_utils.result_type(start, stop, float(num),
np_array_ops.zeros((), dtype))
computation_dtype = np.promote_types(dtype, np.float32)
dtype = dtypes.as_dtype(dtype) if dtype else np_utils.result_type(
start, stop, float(num), np_array_ops.zeros((), dtype))
computation_dtype = np.promote_types(dtype.as_numpy_dtype, np.float32)
start = np_array_ops.asarray(start, dtype=computation_dtype)
stop = np_array_ops.asarray(stop, dtype=computation_dtype)
# follow the numpy geomspace convention for negative and complex endpoints
@ -1147,7 +1117,7 @@ def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0): # pylint
axis=0)
if axis != 0:
res = np_array_ops.moveaxis(res, 0, axis)
return np_utils.tensor_to_ndarray(math_ops.cast(res, dtype))
return math_ops.cast(res, dtype)
@np_utils.np_doc('ptp')
@ -1163,14 +1133,14 @@ def concatenate(arys, axis=0):
if not arys:
raise ValueError('Need at least one array to concatenate.')
dtype = np_utils.result_type(*arys)
arys = [np_array_ops.array(array, dtype=dtype).data for array in arys]
return np_arrays.tensor_to_ndarray(array_ops.concat(arys, axis))
arys = [np_array_ops.array(array, dtype=dtype) for array in arys]
return array_ops.concat(arys, axis)
@np_utils.np_doc_only('tile')
def tile(a, reps): # pylint: disable=missing-function-docstring
a = np_array_ops.array(a).data
reps = np_array_ops.array(reps, dtype=dtypes.int32).reshape([-1]).data
a = np_array_ops.array(a)
reps = np_array_ops.array(reps, dtype=dtypes.int32).reshape([-1])
a_rank = array_ops.rank(a)
reps_size = array_ops.size(reps)
@ -1181,13 +1151,12 @@ def tile(a, reps): # pylint: disable=missing-function-docstring
constant_values=1)
a = array_ops.reshape(a, a_shape)
return np_arrays.tensor_to_ndarray(array_ops.tile(a, reps))
return array_ops.tile(a, reps)
@np_utils.np_doc('count_nonzero')
def count_nonzero(a, axis=None):
return np_arrays.tensor_to_ndarray(
math_ops.count_nonzero(np_array_ops.array(a).data, axis))
return math_ops.count_nonzero(np_array_ops.array(a), axis)
@np_utils.np_doc('argsort')
@ -1199,7 +1168,7 @@ def argsort(a, axis=-1, kind='quicksort', order=None): # pylint: disable=missin
raise ValueError("'order' argument to sort is not supported.")
stable = (kind == 'stable')
a = np_array_ops.array(a).data
a = np_array_ops.array(a)
def _argsort(a, axis, stable):
if axis is None:
@ -1225,20 +1194,19 @@ def sort(a, axis=-1, kind='quicksort', order=None): # pylint: disable=missing-d
a = np_array_ops.array(a)
if axis is None:
result_t = sort_ops.sort(array_ops.reshape(a.data, [-1]), 0)
return np_utils.tensor_to_ndarray(result_t)
return sort_ops.sort(array_ops.reshape(a, [-1]), 0)
else:
return np_utils.tensor_to_ndarray(sort_ops.sort(a.data, axis))
return sort_ops.sort(a, axis)
def _argminmax(fn, a, axis=None):
a = np_array_ops.array(a)
if axis is None:
# When axis is None numpy flattens the array.
a_t = array_ops.reshape(a.data, [-1])
a_t = array_ops.reshape(a, [-1])
else:
a_t = np_array_ops.atleast_1d(a).data
return np_utils.tensor_to_ndarray(fn(input=a_t, axis=axis))
a_t = np_array_ops.atleast_1d(a)
return fn(input=a_t, axis=axis)
@np_utils.np_doc('argmax')
@ -1267,24 +1235,24 @@ def average(a, axis=None, weights=None, returned=False): # pylint: disable=miss
'supported yet. Got type: %s' % type(axis))
a = np_array_ops.array(a)
if weights is None: # Treat all weights as 1
if not np.issubdtype(a.dtype, np.inexact):
if not np.issubdtype(a.dtype.as_numpy_dtype, np.inexact):
a = a.astype(
np_utils.result_type(a.dtype, np_dtypes.default_float_type()))
avg = math_ops.reduce_mean(a.data, axis=axis)
avg = math_ops.reduce_mean(a, axis=axis)
if returned:
if axis is None:
weights_sum = array_ops.size(a.data)
weights_sum = array_ops.size(a)
else:
weights_sum = array_ops.shape(a.data)[axis]
weights_sum = math_ops.cast(weights_sum, a.data.dtype)
weights_sum = array_ops.shape(a)[axis]
weights_sum = math_ops.cast(weights_sum, a.dtype)
else:
if np.issubdtype(a.dtype, np.inexact):
if np.issubdtype(a.dtype.as_numpy_dtype, np.inexact):
out_dtype = np_utils.result_type(a.dtype, weights)
else:
out_dtype = np_utils.result_type(a.dtype, weights,
np_dtypes.default_float_type())
a = np_array_ops.array(a, out_dtype).data
weights = np_array_ops.array(weights, out_dtype).data
a = np_array_ops.array(a, out_dtype)
weights = np_array_ops.array(weights, out_dtype)
def rank_equal_case():
control_flow_ops.Assert(
@ -1316,8 +1284,7 @@ def average(a, axis=None, weights=None, returned=False): # pylint: disable=miss
avg = np_array_ops.array(avg)
if returned:
weights_sum = np_array_ops.broadcast_to(weights_sum,
array_ops.shape(avg.data))
weights_sum = np_array_ops.broadcast_to(weights_sum, array_ops.shape(avg))
return avg, weights_sum
return avg
@ -1326,7 +1293,7 @@ def average(a, axis=None, weights=None, returned=False): # pylint: disable=miss
def trace(a, offset=0, axis1=0, axis2=1, dtype=None): # pylint: disable=missing-docstring
if dtype:
dtype = np_utils.result_type(dtype)
a = np_array_ops.asarray(a, dtype).data
a = np_array_ops.asarray(a, dtype)
if offset == 0:
a_shape = a.shape
@ -1334,7 +1301,7 @@ def trace(a, offset=0, axis1=0, axis2=1, dtype=None): # pylint: disable=missing
rank = len(a_shape)
if (axis1 == -2 or axis1 == rank - 2) and (axis2 == -1 or
axis2 == rank - 1):
return np_utils.tensor_to_ndarray(math_ops.trace(a))
return math_ops.trace(a)
a = np_array_ops.diagonal(a, offset, axis1, axis2)
return np_array_ops.sum(a, -1, dtype)
@ -1353,11 +1320,10 @@ def meshgrid(*xi, **kwargs):
indexing = kwargs.get('indexing', 'xy')
xi = [np_array_ops.asarray(arg).data for arg in xi]
xi = [np_array_ops.asarray(arg) for arg in xi]
kwargs = {'indexing': indexing}
outputs = array_ops.meshgrid(*xi, **kwargs)
outputs = [np_utils.tensor_to_ndarray(output) for output in outputs]
return outputs
@ -1387,7 +1353,62 @@ def einsum(subscripts, *operands, **kwargs): # pylint: disable=missing-docstrin
tf_optimize = 'optimal'
else:
raise ValueError('`optimize` method not supported: %s' % optimize)
operands = [x.data for x in operands]
res = special_math_ops.einsum(subscripts, *operands, optimize=tf_optimize)
res = np_utils.tensor_to_ndarray(res)
return res
def _tensor_t(self):
"""Returns a Tensor which is the transpose of this Tensor."""
return self.transpose()
def _tensor_ndim(self):
"""Returns the rank of the Tensor."""
return self.shape.ndims
def _tensor_pos(self):
"""Returns self, for unary operator `+`."""
return self
def _tensor_size(self):
"""Returns the number of elements in this Tensor, if fully known."""
if not self.shape.is_fully_defined():
return None
return np.prod(self.shape.as_list())
def _tensor_tolist(self):
if isinstance(self, ops.EagerTensor):
return self._numpy().tolist() # pylint: disable=protected-access
raise ValueError('Symbolic Tensors do not support the tolist API.')
def enable_numpy_methods_on_tensor():
"""Adds additional NumPy methods on tf.Tensor class."""
t = property(_tensor_t)
setattr(ops.Tensor, 'T', t)
ndim = property(_tensor_ndim)
setattr(ops.Tensor, 'ndim', ndim)
size = property(_tensor_size)
setattr(ops.Tensor, 'size', size)
setattr(ops.Tensor, '__pos__', _tensor_pos)
setattr(ops.Tensor, 'tolist', _tensor_tolist)
# TODO(b/178540516): Make a custom `setattr` that changes the method's
# docstring to the TF one.
setattr(ops.Tensor, 'transpose', np_array_ops.transpose)
setattr(ops.Tensor, 'reshape', np_array_ops._reshape_method_wrapper) # pylint: disable=protected-access
setattr(ops.Tensor, 'ravel', np_array_ops.ravel)
setattr(ops.Tensor, 'clip', clip)
setattr(ops.Tensor, 'astype', math_ops.cast)
setattr(ops.Tensor, '__round__', np_array_ops.around)
# TODO(wangpeng): Remove `data` when all uses of it are removed
data = property(lambda self: self)
setattr(ops.Tensor, 'data', data)

View File

@ -160,7 +160,7 @@ class MathTest(test.TestCase, parameterized.TestCase):
self.assertEqual(
actual.dtype, expected.dtype,
'Dtype mismatch.\nActual: {}\nExpected: {}\n{}'.format(
actual.dtype, expected.dtype, msg))
actual.dtype.as_numpy_dtype, expected.dtype, msg))
self.assertEqual(
actual.shape, expected.shape,
'Shape mismatch.\nActual: {}\nExpected: {}\n{}'.format(
@ -350,4 +350,6 @@ class MathTest(test.TestCase, parameterized.TestCase):
if __name__ == '__main__':
ops.enable_eager_execution()
ops.enable_numpy_style_type_promotion()
np_math_ops.enable_numpy_methods_on_tensor()
test.main()

View File

@ -73,7 +73,7 @@ def standard_normal(size=None):
elif np_utils.isscalar(size):
size = (size,)
dtype = np_dtypes.default_float_type()
return np_utils.tensor_to_ndarray(random_ops.random_normal(size, dtype=dtype))
return random_ops.random_normal(size, dtype=dtype)
@np_utils.np_doc('random.uniform')
@ -83,9 +83,8 @@ def uniform(low=0.0, high=1.0, size=None):
high = np_array_ops.asarray(high, dtype=dtype)
if size is None:
size = array_ops.broadcast_dynamic_shape(low.shape, high.shape)
return np_utils.tensor_to_ndarray(
random_ops.random_uniform(
shape=size, minval=low, maxval=high, dtype=dtype))
return random_ops.random_uniform(
shape=size, minval=low, maxval=high, dtype=dtype)
@np_utils.np_doc('random.poisson')
@ -94,8 +93,7 @@ def poisson(lam=1.0, size=None):
size = ()
elif np_utils.isscalar(size):
size = (size,)
return np_utils.tensor_to_ndarray(
random_ops.random_poisson(shape=size, lam=lam, dtype=np_dtypes.int_))
return random_ops.random_poisson(shape=size, lam=lam, dtype=np_dtypes.int_)
@np_utils.np_doc('random.random')
@ -121,6 +119,5 @@ def randint(low, high=None, size=None, dtype=onp.int): # pylint: disable=missin
dtype = np_utils.result_type(dtype)
if dtype not in (onp.int32, onp.int64):
raise ValueError('Only np.int32 or np.int64 types are supported')
return np_utils.tensor_to_ndarray(
random_ops.random_uniform(
shape=size, minval=low, maxval=high, dtype=dtype))
return random_ops.random_uniform(
shape=size, minval=low, maxval=high, dtype=dtype)

View File

@ -28,6 +28,7 @@ from tensorflow.python.ops import numpy_ops as np
# Needed for ndarray.reshape.
from tensorflow.python.ops.numpy_ops import np_array_ops # pylint: disable=unused-import
from tensorflow.python.ops.numpy_ops import np_dtypes
from tensorflow.python.ops.numpy_ops import np_math_ops
from tensorflow.python.ops.numpy_ops import np_random
from tensorflow.python.platform import test
@ -192,7 +193,7 @@ class RandNDistriutionTest(test.TestCase):
self.assertEqual(output.shape, tuple(args))
default_dtype = (
np.float64 if np_dtypes.is_allow_float64() else np.float32)
self.assertEqual(output.dtype.type, default_dtype)
self.assertEqual(output.dtype.as_numpy_dtype, default_dtype)
if np.prod(args): # Don't bother with empty arrays.
outputs = [output.tolist() for output in outputs]
@ -230,4 +231,5 @@ class RandNDistriutionTest(test.TestCase):
if __name__ == '__main__':
ops.enable_eager_execution()
np_math_ops.enable_numpy_methods_on_tensor()
test.main()

View File

@ -38,9 +38,6 @@ from tensorflow.python.types import core
from tensorflow.python.util import nest
tensor_to_ndarray = np_arrays.tensor_to_ndarray
def _canonicalize_axis(axis, rank):
return _canonicalize_axes([axis], rank)[0]
@ -478,8 +475,6 @@ def _maybe_get_dtype(x):
"""Returns a numpy type if available from x. Skips if x is numpy.ndarray."""
# Don't put np.ndarray in this list, because np.result_type looks at the
# value (not just dtype) of np.ndarray to decide the result type.
if isinstance(x, np_arrays.ndarray):
return x.dtype
if isinstance(x, numbers.Real):
return x
if isinstance(x, (core.Tensor, indexed_slices.IndexedSlices)):

View File

@ -34,7 +34,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.numpy_ops import np_arrays
from tensorflow.python.ops.parallel_for.pfor import PFor
from tensorflow.python.ops.parallel_for.pfor import PForConfig
from tensorflow.python.platform import tf_logging as logging
@ -289,7 +288,6 @@ def _pfor_impl(loop_fn,
loop_fn_outputs)
# Convert outputs to Tensor if needed.
rewrap_as_ndarray = False
tmp_loop_fn_outputs = []
for loop_fn_output in nest.flatten(loop_fn_output_tensors):
if (loop_fn_output is not None and not isinstance(
@ -301,9 +299,6 @@ def _pfor_impl(loop_fn,
" IndexedSlices separately, and handle the vectorized"
" outputs directly." % loop_fn_output)
loop_fn_output = ops.convert_to_tensor(loop_fn_output)
elif isinstance(loop_fn_output, np_arrays.ndarray):
loop_fn_output = loop_fn_output.data
rewrap_as_ndarray = True
else:
loop_fn_output = ops.convert_to_tensor(loop_fn_output)
tmp_loop_fn_outputs.append(loop_fn_output)
@ -327,8 +322,6 @@ def _pfor_impl(loop_fn,
flattened_output_tensors = []
for loop_fn_output in nest.flatten(loop_fn_output_tensors):
output = converter.convert(loop_fn_output)
if rewrap_as_ndarray:
output = np_arrays.tensor_to_ndarray(output)
flattened_output_tensors.append(output)
else:
if pfor_config is not None and pfor_config._has_reductions(): # pylint: disable=protected-access
@ -346,8 +339,6 @@ def _pfor_impl(loop_fn,
flattened_output_tensors = nest.flatten(loop_fn_output_tensors)
for loop_fn_output in flattened_output_tensors:
output = converter.convert(loop_fn_output)
if rewrap_as_ndarray:
output = np_arrays.tensor_to_ndarray(output)
remaining_output_tensors.append(output)
with ops.name_scope("pfor_tiled"):
@ -398,10 +389,6 @@ def _pfor_impl(loop_fn,
tensor_shape.TensorShape([iters_value]).concatenate(
original_output.shape))
if rewrap_as_ndarray:
flattened_output_tensors = [
np_arrays.tensor_to_ndarray(x) for x in flattened_output_tensors]
return nest.map_structure_up_to(
loop_fn_outputs,
functools.partial(_composite_from_tensors, batch_size=iters_value),
@ -418,8 +405,6 @@ def _broadcasting_gather(x, i):
elif static_first_dim is None:
i = array_ops.where_v2(array_ops.shape(x)[0] > 1, i, 0)
result = array_ops.gather(x, i)
if isinstance(x, np_arrays.ndarray):
result = np_arrays.ndarray.from_tensor(result)
return result
@ -548,8 +533,6 @@ def vectorized_map(fn, elems, fallback_to_while_loop=True):
is_batched=True),
elems))
def _get_shape(x):
if isinstance(x, np_arrays.ndarray):
x = x.data
if x.shape.rank is None:
return None
return x.shape.as_list()[0]

View File

@ -50,12 +50,10 @@ from tensorflow.python.ops import cond_v2
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import numpy_ops as tnp
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.numpy_ops import np_arrays
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.saved_model import load
@ -1967,34 +1965,6 @@ class LoadTest(test.TestCase, parameterized.TestCase):
self.assertEqual(self.evaluate(imported.lookup("foo")), 15)
self.assertEqual(self.evaluate(imported.lookup("idk")), -1)
def test_saving_ndarray_specs(self, cycles):
class NdarrayModule(module.Module):
@def_function.function
def plain(self, x):
return tnp.add(x, 1)
@def_function.function(input_signature=[
np_arrays.NdarraySpec(tensor_spec.TensorSpec([], dtypes.float32))])
def with_signature(self, x):
return tnp.add(x, 1)
m = NdarrayModule()
c = tnp.asarray(3.0, tnp.float32)
output_plain, output_with_signature = m.plain(c), m.with_signature(c)
loaded_m = cycle(m, cycles)
load_output_plain, load_output_with_signature = (
loaded_m.plain(c), loaded_m.with_signature(c))
self.assertIsInstance(output_plain, tnp.ndarray)
self.assertIsInstance(load_output_plain, tnp.ndarray)
self.assertIsInstance(output_with_signature, tnp.ndarray)
self.assertIsInstance(load_output_with_signature, tnp.ndarray)
self.assertAllClose(output_plain, load_output_plain)
self.assertAllClose(output_with_signature, load_output_with_signature)
class SingleCycleTests(test.TestCase, parameterized.TestCase):

View File

@ -48,7 +48,6 @@ from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.numpy_ops import np_arrays
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import row_partition
from tensorflow.python.util import compat
@ -517,8 +516,6 @@ class _TypeSpecCodec(object):
resource_variable_ops.VariableSpec,
struct_pb2.TypeSpecProto.ROW_PARTITION_SPEC:
row_partition.RowPartitionSpec,
struct_pb2.TypeSpecProto.NDARRAY_SPEC:
np_arrays.NdarraySpec,
}
# Mapping from type (TypeSpec subclass) to enum value.

View File

@ -28,7 +28,6 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops.numpy_ops import np_arrays
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test
from tensorflow.python.saved_model import nested_structure_coder
@ -332,14 +331,6 @@ class NestedStructureTest(test.TestCase):
decoded = self._coder.decode_proto(encoded)
self.assertEqual(structure, decoded)
def testEncodeDecodeNdarraySpec(self):
structure = [np_arrays.NdarraySpec(
tensor_spec.TensorSpec([4, 2], dtypes.float32))]
self.assertTrue(self._coder.can_encode(structure))
encoded = self._coder.encode_structure(structure)
decoded = self._coder.decode_proto(encoded)
self.assertEqual(structure, decoded)
def testNotEncodable(self):
class NotEncodable(object):

View File

@ -1,14 +1,15 @@
path: "tensorflow.experimental.numpy.ndarray"
tf_class {
is_instance: "<class \'tensorflow.python.ops.numpy_ops.np_arrays.ndarray\'>"
is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>"
is_instance: "<class \'tensorflow.python.framework.ops.Tensor\'>"
is_instance: "<class \'tensorflow.python.types.internal.NativeObject\'>"
is_instance: "<class \'tensorflow.python.types.core.Tensor\'>"
is_instance: "<type \'object\'>"
member {
name: "T"
mtype: "<type \'property\'>"
name: "OVERLOADABLE_OPERATORS"
mtype: "<type \'set\'>"
}
member {
name: "data"
name: "device"
mtype: "<type \'property\'>"
}
member {
@ -16,7 +17,15 @@ tf_class {
mtype: "<type \'property\'>"
}
member {
name: "ndim"
name: "graph"
mtype: "<type \'property\'>"
}
member {
name: "name"
mtype: "<type \'property\'>"
}
member {
name: "op"
mtype: "<type \'property\'>"
}
member {
@ -24,39 +33,35 @@ tf_class {
mtype: "<type \'property\'>"
}
member {
name: "size"
name: "value_index"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'shape\', \'dtype\', \'buffer\'], varargs=None, keywords=None, defaults=[\"<class \'float\'>\", \'None\'], "
argspec: "args=[\'self\', \'op\', \'value_index\', \'dtype\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "astype"
argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "clip"
argspec: "args=[\'a\', \'a_min\', \'a_max\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_tensor"
argspec: "args=[\'cls\', \'tensor\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "ravel"
argspec: "args=[\'a\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "reshape"
argspec: "args=[\'a\'], varargs=newshape, keywords=kwargs, defaults=None"
}
member_method {
name: "tolist"
name: "consumers"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "transpose"
argspec: "args=[\'a\', \'axes\'], varargs=None, keywords=None, defaults=[\'None\'], "
name: "eval"
argspec: "args=[\'self\', \'feed_dict\', \'session\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "experimental_ref"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_shape"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "ref"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "set_shape"
argspec: "args=[\'self\', \'shape\'], varargs=None, keywords=None, defaults=None"
}
}