From 0b9ff2eb1a097602206c6b29823543768bfb34fe Mon Sep 17 00:00:00 2001 From: Peng Wang <wangpeng@google.com> Date: Thu, 4 Feb 2021 19:13:46 -0800 Subject: [PATCH] Remove ndarray wrapper from TF Numpy. We return tensors directly. PiperOrigin-RevId: 355761429 Change-Id: I1ab012bcd831550cd2aa2a8de3d758c23bc6332a --- tensorflow/core/protobuf/struct.proto | 2 +- tensorflow/python/eager/backprop.py | 34 -- tensorflow/python/eager/forwardprop.py | 8 +- tensorflow/python/eager/function.py | 11 - tensorflow/python/framework/ops.py | 36 +- tensorflow/python/framework/ops_test.py | 11 + tensorflow/python/framework/tensor_shape.py | 10 + .../python/framework/tensor_shape_test.py | 14 + .../mixed_precision/autocast_variable_test.py | 4 +- tensorflow/python/ops/array_ops.py | 3 + tensorflow/python/ops/map_fn.py | 8 - tensorflow/python/ops/math_ops.py | 89 ++++- tensorflow/python/ops/numpy_ops/BUILD | 12 + .../ops/numpy_ops/integration_test/BUILD | 11 + .../integration_test/np_config_test.py | 44 +++ .../python/ops/numpy_ops/np_array_ops.py | 313 ++++++++---------- .../python/ops/numpy_ops/np_array_ops_test.py | 77 ++--- tensorflow/python/ops/numpy_ops/np_arrays.py | 302 +---------------- .../python/ops/numpy_ops/np_arrays_test.py | 93 +++--- tensorflow/python/ops/numpy_ops/np_config.py | 39 +++ tensorflow/python/ops/numpy_ops/np_dtypes.py | 31 +- .../python/ops/numpy_ops/np_dtypes_test.py | 57 ++++ .../python/ops/numpy_ops/np_interop_test.py | 21 +- .../python/ops/numpy_ops/np_logic_test.py | 10 +- .../python/ops/numpy_ops/np_math_ops.py | 209 ++++++------ .../python/ops/numpy_ops/np_math_ops_test.py | 4 +- tensorflow/python/ops/numpy_ops/np_random.py | 15 +- .../python/ops/numpy_ops/np_random_test.py | 4 +- tensorflow/python/ops/numpy_ops/np_utils.py | 5 - .../ops/parallel_for/control_flow_ops.py | 17 - tensorflow/python/saved_model/load_test.py | 30 -- .../saved_model/nested_structure_coder.py | 3 - .../nested_structure_coder_test.py | 9 - ...ensorflow.experimental.numpy.ndarray.pbtxt | 67 ++-- 34 files changed, 744 insertions(+), 859 deletions(-) create mode 100644 tensorflow/python/ops/numpy_ops/integration_test/np_config_test.py create mode 100644 tensorflow/python/ops/numpy_ops/np_config.py create mode 100644 tensorflow/python/ops/numpy_ops/np_dtypes_test.py diff --git a/tensorflow/core/protobuf/struct.proto b/tensorflow/core/protobuf/struct.proto index c99eab5dd88..19cd5dfb6dd 100644 --- a/tensorflow/core/protobuf/struct.proto +++ b/tensorflow/core/protobuf/struct.proto @@ -136,7 +136,7 @@ message TypeSpecProto { PER_REPLICA_SPEC = 8; // PerReplicaSpec from distribute/values.py VARIABLE_SPEC = 9; // tf.VariableSpec ROW_PARTITION_SPEC = 10; // RowPartitionSpec from ragged/row_partition.py - NDARRAY_SPEC = 11; // TF Numpy NDarray spec + reserved 11; } TypeSpecClass type_spec_class = 1; diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index 94a3c5a67ac..4681b9c6185 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -62,9 +62,6 @@ from tensorflow.python.util.tf_export import tf_export pfor_ops = LazyLoader( "pfor_ops", globals(), "tensorflow.python.ops.parallel_for.control_flow_ops") -np_arrays = LazyLoader( - "np_arrays", globals(), - "tensorflow.python.ops.numpy_ops.np_arrays") function = LazyLoader("function", globals(), "tensorflow.python.eager.function") @@ -727,8 +724,6 @@ def _handle_or_self(x): """Unwrap resource variable/ndarray to return tensors.""" if resource_variable_ops.is_resource_variable(x): return x.handle - if isinstance(x, np_arrays.ndarray): - return x.data return x @@ -1034,7 +1029,6 @@ class GradientTape(object): " of Tensors or Variables to be differentiated," " but recieved %r" % (target)) - num_ndarrays = 0 flat_targets = [] for t in nest.flatten(target): if not backprop_util.IsTrainable(t): @@ -1045,12 +1039,7 @@ class GradientTape(object): if resource_variable_ops.is_resource_variable(t): with self: t = ops.convert_to_tensor(t) - elif isinstance(t, np_arrays.ndarray): - t = t.data - num_ndarrays += 1 flat_targets.append(t) - # Only rewrap if all targets are ndarray. If not, prefer tensors. - rewrap_as_ndarray = num_ndarrays == len(flat_targets) flat_sources = nest.flatten(sources) flat_sources_raw = flat_sources @@ -1083,13 +1072,6 @@ class GradientTape(object): self._watched_variables = self._tape.watched_variables() self._tape = None - if rewrap_as_ndarray: - def _tensor_to_ndarray(x): - if x is not None: - return np_arrays.tensor_to_ndarray(x) - return None - flat_grad = nest.map_structure(_tensor_to_ndarray, flat_grad) - grad = nest.pack_sequence_as(sources, flat_grad) return grad @@ -1158,10 +1140,6 @@ class GradientTape(object): "compute one set of gradients (or jacobians)") flat_sources = nest.flatten(sources) - rewrap_as_ndarray = False - if isinstance(target, np_arrays.ndarray): - target = target.data - rewrap_as_ndarray = True target_static_shape = target.shape target_shape = array_ops.shape(target) # Note that we push and pop the tape here and below. This is needed since we @@ -1211,8 +1189,6 @@ class GradientTape(object): out = array_ops.reshape(out, new_shape) if context.executing_eagerly(): out.set_shape(target_static_shape.concatenate(flat_sources[i].shape)) - if rewrap_as_ndarray: - out = np_arrays.tensor_to_ndarray(out) output[i] = out return nest.pack_sequence_as(sources, output) @@ -1287,12 +1263,6 @@ class GradientTape(object): if self._tape is None: raise RuntimeError("A non-persistent GradientTape can only be used to" "compute one set of gradients (or jacobians)") - rewrap_as_ndarray = False - if isinstance(target, np_arrays.ndarray): - target = target.data - rewrap_as_ndarray = True - if isinstance(source, np_arrays.ndarray): - source = source.data target_shape = target.shape if target_shape.rank is None: dim = tensor_shape.Dimension(None) @@ -1354,8 +1324,6 @@ class GradientTape(object): # represent unconnected gradients. This is to maintain compatibility with # the previous behavior, which ignored `unconnected_gradients`. output = array_ops.zeros(new_shape, target.dtype) - if rewrap_as_ndarray: - output = np_arrays.tensor_to_ndarray(output) return output else: output = array_ops.reshape(output, @@ -1363,6 +1331,4 @@ class GradientTape(object): output = array_ops.transpose(output, [1, 0, 2]) output = array_ops.reshape(output, new_shape) - if rewrap_as_ndarray: - output = np_arrays.tensor_to_ndarray(output) return output diff --git a/tensorflow/python/eager/forwardprop.py b/tensorflow/python/eager/forwardprop.py index 2f64bad0dff..f3ae1643073 100644 --- a/tensorflow/python/eager/forwardprop.py +++ b/tensorflow/python/eager/forwardprop.py @@ -32,7 +32,6 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops -from tensorflow.python.ops.numpy_ops import np_arrays from tensorflow.python.ops.parallel_for import control_flow_ops from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients from tensorflow.python.platform import tf_logging as logging @@ -441,16 +440,11 @@ class ForwardAccumulator(): if hasattr(tensor, "handle"): unwrapped_tensor = ops.convert_to_tensor(tensor.handle) else: - if isinstance(tensor, np_arrays.ndarray): - unwrapped_tensor = tensor.data - else: - unwrapped_tensor = tensor + unwrapped_tensor = tensor result = pywrap_tfe.TFE_Py_ForwardAccumulatorJVP(self._accumulator, unwrapped_tensor) if result is None and unconnected_gradients == UnconnectedGradients.ZERO: result = array_ops.zeros_like(tensor) - if result is not None and isinstance(tensor, np_arrays.ndarray): - return np_arrays.tensor_to_ndarray(result) return result return nest.map_structure(_fetch_jvp, primals) diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 280040d4157..2daccff8a89 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -1522,11 +1522,6 @@ class ConcreteFunction(object): self._func_graph = func_graph self._captured_inputs = self._func_graph.external_captures self._captured_closures = self._func_graph.deferred_external_captures - structured_outputs = self._func_graph.structured_outputs - self._ndarrays_list = ( - isinstance(structured_outputs, (list, tuple)) and structured_outputs and - all(isinstance(o, np_arrays.ndarray) for o in structured_outputs)) - self._ndarray_singleton = isinstance(structured_outputs, np_arrays.ndarray) # function_spec defines the structured signature. self._set_function_spec(function_spec) @@ -2176,12 +2171,6 @@ class ConcreteFunction(object): if self._func_graph.structured_outputs is None: return result - if result: - if self._ndarrays_list: - return [np_arrays.tensor_to_ndarray(o) for o in result] - elif self._ndarray_singleton: - return np_arrays.tensor_to_ndarray(result[0]) - # Replace outputs with results, skipping over any 'None' values. outputs_list = nest.flatten( self._func_graph.structured_outputs, expand_composites=True) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index ea6ae1b2f6f..794f48799ab 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -257,7 +257,7 @@ def disable_tensor_equality(): # TODO(mdan): This object should subclass Symbol, not just Tensor. -@tf_export("Tensor") +@tf_export("Tensor", "experimental.numpy.ndarray", v1=["Tensor"]) class Tensor(internal.NativeObject, core_tf_types.Tensor): """A tensor is a multidimensional array of elements represented by a @@ -386,6 +386,16 @@ class Tensor(internal.NativeObject, core_tf_types.Tensor): self._id = uid() self._name = None + def __getattr__(self, name): + if name in {"T", "astype", "ravel", "transpose", "reshape", "clip", "size", + "tolist", "data"}: + # TODO(wangpeng): Export the enable_numpy_behavior knob + raise AttributeError(""" + If you are looking for numpy-related methods, please run the following: + import tensorflow.python.ops.numpy_ops.np_config + np_config.enable_numpy_behavior()""") + self.__getattribute__(name) + @staticmethod def _create_with_tf_output(op, value_index, dtype, tf_output): ret = Tensor(op, value_index, dtype) @@ -6943,6 +6953,30 @@ def _reconstruct_sequence_inputs(op_def, inputs, attrs): return grouped_inputs +_numpy_style_type_promotion = False + + +def enable_numpy_style_type_promotion(): + """If called, follows NumPy's rules for type promotion. + + Used for enabling NumPy behavior on methods for TF NumPy. + """ + global _numpy_style_type_promotion + _numpy_style_type_promotion = True + + +_numpy_style_slicing = False + + +def enable_numpy_style_slicing(): + """If called, follows NumPy's rules for slicing Tensors. + + Used for enabling NumPy behavior on slicing for TF NumPy. + """ + global _numpy_style_slicing + _numpy_style_slicing = True + + class _TensorIterator(object): """Iterates over the leading dim of a Tensor. Performs no error checks.""" diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 14db2375d96..19894060ab5 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -202,6 +202,17 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase): self.assertAllEqual(np.array(x), np.ones((3, 4))) self.assertEqual(len(x), 3) + def testConstructor(self): + a = array_ops.ones([]) + for name in ["T", "astype", "ravel", "transpose", "reshape", "clip", "size", + "tolist", "data"]: + with self.assertRaisesRegex( + AttributeError, r"If you are looking for numpy-related methods"): + getattr(a, name) + with self.assertRaisesRegex( + AttributeError, r"object has no attribute"): + a.foo_bar() + def testRef(self): x1 = constant_op.constant(3) x2 = x1 diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py index 8c45906ab37..08c866cd2cf 100644 --- a/tensorflow/python/framework/tensor_shape.py +++ b/tensorflow/python/framework/tensor_shape.py @@ -250,6 +250,11 @@ class Dimension(object): # Allow use in Python 3 range return self._value + def __hash__(self): + if self._value is None: + raise ValueError("Unable to hash Dimension with value 'None'") + return hash(self._value) + @property def value(self): """The value of this dimension, or None if it is unknown.""" @@ -986,6 +991,11 @@ class TensorShape(object): other = TensorShape(other) return other.concatenate(self) + def __hash__(self): + if not self.is_fully_defined(): + raise ValueError("Unable to hash partially defined TensorShape.") + return hash(tuple(self._dims)) + def concatenate(self, other): """Returns the concatenation of the dimension in `self` and `other`. diff --git a/tensorflow/python/framework/tensor_shape_test.py b/tensorflow/python/framework/tensor_shape_test.py index a83d024d5ca..5eb3ff02672 100644 --- a/tensorflow/python/framework/tensor_shape_test.py +++ b/tensorflow/python/framework/tensor_shape_test.py @@ -384,6 +384,20 @@ class ShapeTest(test_util.TensorFlowTestCase, parameterized.TestCase): else: self.assertEqual(expected, mcs.as_list()) + def testHash(self): + base = tensor_shape.TensorShape([1, 2, 3, 4]) + base_copy = tensor_shape.TensorShape([1, 2, 3, 4]) + self.assertEqual(hash(base), hash(base_copy)) + + with self.assertRaisesRegex(ValueError, r"Unable to hash partially"): + hash(tensor_shape.TensorShape([1, 2, 3, 4, None])) + + with self.assertRaisesRegex(ValueError, r"Unable to hash partially"): + hash(tensor_shape.TensorShape(None)) + + with self.assertRaisesRegex(ValueError, r"Unable to hash Dimension"): + hash(tensor_shape.Dimension(None)) + def testMostSpecificCompatibleShape(self): self._testMostSpecificCompatibleShapeHelper([1, 2], None, None) self._testMostSpecificCompatibleShapeHelper(None, [1, 2], None) diff --git a/tensorflow/python/keras/mixed_precision/autocast_variable_test.py b/tensorflow/python/keras/mixed_precision/autocast_variable_test.py index 6d70aaaca27..3fd00608e38 100644 --- a/tensorflow/python/keras/mixed_precision/autocast_variable_test.py +++ b/tensorflow/python/keras/mixed_precision/autocast_variable_test.py @@ -368,9 +368,7 @@ class AutoCastVariableTest(test.TestCase, parameterized.TestCase): # mode. Variable.assign(...).op is None in Eager mode and an op in Graph # mode or a tf.function. We test this is also true of AutoCastVariable. if context.executing_eagerly(): - with self.assertRaisesRegex( - AttributeError, - 'Tensor.op is meaningless when eager execution is enabled'): + with self.assertRaises(AttributeError): x.op # pylint: disable=pointless-statement self.assertIsNone(x.assign(1.0).op) self.assertIsNone(x.assign_add(1.0).op) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 3e053fca024..f0f97f6a054 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -962,6 +962,9 @@ def _slice_helper(tensor, slice_spec, var=None): tf.newaxis or scalar int32/int64 tensors. """ tensor = ops.convert_to_tensor(tensor) + # TODO(wangpeng): Consider supporting var + if var is None and ops._numpy_style_slicing: # pylint: disable=protected-access + return tensor._numpy_style_getitem(slice_spec) # pylint: disable=protected-access if isinstance(slice_spec, bool) or \ (isinstance(slice_spec, ops.Tensor) and slice_spec.dtype == dtypes.bool) or \ diff --git a/tensorflow/python/ops/map_fn.py b/tensorflow/python/ops/map_fn.py index dc4ee6bc5d7..75394ba7fe7 100644 --- a/tensorflow/python/ops/map_fn.py +++ b/tensorflow/python/ops/map_fn.py @@ -38,16 +38,10 @@ from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import deprecation -from tensorflow.python.util import lazy_loader from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export -np_arrays = lazy_loader.LazyLoader( - "np_arrays", globals(), - "tensorflow.python.ops.numpy_ops.np_arrays") - - @tf_export(v1=["map_fn"]) @deprecation.deprecated_args(None, "Use fn_output_signature instead", "dtype") def map_fn(fn, @@ -426,8 +420,6 @@ def map_fn(fn, # Check that inputs are not scalars. first_elem = elems_flat[0] - if isinstance(first_elem, np_arrays.ndarray): - first_elem = first_elem.data elems_static_shape = first_elem.shape if elems_static_shape.ndims is not None and elems_static_shape.ndims < 1: if len(elems_flat) == 1: diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index decad558e85..1cdc90193ab 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -70,6 +70,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numbers import numpy as np import six from six.moves import builtins @@ -99,9 +100,17 @@ from tensorflow.python.util import compat from tensorflow.python.util import deprecation from tensorflow.python.util import dispatch from tensorflow.python.util import nest +from tensorflow.python.util import tf_decorator from tensorflow.python.util.compat import collections_abc +from tensorflow.python.util.lazy_loader import LazyLoader from tensorflow.python.util.tf_export import tf_export + +np_dtypes = LazyLoader( + "np_dtypes", globals(), + "tensorflow.python.ops.numpy_ops.np_dtypes") + + # Aliases for some automatically-generated names. nextafter = gen_math_ops.next_after @@ -1130,6 +1139,48 @@ ops.Tensor._override_operator("__neg__", gen_math_ops.neg) ops.Tensor._override_operator("__abs__", abs) +def _maybe_get_dtype(x): + """Returns a numpy type if available from x. Skips if x is numpy.ndarray.""" + # Don't put np.ndarray in this list, because np.result_type looks at the + # value (not just dtype) of np.ndarray to decide the result type. + if isinstance(x, numbers.Real): + return x + if isinstance(x, ops.Tensor): + return x.dtype.as_numpy_dtype + if isinstance(x, dtypes.DType): + return x.as_numpy_dtype + if isinstance(x, tensor_shape.TensorShape): + return np.int32 + if isinstance(x, (list, tuple)): + raise ValueError("Got sequence {}".format(x)) + return x + + +def maybe_promote_tensors(*tensors, force_same_dtype=True): + """Promote tensors if numpy style promotion is enabled.""" + if not tensors: + return tensors + if not ops._numpy_style_type_promotion: + if not force_same_dtype: + return tensors + promoted_tensors = [] + promoted_tensors.append(tensors[0]) + dtype = tensors[0].dtype.base_dtype + for tensor in tensors[1:]: + promoted_tensors.append( + ops.convert_to_tensor(tensor, dtype, name="x")) + return promoted_tensors + result_type = np_dtypes._result_type( + *[_maybe_get_dtype(x) for x in nest.flatten(tensors)]) + def _promote_or_cast(x): + if isinstance(x, ops.Tensor): + x = cast(x, result_type) + else: + x = ops.convert_to_tensor(x, result_type) + return x + return [_promote_or_cast(x) for x in tensors] + + def _OverrideBinaryOperatorHelper(func, op_name, clazz_object=ops.Tensor): """Register operators with different tensor and scalar versions. @@ -1145,6 +1196,10 @@ def _OverrideBinaryOperatorHelper(func, op_name, clazz_object=ops.Tensor): def binary_op_wrapper(x, y): with ops.name_scope(None, op_name, [x, y]) as name: try: + # force_same_dtype=False to preserve existing TF behavior + # TODO(b/178860388): Figure out why binary_op_wrapper and + # r_binary_op_wrapper use different force_same_dtype values. + x, y = maybe_promote_tensors(x, y, force_same_dtype=False) return func(x, y, name=name) except (TypeError, ValueError) as e: # Even if dispatching the op failed, the RHS may be a tensor aware @@ -1175,7 +1230,9 @@ def _OverrideBinaryOperatorHelper(func, op_name, clazz_object=ops.Tensor): def r_binary_op_wrapper(y, x): with ops.name_scope(None, op_name, [x, y]) as name: - x = ops.convert_to_tensor(x, dtype=y.dtype.base_dtype, name="x") + # TODO(b/178860388): Figure out why binary_op_wrapper and + # r_binary_op_wrapper use different force_same_dtype values. + y, x = maybe_promote_tensors(y, x) return func(x, y, name=name) # Propagate func.__doc__ to the wrappers @@ -1581,10 +1638,21 @@ _OverrideBinaryOperatorHelper(xor_, "xor") ops.Tensor._override_operator("__invert__", invert_) -ops.Tensor._override_operator("__lt__", gen_math_ops.less) -ops.Tensor._override_operator("__le__", gen_math_ops.less_equal) -ops.Tensor._override_operator("__gt__", gen_math_ops.greater) -ops.Tensor._override_operator("__ge__", gen_math_ops.greater_equal) +def _promote_dtypes_decorator(fn): + def wrapper(x, y, *args, **kwargs): + x, y = maybe_promote_tensors(x, y, force_same_dtype=False) + return fn(x, y, *args, **kwargs) + return tf_decorator.make_decorator(fn, wrapper) + + +ops.Tensor._override_operator("__lt__", _promote_dtypes_decorator( + gen_math_ops.less)) +ops.Tensor._override_operator("__le__", _promote_dtypes_decorator( + gen_math_ops.less_equal)) +ops.Tensor._override_operator("__gt__", _promote_dtypes_decorator( + gen_math_ops.greater)) +ops.Tensor._override_operator("__ge__", _promote_dtypes_decorator( + gen_math_ops.greater_equal)) @tf_export("math.equal", "equal") @@ -1691,6 +1759,7 @@ def tensor_equals(self, other): g = getattr(self, "graph", None) if (ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions() and (g is None or g.building_function)): + self, other = maybe_promote_tensors(self, other) return gen_math_ops.equal(self, other, incompatible_shape_error=False) else: # In legacy graph mode, tensor equality is object equality @@ -1727,6 +1796,7 @@ def tensor_not_equals(self, other): if other is None: return True if ops.Tensor._USE_EQUALITY and ops.executing_eagerly_outside_functions(): + self, other = maybe_promote_tensors(self, other) return gen_math_ops.not_equal(self, other, incompatible_shape_error=False) else: # In legacy graph mode, tensor equality is object equality @@ -3482,7 +3552,14 @@ def matvec(a, return array_ops.squeeze(output, axis=-1) -_OverrideBinaryOperatorHelper(matmul, "matmul") +# TODO(b/178650720): Also support numpy-style type promotion in freestanding TF +# functions (e.g. tf.add). +def matmul_wrapper(a, b, name=None): # pylint: disable=missing-function-docstring + if ops._numpy_style_type_promotion: + return a._matmul(b) + return matmul(a, b, name=name) +matmul_wrapper.__doc__ = matmul.__doc__ +_OverrideBinaryOperatorHelper(matmul_wrapper, "matmul") sparse_matmul = deprecation.deprecated(None, "Use `tf.linalg.matmul` instead")( gen_math_ops.sparse_mat_mul) diff --git a/tensorflow/python/ops/numpy_ops/BUILD b/tensorflow/python/ops/numpy_ops/BUILD index 8b742f68c02..514e2ab0cad 100644 --- a/tensorflow/python/ops/numpy_ops/BUILD +++ b/tensorflow/python/ops/numpy_ops/BUILD @@ -13,6 +13,7 @@ py_library( "__init__.py", "np_array_ops.py", "np_arrays.py", + "np_config.py", "np_dtypes.py", "np_export.py", "np_math_ops.py", @@ -40,6 +41,17 @@ py_library( ], ) +cuda_py_test( + name = "np_dtypes_test", + srcs = ["np_dtypes_test.py"], + deps = [ + ":numpy", + "//tensorflow/python:platform", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + cuda_py_test( name = "np_arrays_test", srcs = ["np_arrays_test.py"], diff --git a/tensorflow/python/ops/numpy_ops/integration_test/BUILD b/tensorflow/python/ops/numpy_ops/integration_test/BUILD index e5483166406..06cce0c466e 100644 --- a/tensorflow/python/ops/numpy_ops/integration_test/BUILD +++ b/tensorflow/python/ops/numpy_ops/integration_test/BUILD @@ -1,4 +1,5 @@ load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "cuda_py_test") licenses(["notice"]) @@ -10,3 +11,13 @@ py_test( "//tensorflow:tensorflow_py", ], ) + +cuda_py_test( + name = "np_config_test", + srcs = ["np_config_test.py"], + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/python/ops/numpy_ops:numpy", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/python/ops/numpy_ops/integration_test/np_config_test.py b/tensorflow/python/ops/numpy_ops/integration_test/np_config_test.py new file mode 100644 index 00000000000..014eba12960 --- /dev/null +++ b/tensorflow/python/ops/numpy_ops/integration_test/np_config_test.py @@ -0,0 +1,44 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests that an error is raised when numpy functions are called.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow.compat.v2 as tf +from tensorflow.python.ops.numpy_ops import np_config + + +class ConfigTest(tf.test.TestCase): + + def testMethods(self): + a = tf.constant(1.) + + for name in {'T', 'astype', 'ravel', 'transpose', 'reshape', 'clip', 'size', + 'tolist'}: + with self.assertRaisesRegex(AttributeError, 'enable_numpy_behavior'): + getattr(a, name) + + np_config.enable_numpy_behavior() + + for name in {'T', 'astype', 'ravel', 'transpose', 'reshape', 'clip', 'size', + 'tolist'}: + _ = getattr(a, name) + + +if __name__ == '__main__': + tf.compat.v1.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/python/ops/numpy_ops/np_array_ops.py b/tensorflow/python/ops/numpy_ops/np_array_ops.py index 3d5c3f93d2e..042dc096586 100644 --- a/tensorflow/python/ops/numpy_ops/np_array_ops.py +++ b/tensorflow/python/ops/numpy_ops/np_array_ops.py @@ -61,15 +61,11 @@ def empty_like(a, dtype=None): def zeros(shape, dtype=float): # pylint: disable=redefined-outer-name dtype = ( np_utils.result_type(dtype) if dtype else np_dtypes.default_float_type()) - if isinstance(shape, np_arrays.ndarray): - shape = shape.data - return np_arrays.tensor_to_ndarray(array_ops.zeros(shape, dtype=dtype)) + return array_ops.zeros(shape, dtype=dtype) @np_utils.np_doc('zeros_like') def zeros_like(a, dtype=None): # pylint: disable=missing-docstring - if isinstance(a, np_arrays.ndarray): - a = a.data if dtype is None: # We need to let np_utils.result_type decide the dtype, not tf.zeros_like dtype = np_utils.result_type(a) @@ -78,27 +74,23 @@ def zeros_like(a, dtype=None): # pylint: disable=missing-docstring # `float`, so we let `np_utils.result_type` decide. dtype = np_utils.result_type(dtype) dtype = dtypes.as_dtype(dtype) # Work around b/149877262 - return np_arrays.tensor_to_ndarray(array_ops.zeros_like(a, dtype)) + return array_ops.zeros_like(a, dtype) @np_utils.np_doc('ones') def ones(shape, dtype=float): # pylint: disable=redefined-outer-name if dtype: dtype = np_utils.result_type(dtype) - if isinstance(shape, np_arrays.ndarray): - shape = shape.data - return np_arrays.tensor_to_ndarray(array_ops.ones(shape, dtype=dtype)) + return array_ops.ones(shape, dtype=dtype) @np_utils.np_doc('ones_like') def ones_like(a, dtype=None): - if isinstance(a, np_arrays.ndarray): - a = a.data if dtype is None: dtype = np_utils.result_type(a) else: dtype = np_utils.result_type(dtype) - return np_arrays.tensor_to_ndarray(array_ops.ones_like(a, dtype)) + return array_ops.ones_like(a, dtype) @np_utils.np_doc('eye') @@ -115,7 +107,7 @@ def eye(N, M=None, k=0, dtype=float): # pylint: disable=invalid-name,missing-do # tf.linalg.diag will raise an error in this case return zeros([N, M], dtype=dtype) if k == 0: - return np_arrays.tensor_to_ndarray(linalg_ops.eye(N, M, dtype=dtype)) + return linalg_ops.eye(N, M, dtype=dtype) # We need the precise length, otherwise tf.linalg.diag will raise an error diag_len = min(N, M) if k > 0: @@ -129,8 +121,7 @@ def eye(N, M=None, k=0, dtype=float): # pylint: disable=invalid-name,missing-do elif M - k > N: diag_len = N + k diagonal_ = array_ops.ones([diag_len], dtype=dtype) - return np_arrays.tensor_to_ndarray( - array_ops.matrix_diag(diagonal=diagonal_, num_rows=N, num_cols=M, k=k)) + return array_ops.matrix_diag(diagonal=diagonal_, num_rows=N, num_cols=M, k=k) @np_utils.np_doc('identity') @@ -142,10 +133,9 @@ def identity(n, dtype=float): def full(shape, fill_value, dtype=None): # pylint: disable=redefined-outer-name if not isinstance(shape, np_arrays.ndarray): shape = asarray(np_arrays.convert_to_tensor(shape, dtype_hint=np.int32)) - shape = atleast_1d(shape).data + shape = atleast_1d(shape) fill_value = asarray(fill_value, dtype=dtype) - return np_arrays.tensor_to_ndarray( - array_ops.broadcast_to(fill_value.data, shape)) + return array_ops.broadcast_to(fill_value, shape) # Using doc only here since np full_like signature doesn't seem to have the @@ -160,19 +150,15 @@ def full_like(a, fill_value, dtype=None, order='K', subok=True, shape=None): # if shape: raise ValueError('Overriding the shape is not supported.') - a = asarray(a).data + a = asarray(a) dtype = dtype or np_utils.result_type(a) fill_value = asarray(fill_value, dtype=dtype) - return np_arrays.tensor_to_ndarray( - array_ops.broadcast_to(fill_value.data, array_ops.shape(a))) + return array_ops.broadcast_to(fill_value, array_ops.shape(a)) def _array_internal(val, dtype=None, copy=True, ndmin=0): # pylint: disable=redefined-outer-name """Main implementation of np.array().""" - if isinstance(val, np_arrays.ndarray): - result_t = val.data - else: - result_t = val + result_t = val if not isinstance(result_t, ops.Tensor): if not dtype: @@ -180,13 +166,7 @@ def _array_internal(val, dtype=None, copy=True, ndmin=0): # pylint: disable=red # We can't call `convert_to_tensor(result_t, dtype=dtype)` here because # convert_to_tensor doesn't allow incompatible arguments such as (5.5, int) # while np.array allows them. We need to convert-then-cast. - def maybe_data(x): - if isinstance(x, np_arrays.ndarray): - return x.data - return x - # Handles lists of ndarrays - result_t = nest.map_structure(maybe_data, result_t) # EagerTensor conversion complains about "mixed types" when converting # tensors with no dtype information. This is because it infers types based # on one selected item in the list. So e.g. when converting [2., 2j] @@ -204,7 +184,7 @@ def _array_internal(val, dtype=None, copy=True, ndmin=0): # pylint: disable=red result_t = array_ops.identity(result_t) if ndmin == 0: - return np_arrays.tensor_to_ndarray(result_t) + return result_t ndims = array_ops.rank(result_t) @@ -216,7 +196,7 @@ def _array_internal(val, dtype=None, copy=True, ndmin=0): # pylint: disable=red result_t = np_utils.cond( np_utils.greater(ndmin, ndims), true_fn, lambda: result_t) - return np_arrays.tensor_to_ndarray(result_t) + return result_t # TODO(wangpeng): investigate whether we can make `copy` default to False. @@ -241,7 +221,8 @@ def array(val, dtype=None, copy=True, ndmin=0): # pylint: disable=redefined-out def asarray(a, dtype=None): if dtype: dtype = np_utils.result_type(dtype) - if isinstance(a, np_arrays.ndarray) and (not dtype or dtype == a.dtype): + if isinstance(a, np_arrays.ndarray) and ( + not dtype or dtype == a.dtype.as_numpy_dtype): return a return array(a, dtype, copy=False) @@ -294,15 +275,15 @@ def arange(start, stop=None, step=1, dtype=None): return array([], dtype=dtype) # TODO(srbs): There are some bugs when start or stop is float type and dtype # is integer type. - return np_arrays.tensor_to_ndarray( - math_ops.cast(math_ops.range(start, limit=stop, delta=step), dtype=dtype)) + return math_ops.cast( + math_ops.range(start, limit=stop, delta=step), dtype=dtype) # Building matrices. @np_utils.np_doc('diag') def diag(v, k=0): # pylint: disable=missing-docstring """Raises an error if input is not 1- or 2-d.""" - v = asarray(v).data + v = asarray(v) v_rank = array_ops.rank(v) v.shape.with_rank_at_most(2) @@ -331,20 +312,20 @@ def diag(v, k=0): # pylint: disable=missing-docstring result = np_utils.cond( math_ops.equal(v_rank, 1), lambda: _diag(v, k), lambda: _diag_part(v, k)) - return np_utils.tensor_to_ndarray(result) + return result @np_utils.np_doc('diagonal') def diagonal(a, offset=0, axis1=0, axis2=1): # pylint: disable=missing-docstring - a = asarray(a).data + a = asarray(a) maybe_rank = a.shape.rank if maybe_rank is not None and offset == 0 and ( axis1 == maybe_rank - 2 or axis1 == -2) and (axis2 == maybe_rank - 1 or axis2 == -1): - return np_utils.tensor_to_ndarray(array_ops.matrix_diag_part(a)) + return array_ops.matrix_diag_part(a) - a = moveaxis(np_utils.tensor_to_ndarray(a), (axis1, axis2), (-2, -1)).data + a = moveaxis(a, (axis1, axis2), (-2, -1)) a_shape = array_ops.shape(a) @@ -361,20 +342,20 @@ def diagonal(a, offset=0, axis1=0, axis2=1): # pylint: disable=missing-docstrin np_utils.greater_equal(offset, np_utils.getitem(a_shape, -1)), ), _zeros, lambda: (a, offset)) - a = np_utils.tensor_to_ndarray(array_ops.matrix_diag_part(a, k=offset)) + a = array_ops.matrix_diag_part(a, k=offset) return a @np_utils.np_doc('diagflat') def diagflat(v, k=0): v = asarray(v) - return diag(array_ops.reshape(v.data, [-1]), k) + return diag(array_ops.reshape(v, [-1]), k) def _promote_dtype(*arrays): dtype = np_utils.result_type(*arrays) def _fast_asarray(a): - if isinstance(a, np_arrays.ndarray) and dtype == a.dtype: + if isinstance(a, np_arrays.ndarray) and dtype == a.dtype.as_numpy_dtype: return a return _array_internal(a, dtype=dtype, copy=False) return [_fast_asarray(a) for a in arrays] @@ -382,9 +363,11 @@ def _promote_dtype(*arrays): def _promote_dtype_binary(t1, t2): dtype = np_utils._result_type_binary(t1, t2) # pylint: disable=protected-access - if not(isinstance(t1, np_arrays.ndarray) and dtype == t1.dtype): + if not( + isinstance(t1, np_arrays.ndarray) and dtype == t1.dtype.as_numpy_dtype): t1 = _array_internal(t1, dtype=dtype, copy=False) - if not(isinstance(t2, np_arrays.ndarray) and dtype == t2.dtype): + if not( + isinstance(t2, np_arrays.ndarray) and dtype == t2.dtype.as_numpy_dtype): t2 = _array_internal(t2, dtype=dtype, copy=False) return t1, t2 @@ -392,15 +375,13 @@ def _promote_dtype_binary(t1, t2): @np_utils.np_doc('all') def all(a, axis=None, keepdims=None): # pylint: disable=redefined-builtin a = asarray(a, dtype=bool) - return np_utils.tensor_to_ndarray( - math_ops.reduce_all(input_tensor=a.data, axis=axis, keepdims=keepdims)) + return math_ops.reduce_all(input_tensor=a, axis=axis, keepdims=keepdims) @np_utils.np_doc('any') def any(a, axis=None, keepdims=None): # pylint: disable=redefined-builtin a = asarray(a, dtype=bool) - return np_utils.tensor_to_ndarray( - math_ops.reduce_any(input_tensor=a.data, axis=axis, keepdims=keepdims)) + return math_ops.reduce_any(input_tensor=a, axis=axis, keepdims=keepdims) @np_utils.np_doc('compress') @@ -425,13 +406,12 @@ def compress(condition, a, axis=None): # pylint: disable=redefined-outer-name,m # `tf.boolean_mask` requires the first dimensions of array and condition to # match. `np.compress` pads condition with False when it is shorter. - condition_t = condition.data - a_t = a.data + condition_t = condition + a_t = a if condition.shape[0] < a.shape[axis]: padding = array_ops.fill([a.shape[axis] - condition.shape[0]], False) condition_t = array_ops.concat([condition_t, padding], axis=0) - return np_utils.tensor_to_ndarray( - array_ops.boolean_mask(tensor=a_t, mask=condition_t, axis=axis)) + return array_ops.boolean_mask(tensor=a_t, mask=condition_t, axis=axis) @np_utils.np_doc('copy') @@ -443,8 +423,9 @@ def _maybe_promote_to_int(a): if dtypes.as_dtype(a.dtype).is_integer: # If a is an integer type and its precision is less than that of `int`, # the output type will be `int`. - output_type = np.promote_types(a.dtype, int) - if output_type != a.dtype: + a_numpy_dtype = a.dtype.as_numpy_dtype + output_type = np.promote_types(a_numpy_dtype, int) + if output_type != a_numpy_dtype: a = asarray(a, dtype=output_type) return a @@ -462,8 +443,8 @@ def cumprod(a, axis=None, dtype=None): # pylint: disable=missing-docstring a = ravel(a) axis = 0 elif axis < 0: - axis += array_ops.rank(a.data) - return np_utils.tensor_to_ndarray(math_ops.cumprod(a.data, axis)) + axis += array_ops.rank(a) + return math_ops.cumprod(a, axis) @np_utils.np_doc('cumsum') @@ -478,8 +459,8 @@ def cumsum(a, axis=None, dtype=None): # pylint: disable=missing-docstring a = ravel(a) axis = 0 elif axis < 0: - axis += array_ops.rank(a.data) - return np_utils.tensor_to_ndarray(math_ops.cumsum(a.data, axis)) + axis += array_ops.rank(a) + return math_ops.cumsum(a, axis) @np_utils.np_doc('imag') @@ -487,7 +468,7 @@ def imag(val): val = asarray(val) # TODO(srbs): np.imag returns a scalar if `val` is a scalar, whereas we always # return an ndarray. - return np_utils.tensor_to_ndarray(math_ops.imag(val.data)) + return math_ops.imag(val) _TO_INT_ = 0 @@ -532,10 +513,9 @@ def _reduce(tf_fn, a = asarray(a, dtype=dtype) if ((dtype == np.bool_ or preserve_bool and a.dtype == np.bool_) and tf_bool_fn is not None): - return np_utils.tensor_to_ndarray( - tf_bool_fn(input_tensor=a.data, axis=axis, keepdims=keepdims)) + return tf_bool_fn(input_tensor=a, axis=axis, keepdims=keepdims) if dtype is None: - dtype = a.dtype + dtype = a.dtype.as_numpy_dtype if np.issubdtype(dtype, np.integer) or dtype == np.bool_: if promote_int == _TO_INT_: # If a is an integer/bool type and whose bit width is less than np.int_, @@ -554,12 +534,15 @@ def _reduce(tf_fn, dtype = np.int_ else: dtype = np.uint - a = a.astype(dtype) + a = math_ops.cast(a, dtype) elif promote_int == _TO_FLOAT: - a = a.astype(np_dtypes.default_float_type()) + a = math_ops.cast(a, np_dtypes.default_float_type()) - return np_utils.tensor_to_ndarray( - tf_fn(input_tensor=a.data, axis=axis, keepdims=keepdims)) + if isinstance(axis, ops.Tensor) and axis.dtype not in ( + dtypes.int32, dtypes.int64): + axis = math_ops.cast(axis, dtypes.int64) + + return tf_fn(input_tensor=a, axis=axis, keepdims=keepdims) # TODO (DarrenZhang01): Add `axis` support to the `size` API. @@ -570,11 +553,11 @@ def size(x, axis=None): # pylint: disable=missing-docstring '`np.size` implementation') if isinstance(x, (int, float, np.int32, np.int64, np.float32, np.float64)): return 1 - x = asarray(x).data + x = asarray(x) if x.shape.is_fully_defined(): return np.prod(x.shape.as_list(), dtype=int) else: - return np_utils.tensor_to_ndarray(array_ops.size_v2(x)) + return array_ops.size_v2(x) @np_utils.np_doc('sum') @@ -677,10 +660,10 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=None): # pylint: d axis=axis, dtype=working_dtype, keepdims=keepdims, - promote_int=_TO_FLOAT).data + promote_int=_TO_FLOAT) if dtype: result = math_ops.cast(result, dtype) - return np_utils.tensor_to_ndarray(result) + return result @np_utils.np_doc('std') @@ -697,13 +680,7 @@ def std(a, axis=None, keepdims=None): # pylint: disable=missing-function-docstr @np_utils.np_doc('ravel') def ravel(a): # pylint: disable=missing-docstring a = asarray(a) - out = np_utils.cond( - math_ops.equal(a.ndim, 1), lambda: a.data, - lambda: array_ops.reshape(a.data, [-1])) - return np_utils.tensor_to_ndarray(out) - - -setattr(np_arrays.ndarray, 'ravel', ravel) + return array_ops.reshape(a, [-1]) @np_utils.np_doc('real') @@ -711,12 +688,12 @@ def real(val): val = asarray(val) # TODO(srbs): np.real returns a scalar if val is a scalar, whereas we always # return an ndarray. - return np_utils.tensor_to_ndarray(math_ops.real(val.data)) + return math_ops.real(val) @np_utils.np_doc('repeat') def repeat(a, repeats, axis=None): # pylint: disable=missing-docstring - a = asarray(a).data + a = asarray(a) original_shape = a._shape_as_list() # pylint: disable=protected-access # Best effort recovery of the shape. known_shape = original_shape is not None and None not in original_shape @@ -737,18 +714,18 @@ def repeat(a, repeats, axis=None): # pylint: disable=missing-docstring else: original_shape[axis] = repeats_np.sum() - repeats = asarray(repeats).data + repeats = asarray(repeats) result = array_ops.repeat(a, repeats, axis) if known_shape: result.set_shape(original_shape) - return np_utils.tensor_to_ndarray(result) + return result @np_utils.np_doc('around') def around(a, decimals=0): # pylint: disable=missing-docstring a = asarray(a) - dtype = a.dtype + dtype = a.dtype.as_numpy_dtype factor = math.pow(10, decimals) if np.issubdtype(dtype, np.inexact): factor = math_ops.cast(factor, dtype) @@ -756,12 +733,12 @@ def around(a, decimals=0): # pylint: disable=missing-docstring # Use float as the working dtype when a.dtype is exact (e.g. integer), # because `decimals` can be negative. float_dtype = np_dtypes.default_float_type() - a = a.astype(float_dtype).data + a = a.astype(float_dtype) factor = math_ops.cast(factor, float_dtype) a = math_ops.multiply(a, factor) a = math_ops.round(a) a = math_ops.divide(a, factor) - return np_utils.tensor_to_ndarray(a).astype(dtype) + return a.astype(dtype) setattr(np_arrays.ndarray, '__round__', around) @@ -774,18 +751,16 @@ def reshape(a, newshape, order='C'): raise ValueError('Unsupported order argument {}'.format(order)) a = asarray(a) - if isinstance(newshape, np_arrays.ndarray): - newshape = newshape.data if isinstance(newshape, int): newshape = [newshape] if order == 'F': r = array_ops.transpose( - array_ops.reshape(array_ops.transpose(a.data), newshape[::-1])) + array_ops.reshape(array_ops.transpose(a), newshape[::-1])) else: - r = array_ops.reshape(a.data, newshape) + r = array_ops.reshape(a, newshape) - return np_utils.tensor_to_ndarray(r) + return r def _reshape_method_wrapper(a, *newshape, **kwargs): @@ -802,13 +777,13 @@ def _reshape_method_wrapper(a, *newshape, **kwargs): @np_utils.np_doc('expand_dims') def expand_dims(a, axis): a = asarray(a) - return np_utils.tensor_to_ndarray(array_ops.expand_dims(a.data, axis=axis)) + return array_ops.expand_dims(a, axis=axis) @np_utils.np_doc('squeeze') def squeeze(a, axis=None): a = asarray(a) - return np_utils.tensor_to_ndarray(array_ops.squeeze(a, axis)) + return array_ops.squeeze(a, axis) @np_utils.np_doc('transpose') @@ -816,12 +791,12 @@ def transpose(a, axes=None): a = asarray(a) if axes is not None: axes = asarray(axes) - return np_utils.tensor_to_ndarray(array_ops.transpose(a=a.data, perm=axes)) + return array_ops.transpose(a=a, perm=axes) @np_utils.np_doc('swapaxes') def swapaxes(a, axis1, axis2): # pylint: disable=missing-docstring - a = asarray(a).data + a = asarray(a) def adjust_axes(axes, rank): def f(x): if isinstance(x, int): @@ -848,7 +823,7 @@ def swapaxes(a, axis1, axis2): # pylint: disable=missing-docstring perm = array_ops.tensor_scatter_update(perm, [[axis1], [axis2]], [axis2, axis1]) a = array_ops.transpose(a, perm) - return np_utils.tensor_to_ndarray(a) + return a @np_utils.np_doc('moveaxis') @@ -857,7 +832,7 @@ def moveaxis(a, source, destination): # pylint: disable=missing-docstring if not source and not destination: return a - a = asarray(a).data + a = asarray(a) if isinstance(source, int): source = (source,) @@ -908,13 +883,7 @@ def moveaxis(a, source, destination): # pylint: disable=missing-docstring perm, array_ops.expand_dims(destination, 1), source) a = array_ops.transpose(a, perm) - return np_utils.tensor_to_ndarray(a) - - -# TODO(wangpeng): Make a custom `setattr` that also sets docstring for the -# method. -setattr(np_arrays.ndarray, 'transpose', transpose) -setattr(np_arrays.ndarray, 'reshape', _reshape_method_wrapper) + return a @np_utils.np_doc('pad') @@ -926,12 +895,11 @@ def pad(array, pad_width, mode, **kwargs): # pylint: disable=redefined-outer-na mode = mode.upper() array = asarray(array) pad_width = asarray(pad_width, dtype=dtypes.int32) - return np_utils.tensor_to_ndarray( - array_ops.pad( - tensor=array.data, - paddings=pad_width.data, - mode=mode, - constant_values=constant_values)) + return array_ops.pad( + tensor=array, + paddings=pad_width, + mode=mode, + constant_values=constant_values) @np_utils.np_doc('take') @@ -943,8 +911,8 @@ def take(a, indices, axis=None, out=None, mode='clip'): if mode not in {'raise', 'clip', 'wrap'}: raise ValueError("Invalid mode '{}' for take".format(mode)) - a = asarray(a).data - indices = asarray(indices).data + a = asarray(a) + indices = asarray(indices) if axis is None: a = array_ops.reshape(a, [-1]) @@ -958,7 +926,7 @@ def take(a, indices, axis=None, out=None, mode='clip'): else: raise ValueError("The 'raise' mode to take is not supported.") - return np_utils.tensor_to_ndarray(array_ops.gather(a, indices, axis=axis)) + return array_ops.gather(a, indices, axis=axis) @np_utils.np_doc_only('where') @@ -969,8 +937,7 @@ def where(condition, x=None, y=None): return nonzero(condition) elif x is not None and y is not None: x, y = _promote_dtype(x, y) - return np_utils.tensor_to_ndarray( - array_ops.where_v2(condition.data, x.data, y.data)) + return array_ops.where_v2(condition, x, y) raise ValueError('Both x and y must be ndarrays, or both must be None.') @@ -1044,8 +1011,7 @@ def split(ary, indices_or_sections, axis=0): ary = asarray(ary) if not isinstance(indices_or_sections, six.integer_types): indices_or_sections = _boundaries_to_sizes(ary, indices_or_sections, axis) - result = array_ops.split(ary.data, indices_or_sections, axis=axis) - return [np_utils.tensor_to_ndarray(a) for a in result] + return array_ops.split(ary, indices_or_sections, axis=axis) def _split_on_axis(np_fun_name, axis): @@ -1077,7 +1043,7 @@ def stack(arrays, axis=0): # pylint: disable=missing-function-docstring return swapaxes(arrays, 0, axis) arrays = _promote_dtype(*arrays) # pylint: disable=protected-access unwrapped_arrays = [ - a.data if isinstance(a, np_arrays.ndarray) else a for a in arrays + a if isinstance(a, np_arrays.ndarray) else a for a in arrays ] return asarray(array_ops.stack(unwrapped_arrays, axis)) @@ -1087,7 +1053,7 @@ def hstack(tup): arrays = [atleast_1d(a) for a in tup] arrays = _promote_dtype(*arrays) # pylint: disable=protected-access unwrapped_arrays = [ - a.data if isinstance(a, np_arrays.ndarray) else a for a in arrays + a if isinstance(a, np_arrays.ndarray) else a for a in arrays ] rank = array_ops.rank(unwrapped_arrays[0]) return np_utils.cond( @@ -1101,7 +1067,7 @@ def vstack(tup): arrays = [atleast_2d(a) for a in tup] arrays = _promote_dtype(*arrays) # pylint: disable=protected-access unwrapped_arrays = [ - a.data if isinstance(a, np_arrays.ndarray) else a for a in arrays + a if isinstance(a, np_arrays.ndarray) else a for a in arrays ] return array_ops.concat(unwrapped_arrays, axis=0) @@ -1111,13 +1077,13 @@ def dstack(tup): arrays = [atleast_3d(a) for a in tup] arrays = _promote_dtype(*arrays) # pylint: disable=protected-access unwrapped_arrays = [ - a.data if isinstance(a, np_arrays.ndarray) else a for a in arrays + a if isinstance(a, np_arrays.ndarray) else a for a in arrays ] return array_ops.concat(unwrapped_arrays, axis=2) def _pad_left_to(n, old_shape): - old_shape = asarray(old_shape, dtype=np.int32).data + old_shape = asarray(old_shape, dtype=np.int32) new_shape = array_ops.pad( old_shape, [[math_ops.maximum(n - array_ops.size(old_shape), 0), 0]], constant_values=1) @@ -1143,8 +1109,8 @@ def _atleast_nd(n, new_shape, *arys): return asarray( np_utils.cond( np_utils.greater(n, array_ops.rank(x)), - lambda: reshape(x, new_shape(n, array_ops.shape(x.data))).data, - lambda: x.data)) + lambda: reshape(x, new_shape(n, array_ops.shape(x))), + lambda: x)) arys = list(map(f, arys)) if len(arys) == 1: @@ -1182,16 +1148,14 @@ def atleast_3d(*arys): # pylint: disable=missing-docstring @np_utils.np_doc('nonzero') def nonzero(a): - a = atleast_1d(a).data + a = atleast_1d(a) if a.shape.rank is None: raise ValueError("The rank of `a` is unknown, so we can't decide how many " 'arrays to return.') - return nest.map_structure( - np_arrays.tensor_to_ndarray, - array_ops.unstack( - array_ops.where_v2(math_ops.cast(a, dtypes.bool)), - a.shape.rank, - axis=1)) + return array_ops.unstack( + array_ops.where_v2(math_ops.cast(a, dtypes.bool)), + a.shape.rank, + axis=1) @np_utils.np_doc('diag_indices') @@ -1231,12 +1195,12 @@ def tri(N, M=None, k=0, dtype=None): # pylint: disable=invalid-name,missing-doc r = o else: r = array_ops.matrix_band_part(o, -1, k) - return np_utils.tensor_to_ndarray(r) + return r @np_utils.np_doc('tril') def tril(m, k=0): # pylint: disable=missing-docstring - m = asarray(m).data + m = asarray(m) if m.shape.ndims is None: raise ValueError('Argument to tril should have known rank') m_shape = m.shape.as_list() @@ -1251,14 +1215,13 @@ def tril(m, k=0): # pylint: disable=missing-docstring z = constant_op.constant(0, m.dtype) mask = tri(*m_shape[-2:], k=k, dtype=bool) - return np_utils.tensor_to_ndarray( - array_ops.where_v2( - array_ops.broadcast_to(mask, array_ops.shape(m)), m, z)) + return array_ops.where_v2( + array_ops.broadcast_to(mask, array_ops.shape(m)), m, z) @np_utils.np_doc('triu') def triu(m, k=0): # pylint: disable=missing-docstring - m = asarray(m).data + m = asarray(m) if m.shape.ndims is None: raise ValueError('Argument to triu should have known rank') m_shape = m.shape.as_list() @@ -1273,22 +1236,20 @@ def triu(m, k=0): # pylint: disable=missing-docstring z = constant_op.constant(0, m.dtype) mask = tri(*m_shape[-2:], k=k - 1, dtype=bool) - return np_utils.tensor_to_ndarray( - array_ops.where_v2( - array_ops.broadcast_to(mask, array_ops.shape(m)), z, m)) + return array_ops.where_v2( + array_ops.broadcast_to(mask, array_ops.shape(m)), z, m) @np_utils.np_doc('flip') def flip(m, axis=None): # pylint: disable=missing-docstring - m = asarray(m).data + m = asarray(m) if axis is None: - return np_utils.tensor_to_ndarray( - array_ops.reverse(m, math_ops.range(array_ops.rank(m)))) + return array_ops.reverse(m, math_ops.range(array_ops.rank(m))) axis = np_utils._canonicalize_axis(axis, array_ops.rank(m)) # pylint: disable=protected-access - return np_utils.tensor_to_ndarray(array_ops.reverse(m, [axis])) + return array_ops.reverse(m, [axis]) @np_utils.np_doc('flipud') @@ -1303,15 +1264,15 @@ def fliplr(m): # pylint: disable=missing-docstring @np_utils.np_doc('roll') def roll(a, shift, axis=None): # pylint: disable=missing-docstring - a = asarray(a).data + a = asarray(a) if axis is not None: - return np_utils.tensor_to_ndarray(manip_ops.roll(a, shift, axis)) + return manip_ops.roll(a, shift, axis) # If axis is None, the roll happens as a 1-d tensor. original_shape = array_ops.shape(a) a = manip_ops.roll(array_ops.reshape(a, [-1]), shift, 0) - return np_utils.tensor_to_ndarray(array_ops.reshape(a, original_shape)) + return array_ops.reshape(a, original_shape) @np_utils.np_doc('rot90') @@ -1336,7 +1297,7 @@ def rot90(m, k=1, axes=(0, 1)): # pylint: disable=missing-docstring @np_utils.np_doc('vander') def vander(x, N=None, increasing=False): # pylint: disable=missing-docstring,invalid-name - x = asarray(x).data + x = asarray(x) x_shape = array_ops.shape(x) N = N or x_shape[0] @@ -1368,9 +1329,8 @@ def vander(x, N=None, increasing=False): # pylint: disable=missing-docstring,in delta = -1 x = array_ops.expand_dims(x, -1) - return np_utils.tensor_to_ndarray( - math_ops.pow( - x, math_ops.cast(math_ops.range(start, limit, delta), dtype=x.dtype))) + return math_ops.pow( + x, math_ops.cast(math_ops.range(start, limit, delta), dtype=x.dtype)) @np_utils.np_doc('ix_') @@ -1378,7 +1338,7 @@ def ix_(*args): # pylint: disable=missing-docstring n = len(args) output = [] for i, a in enumerate(args): - a = asarray(a).data + a = asarray(a) a_rank = array_ops.rank(a) a_rank_temp = np_utils.get_static_value(a_rank) if a_rank_temp is not None: @@ -1393,11 +1353,9 @@ def ix_(*args): # pylint: disable=missing-docstring new_shape[i] = -1 dtype = a.dtype if dtype == dtypes.bool: - output.append( - np_utils.tensor_to_ndarray( - array_ops.reshape(nonzero(a)[0].data, new_shape))) + output.append(array_ops.reshape(nonzero(a)[0], new_shape)) elif dtype.is_integer: - output.append(np_utils.tensor_to_ndarray(array_ops.reshape(a, new_shape))) + output.append(array_ops.reshape(a, new_shape)) else: raise ValueError( 'Only integer and bool dtypes are supported, got {}'.format(dtype)) @@ -1413,9 +1371,8 @@ def broadcast_arrays(*args, **kwargs): # pylint: disable=missing-docstring if kwargs: raise ValueError('Received unsupported arguments {}'.format(kwargs.keys())) - args = [asarray(arg).data for arg in args] - args = np_utils.tf_broadcast(*args) - return [np_utils.tensor_to_ndarray(arg) for arg in args] + args = [asarray(arg) for arg in args] + return np_utils.tf_broadcast(*args) @np_utils.np_doc_only('sign') @@ -1428,13 +1385,13 @@ def sign(x, out=None, where=None, **kwargs): # pylint: disable=missing-docstrin raise ValueError('tf.numpy doesnt support setting {}'.format(kwargs.keys())) x = asarray(x) - dtype = x.dtype + dtype = x.dtype.as_numpy_dtype if np.issubdtype(dtype, np.complex): - result = math_ops.cast(math_ops.sign(math_ops.real(x.data)), dtype) + result = math_ops.cast(math_ops.sign(math_ops.real(x)), dtype) else: - result = math_ops.sign(x.data) + result = math_ops.sign(x) - return np_utils.tensor_to_ndarray(result) + return result # Note that np.take_along_axis may not be present in some supported versions of @@ -1447,9 +1404,6 @@ def take_along_axis(arr, indices, axis): # pylint: disable=missing-docstring if axis is None: return take_along_axis(arr.ravel(), indices, 0) - arr = arr.data - indices = indices.data - rank = array_ops.rank(arr) axis = axis + rank if axis < 0 else axis @@ -1475,7 +1429,7 @@ def take_along_axis(arr, indices, axis): # pylint: disable=missing-docstring # Correct indices since gather doesn't correctly handle negative indices. indices = array_ops.where_v2(indices < 0, indices + arr_shape[axis], indices) - swapaxes_ = lambda t: swapaxes(np_utils.tensor_to_ndarray(t), axis, -1).data + swapaxes_ = lambda t: swapaxes(t, axis, -1) dont_move_axis_to_end = math_ops.equal(axis, np_utils.subtract(rank, 1)) arr = np_utils.cond(dont_move_axis_to_end, lambda: arr, @@ -1495,7 +1449,7 @@ def take_along_axis(arr, indices, axis): # pylint: disable=missing-docstring lambda: swapaxes_(result)) result.set_shape(possible_result_shape) - return np_utils.tensor_to_ndarray(result) + return result _SLICE_ERORR = ( @@ -1519,7 +1473,7 @@ def _as_index(idx, need_scalar=True): """ if isinstance(idx, (numbers.Integral, tensor_shape.Dimension)): return idx, True - data = asarray(idx).data + data = asarray(idx) if data.dtype == dtypes.bool: if data.shape.ndims != 1: # TODO(agarwal): handle higher rank boolean masks. @@ -1730,14 +1684,14 @@ def _slice_helper(tensor, slice_spec, update_method=None, updates=None): dims_contiguous = False break indices = [advanced_indices_map[x] for x in dims] - indices = [x.data for x in _promote_dtype(*indices)] + indices = _promote_dtype(*indices) indices = np_utils.tf_broadcast(*indices) stacked_indices = array_ops.stack(indices, axis=-1) # Skip the contiguous-dims optimization for update because there is no # tf.*scatter* op that supports the `axis` argument. if not dims_contiguous or updates is not None: if range(len(dims)) != dims: - tensor = moveaxis(tensor, dims, range(len(dims))).data + tensor = moveaxis(tensor, dims, range(len(dims))) tensor_shape_prefix = array_ops.shape( tensor, out_type=stacked_indices.dtype)[:len(dims)] stacked_indices = array_ops.where_v2( @@ -1763,7 +1717,7 @@ def _slice_helper(tensor, slice_spec, update_method=None, updates=None): def range_(start, length): return range(start, start + length) updates = moveaxis(updates, range_(batch_start, batch_size), - range(batch_size)).data + range(batch_size)) if update_method == _UpdateMethod.UPDATE: update_op = array_ops.tensor_scatter_update elif update_method == _UpdateMethod.ADD: @@ -1775,7 +1729,7 @@ def _slice_helper(tensor, slice_spec, update_method=None, updates=None): tensor = update_op( tensor, stacked_indices, updates) if range(len(dims)) != dims: - tensor = moveaxis(tensor, range(len(dims)), dims).data + tensor = moveaxis(tensor, range(len(dims)), dims) return array_ops.tensor_strided_slice_update( original_tensor, packed_begin, @@ -1842,14 +1796,13 @@ def _getitem(self, slice_spec): slice_spec.dtype == dtypes.bool) or (isinstance(slice_spec, (np.ndarray, np_arrays.ndarray)) and slice_spec.dtype == np.bool)): - return np_utils.tensor_to_ndarray( - array_ops.boolean_mask(tensor=self.data, mask=slice_spec)) + return array_ops.boolean_mask(tensor=self, mask=slice_spec) if not isinstance(slice_spec, tuple): slice_spec = _as_spec_tuple(slice_spec) - result_t = _slice_helper(self.data, slice_spec) - return np_utils.tensor_to_ndarray(result_t) + result_t = _slice_helper(self, slice_spec) + return result_t def _with_index_update_helper(update_method, a, slice_spec, updates): @@ -1865,11 +1818,11 @@ def _with_index_update_helper(update_method, a, slice_spec, updates): a_dtype = a.dtype a, updates = _promote_dtype_binary(a, updates) - result_t = _slice_helper(a.data, slice_spec, update_method, updates.data) - return np_utils.tensor_to_ndarray(result_t).astype(a_dtype) + result_t = _slice_helper(a, slice_spec, update_method, updates) + return result_t.astype(a_dtype) -setattr(np_arrays.ndarray, '__getitem__', _getitem) +setattr(np_arrays.ndarray, '_numpy_style_getitem', _getitem) setattr(np_arrays.ndarray, '_with_index_update', functools.partial(_with_index_update_helper, _UpdateMethod.UPDATE)) setattr(np_arrays.ndarray, '_with_index_add', diff --git a/tensorflow/python/ops/numpy_ops/np_array_ops_test.py b/tensorflow/python/ops/numpy_ops/np_array_ops_test.py index b3beb32793b..8fa324cbd8b 100644 --- a/tensorflow/python/ops/numpy_ops/np_array_ops_test.py +++ b/tensorflow/python/ops/numpy_ops/np_array_ops_test.py @@ -36,6 +36,7 @@ from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops.numpy_ops import np_array_ops from tensorflow.python.ops.numpy_ops import np_arrays +from tensorflow.python.ops.numpy_ops import np_math_ops from tensorflow.python.platform import test @@ -305,49 +306,47 @@ class ArrayCreationTest(test.TestCase): def test_copy_equal_false(): # Backing tensor is the same if copy=False, other attributes being None. - self.assertIs( - np_array_ops.array(zeros_list, copy=False).data, zeros_list.data) - self.assertIs( - np_array_ops.array(zeros_list.data, copy=False).data, zeros_list.data) + self.assertIs(np_array_ops.array(zeros_list, copy=False), zeros_list) + self.assertIs(np_array_ops.array(zeros_list, copy=False), zeros_list) # Backing tensor is different if ndmin is not satisfied. self.assertIsNot( - np_array_ops.array(zeros_list, copy=False, ndmin=2).data, - zeros_list.data) + np_array_ops.array(zeros_list, copy=False, ndmin=2), + zeros_list) self.assertIsNot( - np_array_ops.array(zeros_list.data, copy=False, ndmin=2).data, - zeros_list.data) + np_array_ops.array(zeros_list, copy=False, ndmin=2), + zeros_list) self.assertIs( - np_array_ops.array(zeros_list, copy=False, ndmin=1).data, - zeros_list.data) + np_array_ops.array(zeros_list, copy=False, ndmin=1), + zeros_list) self.assertIs( - np_array_ops.array(zeros_list.data, copy=False, ndmin=1).data, - zeros_list.data) + np_array_ops.array(zeros_list, copy=False, ndmin=1), + zeros_list) # Backing tensor is different if dtype is not satisfied. self.assertIsNot( - np_array_ops.array(zeros_list, copy=False, dtype=int).data, - zeros_list.data) + np_array_ops.array(zeros_list, copy=False, dtype=int), + zeros_list) self.assertIsNot( - np_array_ops.array(zeros_list.data, copy=False, dtype=int).data, - zeros_list.data) + np_array_ops.array(zeros_list, copy=False, dtype=int), + zeros_list) self.assertIs( - np_array_ops.array(zeros_list, copy=False, dtype=float).data, - zeros_list.data) + np_array_ops.array(zeros_list, copy=False, dtype=float), + zeros_list) self.assertIs( - np_array_ops.array(zeros_list.data, copy=False, dtype=float).data, - zeros_list.data) + np_array_ops.array(zeros_list, copy=False, dtype=float), + zeros_list) test_copy_equal_false() with ops.device('CPU:1'): test_copy_equal_false() - self.assertNotIn('CPU:1', zeros_list.data.backing_device) + self.assertNotIn('CPU:1', zeros_list.backing_device) with ops.device('CPU:1'): - self.assertIn('CPU:1', np_array_ops.array(zeros_list, copy=True).data - .backing_device) - self.assertIn('CPU:1', np_array_ops.array(np.array(0), copy=True).data - .backing_device) + self.assertIn( + 'CPU:1', np_array_ops.array(zeros_list, copy=True).backing_device) + self.assertIn( + 'CPU:1', np_array_ops.array(np.array(0), copy=True).backing_device) def testAsArray(self): for a, dtype in itertools.product(self.all_arrays, self.all_types): @@ -515,9 +514,6 @@ class ArrayCreationTest(test.TestCase): msg = 'Shape match failed for: {}. Expected: {} Actual: {}'.format( msg, expected.shape, actual.shape) self.assertEqual(actual.shape, expected.shape, msg=msg) - if msg: - msg = 'Shape: {} is not a tuple for {}'.format(actual.shape, msg) - self.assertIsInstance(actual.shape, tuple, msg=msg) def match_dtype(self, actual, expected, msg=None): if msg: @@ -535,7 +531,7 @@ class ArrayCreationTest(test.TestCase): self.match_dtype(actual, expected, msg) self.match_shape(actual, expected, msg) if not almost: - if not actual.shape: + if not actual.shape.rank: self.assertEqual(actual.tolist(), expected.tolist()) else: self.assertSequenceEqual(actual.tolist(), expected.tolist()) @@ -636,11 +632,11 @@ class ArrayMethodsTest(test.TestCase): run_test(np.arange(9).reshape((3, 3)).tolist()) a = np_array_ops.asarray(0) - self.assertNotIn('CPU:1', a.data.backing_device) + self.assertNotIn('CPU:1', a.backing_device) with ops.device('CPU:1'): - self.assertIn('CPU:1', np_array_ops.array(a, copy=True).data + self.assertIn('CPU:1', np_array_ops.array(a, copy=True) .backing_device) - self.assertIn('CPU:1', np_array_ops.array(np.array(0), copy=True).data + self.assertIn('CPU:1', np_array_ops.array(np.array(0), copy=True) .backing_device) def testCumProdAndSum(self): @@ -824,12 +820,13 @@ class ArrayMethodsTest(test.TestCase): self.assertRaises(NotImplementedError, np_array_ops.size, np.ones((2, 2)), 1) - @def_function.function(input_signature=[tensor_spec.TensorSpec(shape=None)]) + @def_function.function(input_signature=[ + tensor_spec.TensorSpec(dtype=dtypes.float64, shape=None)]) def f(arr): arr = np_array_ops.asarray(arr) return np_array_ops.size(arr) - self.assertEqual(f(np_array_ops.ones((3, 2))).data.numpy(), 6) + self.assertEqual(f(np_array_ops.ones((3, 2))).numpy(), 6) def testRavel(self): @@ -984,9 +981,6 @@ class ArrayMethodsTest(test.TestCase): msg = 'Shape match failed for: {}. Expected: {} Actual: {}'.format( msg, expected.shape, actual.shape) self.assertEqual(actual.shape, expected.shape, msg=msg) - if msg: - msg = 'Shape: {} is not a tuple for {}'.format(actual.shape, msg) - self.assertIsInstance(actual.shape, tuple, msg=msg) def match_dtype(self, actual, expected, msg=None): if msg: @@ -1004,7 +998,7 @@ class ArrayMethodsTest(test.TestCase): if check_dtype: self.match_dtype(actual, expected, msg) self.match_shape(actual, expected, msg) - if not actual.shape: + if not actual.shape.rank: self.assertAllClose(actual.tolist(), expected.tolist()) else: self.assertAllClose(actual.tolist(), expected.tolist()) @@ -1165,9 +1159,6 @@ class ArrayManipulationTest(test.TestCase): msg = 'Shape match failed for: {}. Expected: {} Actual: {}'.format( msg, expected.shape, actual.shape) self.assertEqual(actual.shape, expected.shape, msg=msg) - if msg: - msg = 'Shape: {} is not a tuple for {}'.format(actual.shape, msg) - self.assertIsInstance(actual.shape, tuple, msg=msg) def match_dtype(self, actual, expected, msg=None): if msg: @@ -1184,7 +1175,7 @@ class ArrayManipulationTest(test.TestCase): self.assertIsInstance(actual, np_arrays.ndarray) self.match_dtype(actual, expected, msg) self.match_shape(actual, expected, msg) - if not actual.shape: + if not actual.shape.rank: self.assertEqual(actual.tolist(), expected.tolist()) else: self.assertSequenceEqual(actual.tolist(), expected.tolist()) @@ -1192,4 +1183,6 @@ class ArrayManipulationTest(test.TestCase): if __name__ == '__main__': ops.enable_eager_execution() + ops.enable_numpy_style_type_promotion() + np_math_ops.enable_numpy_methods_on_tensor() test.main() diff --git a/tensorflow/python/ops/numpy_ops/np_arrays.py b/tensorflow/python/ops/numpy_ops/np_arrays.py index ade758d36d3..cd879204a3e 100644 --- a/tensorflow/python/ops/numpy_ops/np_arrays.py +++ b/tensorflow/python/ops/numpy_ops/np_arrays.py @@ -20,18 +20,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np import six -from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_spec -from tensorflow.python.framework import type_spec -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops from tensorflow.python.ops.numpy_ops import np_dtypes -from tensorflow.python.ops.numpy_ops import np_export def convert_to_tensor(value, dtype=None, dtype_hint=None): @@ -58,297 +51,4 @@ def convert_to_tensor(value, dtype=None, dtype_hint=None): return ops.convert_to_tensor(value, dtype=dtype, dtype_hint=dtype_hint) -class NdarraySpec(type_spec.BatchableTypeSpec): - """Type specification for a `tf.experiemntal.numpy.ndarray`.""" - - value_type = property(lambda self: ndarray) - - def __init__(self, data_spec): - if not isinstance(data_spec, tensor_spec.TensorSpec): - raise ValueError('NdarraySpec.__init__ was expecting a tf.TypeSpec, ' - 'but got a {} instead.'.format(type(data_spec))) - self._data_spec = data_spec - self._hash = None - - @property - def _component_specs(self): - return self._data_spec - - def _to_components(self, value): - return value.data - - def _from_components(self, data): - return tensor_to_ndarray(data) - - def _serialize(self): - return (self._data_spec,) - - def _batch(self, batch_size): - return NdarraySpec(self._data_spec._batch(batch_size)) # pylint: disable=protected-access - - def _unbatch(self): - return NdarraySpec(self._data_spec._unbatch()) # pylint: disable=protected-access - - def __hash__(self): - if self._hash is None: - self._hash = hash((type(self), self._data_spec)) - return self._hash - - -@np_export.np_export('ndarray') # pylint: disable=invalid-name -class ndarray(composite_tensor.CompositeTensor): - """Equivalent of numpy.ndarray backed by TensorFlow tensors. - - This does not support all features of NumPy ndarrays e.g. strides and - memory order since, unlike NumPy, the backing storage is not a raw memory - buffer. - - TODO(srbs): Clearly specify which attributes and methods are not supported - or if there are any differences in behavior. - """ - - __slots__ = ['_data', '_dtype', '_type_spec_internal'] - - def __init__(self, shape, dtype=float, buffer=None): # pylint: disable=redefined-builtin - """Initializes an ndarray. - - This is a low level interface for building ndarrays and should be avoided. - Users should instead use methods in array_creation.py. - - This class provides a numpy.ndarray like interface for a TF Tensor with a - fully-defined shape. Note that, unlike the backing buffer of np.ndarray, - Tensors are immutable. So, operations like `__setitem__` are performed by - replacing the Tensor. This restricts the ability to implement NumPy `view` - semantics. - - Compared to numpy.ndarray, this does not support `offset`, `strides` - and `order` arguments. - - Args: - shape: The shape of the array. Must be a scalar, an iterable of integers - or a `TensorShape` object. - dtype: Optional. The dtype of the array. Must be a python type, a numpy - type or a tensorflow `DType` object. - buffer: Optional. The backing buffer of the array. Must have shape - `shape`. Must be a `ndarray`, `np.ndarray` or a `Tensor`. - - Raises: - ValueError: If `buffer` is specified and its shape does not match - `shape`. - """ - if dtype and not isinstance(dtype, dtypes.DType): - dtype = dtypes.as_dtype(np.dtype(dtype)) - if buffer is None: - buffer = array_ops.zeros(shape, dtype=dtype) - else: - if isinstance(buffer, ndarray): - buffer = buffer.data - elif isinstance(buffer, np.ndarray): - # If `buffer` is a np.ndarray, the Tensor will share the underlying - # storage of the array. - buffer = convert_to_tensor(value=buffer, dtype=dtype) - elif not isinstance(buffer, ops.Tensor): - raise ValueError('Unexpected type for `buffer` {}. Must be an ndarray,' - ' Tensor or np.ndarray.'.format(type(buffer))) - - if shape is not None: - buffer.set_shape(shape) - - assert isinstance(buffer, ops.Tensor) - if dtype and dtype != buffer.dtype: - buffer = math_ops.cast(buffer, dtype) - self._data = buffer - self._type_spec_internal = None - self._dtype = None - - @classmethod - def from_tensor(cls, tensor): - o = cls.__new__(cls, None) - # pylint: disable=protected-access - o._data = tensor - o._dtype = None - o._type_spec_internal = None - # pylint: enable=protected-access - return o - - @property - def _type_spec(self): - if self._type_spec_internal is None: - self._type_spec_internal = NdarraySpec( - type_spec.type_spec_from_value(self._data)) - return self._type_spec_internal - - @property - def data(self): - """Tensor object containing the array data. - - This has a few key differences from the Python buffer object used in - NumPy arrays. - 1. Tensors are immutable. So operations requiring in-place edit, e.g. - __setitem__, are performed by replacing the underlying buffer with a new - one. - 2. Tensors do not provide access to their raw buffer. - - Returns: - A Tensor. - """ - return self._data - - @property - def shape(self): - """Returns a tuple or tf.Tensor of array dimensions.""" - shape = self.data.shape - if shape.is_fully_defined(): - return tuple(shape.as_list()) - else: - return array_ops.shape(self.data) - - @property - def dtype(self): - if self._dtype is None: - self._dtype = np_dtypes._get_cached_dtype(self._data.dtype) # pylint: disable=protected-access - return self._dtype - - def _is_boolean(self): - return self._data.dtype == dtypes.bool - - @property - def ndim(self): - ndims = self.data.shape.ndims - if ndims is None: - return array_ops.rank(self.data) - else: - return ndims - - @property - def size(self): - """Returns the number of elements in the array.""" - shape = self.shape - if isinstance(shape, ops.Tensor): - return array_ops.size(self.data) - else: - return np.prod(self.shape) - - @property - def T(self): # pylint: disable=invalid-name - return self.transpose() - - def __len__(self): - shape = self.shape - if isinstance(shape, ops.Tensor): - raise TypeError('len() of symbolic tensor undefined') - elif shape: - return self.shape[0] - else: - raise TypeError('len() of unsized object.') - - def astype(self, dtype): - if self.dtype == dtype: - return self - else: - return tensor_to_ndarray(math_ops.cast(self.data, dtype)) - - # Unary operations - def __neg__(self): - return tensor_to_ndarray(-self.data) # pylint: disable=invalid-unary-operand-type - - def __pos__(self): - return self - - __hash__ = None - - def __int__(self): - return int(self.data) - - def __float__(self): - return float(self.data) - - def __bool__(self): - return bool(self.data) - - def __nonzero__(self): - return self.__bool__() - - def __iter__(self): - if not isinstance(self.data, ops.EagerTensor): - raise TypeError('Iteration over symbolic tensor is not allowed') - for i in range(self.shape[0]): - result_t = self.data[i] - yield tensor_to_ndarray(result_t) - return - - def __array__(self, dtype=None): - """Returns a NumPy ndarray. - - This allows instances of this class to be directly used in NumPy routines. - However, doing that may force a copy to CPU. - - Args: - dtype: A NumPy compatible type. - - Returns: - A NumPy ndarray. - """ - return np.asarray(self.data, dtype) - - # NOTE: we currently prefer interop with TF to allow TF to take precedence. - __array_priority__ = 90 - - def __array_module__(self, types): - # Experimental support for NumPy's module dispatch with NEP-37: - # https://numpy.org/neps/nep-0037-array-module.html - # Currently requires https://github.com/seberg/numpy-dispatch - - # pylint: disable=g-import-not-at-top - import tensorflow.compat.v2 as tf - - if all(issubclass(t, (ndarray, np.ndarray)) for t in types): - return tf.experimental.numpy - else: - return NotImplemented - - def __index__(self): - """Returns a python scalar. - - This allows using an instance of this class as an array index. - Note that only arrays of integer types with size 1 can be used as array - indices. - - Returns: - A Python scalar. - - Raises: - TypeError: If the array is not of an integer type. - ValueError: If the array does not have size 1. - """ - # TODO(wangpeng): Handle graph mode - if not isinstance(self.data, ops.EagerTensor): - raise TypeError('Indexing using symbolic tensor is not allowed') - return self.data.numpy().item() - - def tolist(self): - return self.data.numpy().tolist() - - def __str__(self): - return 'ndarray<{}>'.format(self.data.__str__()) - - def __repr__(self): - return 'ndarray<{}>'.format(self.data.__repr__()) - - -def tensor_to_ndarray(tensor): - return ndarray.from_tensor(tensor) - - -def ndarray_to_tensor(arr, dtype=None, name=None, as_ref=False): - if as_ref: - raise ValueError('as_ref is not supported.') - if dtype and dtypes.as_dtype(arr.dtype) != dtype: - return math_ops.cast(arr.data, dtype) - result_t = arr.data - if name: - result_t = array_ops.identity(result_t, name=name) - return result_t - - -ops.register_tensor_conversion_function(ndarray, ndarray_to_tensor) +ndarray = ops.Tensor diff --git a/tensorflow/python/ops/numpy_ops/np_arrays_test.py b/tensorflow/python/ops/numpy_ops/np_arrays_test.py index ab407d2bfcf..782e3f36617 100644 --- a/tensorflow/python/ops/numpy_ops/np_arrays_test.py +++ b/tensorflow/python/ops/numpy_ops/np_arrays_test.py @@ -18,11 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections - import numpy as np -from tensorflow.python.framework import constant_op +from tensorflow.python.eager import def_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -32,48 +30,33 @@ from tensorflow.python.ops.numpy_ops import np_math_ops # pylint: disable=unuse from tensorflow.python.platform import test from tensorflow.python.util import nest -t2a = np_arrays.tensor_to_ndarray - class ArrayTest(test.TestCase): def testDtype(self): - a = t2a(array_ops.zeros(shape=[1, 2], dtype=dtypes.int64)) - self.assertIs(a.dtype.type, np.int64) - self.assertAllEqual(0, a.dtype.type(0)) + a = array_ops.zeros(shape=[1, 2], dtype=dtypes.int64) + self.assertIs(a.dtype.as_numpy_dtype, np.int64) + np_dt = a.dtype.as_numpy_dtype + self.assertAllEqual(0, np_dt(0)) def testAstype(self): - a = t2a(ops.convert_to_tensor(value=1.1, - dtype=dtypes.float32)).astype(np.int32) - self.assertIs(a.dtype.type, np.int32) + a = ops.convert_to_tensor(value=1.1, dtype=dtypes.float32).astype(np.int32) + self.assertIs(a.dtype.as_numpy_dtype, np.int32) self.assertAllEqual(1, a) - a = t2a(ops.convert_to_tensor(value=[0.0, 1.1], - dtype=dtypes.float32)).astype(np.bool_) - self.assertIs(a.dtype.type, np.bool_) + a = ops.convert_to_tensor(value=[0.0, 1.1], dtype=dtypes.float32).astype( + np.bool_) + self.assertIs(a.dtype.as_numpy_dtype, np.bool_) self.assertAllEqual([False, True], a) - def testConstructor(self): - t = constant_op.constant([[1], [1]]) - a = np_arrays.ndarray(shape=(2, 1), buffer=t) - self.assertAllEqual(t, a) - self.assertEqual(dtypes.float64, a.dtype) - - a = np_arrays.ndarray(shape=(2, 1), dtype=dtypes.int32, buffer=t) - self.assertAllEqual(t, a) - self.assertEqual(dtypes.int32, a.dtype) - - with self.assertRaises(ValueError): # bad shape - _ = np_arrays.ndarray((2, 2), buffer=t) - def testNeg(self): - a = t2a(ops.convert_to_tensor(value=[1.0, 2.0])) - self.assertAllEqual([-1.0, -2.0], -a) + a = ops.convert_to_tensor(value=[1.0, 2.0]) + self.assertAllEqual([-1.0, -2.0], -a) # pylint: disable=invalid-unary-operand-type def _testBinOp(self, a, b, out, f, types=None): - a = t2a(ops.convert_to_tensor(value=a, dtype=np.int32)) - b = t2a(ops.convert_to_tensor(value=b, dtype=np.int32)) + a = ops.convert_to_tensor(value=a, dtype=np.int32) + b = ops.convert_to_tensor(value=b, dtype=np.int32) if not isinstance(out, np_arrays.ndarray): - out = t2a(ops.convert_to_tensor(value=out, dtype=np.int32)) + out = ops.convert_to_tensor(value=out, dtype=np.int32) if types is None: types = [[np.int32, np.int32, np.int32], [np.int64, np.int32, np.int64], [np.int32, np.int64, np.int64], @@ -84,7 +67,7 @@ class ArrayTest(test.TestCase): [np.float32, np.float64, np.float64]] for a_type, b_type, out_type in types: o = f(a.astype(a_type), b.astype(b_type)) - self.assertIs(o.dtype.type, out_type) + self.assertIs(o.dtype.as_numpy_dtype, out_type) out = out.astype(out_type) if np.issubdtype(out_type, np.inexact): self.assertAllClose(out, o) @@ -126,19 +109,20 @@ class ArrayTest(test.TestCase): def testTruediv(self): self._testBinOp([3, 5], [2, 4], - t2a(ops.convert_to_tensor(value=[1.5, 1.25])), + ops.convert_to_tensor(value=[1.5, 1.25]), lambda a, b: a.__truediv__(b), types=self._truediv_types) def testRtruediv(self): self._testBinOp([3, 5], [2, 4], - t2a(ops.convert_to_tensor(value=[1.5, 1.25])), + ops.convert_to_tensor(value=[1.5, 1.25]), lambda a, b: b.__rtruediv__(a), types=self._truediv_types) def _testCmp(self, a, b, out, f): - a = t2a(ops.convert_to_tensor(value=a, dtype=np.int32)) - b = t2a(ops.convert_to_tensor(value=b, dtype=np.int32)) + a = ops.convert_to_tensor(value=a, dtype=np.int32) + b = ops.convert_to_tensor(value=b, dtype=np.int32) + types = [[np.int32, np.int32], [np.int64, np.int32], [np.int32, np.int64], [np.float32, np.int32], [np.int32, np.float32], [np.float32, np.float32], [np.float64, np.float32], @@ -173,32 +157,41 @@ class ArrayTest(test.TestCase): def testInt(self): v = 10 - u = int(t2a(ops.convert_to_tensor(value=v))) + u = int(ops.convert_to_tensor(value=v)) self.assertIsInstance(u, int) self.assertAllEqual(v, u) def testFloat(self): v = 21.32 - u = float(t2a(ops.convert_to_tensor(value=v))) + u = float(ops.convert_to_tensor(value=v)) self.assertIsInstance(u, float) self.assertAllClose(v, u) def testBool(self): - b = bool(t2a(ops.convert_to_tensor(value=10))) + b = bool(ops.convert_to_tensor(value=10)) self.assertIsInstance(b, bool) self.assertTrue(b) - self.assertFalse(bool(t2a(ops.convert_to_tensor(value=0)))) - self.assertTrue(bool(t2a(ops.convert_to_tensor(value=0.1)))) - self.assertFalse(bool(t2a(ops.convert_to_tensor(value=0.0)))) + self.assertFalse(bool(ops.convert_to_tensor(value=0))) + self.assertTrue(bool(ops.convert_to_tensor(value=0.1))) + self.assertFalse(bool(ops.convert_to_tensor(value=0.0))) def testHash(self): - a = t2a(ops.convert_to_tensor(value=10)) - self.assertNotIsInstance(a, collections.Hashable) - with self.assertRaisesWithPredicateMatch(TypeError, r'unhashable type'): + a = ops.convert_to_tensor(value=10) + def eager(): hash(a) + def graph(): + @def_function.function + def f(x): + hash(x) + f(a) + for f in [eager, graph]: + with self.assertRaisesRegexp( + TypeError, + r'Tensor is unhashable. Instead, use tensor.ref\(\) as the key.'): + f() def testFromToCompositeTensor(self): - tensors = [t2a(ops.convert_to_tensor(0.1)), t2a(ops.convert_to_tensor(0.2))] + tensors = [ops.convert_to_tensor(0.1), ops.convert_to_tensor(0.2)] flattened = nest.flatten(tensors, expand_composites=True) # Each ndarray contains only one tensor, so the flattened output should be @@ -216,6 +209,10 @@ class ArrayTest(test.TestCase): if __name__ == '__main__': - # TODO(wangpeng): Test in graph mode as well. + # TODO(wangpeng): Test in graph mode as well. Also test in V2 (the requirement + # for setting _USE_EQUALITY points to V2 behavior not being on). ops.enable_eager_execution() + ops.Tensor._USE_EQUALITY = True + ops.enable_numpy_style_type_promotion() + np_math_ops.enable_numpy_methods_on_tensor() test.main() diff --git a/tensorflow/python/ops/numpy_ops/np_config.py b/tensorflow/python/ops/numpy_ops/np_config.py new file mode 100644 index 00000000000..05fc64ffd76 --- /dev/null +++ b/tensorflow/python/ops/numpy_ops/np_config.py @@ -0,0 +1,39 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Config functions for TF NumPy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops.numpy_ops import np_dtypes +from tensorflow.python.ops.numpy_ops import np_math_ops + + +def enable_numpy_behavior(prefer_float32=False): + """Enable NumPy behavior on Tensors. + + Includes addition of methods, type promotion on operator overloads and + support for NumPy-style slicing. + + Args: + prefer_float32: Whether to allow type inference to use float32, or use + float64 similar to NumPy. + """ + ops.enable_numpy_style_type_promotion() + ops.enable_numpy_style_slicing() + np_math_ops.enable_numpy_methods_on_tensor() + np_dtypes.set_prefer_float32(prefer_float32) diff --git a/tensorflow/python/ops/numpy_ops/np_dtypes.py b/tensorflow/python/ops/numpy_ops/np_dtypes.py index cde3883d3d9..1f4bb97e380 100644 --- a/tensorflow/python/ops/numpy_ops/np_dtypes.py +++ b/tensorflow/python/ops/numpy_ops/np_dtypes.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.framework import dtypes from tensorflow.python.ops.numpy_ops import np_export @@ -63,9 +64,27 @@ _to_float32 = { _cached_np_dtypes = {} + +# Difference between is_prefer_float32 and is_allow_float64: is_prefer_float32 +# only decides which dtype to use for Python floats; is_allow_float64 decides +# whether float64 dtypes can ever appear in programs. The latter is more +# restrictive than the former. +_prefer_float32 = False + + +# TODO(b/178862061): Consider removing this knob _allow_float64 = True +def is_prefer_float32(): + return _prefer_float32 + + +def set_prefer_float32(b): + global _prefer_float32 + _prefer_float32 = b + + def is_allow_float64(): return _allow_float64 @@ -85,8 +104,13 @@ def canonicalize_dtype(dtype): def _result_type(*arrays_and_dtypes): + def preprocess_float(x): + if is_prefer_float32() and isinstance(x, float): + return np.float32(x) + return x + arrays_and_dtypes = [preprocess_float(x) for x in arrays_and_dtypes] dtype = np.result_type(*arrays_and_dtypes) - return canonicalize_dtype(dtype) + return dtypes.as_dtype(canonicalize_dtype(dtype)) def _get_cached_dtype(dtype): @@ -105,9 +129,10 @@ def default_float_type(): """Gets the default float type. Returns: - If `is_allow_float64()` is true, returns float64; otherwise returns float32. + If `is_prefer_float32()` is false and `is_allow_float64()` is true, returns + float64; otherwise returns float32. """ - if is_allow_float64(): + if not is_prefer_float32() and is_allow_float64(): return float64 else: return float32 diff --git a/tensorflow/python/ops/numpy_ops/np_dtypes_test.py b/tensorflow/python/ops/numpy_ops/np_dtypes_test.py new file mode 100644 index 00000000000..b5a3ab9c325 --- /dev/null +++ b/tensorflow/python/ops/numpy_ops/np_dtypes_test.py @@ -0,0 +1,57 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf-numpy dtype utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops.numpy_ops import np_dtypes +from tensorflow.python.platform import test + + +class DTypeTest(test.TestCase, parameterized.TestCase): + + @parameterized.parameters([False, True]) + def testAllowF64False(self, prefer_f32): + np_dtypes.set_allow_float64(False) + np_dtypes.set_prefer_float32(prefer_f32) + self.assertEqual(dtypes.float32, np_dtypes.default_float_type()) + self.assertEqual(dtypes.float32, + np_dtypes._result_type(np.zeros([], np.float64), 1.1)) + + def testAllowF64TruePreferF32False(self): + np_dtypes.set_allow_float64(True) + np_dtypes.set_prefer_float32(False) + self.assertEqual(dtypes.float64, np_dtypes.default_float_type()) + self.assertEqual(dtypes.float64, np_dtypes._result_type(1.1)) + + def testAllowF64TruePreferF32True(self): + np_dtypes.set_allow_float64(True) + np_dtypes.set_prefer_float32(True) + self.assertEqual(dtypes.float32, np_dtypes.default_float_type()) + self.assertEqual(dtypes.float32, np_dtypes._result_type(1.1)) + self.assertEqual(dtypes.float64, + np_dtypes._result_type(np.zeros([], np.float64), 1.1)) + + +if __name__ == '__main__': + ops.enable_eager_execution() + test.main() diff --git a/tensorflow/python/ops/numpy_ops/np_interop_test.py b/tensorflow/python/ops/numpy_ops/np_interop_test.py index 8999c8f832e..d265b5e7d66 100644 --- a/tensorflow/python/ops/numpy_ops/np_interop_test.py +++ b/tensorflow/python/ops/numpy_ops/np_interop_test.py @@ -21,7 +21,9 @@ from __future__ import print_function import numpy as onp import tensorflow.compat.v2 as tf +from tensorflow.python.framework import ops from tensorflow.python.ops import numpy_ops as np +from tensorflow.python.ops.numpy_ops import np_math_ops # Tests for code snippet put in README.md @@ -174,27 +176,26 @@ class InteropTest(tf.test.TestCase): self.assertIsInstance(sq, onp.ndarray) self.assertEqual(100., sq[0]) +# TODO(b/171313773): why doesn't tensor have __array_module__ def testArrayModule(self): + self.skipTest("Tensor doesn't have __array_module__") arr = np.asarray([10]) - module = arr.__array_module__((np.ndarray,)) + module = arr.__array_module__((tf.Tensor,)) self.assertIs(module, tf.experimental.numpy) class Dummy: pass - module = arr.__array_module__((np.ndarray, Dummy)) + module = arr.__array_module__((tf.Tensor, Dummy)) self.assertIs(module, NotImplemented) - # TODO(nareshmodi): Fails since the autopacking code doesn't use - # nest.flatten. - - +# TODO(nareshmodi): Fails since the autopacking code doesn't use +# nest.flatten. # def testAutopacking(self): # arr1 = np.asarray(1.) # arr2 = np.asarray(2.) # arr3 = np.asarray(3.) # t = ops.convert_to_tensor_v2([arr1, arr2, arr3]) - # self.assertEqual(t.numpy(), [1., 2., 3.]) def testDistStratInterop(self): @@ -409,7 +410,9 @@ class FunctionTest(InteropTest): def testLen(self): - @tf.function + # len can be fixed by autograph. + # TODO(wangpeng): this test can just be removed + @tf.function(autograph=False) def f(x): # Note that shape of input to len is data dependent. return len(np.where(x)[0]) @@ -451,5 +454,7 @@ class VariableTest(InteropTest): if __name__ == '__main__': + ops.enable_numpy_style_type_promotion() + np_math_ops.enable_numpy_methods_on_tensor() tf.compat.v1.enable_eager_execution() tf.test.main() diff --git a/tensorflow/python/ops/numpy_ops/np_logic_test.py b/tensorflow/python/ops/numpy_ops/np_logic_test.py index 85826873356..9e38a87e70c 100644 --- a/tensorflow/python/ops/numpy_ops/np_logic_test.py +++ b/tensorflow/python/ops/numpy_ops/np_logic_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for tf numpy random number methods.""" +"""Tests for tf numpy logical methods.""" from __future__ import absolute_import from __future__ import division @@ -76,9 +76,6 @@ class LogicTest(test.TestCase): msg = 'Shape match failed for: {}. Expected: {} Actual: {}'.format( msg, expected.shape, actual.shape) self.assertEqual(actual.shape, expected.shape, msg=msg) - if msg: - msg = 'Shape: {} is not a tuple for {}'.format(actual.shape, msg) - self.assertIsInstance(actual.shape, tuple, msg=msg) def match_dtype(self, actual, expected, msg=None): if msg: @@ -95,16 +92,17 @@ class LogicTest(test.TestCase): self.assertIsInstance(actual, np_arrays.ndarray) self.match_dtype(actual, expected, msg) self.match_shape(actual, expected, msg) - if not actual.shape: + if not actual.shape.rank: self.assertEqual(actual.tolist(), expected.tolist()) else: self.assertSequenceEqual(actual.tolist(), expected.tolist()) def make_numpy_compatible(s): - return s if not isinstance(s, np_arrays.ndarray) else s.data.numpy() + return s if not isinstance(s, np_arrays.ndarray) else s.numpy() if __name__ == '__main__': ops.enable_eager_execution() + np_math_ops.enable_numpy_methods_on_tensor() test.main() diff --git a/tensorflow/python/ops/numpy_ops/np_math_ops.py b/tensorflow/python/ops/numpy_ops/np_math_ops.py index 85cfdf6c5b8..1fd90df06c8 100644 --- a/tensorflow/python/ops/numpy_ops/np_math_ops.py +++ b/tensorflow/python/ops/numpy_ops/np_math_ops.py @@ -74,7 +74,7 @@ def _bin_op(tf_fun, a, b, promote=True): else: a = np_array_ops.array(a) b = np_array_ops.array(b) - return np_utils.tensor_to_ndarray(tf_fun(a.data, b.data)) + return tf_fun(a, b) @np_utils.np_doc('add') @@ -177,9 +177,8 @@ def maximum(x1, x2): # pylint: disable=missing-function-docstring # Fast path for when maximum is used as relu. if isinstance( x2, numbers.Real) and not isinstance(x2, bool) and x2 == 0 and isinstance( - x1, np_arrays.ndarray) and not x1._is_boolean(): # pylint: disable=protected-access - return np_utils.tensor_to_ndarray( - nn_ops.relu(np_array_ops.asarray(x1).data)) + x1, np_arrays.ndarray) and x1.dtype != dtypes.bool: + return nn_ops.relu(np_array_ops.asarray(x1)) def max_or_or(x1, x2): if x1.dtype == dtypes.bool: @@ -212,12 +211,7 @@ def clip(a, a_min, a_max): # pylint: disable=missing-docstring return maximum(a, a_min) else: a, a_min, a_max = np_array_ops._promote_dtype(a, a_min, a_max) # pylint: disable=protected-access - return np_utils.tensor_to_ndarray( - clip_ops.clip_by_value( - *np_utils.tf_broadcast(a.data, a_min.data, a_max.data))) - - -setattr(np_arrays.ndarray, 'clip', clip) + return clip_ops.clip_by_value(*np_utils.tf_broadcast(a, a_min, a_max)) @np_utils.np_doc('matmul') @@ -241,6 +235,12 @@ def matmul(x1, x2): # pylint: disable=missing-docstring return _bin_op(f, x1, x2) +# Exported so it can be called from Tensor.__matmul__. NumPy's matmul handles +# batched matmul as well, so simply including promotion in TF's current +# __matmul__ implementation was not sufficient. +setattr(np_arrays.ndarray, '_matmul', matmul) + + @np_utils.np_doc('tensordot') def tensordot(a, b, axes=2): return _bin_op(lambda a, b: math_ops.tensordot(a, b, axes=axes), a, b) @@ -375,7 +375,7 @@ def heaviside(x1, x2): # pylint: disable=missing-function-docstring array_ops.where_v2(x1 > 0, constant_op.constant(1, dtype=x2.dtype), x2)) y = _bin_op(f, x1, x2) - if not np.issubdtype(y.dtype, np.inexact): + if not np.issubdtype(y.dtype.as_numpy_dtype, np.inexact): y = y.astype(np_dtypes.default_float_type()) return y @@ -392,13 +392,13 @@ def kron(a, b): # pylint: disable=missing-function-docstring t_a = np_utils.cond( a.ndim < b.ndim, lambda: np_array_ops.reshape( # pylint: disable=g-long-lambda - a.data, np_array_ops._pad_left_to(b.ndim, a.shape)), - lambda: a.data) + a, np_array_ops._pad_left_to(b.ndim, a.shape)), + lambda: a) t_b = np_utils.cond( b.ndim < a.ndim, lambda: np_array_ops.reshape( # pylint: disable=g-long-lambda - b.data, np_array_ops._pad_left_to(a.ndim, b.shape)), - lambda: b.data) + b, np_array_ops._pad_left_to(a.ndim, b.shape)), + lambda: b) def _make_shape(shape, prepend): ones = array_ops.ones_like(shape) @@ -596,9 +596,9 @@ def _scalar(tf_fn, x, promote_to_float=False): floating point type, in which case the output type is same as x.dtype. """ x = np_array_ops.asarray(x) - if promote_to_float and not np.issubdtype(x.dtype, np.inexact): + if promote_to_float and not np.issubdtype(x.dtype.as_numpy_dtype, np.inexact): x = x.astype(np_dtypes.default_float_type()) - return np_utils.tensor_to_ndarray(tf_fn(x.data)) + return tf_fn(x) @np_utils.np_doc('log') @@ -814,7 +814,7 @@ def isreal(x): @np_utils.np_doc('iscomplexobj') def iscomplexobj(x): x = np_array_ops.array(x) - return np.issubdtype(x.dtype, np.complexfloating) + return np.issubdtype(x.dtype.as_numpy_dtype, np.complexfloating) @np_utils.np_doc('isrealobj') @@ -850,11 +850,12 @@ nanprod = _make_nan_reduction('nanprod', np_array_ops.prod, 1) @np_utils.np_doc('nanmean') def nanmean(a, axis=None, dtype=None, keepdims=None): # pylint: disable=missing-docstring a = np_array_ops.array(a) - if np.issubdtype(a.dtype, np.bool_) or np.issubdtype(a.dtype, np.integer): + if np.issubdtype(a.dtype.as_numpy_dtype, np.bool_) or np.issubdtype( + a.dtype.as_numpy_dtype, np.integer): return np_array_ops.mean(a, axis=axis, dtype=dtype, keepdims=keepdims) nan_mask = logical_not(isnan(a)) if dtype is None: - dtype = a.dtype + dtype = a.dtype.as_numpy_dtype normalizer = np_array_ops.sum( nan_mask, axis=axis, dtype=dtype, keepdims=keepdims) return nansum(a, axis=axis, dtype=dtype, keepdims=keepdims) / normalizer @@ -960,37 +961,16 @@ def _wrap(f, reverse=False): return _f -setattr(np_arrays.ndarray, '__abs__', absolute) -setattr(np_arrays.ndarray, '__floordiv__', _wrap(floor_divide)) -setattr(np_arrays.ndarray, '__rfloordiv__', _wrap(floor_divide, True)) -setattr(np_arrays.ndarray, '__mod__', _wrap(mod)) -setattr(np_arrays.ndarray, '__rmod__', _wrap(mod, True)) -setattr(np_arrays.ndarray, '__add__', _wrap(add)) -setattr(np_arrays.ndarray, '__radd__', _wrap(add, True)) -setattr(np_arrays.ndarray, '__sub__', _wrap(subtract)) -setattr(np_arrays.ndarray, '__rsub__', _wrap(subtract, True)) -setattr(np_arrays.ndarray, '__mul__', _wrap(multiply)) -setattr(np_arrays.ndarray, '__rmul__', _wrap(multiply, True)) -setattr(np_arrays.ndarray, '__matmul__', _wrap(matmul)) -setattr(np_arrays.ndarray, '__rmatmul__', _wrap(matmul, True)) -setattr(np_arrays.ndarray, '__pow__', _wrap(power)) -setattr(np_arrays.ndarray, '__rpow__', _wrap(power, True)) -setattr(np_arrays.ndarray, '__truediv__', _wrap(true_divide)) -setattr(np_arrays.ndarray, '__rtruediv__', _wrap(true_divide, True)) - - def _comparison(tf_fun, x1, x2, cast_bool_to_int=False): """Helper function for comparision.""" dtype = np_utils.result_type(x1, x2) # Cast x1 and x2 to the result_type if needed. x1 = np_array_ops.array(x1, dtype=dtype) x2 = np_array_ops.array(x2, dtype=dtype) - x1 = x1.data - x2 = x2.data if cast_bool_to_int and x1.dtype == dtypes.bool: x1 = math_ops.cast(x1, dtypes.int32) x2 = math_ops.cast(x2, dtypes.int32) - return np_utils.tensor_to_ndarray(tf_fun(x1, x2)) + return tf_fun(x1, x2) @np_utils.np_doc('equal') @@ -1043,7 +1023,7 @@ def array_equal(a1, a2): # pylint: disable=missing-function-docstring def _logical_binary_op(tf_fun, x1, x2): x1 = np_array_ops.array(x1, dtype=np.bool_) x2 = np_array_ops.array(x2, dtype=np.bool_) - return np_utils.tensor_to_ndarray(tf_fun(x1.data, x2.data)) + return tf_fun(x1, x2) @np_utils.np_doc('logical_and') @@ -1064,16 +1044,7 @@ def logical_xor(x1, x2): @np_utils.np_doc('logical_not') def logical_not(x): x = np_array_ops.array(x, dtype=np.bool_) - return np_utils.tensor_to_ndarray(math_ops.logical_not(x.data)) - - -setattr(np_arrays.ndarray, '__invert__', logical_not) -setattr(np_arrays.ndarray, '__lt__', _wrap(less)) -setattr(np_arrays.ndarray, '__le__', _wrap(less_equal)) -setattr(np_arrays.ndarray, '__gt__', _wrap(greater)) -setattr(np_arrays.ndarray, '__ge__', _wrap(greater_equal)) -setattr(np_arrays.ndarray, '__eq__', _wrap(equal)) -setattr(np_arrays.ndarray, '__ne__', _wrap(not_equal)) + return math_ops.logical_not(x) @np_utils.np_doc('linspace') @@ -1087,8 +1058,8 @@ def linspace( # pylint: disable=missing-docstring axis=0): if dtype: dtype = np_utils.result_type(dtype) - start = np_array_ops.array(start, dtype=dtype).data - stop = np_array_ops.array(stop, dtype=dtype).data + start = np_array_ops.array(start, dtype=dtype) + stop = np_array_ops.array(stop, dtype=dtype) if num < 0: raise ValueError('Number of samples {} must be non-negative.'.format(num)) step = ops.convert_to_tensor(np.nan) @@ -1109,28 +1080,27 @@ def linspace( # pylint: disable=missing-docstring if dtype: result = math_ops.cast(result, dtype) if retstep: - return (np_arrays.tensor_to_ndarray(result), - np_arrays.tensor_to_ndarray(step)) + return (result, step) else: - return np_arrays.tensor_to_ndarray(result) + return result @np_utils.np_doc('logspace') def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0): dtype = np_utils.result_type(start, stop, dtype) result = linspace( - start, stop, num=num, endpoint=endpoint, dtype=dtype, axis=axis).data + start, stop, num=num, endpoint=endpoint, dtype=dtype, axis=axis) result = math_ops.pow(math_ops.cast(base, result.dtype), result) if dtype: result = math_ops.cast(result, dtype) - return np_arrays.tensor_to_ndarray(result) + return result @np_utils.np_doc('geomspace') def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0): # pylint: disable=missing-docstring - dtype = dtype or np_utils.result_type(start, stop, float(num), - np_array_ops.zeros((), dtype)) - computation_dtype = np.promote_types(dtype, np.float32) + dtype = dtypes.as_dtype(dtype) if dtype else np_utils.result_type( + start, stop, float(num), np_array_ops.zeros((), dtype)) + computation_dtype = np.promote_types(dtype.as_numpy_dtype, np.float32) start = np_array_ops.asarray(start, dtype=computation_dtype) stop = np_array_ops.asarray(stop, dtype=computation_dtype) # follow the numpy geomspace convention for negative and complex endpoints @@ -1147,7 +1117,7 @@ def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0): # pylint axis=0) if axis != 0: res = np_array_ops.moveaxis(res, 0, axis) - return np_utils.tensor_to_ndarray(math_ops.cast(res, dtype)) + return math_ops.cast(res, dtype) @np_utils.np_doc('ptp') @@ -1163,14 +1133,14 @@ def concatenate(arys, axis=0): if not arys: raise ValueError('Need at least one array to concatenate.') dtype = np_utils.result_type(*arys) - arys = [np_array_ops.array(array, dtype=dtype).data for array in arys] - return np_arrays.tensor_to_ndarray(array_ops.concat(arys, axis)) + arys = [np_array_ops.array(array, dtype=dtype) for array in arys] + return array_ops.concat(arys, axis) @np_utils.np_doc_only('tile') def tile(a, reps): # pylint: disable=missing-function-docstring - a = np_array_ops.array(a).data - reps = np_array_ops.array(reps, dtype=dtypes.int32).reshape([-1]).data + a = np_array_ops.array(a) + reps = np_array_ops.array(reps, dtype=dtypes.int32).reshape([-1]) a_rank = array_ops.rank(a) reps_size = array_ops.size(reps) @@ -1181,13 +1151,12 @@ def tile(a, reps): # pylint: disable=missing-function-docstring constant_values=1) a = array_ops.reshape(a, a_shape) - return np_arrays.tensor_to_ndarray(array_ops.tile(a, reps)) + return array_ops.tile(a, reps) @np_utils.np_doc('count_nonzero') def count_nonzero(a, axis=None): - return np_arrays.tensor_to_ndarray( - math_ops.count_nonzero(np_array_ops.array(a).data, axis)) + return math_ops.count_nonzero(np_array_ops.array(a), axis) @np_utils.np_doc('argsort') @@ -1199,7 +1168,7 @@ def argsort(a, axis=-1, kind='quicksort', order=None): # pylint: disable=missin raise ValueError("'order' argument to sort is not supported.") stable = (kind == 'stable') - a = np_array_ops.array(a).data + a = np_array_ops.array(a) def _argsort(a, axis, stable): if axis is None: @@ -1225,20 +1194,19 @@ def sort(a, axis=-1, kind='quicksort', order=None): # pylint: disable=missing-d a = np_array_ops.array(a) if axis is None: - result_t = sort_ops.sort(array_ops.reshape(a.data, [-1]), 0) - return np_utils.tensor_to_ndarray(result_t) + return sort_ops.sort(array_ops.reshape(a, [-1]), 0) else: - return np_utils.tensor_to_ndarray(sort_ops.sort(a.data, axis)) + return sort_ops.sort(a, axis) def _argminmax(fn, a, axis=None): a = np_array_ops.array(a) if axis is None: # When axis is None numpy flattens the array. - a_t = array_ops.reshape(a.data, [-1]) + a_t = array_ops.reshape(a, [-1]) else: - a_t = np_array_ops.atleast_1d(a).data - return np_utils.tensor_to_ndarray(fn(input=a_t, axis=axis)) + a_t = np_array_ops.atleast_1d(a) + return fn(input=a_t, axis=axis) @np_utils.np_doc('argmax') @@ -1267,24 +1235,24 @@ def average(a, axis=None, weights=None, returned=False): # pylint: disable=miss 'supported yet. Got type: %s' % type(axis)) a = np_array_ops.array(a) if weights is None: # Treat all weights as 1 - if not np.issubdtype(a.dtype, np.inexact): + if not np.issubdtype(a.dtype.as_numpy_dtype, np.inexact): a = a.astype( np_utils.result_type(a.dtype, np_dtypes.default_float_type())) - avg = math_ops.reduce_mean(a.data, axis=axis) + avg = math_ops.reduce_mean(a, axis=axis) if returned: if axis is None: - weights_sum = array_ops.size(a.data) + weights_sum = array_ops.size(a) else: - weights_sum = array_ops.shape(a.data)[axis] - weights_sum = math_ops.cast(weights_sum, a.data.dtype) + weights_sum = array_ops.shape(a)[axis] + weights_sum = math_ops.cast(weights_sum, a.dtype) else: - if np.issubdtype(a.dtype, np.inexact): + if np.issubdtype(a.dtype.as_numpy_dtype, np.inexact): out_dtype = np_utils.result_type(a.dtype, weights) else: out_dtype = np_utils.result_type(a.dtype, weights, np_dtypes.default_float_type()) - a = np_array_ops.array(a, out_dtype).data - weights = np_array_ops.array(weights, out_dtype).data + a = np_array_ops.array(a, out_dtype) + weights = np_array_ops.array(weights, out_dtype) def rank_equal_case(): control_flow_ops.Assert( @@ -1316,8 +1284,7 @@ def average(a, axis=None, weights=None, returned=False): # pylint: disable=miss avg = np_array_ops.array(avg) if returned: - weights_sum = np_array_ops.broadcast_to(weights_sum, - array_ops.shape(avg.data)) + weights_sum = np_array_ops.broadcast_to(weights_sum, array_ops.shape(avg)) return avg, weights_sum return avg @@ -1326,7 +1293,7 @@ def average(a, axis=None, weights=None, returned=False): # pylint: disable=miss def trace(a, offset=0, axis1=0, axis2=1, dtype=None): # pylint: disable=missing-docstring if dtype: dtype = np_utils.result_type(dtype) - a = np_array_ops.asarray(a, dtype).data + a = np_array_ops.asarray(a, dtype) if offset == 0: a_shape = a.shape @@ -1334,7 +1301,7 @@ def trace(a, offset=0, axis1=0, axis2=1, dtype=None): # pylint: disable=missing rank = len(a_shape) if (axis1 == -2 or axis1 == rank - 2) and (axis2 == -1 or axis2 == rank - 1): - return np_utils.tensor_to_ndarray(math_ops.trace(a)) + return math_ops.trace(a) a = np_array_ops.diagonal(a, offset, axis1, axis2) return np_array_ops.sum(a, -1, dtype) @@ -1353,11 +1320,10 @@ def meshgrid(*xi, **kwargs): indexing = kwargs.get('indexing', 'xy') - xi = [np_array_ops.asarray(arg).data for arg in xi] + xi = [np_array_ops.asarray(arg) for arg in xi] kwargs = {'indexing': indexing} outputs = array_ops.meshgrid(*xi, **kwargs) - outputs = [np_utils.tensor_to_ndarray(output) for output in outputs] return outputs @@ -1387,7 +1353,62 @@ def einsum(subscripts, *operands, **kwargs): # pylint: disable=missing-docstrin tf_optimize = 'optimal' else: raise ValueError('`optimize` method not supported: %s' % optimize) - operands = [x.data for x in operands] res = special_math_ops.einsum(subscripts, *operands, optimize=tf_optimize) - res = np_utils.tensor_to_ndarray(res) return res + + +def _tensor_t(self): + """Returns a Tensor which is the transpose of this Tensor.""" + return self.transpose() + + +def _tensor_ndim(self): + """Returns the rank of the Tensor.""" + return self.shape.ndims + + +def _tensor_pos(self): + """Returns self, for unary operator `+`.""" + return self + + +def _tensor_size(self): + """Returns the number of elements in this Tensor, if fully known.""" + if not self.shape.is_fully_defined(): + return None + return np.prod(self.shape.as_list()) + + +def _tensor_tolist(self): + if isinstance(self, ops.EagerTensor): + return self._numpy().tolist() # pylint: disable=protected-access + + raise ValueError('Symbolic Tensors do not support the tolist API.') + + +def enable_numpy_methods_on_tensor(): + """Adds additional NumPy methods on tf.Tensor class.""" + t = property(_tensor_t) + setattr(ops.Tensor, 'T', t) + + ndim = property(_tensor_ndim) + setattr(ops.Tensor, 'ndim', ndim) + + size = property(_tensor_size) + setattr(ops.Tensor, 'size', size) + + setattr(ops.Tensor, '__pos__', _tensor_pos) + setattr(ops.Tensor, 'tolist', _tensor_tolist) + + # TODO(b/178540516): Make a custom `setattr` that changes the method's + # docstring to the TF one. + setattr(ops.Tensor, 'transpose', np_array_ops.transpose) + setattr(ops.Tensor, 'reshape', np_array_ops._reshape_method_wrapper) # pylint: disable=protected-access + setattr(ops.Tensor, 'ravel', np_array_ops.ravel) + setattr(ops.Tensor, 'clip', clip) + setattr(ops.Tensor, 'astype', math_ops.cast) + setattr(ops.Tensor, '__round__', np_array_ops.around) + + # TODO(wangpeng): Remove `data` when all uses of it are removed + data = property(lambda self: self) + setattr(ops.Tensor, 'data', data) diff --git a/tensorflow/python/ops/numpy_ops/np_math_ops_test.py b/tensorflow/python/ops/numpy_ops/np_math_ops_test.py index cb5326bcded..fd9dc18abfc 100644 --- a/tensorflow/python/ops/numpy_ops/np_math_ops_test.py +++ b/tensorflow/python/ops/numpy_ops/np_math_ops_test.py @@ -160,7 +160,7 @@ class MathTest(test.TestCase, parameterized.TestCase): self.assertEqual( actual.dtype, expected.dtype, 'Dtype mismatch.\nActual: {}\nExpected: {}\n{}'.format( - actual.dtype, expected.dtype, msg)) + actual.dtype.as_numpy_dtype, expected.dtype, msg)) self.assertEqual( actual.shape, expected.shape, 'Shape mismatch.\nActual: {}\nExpected: {}\n{}'.format( @@ -350,4 +350,6 @@ class MathTest(test.TestCase, parameterized.TestCase): if __name__ == '__main__': ops.enable_eager_execution() + ops.enable_numpy_style_type_promotion() + np_math_ops.enable_numpy_methods_on_tensor() test.main() diff --git a/tensorflow/python/ops/numpy_ops/np_random.py b/tensorflow/python/ops/numpy_ops/np_random.py index d0b199b2d21..f6a6462760f 100644 --- a/tensorflow/python/ops/numpy_ops/np_random.py +++ b/tensorflow/python/ops/numpy_ops/np_random.py @@ -73,7 +73,7 @@ def standard_normal(size=None): elif np_utils.isscalar(size): size = (size,) dtype = np_dtypes.default_float_type() - return np_utils.tensor_to_ndarray(random_ops.random_normal(size, dtype=dtype)) + return random_ops.random_normal(size, dtype=dtype) @np_utils.np_doc('random.uniform') @@ -83,9 +83,8 @@ def uniform(low=0.0, high=1.0, size=None): high = np_array_ops.asarray(high, dtype=dtype) if size is None: size = array_ops.broadcast_dynamic_shape(low.shape, high.shape) - return np_utils.tensor_to_ndarray( - random_ops.random_uniform( - shape=size, minval=low, maxval=high, dtype=dtype)) + return random_ops.random_uniform( + shape=size, minval=low, maxval=high, dtype=dtype) @np_utils.np_doc('random.poisson') @@ -94,8 +93,7 @@ def poisson(lam=1.0, size=None): size = () elif np_utils.isscalar(size): size = (size,) - return np_utils.tensor_to_ndarray( - random_ops.random_poisson(shape=size, lam=lam, dtype=np_dtypes.int_)) + return random_ops.random_poisson(shape=size, lam=lam, dtype=np_dtypes.int_) @np_utils.np_doc('random.random') @@ -121,6 +119,5 @@ def randint(low, high=None, size=None, dtype=onp.int): # pylint: disable=missin dtype = np_utils.result_type(dtype) if dtype not in (onp.int32, onp.int64): raise ValueError('Only np.int32 or np.int64 types are supported') - return np_utils.tensor_to_ndarray( - random_ops.random_uniform( - shape=size, minval=low, maxval=high, dtype=dtype)) + return random_ops.random_uniform( + shape=size, minval=low, maxval=high, dtype=dtype) diff --git a/tensorflow/python/ops/numpy_ops/np_random_test.py b/tensorflow/python/ops/numpy_ops/np_random_test.py index 61ddbcaf47b..de0ee32802e 100644 --- a/tensorflow/python/ops/numpy_ops/np_random_test.py +++ b/tensorflow/python/ops/numpy_ops/np_random_test.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import numpy_ops as np # Needed for ndarray.reshape. from tensorflow.python.ops.numpy_ops import np_array_ops # pylint: disable=unused-import from tensorflow.python.ops.numpy_ops import np_dtypes +from tensorflow.python.ops.numpy_ops import np_math_ops from tensorflow.python.ops.numpy_ops import np_random from tensorflow.python.platform import test @@ -192,7 +193,7 @@ class RandNDistriutionTest(test.TestCase): self.assertEqual(output.shape, tuple(args)) default_dtype = ( np.float64 if np_dtypes.is_allow_float64() else np.float32) - self.assertEqual(output.dtype.type, default_dtype) + self.assertEqual(output.dtype.as_numpy_dtype, default_dtype) if np.prod(args): # Don't bother with empty arrays. outputs = [output.tolist() for output in outputs] @@ -230,4 +231,5 @@ class RandNDistriutionTest(test.TestCase): if __name__ == '__main__': ops.enable_eager_execution() + np_math_ops.enable_numpy_methods_on_tensor() test.main() diff --git a/tensorflow/python/ops/numpy_ops/np_utils.py b/tensorflow/python/ops/numpy_ops/np_utils.py index ca09624de76..90f3c3913a6 100644 --- a/tensorflow/python/ops/numpy_ops/np_utils.py +++ b/tensorflow/python/ops/numpy_ops/np_utils.py @@ -38,9 +38,6 @@ from tensorflow.python.types import core from tensorflow.python.util import nest -tensor_to_ndarray = np_arrays.tensor_to_ndarray - - def _canonicalize_axis(axis, rank): return _canonicalize_axes([axis], rank)[0] @@ -478,8 +475,6 @@ def _maybe_get_dtype(x): """Returns a numpy type if available from x. Skips if x is numpy.ndarray.""" # Don't put np.ndarray in this list, because np.result_type looks at the # value (not just dtype) of np.ndarray to decide the result type. - if isinstance(x, np_arrays.ndarray): - return x.dtype if isinstance(x, numbers.Real): return x if isinstance(x, (core.Tensor, indexed_slices.IndexedSlices)): diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops.py b/tensorflow/python/ops/parallel_for/control_flow_ops.py index 169eb17cda1..569ef81b410 100644 --- a/tensorflow/python/ops/parallel_for/control_flow_ops.py +++ b/tensorflow/python/ops/parallel_for/control_flow_ops.py @@ -34,7 +34,6 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import tensor_array_ops -from tensorflow.python.ops.numpy_ops import np_arrays from tensorflow.python.ops.parallel_for.pfor import PFor from tensorflow.python.ops.parallel_for.pfor import PForConfig from tensorflow.python.platform import tf_logging as logging @@ -289,7 +288,6 @@ def _pfor_impl(loop_fn, loop_fn_outputs) # Convert outputs to Tensor if needed. - rewrap_as_ndarray = False tmp_loop_fn_outputs = [] for loop_fn_output in nest.flatten(loop_fn_output_tensors): if (loop_fn_output is not None and not isinstance( @@ -301,9 +299,6 @@ def _pfor_impl(loop_fn, " IndexedSlices separately, and handle the vectorized" " outputs directly." % loop_fn_output) loop_fn_output = ops.convert_to_tensor(loop_fn_output) - elif isinstance(loop_fn_output, np_arrays.ndarray): - loop_fn_output = loop_fn_output.data - rewrap_as_ndarray = True else: loop_fn_output = ops.convert_to_tensor(loop_fn_output) tmp_loop_fn_outputs.append(loop_fn_output) @@ -327,8 +322,6 @@ def _pfor_impl(loop_fn, flattened_output_tensors = [] for loop_fn_output in nest.flatten(loop_fn_output_tensors): output = converter.convert(loop_fn_output) - if rewrap_as_ndarray: - output = np_arrays.tensor_to_ndarray(output) flattened_output_tensors.append(output) else: if pfor_config is not None and pfor_config._has_reductions(): # pylint: disable=protected-access @@ -346,8 +339,6 @@ def _pfor_impl(loop_fn, flattened_output_tensors = nest.flatten(loop_fn_output_tensors) for loop_fn_output in flattened_output_tensors: output = converter.convert(loop_fn_output) - if rewrap_as_ndarray: - output = np_arrays.tensor_to_ndarray(output) remaining_output_tensors.append(output) with ops.name_scope("pfor_tiled"): @@ -398,10 +389,6 @@ def _pfor_impl(loop_fn, tensor_shape.TensorShape([iters_value]).concatenate( original_output.shape)) - if rewrap_as_ndarray: - flattened_output_tensors = [ - np_arrays.tensor_to_ndarray(x) for x in flattened_output_tensors] - return nest.map_structure_up_to( loop_fn_outputs, functools.partial(_composite_from_tensors, batch_size=iters_value), @@ -418,8 +405,6 @@ def _broadcasting_gather(x, i): elif static_first_dim is None: i = array_ops.where_v2(array_ops.shape(x)[0] > 1, i, 0) result = array_ops.gather(x, i) - if isinstance(x, np_arrays.ndarray): - result = np_arrays.ndarray.from_tensor(result) return result @@ -548,8 +533,6 @@ def vectorized_map(fn, elems, fallback_to_while_loop=True): is_batched=True), elems)) def _get_shape(x): - if isinstance(x, np_arrays.ndarray): - x = x.data if x.shape.rank is None: return None return x.shape.as_list()[0] diff --git a/tensorflow/python/saved_model/load_test.py b/tensorflow/python/saved_model/load_test.py index 69c59744efe..ecdc53d2572 100644 --- a/tensorflow/python/saved_model/load_test.py +++ b/tensorflow/python/saved_model/load_test.py @@ -50,12 +50,10 @@ from tensorflow.python.ops import cond_v2 from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import numpy_ops as tnp from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import string_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables -from tensorflow.python.ops.numpy_ops import np_arrays from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.saved_model import load @@ -1967,34 +1965,6 @@ class LoadTest(test.TestCase, parameterized.TestCase): self.assertEqual(self.evaluate(imported.lookup("foo")), 15) self.assertEqual(self.evaluate(imported.lookup("idk")), -1) - def test_saving_ndarray_specs(self, cycles): - class NdarrayModule(module.Module): - - @def_function.function - def plain(self, x): - return tnp.add(x, 1) - - @def_function.function(input_signature=[ - np_arrays.NdarraySpec(tensor_spec.TensorSpec([], dtypes.float32))]) - def with_signature(self, x): - return tnp.add(x, 1) - - m = NdarrayModule() - c = tnp.asarray(3.0, tnp.float32) - output_plain, output_with_signature = m.plain(c), m.with_signature(c) - - loaded_m = cycle(m, cycles) - - load_output_plain, load_output_with_signature = ( - loaded_m.plain(c), loaded_m.with_signature(c)) - - self.assertIsInstance(output_plain, tnp.ndarray) - self.assertIsInstance(load_output_plain, tnp.ndarray) - self.assertIsInstance(output_with_signature, tnp.ndarray) - self.assertIsInstance(load_output_with_signature, tnp.ndarray) - self.assertAllClose(output_plain, load_output_plain) - self.assertAllClose(output_with_signature, load_output_with_signature) - class SingleCycleTests(test.TestCase, parameterized.TestCase): diff --git a/tensorflow/python/saved_model/nested_structure_coder.py b/tensorflow/python/saved_model/nested_structure_coder.py index a7e5548ee06..9c71b853675 100644 --- a/tensorflow/python/saved_model/nested_structure_coder.py +++ b/tensorflow/python/saved_model/nested_structure_coder.py @@ -48,7 +48,6 @@ from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import tensor_array_ops -from tensorflow.python.ops.numpy_ops import np_arrays from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import row_partition from tensorflow.python.util import compat @@ -517,8 +516,6 @@ class _TypeSpecCodec(object): resource_variable_ops.VariableSpec, struct_pb2.TypeSpecProto.ROW_PARTITION_SPEC: row_partition.RowPartitionSpec, - struct_pb2.TypeSpecProto.NDARRAY_SPEC: - np_arrays.NdarraySpec, } # Mapping from type (TypeSpec subclass) to enum value. diff --git a/tensorflow/python/saved_model/nested_structure_coder_test.py b/tensorflow/python/saved_model/nested_structure_coder_test.py index fb074f76eb0..9951ea64a49 100644 --- a/tensorflow/python/saved_model/nested_structure_coder_test.py +++ b/tensorflow/python/saved_model/nested_structure_coder_test.py @@ -28,7 +28,6 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_util -from tensorflow.python.ops.numpy_ops import np_arrays from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import test from tensorflow.python.saved_model import nested_structure_coder @@ -332,14 +331,6 @@ class NestedStructureTest(test.TestCase): decoded = self._coder.decode_proto(encoded) self.assertEqual(structure, decoded) - def testEncodeDecodeNdarraySpec(self): - structure = [np_arrays.NdarraySpec( - tensor_spec.TensorSpec([4, 2], dtypes.float32))] - self.assertTrue(self._coder.can_encode(structure)) - encoded = self._coder.encode_structure(structure) - decoded = self._coder.decode_proto(encoded) - self.assertEqual(structure, decoded) - def testNotEncodable(self): class NotEncodable(object): diff --git a/third_party/py/numpy/tf_numpy_api/tensorflow.experimental.numpy.ndarray.pbtxt b/third_party/py/numpy/tf_numpy_api/tensorflow.experimental.numpy.ndarray.pbtxt index 4edc5f08e84..104e85835cd 100644 --- a/third_party/py/numpy/tf_numpy_api/tensorflow.experimental.numpy.ndarray.pbtxt +++ b/third_party/py/numpy/tf_numpy_api/tensorflow.experimental.numpy.ndarray.pbtxt @@ -1,14 +1,15 @@ path: "tensorflow.experimental.numpy.ndarray" tf_class { - is_instance: "<class \'tensorflow.python.ops.numpy_ops.np_arrays.ndarray\'>" - is_instance: "<class \'tensorflow.python.framework.composite_tensor.CompositeTensor\'>" + is_instance: "<class \'tensorflow.python.framework.ops.Tensor\'>" + is_instance: "<class \'tensorflow.python.types.internal.NativeObject\'>" + is_instance: "<class \'tensorflow.python.types.core.Tensor\'>" is_instance: "<type \'object\'>" member { - name: "T" - mtype: "<type \'property\'>" + name: "OVERLOADABLE_OPERATORS" + mtype: "<type \'set\'>" } member { - name: "data" + name: "device" mtype: "<type \'property\'>" } member { @@ -16,7 +17,15 @@ tf_class { mtype: "<type \'property\'>" } member { - name: "ndim" + name: "graph" + mtype: "<type \'property\'>" + } + member { + name: "name" + mtype: "<type \'property\'>" + } + member { + name: "op" mtype: "<type \'property\'>" } member { @@ -24,39 +33,35 @@ tf_class { mtype: "<type \'property\'>" } member { - name: "size" + name: "value_index" mtype: "<type \'property\'>" } member_method { name: "__init__" - argspec: "args=[\'self\', \'shape\', \'dtype\', \'buffer\'], varargs=None, keywords=None, defaults=[\"<class \'float\'>\", \'None\'], " + argspec: "args=[\'self\', \'op\', \'value_index\', \'dtype\'], varargs=None, keywords=None, defaults=None" } member_method { - name: "astype" - argspec: "args=[\'self\', \'dtype\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "clip" - argspec: "args=[\'a\', \'a_min\', \'a_max\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "from_tensor" - argspec: "args=[\'cls\', \'tensor\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "ravel" - argspec: "args=[\'a\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "reshape" - argspec: "args=[\'a\'], varargs=newshape, keywords=kwargs, defaults=None" - } - member_method { - name: "tolist" + name: "consumers" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } member_method { - name: "transpose" - argspec: "args=[\'a\', \'axes\'], varargs=None, keywords=None, defaults=[\'None\'], " + name: "eval" + argspec: "args=[\'self\', \'feed_dict\', \'session\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "experimental_ref" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_shape" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "ref" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_shape" + argspec: "args=[\'self\', \'shape\'], varargs=None, keywords=None, defaults=None" } }