tf.numpy: Numpy API on TF.
PiperOrigin-RevId: 314792337 Change-Id: I92dc879492af15c6160e3f671e185ea4f87e8ed4
This commit is contained in:
parent
bab81497b0
commit
84c796966b
@ -223,6 +223,7 @@ py_library(
|
|||||||
"//tensorflow/python/ops/linalg",
|
"//tensorflow/python/ops/linalg",
|
||||||
"//tensorflow/python/ops/linalg/sparse",
|
"//tensorflow/python/ops/linalg/sparse",
|
||||||
"//tensorflow/python/ops/losses",
|
"//tensorflow/python/ops/losses",
|
||||||
|
"//tensorflow/python/ops/numpy_ops:numpy",
|
||||||
"//tensorflow/python/ops/parallel_for",
|
"//tensorflow/python/ops/parallel_for",
|
||||||
"//tensorflow/python/ops/ragged",
|
"//tensorflow/python/ops/ragged",
|
||||||
"//tensorflow/python/ops/signal",
|
"//tensorflow/python/ops/signal",
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
# TF numpy API
|
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||||
|
|
||||||
package(
|
package(
|
||||||
default_visibility = [
|
default_visibility = [
|
||||||
@ -8,9 +8,109 @@ package(
|
|||||||
)
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "numpy_ops",
|
name = "numpy",
|
||||||
srcs = [
|
srcs = [
|
||||||
"__init__.py",
|
"__init__.py",
|
||||||
|
"np_array_ops.py",
|
||||||
|
"np_arrays.py",
|
||||||
|
"np_dtypes.py",
|
||||||
|
"np_math_ops.py",
|
||||||
|
"np_random.py",
|
||||||
|
"np_utils.py",
|
||||||
],
|
],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/python:array_ops",
|
||||||
|
"//tensorflow/python:bitwise_ops",
|
||||||
|
"//tensorflow/python:constant_op",
|
||||||
|
"//tensorflow/python:control_flow_ops",
|
||||||
|
"//tensorflow/python:dtypes",
|
||||||
|
"//tensorflow/python:errors",
|
||||||
|
"//tensorflow/python:framework_ops",
|
||||||
|
"//tensorflow/python:indexed_slices",
|
||||||
|
"//tensorflow/python:linalg_ops",
|
||||||
|
"//tensorflow/python:math_ops",
|
||||||
|
"//tensorflow/python:sort_ops",
|
||||||
|
"//tensorflow/python:tensor_util",
|
||||||
|
"//tensorflow/python:util",
|
||||||
|
"//third_party/py/numpy",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cuda_py_test(
|
||||||
|
name = "np_arrays_test",
|
||||||
|
srcs = ["np_arrays_test.py"],
|
||||||
|
deps = [
|
||||||
|
":numpy",
|
||||||
|
"//tensorflow/python:dtypes",
|
||||||
|
"//tensorflow/python:framework_ops",
|
||||||
|
"//tensorflow/python:math_ops",
|
||||||
|
"//tensorflow/python:platform",
|
||||||
|
"//third_party/py/numpy",
|
||||||
|
"@absl_py//absl/testing:parameterized",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cuda_py_test(
|
||||||
|
name = "np_array_ops_test",
|
||||||
|
srcs = ["np_array_ops_test.py"],
|
||||||
|
deps = [
|
||||||
|
":numpy",
|
||||||
|
"//tensorflow/python:platform",
|
||||||
|
"//third_party/py/numpy",
|
||||||
|
"@six_archive//:six",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cuda_py_test(
|
||||||
|
name = "np_backprop_test",
|
||||||
|
srcs = ["np_backprop_test.py"],
|
||||||
|
deps = [
|
||||||
|
":numpy",
|
||||||
|
"//tensorflow/python:platform",
|
||||||
|
"@absl_py//absl/testing:parameterized",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cuda_py_test(
|
||||||
|
name = "np_logic_test",
|
||||||
|
srcs = ["np_logic_test.py"],
|
||||||
|
deps = [
|
||||||
|
":numpy",
|
||||||
|
"//third_party/py/numpy",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cuda_py_test(
|
||||||
|
name = "np_math_ops_test",
|
||||||
|
srcs = ["np_math_ops_test.py"],
|
||||||
|
deps = [
|
||||||
|
":numpy",
|
||||||
|
"//tensorflow/python:platform",
|
||||||
|
"//third_party/py/numpy",
|
||||||
|
"@absl_py//absl/testing:parameterized",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cuda_py_test(
|
||||||
|
name = "np_random_test",
|
||||||
|
srcs = ["np_random_test.py"],
|
||||||
|
deps = [
|
||||||
|
":numpy",
|
||||||
|
"//tensorflow/python:platform",
|
||||||
|
"//third_party/py/numpy",
|
||||||
|
"@absl_py//absl/testing:parameterized",
|
||||||
|
"@six_archive//:six",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cuda_py_test(
|
||||||
|
name = "np_utils_test",
|
||||||
|
srcs = ["np_utils_test.py"],
|
||||||
|
deps = [
|
||||||
|
":numpy",
|
||||||
|
"//tensorflow/python:platform",
|
||||||
|
"@absl_py//absl/testing:parameterized",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
@ -17,3 +17,21 @@
|
|||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.ops.numpy_ops import np_random as random
|
||||||
|
|
||||||
|
# pylint: disable=wildcard-import
|
||||||
|
from tensorflow.python.ops.numpy_ops.np_array_ops import *
|
||||||
|
from tensorflow.python.ops.numpy_ops.np_arrays import ndarray
|
||||||
|
from tensorflow.python.ops.numpy_ops.np_dtypes import *
|
||||||
|
from tensorflow.python.ops.numpy_ops.np_math_ops import *
|
||||||
|
from tensorflow.python.ops.numpy_ops.np_utils import finfo
|
||||||
|
from tensorflow.python.ops.numpy_ops.np_utils import promote_types
|
||||||
|
from tensorflow.python.ops.numpy_ops.np_utils import result_type
|
||||||
|
# pylint: enable=wildcard-import
|
||||||
|
|
||||||
|
# pylint: disable=redefined-builtin,undefined-variable
|
||||||
|
max = amax
|
||||||
|
min = amin
|
||||||
|
round = around
|
||||||
|
# pylint: enable=redefined-builtin,undefined-variable
|
||||||
|
1415
tensorflow/python/ops/numpy_ops/np_array_ops.py
Normal file
1415
tensorflow/python/ops/numpy_ops/np_array_ops.py
Normal file
File diff suppressed because it is too large
Load Diff
1139
tensorflow/python/ops/numpy_ops/np_array_ops_test.py
Normal file
1139
tensorflow/python/ops/numpy_ops/np_array_ops_test.py
Normal file
File diff suppressed because it is too large
Load Diff
292
tensorflow/python/ops/numpy_ops/np_arrays.py
Normal file
292
tensorflow/python/ops/numpy_ops/np_arrays.py
Normal file
@ -0,0 +1,292 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""ndarray class."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import six
|
||||||
|
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops.numpy_ops import np_dtypes
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_tensor(value, dtype=None):
|
||||||
|
# A safer version of `tf.convert_to_tensor` to work around b/149876037.
|
||||||
|
# TODO(wangpeng): Remove this function once the bug is fixed.
|
||||||
|
if (dtype is None and isinstance(value, six.integer_types) and
|
||||||
|
value >= 2**63):
|
||||||
|
dtype = dtypes.uint64
|
||||||
|
elif (dtype is None and isinstance(value, float)):
|
||||||
|
dtype = np_dtypes.default_float_type()
|
||||||
|
return ops.convert_to_tensor(value, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class ndarray(object): # pylint: disable=invalid-name
|
||||||
|
"""Equivalent of numpy.ndarray backed by TensorFlow tensors.
|
||||||
|
|
||||||
|
This does not support all features of NumPy ndarrays e.g. strides and
|
||||||
|
memory order since, unlike NumPy, the backing storage is not a raw memory
|
||||||
|
buffer.
|
||||||
|
|
||||||
|
TODO(srbs): Clearly specify which attributes and methods are not supported
|
||||||
|
or if there are any differences in behavior.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, shape, dtype=float, buffer=None): # pylint: disable=redefined-builtin
|
||||||
|
"""Initializes an ndarray.
|
||||||
|
|
||||||
|
This is a low level interface for building ndarrays and should be avoided.
|
||||||
|
Users should instead use methods in array_creation.py.
|
||||||
|
|
||||||
|
This class provides a numpy.ndarray like interface for a TF Tensor with a
|
||||||
|
fully-defined shape. Note that, unlike the backing buffer of np.ndarray,
|
||||||
|
Tensors are immutable. So, operations like `__setitem__` are performed by
|
||||||
|
replacing the Tensor. This restricts the ability to implement NumPy `view`
|
||||||
|
semantics.
|
||||||
|
|
||||||
|
Compared to numpy.ndarray, this does not support `offset`, `strides`
|
||||||
|
and `order` arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shape: The shape of the array. Must be a scalar, an iterable of integers
|
||||||
|
or a `TensorShape` object.
|
||||||
|
dtype: Optional. The dtype of the array. Must be a python type, a numpy
|
||||||
|
type or a tensorflow `DType` object.
|
||||||
|
buffer: Optional. The backing buffer of the array. Must have shape
|
||||||
|
`shape`. Must be a `ndarray`, `np.ndarray` or a `Tensor`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If `buffer` is specified and its shape does not match
|
||||||
|
`shape`.
|
||||||
|
"""
|
||||||
|
if dtype and not isinstance(dtype, dtypes.DType):
|
||||||
|
dtype = dtypes.as_dtype(np.dtype(dtype))
|
||||||
|
if buffer is None:
|
||||||
|
buffer = array_ops.zeros(shape, dtype=dtype)
|
||||||
|
else:
|
||||||
|
if isinstance(buffer, ndarray):
|
||||||
|
buffer = buffer.data
|
||||||
|
elif isinstance(buffer, np.ndarray):
|
||||||
|
# If `buffer` is a np.ndarray, the Tensor will share the underlying
|
||||||
|
# storage of the array.
|
||||||
|
buffer = convert_to_tensor(value=buffer, dtype=dtype)
|
||||||
|
elif not isinstance(buffer, ops.Tensor):
|
||||||
|
raise ValueError('Unexpected type for `buffer` {}. Must be an ndarray,'
|
||||||
|
' Tensor or np.ndarray.'.format(type(buffer)))
|
||||||
|
|
||||||
|
if shape is not None and tuple(shape) != buffer._shape_tuple(): # pylint: disable=protected-access
|
||||||
|
# TODO(srbs): NumPy allows this. Investigate if/how to support this.
|
||||||
|
raise ValueError('shape arg must match buffer.shape.')
|
||||||
|
|
||||||
|
assert isinstance(buffer, ops.Tensor)
|
||||||
|
if dtype and dtype != buffer.dtype:
|
||||||
|
buffer = array_ops.bitcast(buffer, dtype)
|
||||||
|
self._data = buffer
|
||||||
|
self.base = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def data(self):
|
||||||
|
"""Tensor object containing the array data.
|
||||||
|
|
||||||
|
This has a few key differences from the Python buffer object used in
|
||||||
|
NumPy arrays.
|
||||||
|
1. Tensors are immutable. So operations requiring in-place edit, e.g.
|
||||||
|
__setitem__, are performed by replacing the underlying buffer with a new
|
||||||
|
one.
|
||||||
|
2. Tensors do not provide access to their raw buffer.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Tensor.
|
||||||
|
"""
|
||||||
|
return self._data
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shape(self):
|
||||||
|
"""Returns a tuple of array dimensions."""
|
||||||
|
return self.data._shape_tuple() # pylint: disable=protected-access
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self):
|
||||||
|
return np.dtype(self.data.dtype.as_numpy_dtype)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ndim(self):
|
||||||
|
return self.data.shape.ndims
|
||||||
|
|
||||||
|
@property
|
||||||
|
def size(self):
|
||||||
|
"""Returns the number of elements in the array."""
|
||||||
|
return np.prod(self.shape)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def T(self): # pylint: disable=invalid-name
|
||||||
|
return self.transpose()
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
if self.shape:
|
||||||
|
return self.shape[0]
|
||||||
|
else:
|
||||||
|
raise TypeError('len() of unsized object.')
|
||||||
|
|
||||||
|
def astype(self, dtype):
|
||||||
|
if self.dtype == dtype:
|
||||||
|
return self
|
||||||
|
else:
|
||||||
|
return tensor_to_ndarray(math_ops.cast(self.data, dtype))
|
||||||
|
|
||||||
|
# Unary operations
|
||||||
|
def __neg__(self):
|
||||||
|
return tensor_to_ndarray(-self.data) # pylint: disable=invalid-unary-operand-type
|
||||||
|
|
||||||
|
def __pos__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
__hash__ = None
|
||||||
|
|
||||||
|
def __int__(self):
|
||||||
|
return int(self.data)
|
||||||
|
|
||||||
|
def __float__(self):
|
||||||
|
return float(self.data)
|
||||||
|
|
||||||
|
def __nonzero__(self):
|
||||||
|
return bool(self.data)
|
||||||
|
|
||||||
|
def __bool__(self):
|
||||||
|
return self.__nonzero__()
|
||||||
|
|
||||||
|
def __getitem__(self, slice_spec):
|
||||||
|
# TODO(srbs): Need to support better indexing.
|
||||||
|
result_t = self.data.__getitem__(slice_spec)
|
||||||
|
return tensor_to_ndarray(result_t)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
for i in range(self.shape[0]):
|
||||||
|
result_t = self.data[i]
|
||||||
|
yield tensor_to_ndarray(result_t)
|
||||||
|
return
|
||||||
|
|
||||||
|
def __array__(self, dtype=None):
|
||||||
|
"""Returns a NumPy ndarray.
|
||||||
|
|
||||||
|
This allows instances of this class to be directly used in NumPy routines.
|
||||||
|
However, doing that may force a copy to CPU.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dtype: A NumPy compatible type.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A NumPy ndarray.
|
||||||
|
"""
|
||||||
|
return np.asarray(self.data, dtype)
|
||||||
|
|
||||||
|
__array_priority__ = 110
|
||||||
|
|
||||||
|
def __index__(self):
|
||||||
|
"""Returns a python scalar.
|
||||||
|
|
||||||
|
This allows using an instance of this class as an array index.
|
||||||
|
Note that only arrays of integer types with size 1 can be used as array
|
||||||
|
indices.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Python scalar.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If the array is not of an integer type.
|
||||||
|
ValueError: If the array does not have size 1.
|
||||||
|
"""
|
||||||
|
# TODO(wangpeng): Handle graph mode
|
||||||
|
return np.asscalar(self.data.numpy())
|
||||||
|
|
||||||
|
def tolist(self):
|
||||||
|
return self.data.numpy().tolist()
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return 'ndarray<{}>'.format(self.data.__str__())
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return 'ndarray<{}>'.format(self.data.__repr__())
|
||||||
|
|
||||||
|
|
||||||
|
def tensor_to_ndarray(tensor):
|
||||||
|
return ndarray(tensor._shape_tuple(), dtype=tensor.dtype, buffer=tensor) # pylint: disable=protected-access
|
||||||
|
|
||||||
|
|
||||||
|
def ndarray_to_tensor(arr, dtype=None, name=None, as_ref=False):
|
||||||
|
if as_ref:
|
||||||
|
raise ValueError('as_ref is not supported.')
|
||||||
|
if dtype and dtypes.as_dtype(arr.dtype) != dtype:
|
||||||
|
return math_ops.cast(arr.data, dtype)
|
||||||
|
result_t = arr.data
|
||||||
|
if name:
|
||||||
|
result_t = array_ops.identity(result_t, name=name)
|
||||||
|
return result_t
|
||||||
|
|
||||||
|
|
||||||
|
ops.register_tensor_conversion_function(ndarray, ndarray_to_tensor)
|
||||||
|
|
||||||
|
|
||||||
|
# Don't use a namedtuple since nest considers that a tuple and unflattens and
|
||||||
|
# flattens it.
|
||||||
|
class ShardedNdArray(object):
|
||||||
|
"""Wrapper over ndarray that can contain tensors on multiple devices.
|
||||||
|
|
||||||
|
This is returned by extensions.pmap, and contains the individual tensors on
|
||||||
|
different devices.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, tensors):
|
||||||
|
"""Initializes the ShardedNdArray.
|
||||||
|
|
||||||
|
Note that the tensors should be ordered in the way the pmap producing these
|
||||||
|
tensors is run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensors: list or tuple of eager tensors, one for each device.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not isinstance(tensors, (list, tuple)) or not tensors:
|
||||||
|
raise ValueError(
|
||||||
|
'Unable to create a ShardedNdArray without a list of tensors.')
|
||||||
|
self.tensors = tensors
|
||||||
|
self.n_devices = len(tensors)
|
||||||
|
|
||||||
|
def __getitem__(self, i):
|
||||||
|
return self.tensors[i]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shape(self):
|
||||||
|
return (self.n_devices,) + self.tensors[0]._shape_tuple() # pylint: disable=protected-access
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self):
|
||||||
|
return self.tensors[0].dtype
|
||||||
|
|
||||||
|
|
||||||
|
def convert_sharded_tensor_to_eager_tensor(value, *args, **kwargs):
|
||||||
|
del args, kwargs
|
||||||
|
# TODO(nareshmodi): Consider a collective op to gather the tensors from the
|
||||||
|
# various devices for performance reasons.
|
||||||
|
return array_ops.stack(value.tensors)
|
||||||
|
|
||||||
|
|
||||||
|
ops.register_tensor_conversion_function(ShardedNdArray,
|
||||||
|
convert_sharded_tensor_to_eager_tensor)
|
189
tensorflow/python/ops/numpy_ops/np_arrays_test.py
Normal file
189
tensorflow/python/ops/numpy_ops/np_arrays_test.py
Normal file
@ -0,0 +1,189 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for ndarray."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import collections
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops.numpy_ops import np_arrays
|
||||||
|
# Required for operator overloads
|
||||||
|
from tensorflow.python.ops.numpy_ops import np_math_ops # pylint: disable=unused-import
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
t2a = np_arrays.tensor_to_ndarray
|
||||||
|
|
||||||
|
|
||||||
|
class ArrayTest(test.TestCase):
|
||||||
|
|
||||||
|
def testDtype(self):
|
||||||
|
a = t2a(array_ops.zeros(shape=[1, 2], dtype=dtypes.int64))
|
||||||
|
self.assertIs(a.dtype.type, np.int64)
|
||||||
|
self.assertAllEqual(0, a.dtype.type(0))
|
||||||
|
|
||||||
|
def testAstype(self):
|
||||||
|
a = t2a(ops.convert_to_tensor(value=1.1,
|
||||||
|
dtype=dtypes.float32)).astype(np.int32)
|
||||||
|
self.assertIs(a.dtype.type, np.int32)
|
||||||
|
self.assertAllEqual(1, a)
|
||||||
|
a = t2a(ops.convert_to_tensor(value=[0.0, 1.1],
|
||||||
|
dtype=dtypes.float32)).astype(np.bool_)
|
||||||
|
self.assertIs(a.dtype.type, np.bool_)
|
||||||
|
self.assertAllEqual([False, True], a)
|
||||||
|
|
||||||
|
def testNeg(self):
|
||||||
|
a = t2a(ops.convert_to_tensor(value=[1.0, 2.0]))
|
||||||
|
self.assertAllEqual([-1.0, -2.0], -a)
|
||||||
|
|
||||||
|
def _testBinOp(self, a, b, out, f, types=None):
|
||||||
|
a = t2a(ops.convert_to_tensor(value=a, dtype=np.int32))
|
||||||
|
b = t2a(ops.convert_to_tensor(value=b, dtype=np.int32))
|
||||||
|
if not isinstance(out, np_arrays.ndarray):
|
||||||
|
out = t2a(ops.convert_to_tensor(value=out, dtype=np.int32))
|
||||||
|
if types is None:
|
||||||
|
types = [[np.int32, np.int32, np.int32], [np.int64, np.int32, np.int64],
|
||||||
|
[np.int32, np.int64, np.int64],
|
||||||
|
[np.float32, np.int32, np.float64],
|
||||||
|
[np.int32, np.float32, np.float64],
|
||||||
|
[np.float32, np.float32, np.float32],
|
||||||
|
[np.float64, np.float32, np.float64],
|
||||||
|
[np.float32, np.float64, np.float64]]
|
||||||
|
for a_type, b_type, out_type in types:
|
||||||
|
o = f(a.astype(a_type), b.astype(b_type))
|
||||||
|
self.assertIs(o.dtype.type, out_type)
|
||||||
|
out = out.astype(out_type)
|
||||||
|
if np.issubdtype(out_type, np.inexact):
|
||||||
|
self.assertAllClose(out, o)
|
||||||
|
else:
|
||||||
|
self.assertAllEqual(out, o)
|
||||||
|
|
||||||
|
def testAdd(self):
|
||||||
|
self._testBinOp([1, 2], [3, 4], [4, 6], lambda a, b: a.__add__(b))
|
||||||
|
|
||||||
|
def testRadd(self):
|
||||||
|
self._testBinOp([1, 2], [3, 4], [4, 6], lambda a, b: b.__radd__(a))
|
||||||
|
|
||||||
|
def testSub(self):
|
||||||
|
self._testBinOp([1, 2], [3, 5], [-2, -3], lambda a, b: a.__sub__(b))
|
||||||
|
|
||||||
|
def testRsub(self):
|
||||||
|
self._testBinOp([1, 2], [3, 5], [-2, -3], lambda a, b: b.__rsub__(a))
|
||||||
|
|
||||||
|
def testMul(self):
|
||||||
|
self._testBinOp([1, 2], [3, 4], [3, 8], lambda a, b: a.__mul__(b))
|
||||||
|
|
||||||
|
def testRmul(self):
|
||||||
|
self._testBinOp([1, 2], [3, 4], [3, 8], lambda a, b: b.__rmul__(a))
|
||||||
|
|
||||||
|
def testPow(self):
|
||||||
|
self._testBinOp([4, 5], [3, 2], [64, 25], lambda a, b: a.__pow__(b))
|
||||||
|
|
||||||
|
def testRpow(self):
|
||||||
|
self._testBinOp([4, 5], [3, 2], [64, 25], lambda a, b: b.__rpow__(a))
|
||||||
|
|
||||||
|
_truediv_types = [[np.int32, np.int32, np.float64],
|
||||||
|
[np.int64, np.int32, np.float64],
|
||||||
|
[np.int32, np.int64, np.float64],
|
||||||
|
[np.float32, np.int32, np.float64],
|
||||||
|
[np.int32, np.float32, np.float64],
|
||||||
|
[np.float32, np.float32, np.float32],
|
||||||
|
[np.float64, np.float32, np.float64],
|
||||||
|
[np.float32, np.float64, np.float64]]
|
||||||
|
|
||||||
|
def testTruediv(self):
|
||||||
|
self._testBinOp([3, 5], [2, 4],
|
||||||
|
t2a(ops.convert_to_tensor(value=[1.5, 1.25])),
|
||||||
|
lambda a, b: a.__truediv__(b),
|
||||||
|
types=self._truediv_types)
|
||||||
|
|
||||||
|
def testRtruediv(self):
|
||||||
|
self._testBinOp([3, 5], [2, 4],
|
||||||
|
t2a(ops.convert_to_tensor(value=[1.5, 1.25])),
|
||||||
|
lambda a, b: b.__rtruediv__(a),
|
||||||
|
types=self._truediv_types)
|
||||||
|
|
||||||
|
def _testCmp(self, a, b, out, f):
|
||||||
|
a = t2a(ops.convert_to_tensor(value=a, dtype=np.int32))
|
||||||
|
b = t2a(ops.convert_to_tensor(value=b, dtype=np.int32))
|
||||||
|
types = [[np.int32, np.int32], [np.int64, np.int32], [np.int32, np.int64],
|
||||||
|
[np.float32, np.int32], [np.int32, np.float32],
|
||||||
|
[np.float32, np.float32], [np.float64, np.float32],
|
||||||
|
[np.float32, np.float64]]
|
||||||
|
for a_type, b_type in types:
|
||||||
|
o = f(a.astype(a_type), b.astype(b_type))
|
||||||
|
self.assertAllEqual(out, o)
|
||||||
|
|
||||||
|
def testLt(self):
|
||||||
|
self._testCmp([1, 2, 3], [3, 2, 1], [True, False, False],
|
||||||
|
lambda a, b: a.__lt__(b))
|
||||||
|
|
||||||
|
def testLe(self):
|
||||||
|
self._testCmp([1, 2, 3], [3, 2, 1], [True, True, False],
|
||||||
|
lambda a, b: a.__le__(b))
|
||||||
|
|
||||||
|
def testGt(self):
|
||||||
|
self._testCmp([1, 2, 3], [3, 2, 1], [False, False, True],
|
||||||
|
lambda a, b: a.__gt__(b))
|
||||||
|
|
||||||
|
def testGe(self):
|
||||||
|
self._testCmp([1, 2, 3], [3, 2, 1], [False, True, True],
|
||||||
|
lambda a, b: a.__ge__(b))
|
||||||
|
|
||||||
|
def testEq(self):
|
||||||
|
self._testCmp([1, 2, 3], [3, 2, 1], [False, True, False],
|
||||||
|
lambda a, b: a.__eq__(b))
|
||||||
|
|
||||||
|
def testNe(self):
|
||||||
|
self._testCmp([1, 2, 3], [3, 2, 1], [True, False, True],
|
||||||
|
lambda a, b: a.__ne__(b))
|
||||||
|
|
||||||
|
def testInt(self):
|
||||||
|
v = 10
|
||||||
|
u = int(t2a(ops.convert_to_tensor(value=v)))
|
||||||
|
self.assertIsInstance(u, int)
|
||||||
|
self.assertAllEqual(v, u)
|
||||||
|
|
||||||
|
def testFloat(self):
|
||||||
|
v = 21.32
|
||||||
|
u = float(t2a(ops.convert_to_tensor(value=v)))
|
||||||
|
self.assertIsInstance(u, float)
|
||||||
|
self.assertAllClose(v, u)
|
||||||
|
|
||||||
|
def testBool(self):
|
||||||
|
b = bool(t2a(ops.convert_to_tensor(value=10)))
|
||||||
|
self.assertIsInstance(b, bool)
|
||||||
|
self.assertTrue(b)
|
||||||
|
self.assertFalse(bool(t2a(ops.convert_to_tensor(value=0))))
|
||||||
|
self.assertTrue(bool(t2a(ops.convert_to_tensor(value=0.1))))
|
||||||
|
self.assertFalse(bool(t2a(ops.convert_to_tensor(value=0.0))))
|
||||||
|
|
||||||
|
def testHash(self):
|
||||||
|
a = t2a(ops.convert_to_tensor(value=10))
|
||||||
|
self.assertNotIsInstance(a, collections.Hashable)
|
||||||
|
with self.assertRaisesWithPredicateMatch(TypeError, r'unhashable type'):
|
||||||
|
hash(a)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# TODO(wangpeng): Test in graph mode as well.
|
||||||
|
ops.enable_eager_execution()
|
||||||
|
test.main()
|
69
tensorflow/python/ops/numpy_ops/np_backprop_test.py
Normal file
69
tensorflow/python/ops/numpy_ops/np_backprop_test.py
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for backpropgration on tf-numpy functions."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.eager import backprop
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.ops.numpy_ops import np_array_ops
|
||||||
|
# Required for operator overloads
|
||||||
|
from tensorflow.python.ops.numpy_ops import np_math_ops # pylint: disable=unused-import
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class BackpropTest(test.TestCase):
|
||||||
|
|
||||||
|
def test_setitem(self):
|
||||||
|
# Single integer index.
|
||||||
|
a = np_array_ops.array([1., 2., 3.])
|
||||||
|
b = np_array_ops.array(5.)
|
||||||
|
c = np_array_ops.array(10.)
|
||||||
|
|
||||||
|
tensors = [arr.data for arr in [a, b, c]]
|
||||||
|
with backprop.GradientTape() as g:
|
||||||
|
g.watch(tensors)
|
||||||
|
a[1] = b + c
|
||||||
|
loss = np_array_ops.sum(a)
|
||||||
|
|
||||||
|
gradients = g.gradient(loss.data, tensors)
|
||||||
|
self.assertSequenceEqual(
|
||||||
|
np_array_ops.array(gradients[0]).tolist(), [1., 0., 1.])
|
||||||
|
self.assertEqual(np_array_ops.array(gradients[1]).tolist(), 1.)
|
||||||
|
self.assertEqual(np_array_ops.array(gradients[2]).tolist(), 1.)
|
||||||
|
|
||||||
|
# Tuple index.
|
||||||
|
a = np_array_ops.array([[[1., 2.], [3., 4.]], [[5., 6.],
|
||||||
|
[7., 8.]]]) # 2x2x2 array.
|
||||||
|
b = np_array_ops.array([10., 11.])
|
||||||
|
|
||||||
|
tensors = [arr.data for arr in [a, b]]
|
||||||
|
with backprop.GradientTape() as g:
|
||||||
|
g.watch(tensors)
|
||||||
|
a[(1, 0)] = b
|
||||||
|
loss = np_array_ops.sum(a)
|
||||||
|
|
||||||
|
gradients = g.gradient(loss.data, tensors)
|
||||||
|
self.assertSequenceEqual(
|
||||||
|
np_array_ops.array(gradients[0]).tolist(),
|
||||||
|
[[[1., 1.], [1., 1.]], [[0., 0.], [1., 1.]]])
|
||||||
|
self.assertEqual(np_array_ops.array(gradients[1]).tolist(), [1., 1.])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
ops.enable_eager_execution()
|
||||||
|
test.main()
|
96
tensorflow/python/ops/numpy_ops/np_dtypes.py
Normal file
96
tensorflow/python/ops/numpy_ops/np_dtypes.py
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Dtypes and dtype utilities."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# We use numpy's dtypes instead of TF's, because the user expects to use them
|
||||||
|
# with numpy facilities such as `np.dtype(np.int64)` and
|
||||||
|
# `if x.dtype.type is np.int64`.
|
||||||
|
# pylint: disable=unused-import
|
||||||
|
# pylint: disable=g-bad-import-order
|
||||||
|
from numpy import bool_
|
||||||
|
from numpy import int_
|
||||||
|
from numpy import int16
|
||||||
|
from numpy import int32
|
||||||
|
from numpy import int64
|
||||||
|
from numpy import int8
|
||||||
|
from numpy import uint16
|
||||||
|
from numpy import uint32
|
||||||
|
from numpy import uint64
|
||||||
|
from numpy import uint8
|
||||||
|
from numpy import float_
|
||||||
|
from numpy import float16
|
||||||
|
from numpy import float32
|
||||||
|
from numpy import float64
|
||||||
|
from numpy import complex_
|
||||||
|
from numpy import complex64
|
||||||
|
from numpy import complex128
|
||||||
|
|
||||||
|
from numpy import inexact
|
||||||
|
|
||||||
|
from numpy import iinfo
|
||||||
|
from numpy import issubdtype
|
||||||
|
|
||||||
|
from numpy import inf
|
||||||
|
|
||||||
|
# TODO(wangpeng): Make bfloat16 a numpy dtype instead of using TF's
|
||||||
|
from tensorflow.python.framework.dtypes import bfloat16
|
||||||
|
# pylint: enable=g-bad-import-order
|
||||||
|
# pylint: enable=unused-import
|
||||||
|
|
||||||
|
_to_float32 = {
|
||||||
|
np.dtype('float64'): np.dtype('float32'),
|
||||||
|
np.dtype('complex128'): np.dtype('complex64'),
|
||||||
|
}
|
||||||
|
|
||||||
|
_allow_float64 = True
|
||||||
|
|
||||||
|
|
||||||
|
def is_allow_float64():
|
||||||
|
return _allow_float64
|
||||||
|
|
||||||
|
|
||||||
|
def set_allow_float64(b):
|
||||||
|
global _allow_float64
|
||||||
|
_allow_float64 = b
|
||||||
|
|
||||||
|
|
||||||
|
def canonicalize_dtype(dtype):
|
||||||
|
if not is_allow_float64():
|
||||||
|
return _to_float32.get(dtype, dtype)
|
||||||
|
else:
|
||||||
|
return dtype
|
||||||
|
|
||||||
|
|
||||||
|
def _result_type(*arrays_and_dtypes):
|
||||||
|
dtype = np.result_type(*arrays_and_dtypes)
|
||||||
|
return canonicalize_dtype(dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def default_float_type():
|
||||||
|
"""Gets the default float type.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
If `is_allow_float64()` is true, returns float64; otherwise returns float32.
|
||||||
|
"""
|
||||||
|
if is_allow_float64():
|
||||||
|
return float64
|
||||||
|
else:
|
||||||
|
return float32
|
110
tensorflow/python/ops/numpy_ops/np_logic_test.py
Normal file
110
tensorflow/python/ops/numpy_ops/np_logic_test.py
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for tf numpy random number methods."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.ops.numpy_ops import np_array_ops
|
||||||
|
from tensorflow.python.ops.numpy_ops import np_arrays
|
||||||
|
from tensorflow.python.ops.numpy_ops import np_math_ops
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class LogicTest(test.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super(LogicTest, self).setUp()
|
||||||
|
self.array_transforms = [
|
||||||
|
lambda x: x, # Identity,
|
||||||
|
ops.convert_to_tensor,
|
||||||
|
np.array,
|
||||||
|
lambda x: np.array(x, dtype=np.int32),
|
||||||
|
lambda x: np.array(x, dtype=np.int64),
|
||||||
|
lambda x: np.array(x, dtype=np.float32),
|
||||||
|
lambda x: np.array(x, dtype=np.float64),
|
||||||
|
np_array_ops.array,
|
||||||
|
lambda x: np_array_ops.array(x, dtype=dtypes.int32),
|
||||||
|
lambda x: np_array_ops.array(x, dtype=dtypes.int64),
|
||||||
|
lambda x: np_array_ops.array(x, dtype=dtypes.float32),
|
||||||
|
lambda x: np_array_ops.array(x, dtype=dtypes.float64),
|
||||||
|
]
|
||||||
|
|
||||||
|
def testEqual(self):
|
||||||
|
|
||||||
|
def run_test(x1, x2=None):
|
||||||
|
if x2 is None:
|
||||||
|
x2 = x1
|
||||||
|
for fn1 in self.array_transforms:
|
||||||
|
for fn2 in self.array_transforms:
|
||||||
|
arg1 = fn1(x1)
|
||||||
|
arg2 = fn2(x2)
|
||||||
|
self.match(
|
||||||
|
np_math_ops.equal(arg1, arg2),
|
||||||
|
np.equal(
|
||||||
|
make_numpy_compatible(arg1), make_numpy_compatible(arg2)))
|
||||||
|
|
||||||
|
run_test(1)
|
||||||
|
run_test(1, 2)
|
||||||
|
run_test([1, 2])
|
||||||
|
run_test([1, 2, 3], [2])
|
||||||
|
run_test([[1, 2], [3, 4]], [1, 2])
|
||||||
|
run_test([[1, 2], [1, 4]], [1, 2])
|
||||||
|
run_test([1, 2], [[1, 2], [1, 4]])
|
||||||
|
run_test([[1, 2], [3, 4]], [[1, 2], [3, 4]])
|
||||||
|
run_test([[1, 2], [3, 4]], [[1, 3], [3, 4]])
|
||||||
|
|
||||||
|
def match_shape(self, actual, expected, msg=None):
|
||||||
|
if msg:
|
||||||
|
msg = 'Shape match failed for: {}. Expected: {} Actual: {}'.format(
|
||||||
|
msg, expected.shape, actual.shape)
|
||||||
|
self.assertEqual(actual.shape, expected.shape, msg=msg)
|
||||||
|
if msg:
|
||||||
|
msg = 'Shape: {} is not a tuple for {}'.format(actual.shape, msg)
|
||||||
|
self.assertIsInstance(actual.shape, tuple, msg=msg)
|
||||||
|
|
||||||
|
def match_dtype(self, actual, expected, msg=None):
|
||||||
|
if msg:
|
||||||
|
msg = 'Dtype match failed for: {}. Expected: {} Actual: {}.'.format(
|
||||||
|
msg, expected.dtype, actual.dtype)
|
||||||
|
self.assertEqual(actual.dtype, expected.dtype, msg=msg)
|
||||||
|
|
||||||
|
def match(self, actual, expected, msg=None):
|
||||||
|
msg_ = 'Expected: {} Actual: {}'.format(expected, actual)
|
||||||
|
if msg:
|
||||||
|
msg = '{} {}'.format(msg_, msg)
|
||||||
|
else:
|
||||||
|
msg = msg_
|
||||||
|
self.assertIsInstance(actual, np_arrays.ndarray)
|
||||||
|
self.match_dtype(actual, expected, msg)
|
||||||
|
self.match_shape(actual, expected, msg)
|
||||||
|
if not actual.shape:
|
||||||
|
self.assertEqual(actual.tolist(), expected.tolist())
|
||||||
|
else:
|
||||||
|
self.assertSequenceEqual(actual.tolist(), expected.tolist())
|
||||||
|
|
||||||
|
|
||||||
|
def make_numpy_compatible(s):
|
||||||
|
return s if not isinstance(s, np_arrays.ndarray) else s.data.numpy()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
ops.enable_eager_execution()
|
||||||
|
test.main()
|
1251
tensorflow/python/ops/numpy_ops/np_math_ops.py
Normal file
1251
tensorflow/python/ops/numpy_ops/np_math_ops.py
Normal file
File diff suppressed because it is too large
Load Diff
332
tensorflow/python/ops/numpy_ops/np_math_ops_test.py
Normal file
332
tensorflow/python/ops/numpy_ops/np_math_ops_test.py
Normal file
@ -0,0 +1,332 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for tf numpy mathematical methods."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
from absl.testing import parameterized
|
||||||
|
import numpy as np
|
||||||
|
from six.moves import range
|
||||||
|
|
||||||
|
from tensorflow.python.framework import errors
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.ops.numpy_ops import np_array_ops
|
||||||
|
from tensorflow.python.ops.numpy_ops import np_arrays
|
||||||
|
from tensorflow.python.ops.numpy_ops import np_math_ops
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class MathTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super(MathTest, self).setUp()
|
||||||
|
self.array_transforms = [
|
||||||
|
lambda x: x, # Identity,
|
||||||
|
ops.convert_to_tensor,
|
||||||
|
np.array,
|
||||||
|
lambda x: np.array(x, dtype=np.float32),
|
||||||
|
lambda x: np.array(x, dtype=np.float64),
|
||||||
|
np_array_ops.array,
|
||||||
|
lambda x: np_array_ops.array(x, dtype=np.float32),
|
||||||
|
lambda x: np_array_ops.array(x, dtype=np.float64),
|
||||||
|
]
|
||||||
|
self.types = [np.int32, np.int64, np.float32, np.float64]
|
||||||
|
|
||||||
|
def _testBinaryOp(self,
|
||||||
|
math_fun,
|
||||||
|
np_fun,
|
||||||
|
name,
|
||||||
|
operands=None,
|
||||||
|
extra_operands=None,
|
||||||
|
check_promotion=True,
|
||||||
|
check_promotion_result_type=True):
|
||||||
|
|
||||||
|
def run_test(a, b):
|
||||||
|
for fn in self.array_transforms:
|
||||||
|
arg1 = fn(a)
|
||||||
|
arg2 = fn(b)
|
||||||
|
self.match(
|
||||||
|
math_fun(arg1, arg2),
|
||||||
|
np_fun(arg1, arg2),
|
||||||
|
msg='{}({}, {})'.format(name, arg1, arg2))
|
||||||
|
# Tests type promotion
|
||||||
|
for type_a in self.types:
|
||||||
|
for type_b in self.types:
|
||||||
|
if not check_promotion and type_a != type_b:
|
||||||
|
continue
|
||||||
|
arg1 = np_array_ops.array(a, dtype=type_a)
|
||||||
|
arg2 = np_array_ops.array(b, dtype=type_b)
|
||||||
|
self.match(
|
||||||
|
math_fun(arg1, arg2),
|
||||||
|
np_fun(arg1, arg2),
|
||||||
|
msg='{}({}, {})'.format(name, arg1, arg2),
|
||||||
|
check_dtype=check_promotion_result_type)
|
||||||
|
|
||||||
|
if operands is None:
|
||||||
|
operands = [(5, 2), (5, [2, 3]), (5, [[2, 3], [6, 7]]), ([1, 2, 3], 7),
|
||||||
|
([1, 2, 3], [5, 6, 7])]
|
||||||
|
for operand1, operand2 in operands:
|
||||||
|
run_test(operand1, operand2)
|
||||||
|
if extra_operands is not None:
|
||||||
|
for operand1, operand2 in extra_operands:
|
||||||
|
run_test(operand1, operand2)
|
||||||
|
|
||||||
|
def testDot(self):
|
||||||
|
extra_operands = [([1, 2], [[5, 6, 7], [8, 9, 10]]),
|
||||||
|
(np.arange(2 * 3 * 5).reshape([2, 3, 5]).tolist(),
|
||||||
|
np.arange(5 * 7 * 11).reshape([7, 5, 11]).tolist())]
|
||||||
|
return self._testBinaryOp(
|
||||||
|
np_math_ops.dot, np.dot, 'dot', extra_operands=extra_operands)
|
||||||
|
|
||||||
|
def testMinimum(self):
|
||||||
|
# The numpy version has strange result type when promotion happens,
|
||||||
|
# so set check_promotion_result_type to False.
|
||||||
|
return self._testBinaryOp(
|
||||||
|
np_math_ops.minimum,
|
||||||
|
np.minimum,
|
||||||
|
'minimum',
|
||||||
|
check_promotion_result_type=False)
|
||||||
|
|
||||||
|
def testMaximum(self):
|
||||||
|
# The numpy version has strange result type when promotion happens,
|
||||||
|
# so set check_promotion_result_type to False.
|
||||||
|
return self._testBinaryOp(
|
||||||
|
np_math_ops.maximum,
|
||||||
|
np.maximum,
|
||||||
|
'maximum',
|
||||||
|
check_promotion_result_type=False)
|
||||||
|
|
||||||
|
def testMatmul(self):
|
||||||
|
operands = [([[1, 2]], [[3, 4, 5], [6, 7, 8]])]
|
||||||
|
return self._testBinaryOp(
|
||||||
|
np_math_ops.matmul, np.matmul, 'matmul', operands=operands)
|
||||||
|
|
||||||
|
def testMatmulError(self):
|
||||||
|
with self.assertRaisesRegex(ValueError, r''):
|
||||||
|
np_math_ops.matmul(
|
||||||
|
np_array_ops.ones([], np.int32), np_array_ops.ones([2, 3], np.int32))
|
||||||
|
with self.assertRaisesRegex(ValueError, r''):
|
||||||
|
np_math_ops.matmul(
|
||||||
|
np_array_ops.ones([2, 3], np.int32), np_array_ops.ones([], np.int32))
|
||||||
|
|
||||||
|
def _testUnaryOp(self, math_fun, np_fun, name):
|
||||||
|
|
||||||
|
def run_test(a):
|
||||||
|
for fn in self.array_transforms:
|
||||||
|
arg1 = fn(a)
|
||||||
|
self.match(
|
||||||
|
math_fun(arg1), np_fun(arg1), msg='{}({})'.format(name, arg1))
|
||||||
|
|
||||||
|
run_test(5)
|
||||||
|
run_test([2, 3])
|
||||||
|
run_test([[2, -3], [-6, 7]])
|
||||||
|
|
||||||
|
def testLog(self):
|
||||||
|
self._testUnaryOp(np_math_ops.log, np.log, 'log')
|
||||||
|
|
||||||
|
def testExp(self):
|
||||||
|
self._testUnaryOp(np_math_ops.exp, np.exp, 'exp')
|
||||||
|
|
||||||
|
def testTanh(self):
|
||||||
|
self._testUnaryOp(np_math_ops.tanh, np.tanh, 'tanh')
|
||||||
|
|
||||||
|
def testSqrt(self):
|
||||||
|
self._testUnaryOp(np_math_ops.sqrt, np.sqrt, 'sqrt')
|
||||||
|
|
||||||
|
def match(self, actual, expected, msg='', check_dtype=True):
|
||||||
|
self.assertIsInstance(actual, np_arrays.ndarray)
|
||||||
|
if check_dtype:
|
||||||
|
self.assertEqual(
|
||||||
|
actual.dtype, expected.dtype,
|
||||||
|
'Dtype mismatch.\nActual: {}\nExpected: {}\n{}'.format(
|
||||||
|
actual.dtype, expected.dtype, msg))
|
||||||
|
self.assertEqual(
|
||||||
|
actual.shape, expected.shape,
|
||||||
|
'Shape mismatch.\nActual: {}\nExpected: {}\n{}'.format(
|
||||||
|
actual.shape, expected.shape, msg))
|
||||||
|
np.testing.assert_almost_equal(actual.tolist(), expected.tolist())
|
||||||
|
|
||||||
|
def testArgsort(self):
|
||||||
|
self._testUnaryOp(np_math_ops.argsort, np.argsort, 'argsort')
|
||||||
|
|
||||||
|
# Test stability
|
||||||
|
r = np.arange(100)
|
||||||
|
a = np.zeros(100)
|
||||||
|
np.testing.assert_equal(np_math_ops.argsort(a, kind='stable'), r)
|
||||||
|
|
||||||
|
def testArgMaxArgMin(self):
|
||||||
|
data = [
|
||||||
|
0,
|
||||||
|
5,
|
||||||
|
[1],
|
||||||
|
[1, 2, 3],
|
||||||
|
[[1, 2, 3]],
|
||||||
|
[[4, 6], [7, 8]],
|
||||||
|
[[[4, 6], [9, 10]], [[7, 8], [12, 34]]],
|
||||||
|
]
|
||||||
|
for fn, d in itertools.product(self.array_transforms, data):
|
||||||
|
arr = fn(d)
|
||||||
|
self.match(np_math_ops.argmax(arr), np.argmax(arr))
|
||||||
|
self.match(np_math_ops.argmin(arr), np.argmin(arr))
|
||||||
|
if hasattr(arr, 'shape'):
|
||||||
|
ndims = len(arr.shape)
|
||||||
|
else:
|
||||||
|
ndims = np_array_ops.array(arr, copy=False).ndim
|
||||||
|
if ndims == 0:
|
||||||
|
# Numpy flattens the scalar ndarray and treats it as a 1-d array of
|
||||||
|
# size 1.
|
||||||
|
ndims = 1
|
||||||
|
for axis in range(-ndims, ndims):
|
||||||
|
self.match(
|
||||||
|
np_math_ops.argmax(arr, axis=axis), np.argmax(arr, axis=axis))
|
||||||
|
self.match(
|
||||||
|
np_math_ops.argmin(arr, axis=axis), np.argmin(arr, axis=axis))
|
||||||
|
|
||||||
|
@parameterized.parameters([False, True])
|
||||||
|
def testIsCloseEqualNan(self, equal_nan):
|
||||||
|
a = np.asarray([1, 1, np.nan, 1, np.nan], np.float32)
|
||||||
|
b = np.asarray([1, 2, 1, np.nan, np.nan], np.float32)
|
||||||
|
self.match(
|
||||||
|
np_math_ops.isclose(a, b, equal_nan=equal_nan),
|
||||||
|
np.isclose(a, b, equal_nan=equal_nan))
|
||||||
|
|
||||||
|
def testAverageWrongShape(self):
|
||||||
|
with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, r''):
|
||||||
|
np_math_ops.average(np.ones([2, 3]), weights=np.ones([2, 4]))
|
||||||
|
with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, r''):
|
||||||
|
np_math_ops.average(np.ones([2, 3]), axis=0, weights=np.ones([2, 4]))
|
||||||
|
with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, r''):
|
||||||
|
np_math_ops.average(np.ones([2, 3]), axis=0, weights=np.ones([]))
|
||||||
|
with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, r''):
|
||||||
|
np_math_ops.average(np.ones([2, 3]), axis=0, weights=np.ones([5]))
|
||||||
|
|
||||||
|
def testClip(self):
|
||||||
|
|
||||||
|
def run_test(arr, *args, **kwargs):
|
||||||
|
check_dtype = kwargs.pop('check_dtype', True)
|
||||||
|
for fn in self.array_transforms:
|
||||||
|
arr = fn(arr)
|
||||||
|
self.match(
|
||||||
|
np_math_ops.clip(arr, *args, **kwargs),
|
||||||
|
np.clip(arr, *args, **kwargs),
|
||||||
|
check_dtype=check_dtype)
|
||||||
|
|
||||||
|
# NumPy exhibits weird typing behavior when a/a_min/a_max are scalars v/s
|
||||||
|
# lists, e.g.,
|
||||||
|
#
|
||||||
|
# np.clip(np.array(0, dtype=np.int32), -5, 5).dtype == np.int64
|
||||||
|
# np.clip(np.array([0], dtype=np.int32), -5, 5).dtype == np.int32
|
||||||
|
# np.clip(np.array([0], dtype=np.int32), [-5], [5]).dtype == np.int64
|
||||||
|
#
|
||||||
|
# So we skip matching type. In tf-numpy the type of the output array is
|
||||||
|
# always the same as the input array.
|
||||||
|
run_test(0, -1, 5, check_dtype=False)
|
||||||
|
run_test(-1, -1, 5, check_dtype=False)
|
||||||
|
run_test(5, -1, 5, check_dtype=False)
|
||||||
|
run_test(-10, -1, 5, check_dtype=False)
|
||||||
|
run_test(10, -1, 5, check_dtype=False)
|
||||||
|
run_test(10, None, 5, check_dtype=False)
|
||||||
|
run_test(10, -1, None, check_dtype=False)
|
||||||
|
run_test([0, 20, -5, 4], -1, 5, check_dtype=False)
|
||||||
|
run_test([0, 20, -5, 4], None, 5, check_dtype=False)
|
||||||
|
run_test([0, 20, -5, 4], -1, None, check_dtype=False)
|
||||||
|
run_test([0.5, 20.2, -5.7, 4.4], -1.5, 5.1, check_dtype=False)
|
||||||
|
|
||||||
|
run_test([0, 20, -5, 4], [-5, 0, -5, 0], [0, 5, 0, 5], check_dtype=False)
|
||||||
|
run_test([[1, 2, 3], [4, 5, 6]], [2, 0, 2], 5, check_dtype=False)
|
||||||
|
run_test([[1, 2, 3], [4, 5, 6]], 0, [5, 3, 1], check_dtype=False)
|
||||||
|
|
||||||
|
def testPtp(self):
|
||||||
|
|
||||||
|
def run_test(arr, *args, **kwargs):
|
||||||
|
for fn in self.array_transforms:
|
||||||
|
arg = fn(arr)
|
||||||
|
self.match(
|
||||||
|
np_math_ops.ptp(arg, *args, **kwargs), np.ptp(arg, *args, **kwargs))
|
||||||
|
|
||||||
|
run_test([1, 2, 3])
|
||||||
|
run_test([1., 2., 3.])
|
||||||
|
run_test([[1, 2], [3, 4]], axis=1)
|
||||||
|
run_test([[1, 2], [3, 4]], axis=0)
|
||||||
|
run_test([[1, 2], [3, 4]], axis=-1)
|
||||||
|
run_test([[1, 2], [3, 4]], axis=-2)
|
||||||
|
|
||||||
|
def testLinSpace(self):
|
||||||
|
array_transforms = [
|
||||||
|
lambda x: x, # Identity,
|
||||||
|
ops.convert_to_tensor,
|
||||||
|
np.array,
|
||||||
|
lambda x: np.array(x, dtype=np.float32),
|
||||||
|
lambda x: np.array(x, dtype=np.float64),
|
||||||
|
np_array_ops.array,
|
||||||
|
lambda x: np_array_ops.array(x, dtype=np.float32),
|
||||||
|
lambda x: np_array_ops.array(x, dtype=np.float64)
|
||||||
|
]
|
||||||
|
|
||||||
|
def run_test(start, stop, **kwargs):
|
||||||
|
for fn1 in array_transforms:
|
||||||
|
for fn2 in array_transforms:
|
||||||
|
arg1 = fn1(start)
|
||||||
|
arg2 = fn2(stop)
|
||||||
|
self.match(
|
||||||
|
np_math_ops.linspace(arg1, arg2, **kwargs),
|
||||||
|
np.linspace(arg1, arg2, **kwargs),
|
||||||
|
msg='linspace({}, {})'.format(arg1, arg2))
|
||||||
|
|
||||||
|
run_test(0, 1)
|
||||||
|
run_test(0, 1, num=10)
|
||||||
|
run_test(0, 1, endpoint=False)
|
||||||
|
run_test(0, -1)
|
||||||
|
run_test(0, -1, num=10)
|
||||||
|
run_test(0, -1, endpoint=False)
|
||||||
|
|
||||||
|
def testLogSpace(self):
|
||||||
|
array_transforms = [
|
||||||
|
lambda x: x, # Identity,
|
||||||
|
ops.convert_to_tensor,
|
||||||
|
np.array,
|
||||||
|
lambda x: np.array(x, dtype=np.float32),
|
||||||
|
lambda x: np.array(x, dtype=np.float64),
|
||||||
|
np_array_ops.array,
|
||||||
|
lambda x: np_array_ops.array(x, dtype=np.float32),
|
||||||
|
lambda x: np_array_ops.array(x, dtype=np.float64)
|
||||||
|
]
|
||||||
|
|
||||||
|
def run_test(start, stop, **kwargs):
|
||||||
|
for fn1 in array_transforms:
|
||||||
|
for fn2 in array_transforms:
|
||||||
|
arg1 = fn1(start)
|
||||||
|
arg2 = fn2(stop)
|
||||||
|
self.match(
|
||||||
|
np_math_ops.logspace(arg1, arg2, **kwargs),
|
||||||
|
np.logspace(arg1, arg2, **kwargs),
|
||||||
|
msg='logspace({}, {})'.format(arg1, arg2))
|
||||||
|
|
||||||
|
run_test(0, 5)
|
||||||
|
run_test(0, 5, num=10)
|
||||||
|
run_test(0, 5, endpoint=False)
|
||||||
|
run_test(0, 5, base=2.0)
|
||||||
|
run_test(0, -5)
|
||||||
|
run_test(0, -5, num=10)
|
||||||
|
run_test(0, -5, endpoint=False)
|
||||||
|
run_test(0, -5, base=2.0)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
ops.enable_eager_execution()
|
||||||
|
test.main()
|
56
tensorflow/python/ops/numpy_ops/np_random.py
Normal file
56
tensorflow/python/ops/numpy_ops/np_random.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Random functions."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.framework import random_seed
|
||||||
|
from tensorflow.python.ops import random_ops
|
||||||
|
from tensorflow.python.ops.numpy_ops import np_utils
|
||||||
|
|
||||||
|
DEFAULT_RANDN_DTYPE = np.float32
|
||||||
|
|
||||||
|
|
||||||
|
def randn(*args):
|
||||||
|
"""Returns samples from a normal distribution.
|
||||||
|
|
||||||
|
Uses `tf.random_normal`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*args: The shape of the output array.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An ndarray with shape `args` and dtype `float64`.
|
||||||
|
"""
|
||||||
|
# TODO(wangpeng): Use new stateful RNG
|
||||||
|
if np_utils.isscalar(args):
|
||||||
|
args = (args,)
|
||||||
|
return np_utils.tensor_to_ndarray(
|
||||||
|
random_ops.random_normal(args, dtype=DEFAULT_RANDN_DTYPE))
|
||||||
|
|
||||||
|
|
||||||
|
def seed(s):
|
||||||
|
"""Sets the seed for the random number generator.
|
||||||
|
|
||||||
|
Uses `tf.set_random_seed`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
s: an integer.
|
||||||
|
"""
|
||||||
|
# TODO(wangpeng): make the signature the same as numpy
|
||||||
|
random_seed.set_seed(s)
|
91
tensorflow/python/ops/numpy_ops/np_random_test.py
Normal file
91
tensorflow/python/ops/numpy_ops/np_random_test.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for tf numpy random number methods."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from six.moves import range
|
||||||
|
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
# Needed for ndarray.reshape.
|
||||||
|
from tensorflow.python.ops.numpy_ops import np_array_ops # pylint: disable=unused-import
|
||||||
|
from tensorflow.python.ops.numpy_ops import np_random
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class RandomTest(test.TestCase):
|
||||||
|
|
||||||
|
def assertNotAllClose(self, a, b, **kwargs):
|
||||||
|
try:
|
||||||
|
self.assertAllClose(a, b, **kwargs)
|
||||||
|
except AssertionError:
|
||||||
|
return
|
||||||
|
raise AssertionError('The two values are close at all %d elements' %
|
||||||
|
np.size(a))
|
||||||
|
|
||||||
|
def testRandN(self):
|
||||||
|
|
||||||
|
def run_test(*args):
|
||||||
|
num_samples = 1000
|
||||||
|
tol = 0.1 # High tolerance to keep the # of samples low else the test
|
||||||
|
# takes a long time to run.
|
||||||
|
np_random.seed(10)
|
||||||
|
outputs = [np_random.randn(*args) for _ in range(num_samples)]
|
||||||
|
|
||||||
|
# Test output shape.
|
||||||
|
for output in outputs:
|
||||||
|
self.assertEqual(output.shape, tuple(args))
|
||||||
|
self.assertEqual(output.dtype.type, np_random.DEFAULT_RANDN_DTYPE)
|
||||||
|
|
||||||
|
if np.prod(args): # Don't bother with empty arrays.
|
||||||
|
outputs = [output.tolist() for output in outputs]
|
||||||
|
|
||||||
|
# Test that the properties of normal distribution are satisfied.
|
||||||
|
mean = np.mean(outputs, axis=0)
|
||||||
|
stddev = np.std(outputs, axis=0)
|
||||||
|
self.assertAllClose(mean, np.zeros(args), atol=tol)
|
||||||
|
self.assertAllClose(stddev, np.ones(args), atol=tol)
|
||||||
|
|
||||||
|
# Test that outputs are different with different seeds.
|
||||||
|
np_random.seed(20)
|
||||||
|
diff_seed_outputs = [
|
||||||
|
np_random.randn(*args).tolist() for _ in range(num_samples)
|
||||||
|
]
|
||||||
|
self.assertNotAllClose(outputs, diff_seed_outputs)
|
||||||
|
|
||||||
|
# Test that outputs are the same with the same seed.
|
||||||
|
np_random.seed(10)
|
||||||
|
same_seed_outputs = [
|
||||||
|
np_random.randn(*args).tolist() for _ in range(num_samples)
|
||||||
|
]
|
||||||
|
self.assertAllClose(outputs, same_seed_outputs)
|
||||||
|
|
||||||
|
run_test()
|
||||||
|
run_test(0)
|
||||||
|
run_test(1)
|
||||||
|
run_test(5)
|
||||||
|
run_test(2, 3)
|
||||||
|
run_test(0, 2, 3)
|
||||||
|
run_test(2, 0, 3)
|
||||||
|
run_test(2, 3, 0)
|
||||||
|
run_test(2, 3, 5)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
ops.enable_eager_execution()
|
||||||
|
test.main()
|
448
tensorflow/python/ops/numpy_ops/np_utils.py
Normal file
448
tensorflow/python/ops/numpy_ops/np_utils.py
Normal file
@ -0,0 +1,448 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Utility functions for internal use."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import indexed_slices
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import tensor_util
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import control_flow_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops.numpy_ops import np_arrays
|
||||||
|
from tensorflow.python.ops.numpy_ops import np_dtypes
|
||||||
|
from tensorflow.python.util import nest
|
||||||
|
|
||||||
|
tensor_to_ndarray = np_arrays.tensor_to_ndarray
|
||||||
|
|
||||||
|
|
||||||
|
def _supports_signature():
|
||||||
|
return hasattr(inspect, 'signature')
|
||||||
|
|
||||||
|
|
||||||
|
def _to_tf_type(dtype):
|
||||||
|
"""Converts a native python or numpy type to TF DType.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dtype: Could be a python type, a numpy type or a TF DType.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tensorflow `DType`.
|
||||||
|
"""
|
||||||
|
return dtypes.as_dtype(dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def _to_numpy_type(dtype):
|
||||||
|
"""Converts a native python or TF DType to numpy type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dtype: Could be a python type, a numpy type or a TF DType.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A NumPy `dtype`.
|
||||||
|
"""
|
||||||
|
if isinstance(dtype, dtypes.DType):
|
||||||
|
return dtype.as_numpy_dtype
|
||||||
|
return np.dtype(dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def finfo(dtype):
|
||||||
|
"""Returns properties of floating point types.
|
||||||
|
|
||||||
|
Note that currently it just forwards to the numpy namesake, while tensorflow
|
||||||
|
and numpy dtypes may have different properties.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dtype: Could be a python type, a numpy type or a TF DType.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A class describing properties of `dtype`, as described by
|
||||||
|
https://docs.scipy.org/doc/numpy/reference/generated/numpy.finfo.html
|
||||||
|
"""
|
||||||
|
return np.finfo(_to_numpy_type(dtype))
|
||||||
|
|
||||||
|
|
||||||
|
def isscalar(val):
|
||||||
|
"""Returns whether `val` is a scalar value or scalar Tensor."""
|
||||||
|
if isinstance(val, (np.ndarray, np_arrays.ndarray, ops.Tensor)):
|
||||||
|
return len(val.shape) == 0 # pylint: disable=g-explicit-length-test
|
||||||
|
return np.isscalar(val)
|
||||||
|
|
||||||
|
|
||||||
|
# Can't use np_doc because np.result_type is a builtin function.
|
||||||
|
def result_type(*arrays_and_dtypes):
|
||||||
|
"""Returns the type resulting from applying NumPy type promotion to arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*arrays_and_dtypes: A list of array_like objects or dtypes.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A numpy dtype.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def maybe_get_dtype(x):
|
||||||
|
# Don't put np.ndarray in this list, because np.result_type looks at the
|
||||||
|
# value (not just dtype) of np.ndarray to decide the result type.
|
||||||
|
if isinstance(x, (np_arrays.ndarray, np_arrays.ShardedNdArray, ops.Tensor,
|
||||||
|
indexed_slices.IndexedSlices)):
|
||||||
|
return _to_numpy_type(x.dtype)
|
||||||
|
elif isinstance(x, dtypes.DType):
|
||||||
|
return _to_numpy_type(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
arrays_and_dtypes = [
|
||||||
|
maybe_get_dtype(x) for x in nest.flatten(arrays_and_dtypes)
|
||||||
|
]
|
||||||
|
if not arrays_and_dtypes:
|
||||||
|
# If arrays_and_dtypes is an empty list, let numpy decide what the dtype is.
|
||||||
|
arrays_and_dtypes = [np.asarray([])]
|
||||||
|
return np_dtypes._result_type(*arrays_and_dtypes) # pylint: disable=protected-access
|
||||||
|
|
||||||
|
|
||||||
|
def promote_types(type1, type2):
|
||||||
|
"""Returns the type resulting from applying NumPy type promotion.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
type1: A numpy type.
|
||||||
|
type2: A numpy type.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A numpy type.
|
||||||
|
"""
|
||||||
|
type1 = _to_numpy_type(type1)
|
||||||
|
type2 = _to_numpy_type(type2)
|
||||||
|
return np_dtypes.canonicalize_dtype(np.promote_types(type1, type2))
|
||||||
|
|
||||||
|
|
||||||
|
def _has_docstring(f):
|
||||||
|
return (f and hasattr(f, '__doc__') and isinstance(f.__doc__, str) and
|
||||||
|
f.__doc__)
|
||||||
|
|
||||||
|
|
||||||
|
def _add_blank_line(s):
|
||||||
|
if s.endswith('\n'):
|
||||||
|
return s + '\n'
|
||||||
|
else:
|
||||||
|
return s + '\n\n'
|
||||||
|
|
||||||
|
|
||||||
|
def _np_signature(f):
|
||||||
|
"""An enhanced inspect.signature that can handle numpy.ufunc."""
|
||||||
|
# TODO(wangpeng): consider migrating away from inspect.signature.
|
||||||
|
# inspect.signature is supported in Python 3.3.
|
||||||
|
if not hasattr(inspect, 'signature'):
|
||||||
|
return None
|
||||||
|
if f is None:
|
||||||
|
return None
|
||||||
|
if not isinstance(f, np.ufunc):
|
||||||
|
try:
|
||||||
|
return inspect.signature(f)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def names_from_num(prefix, n):
|
||||||
|
if n <= 0:
|
||||||
|
return []
|
||||||
|
elif n == 1:
|
||||||
|
return [prefix]
|
||||||
|
else:
|
||||||
|
return [prefix + str(i + 1) for i in range(n)]
|
||||||
|
|
||||||
|
input_names = names_from_num('x', f.nin)
|
||||||
|
output_names = names_from_num('out', f.nout)
|
||||||
|
keyword_only_params = [('where', True), ('casting', 'same_kind'),
|
||||||
|
('order', 'K'), ('dtype', None), ('subok', True),
|
||||||
|
('signature', None), ('extobj', None)]
|
||||||
|
params = []
|
||||||
|
params += [
|
||||||
|
inspect.Parameter(name, inspect.Parameter.POSITIONAL_ONLY)
|
||||||
|
for name in input_names
|
||||||
|
]
|
||||||
|
if f.nout > 1:
|
||||||
|
params += [
|
||||||
|
inspect.Parameter(
|
||||||
|
name, inspect.Parameter.POSITIONAL_ONLY, default=None)
|
||||||
|
for name in output_names
|
||||||
|
]
|
||||||
|
params += [
|
||||||
|
inspect.Parameter(
|
||||||
|
'out',
|
||||||
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||||
|
default=None if f.nout == 1 else (None,) * f.nout)
|
||||||
|
]
|
||||||
|
params += [
|
||||||
|
inspect.Parameter(name, inspect.Parameter.KEYWORD_ONLY, default=default)
|
||||||
|
for name, default in keyword_only_params
|
||||||
|
]
|
||||||
|
return inspect.Signature(params)
|
||||||
|
|
||||||
|
|
||||||
|
# Python 2 doesn't allow keyword-only argument. Python prior to 3.8 doesn't
|
||||||
|
# allow positional-only argument. So we conflate positional-only, keyword-only
|
||||||
|
# and positional-or-keyword arguments here.
|
||||||
|
def _is_compatible_param_kind(a, b):
|
||||||
|
|
||||||
|
def relax(k):
|
||||||
|
if k in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.KEYWORD_ONLY):
|
||||||
|
return inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||||
|
return k
|
||||||
|
|
||||||
|
return relax(a) == relax(b)
|
||||||
|
|
||||||
|
|
||||||
|
def np_doc(np_fun, np_fun_name=None):
|
||||||
|
"""Attachs numpy docstring to a function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
np_fun: the numpy function whose docstring will be used.
|
||||||
|
np_fun_name: optional name for the np_fun symbol. At least one of np_fun or
|
||||||
|
np_fun_name shoud be set.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A function decorator that attaches the docstring from `np_fun` to the
|
||||||
|
decorated function.
|
||||||
|
"""
|
||||||
|
if np_fun is None:
|
||||||
|
assert np_fun_name is not None
|
||||||
|
try:
|
||||||
|
np_fun = getattr(np, str(np_fun_name))
|
||||||
|
except AttributeError:
|
||||||
|
np_fun = None
|
||||||
|
np_sig = _np_signature(np_fun)
|
||||||
|
if np_fun_name is None:
|
||||||
|
assert np_fun is not None
|
||||||
|
np_fun_name = np_fun.__name__
|
||||||
|
|
||||||
|
def decorator(f):
|
||||||
|
"""The decorator."""
|
||||||
|
unsupported_params = []
|
||||||
|
if hasattr(inspect, 'signature') and np_sig is not None:
|
||||||
|
try:
|
||||||
|
sig = inspect.signature(f)
|
||||||
|
except ValueError:
|
||||||
|
sig = None
|
||||||
|
# TODO(wangpeng): Enable this.
|
||||||
|
# Looks like this may not work with different versions of numpy.
|
||||||
|
# if sig is not None:
|
||||||
|
# for name, param in sig.parameters.items():
|
||||||
|
# np_param = np_sig.parameters.get(name)
|
||||||
|
# if np_param is None:
|
||||||
|
# raise TypeError('Cannot find parameter "%s" in the numpy
|
||||||
|
# function\'s ' 'signature' % name)
|
||||||
|
# if not _is_compatible_param_kind(param.kind, np_param.kind):
|
||||||
|
# raise TypeError(
|
||||||
|
# 'Parameter "%s" is of kind %s while in numpy it is of '
|
||||||
|
# 'kind %s' % (name, param.kind, np_param.kind))
|
||||||
|
# has_default = (param.default != inspect.Parameter.empty)
|
||||||
|
# np_has_default = (np_param.default != inspect.Parameter.empty)
|
||||||
|
# if has_default != np_has_default:
|
||||||
|
# raise TypeError('Parameter "%s" should%s have a default value' %
|
||||||
|
# (name, '' if np_has_default else ' not'))
|
||||||
|
# for name in np_sig.parameters:
|
||||||
|
# if name not in sig.parameters:
|
||||||
|
# unsupported_params.append(name)
|
||||||
|
f.__doc__ = _np_doc_helper(f, np_fun, np_fun_name=np_fun_name,
|
||||||
|
unsupported_params=unsupported_params)
|
||||||
|
return f
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def _np_doc_helper(f, np_f, np_fun_name=None, unsupported_params=None):
|
||||||
|
"""Helper to get docs."""
|
||||||
|
if not unsupported_params and not _has_docstring(f) and _has_docstring(np_f):
|
||||||
|
# TODO(wangpeng): It looks like code snippets in numpy doc don't work
|
||||||
|
# correctly with doctest. Fix that and remove the reformatting of the np_f
|
||||||
|
# comment, here and below.
|
||||||
|
return np_f.__doc__.replace('>>>', '>')
|
||||||
|
assert np_f or np_fun_name
|
||||||
|
if not np_fun_name:
|
||||||
|
np_fun_name = np_f.__name__
|
||||||
|
doc = 'TensorFlow variant of `numpy.%s`.\n\n' % np_fun_name
|
||||||
|
if unsupported_params:
|
||||||
|
doc += 'Unsupported arguments: ' + ', '.join(
|
||||||
|
'`' + name + '`' for name in unsupported_params) + '.\n\n'
|
||||||
|
if _has_docstring(f):
|
||||||
|
doc += f.__doc__
|
||||||
|
doc = _add_blank_line(doc)
|
||||||
|
if _has_docstring(np_f):
|
||||||
|
doc += 'Documentation for `numpy.%s`:\n\n' % np_f.__name__
|
||||||
|
doc += np_f.__doc__.replace('>>>', '>')
|
||||||
|
return doc
|
||||||
|
|
||||||
|
|
||||||
|
def np_doc_only(np_f):
|
||||||
|
"""Attachs numpy docstring to a function.
|
||||||
|
|
||||||
|
This differs from np_doc in that it doesn't check for a match in signature.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
np_f: the numpy function whose docstring will be used.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A function decorator that attaches the docstring from `np_f` to the
|
||||||
|
decorated function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(f):
|
||||||
|
f.__doc__ = _np_doc_helper(f, np_f)
|
||||||
|
return f
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def tf_broadcast(*args):
|
||||||
|
"""Broadcast tensors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*args: a list of tensors whose shapes are broadcastable against each other.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensors broadcasted to the common shape.
|
||||||
|
"""
|
||||||
|
if len(args) <= 1:
|
||||||
|
return args
|
||||||
|
sh = array_ops.shape(args[0])
|
||||||
|
for arg in args[1:]:
|
||||||
|
sh = array_ops.broadcast_dynamic_shape(sh, array_ops.shape(arg))
|
||||||
|
return [array_ops.broadcast_to(arg, sh) for arg in args]
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(wangpeng): Move the following functions to a separate file and check for
|
||||||
|
# float dtypes in each of them.
|
||||||
|
|
||||||
|
|
||||||
|
def get_static_value(x):
|
||||||
|
"""A version of tf.get_static_value that returns None on float dtypes.
|
||||||
|
|
||||||
|
It returns None on float dtypes in order to avoid breaking gradients.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: a tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Same as `tf.get_static_value`, except that it returns None when `x` has a
|
||||||
|
float dtype.
|
||||||
|
"""
|
||||||
|
if isinstance(x, ops.Tensor) and (x.dtype.is_floating or x.dtype.is_complex):
|
||||||
|
return None
|
||||||
|
return tensor_util.constant_value(x)
|
||||||
|
|
||||||
|
|
||||||
|
def _maybe_static(x):
|
||||||
|
value = get_static_value(x)
|
||||||
|
if value is None:
|
||||||
|
return x
|
||||||
|
else:
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
# All the following functions exist becaues get_static_value can't handle
|
||||||
|
# their TF counterparts.
|
||||||
|
|
||||||
|
|
||||||
|
def cond(pred, true_fn, false_fn):
|
||||||
|
"""A version of tf.cond that tries to evaluate the condition."""
|
||||||
|
v = get_static_value(pred)
|
||||||
|
if v is None:
|
||||||
|
return control_flow_ops.cond(pred, true_fn, false_fn)
|
||||||
|
if v:
|
||||||
|
return true_fn()
|
||||||
|
else:
|
||||||
|
return false_fn()
|
||||||
|
|
||||||
|
|
||||||
|
def add(a, b):
|
||||||
|
"""A version of tf.add that eagerly evaluates if possible."""
|
||||||
|
return _maybe_static(a) + _maybe_static(b)
|
||||||
|
|
||||||
|
|
||||||
|
def subtract(a, b):
|
||||||
|
"""A version of tf.subtract that eagerly evaluates if possible."""
|
||||||
|
return _maybe_static(a) - _maybe_static(b)
|
||||||
|
|
||||||
|
|
||||||
|
def greater(a, b):
|
||||||
|
"""A version of tf.greater that eagerly evaluates if possible."""
|
||||||
|
return _maybe_static(a) > _maybe_static(b)
|
||||||
|
|
||||||
|
|
||||||
|
def greater_equal(a, b):
|
||||||
|
"""A version of tf.greater_equal that eagerly evaluates if possible."""
|
||||||
|
return _maybe_static(a) >= _maybe_static(b)
|
||||||
|
|
||||||
|
|
||||||
|
def less_equal(a, b):
|
||||||
|
"""A version of tf.less_equal that eagerly evaluates if possible."""
|
||||||
|
return _maybe_static(a) <= _maybe_static(b)
|
||||||
|
|
||||||
|
|
||||||
|
def logical_and(a, b):
|
||||||
|
"""A version of tf.logical_and that eagerly evaluates if possible."""
|
||||||
|
a_value = get_static_value(a)
|
||||||
|
if a_value is not None:
|
||||||
|
if np.isscalar(a_value):
|
||||||
|
if a_value:
|
||||||
|
return _maybe_static(b)
|
||||||
|
else:
|
||||||
|
return a_value
|
||||||
|
else:
|
||||||
|
return a_value & _maybe_static(b)
|
||||||
|
else:
|
||||||
|
return a & _maybe_static(b)
|
||||||
|
|
||||||
|
|
||||||
|
def logical_or(a, b):
|
||||||
|
"""A version of tf.logical_or that eagerly evaluates if possible."""
|
||||||
|
a_value = get_static_value(a)
|
||||||
|
if a_value is not None:
|
||||||
|
if np.isscalar(a_value):
|
||||||
|
if a_value:
|
||||||
|
return a_value
|
||||||
|
else:
|
||||||
|
return _maybe_static(b)
|
||||||
|
else:
|
||||||
|
return a_value | _maybe_static(b)
|
||||||
|
else:
|
||||||
|
return a | _maybe_static(b)
|
||||||
|
|
||||||
|
|
||||||
|
def getitem(a, slice_spec):
|
||||||
|
"""A version of __getitem__ that eagerly evaluates if possible."""
|
||||||
|
return _maybe_static(a)[slice_spec]
|
||||||
|
|
||||||
|
|
||||||
|
def reduce_all(input_tensor, axis=None, keepdims=False):
|
||||||
|
"""A version of tf.reduce_all that eagerly evaluates if possible."""
|
||||||
|
v = get_static_value(input_tensor)
|
||||||
|
if v is None:
|
||||||
|
return math_ops.reduce_all(input_tensor, axis=axis, keepdims=keepdims)
|
||||||
|
else:
|
||||||
|
return v.all(axis=axis, keepdims=keepdims)
|
||||||
|
|
||||||
|
|
||||||
|
def reduce_any(input_tensor, axis=None, keepdims=False):
|
||||||
|
"""A version of tf.reduce_any that eagerly evaluates if possible."""
|
||||||
|
v = get_static_value(input_tensor)
|
||||||
|
if v is None:
|
||||||
|
return math_ops.reduce_any(input_tensor, axis=axis, keepdims=keepdims)
|
||||||
|
else:
|
||||||
|
return v.any(axis=axis, keepdims=keepdims)
|
92
tensorflow/python/ops/numpy_ops/np_utils_test.py
Normal file
92
tensorflow/python/ops/numpy_ops/np_utils_test.py
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for utils.py."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.ops.numpy_ops import np_utils
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class UtilsTest(test.TestCase):
|
||||||
|
|
||||||
|
# pylint: disable=unused-argument
|
||||||
|
def testNpDoc(self):
|
||||||
|
|
||||||
|
def np_fun(x):
|
||||||
|
"""np_fun docstring."""
|
||||||
|
return
|
||||||
|
|
||||||
|
@np_utils.np_doc(np_fun)
|
||||||
|
def f():
|
||||||
|
"""f docstring."""
|
||||||
|
return
|
||||||
|
|
||||||
|
expected = """TensorFlow variant of `numpy.np_fun`.
|
||||||
|
|
||||||
|
f docstring.
|
||||||
|
|
||||||
|
Documentation for `numpy.np_fun`:
|
||||||
|
|
||||||
|
np_fun docstring."""
|
||||||
|
self.assertEqual(expected, f.__doc__)
|
||||||
|
|
||||||
|
def testNpDocName(self):
|
||||||
|
|
||||||
|
@np_utils.np_doc(None, np_fun_name='foo')
|
||||||
|
def f():
|
||||||
|
"""f docstring."""
|
||||||
|
return
|
||||||
|
expected = """TensorFlow variant of `numpy.foo`.
|
||||||
|
|
||||||
|
f docstring.
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.assertEqual(expected, f.__doc__)
|
||||||
|
|
||||||
|
def testNpDocErrors(self):
|
||||||
|
|
||||||
|
self.skipTest('Enable once np signature checking is done.')
|
||||||
|
# if not np_utils._supports_signature():
|
||||||
|
# self.skipTest("inspect.signature not supported")
|
||||||
|
|
||||||
|
def np_fun(x, y=1, **kwargs):
|
||||||
|
return
|
||||||
|
|
||||||
|
# pylint: disable=unused-variable
|
||||||
|
with self.assertRaisesRegexp(TypeError, 'Cannot find parameter'):
|
||||||
|
|
||||||
|
@np_utils.np_doc(np_fun)
|
||||||
|
def f1(a):
|
||||||
|
return
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(TypeError, 'is of kind'):
|
||||||
|
|
||||||
|
@np_utils.np_doc(np_fun)
|
||||||
|
def f2(x, kwargs):
|
||||||
|
return
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(TypeError,
|
||||||
|
'Parameter "y" should have a default value'):
|
||||||
|
|
||||||
|
@np_utils.np_doc(np_fun)
|
||||||
|
def f3(x, y):
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test.main()
|
Loading…
x
Reference in New Issue
Block a user