numpy_ops: Remove convert_to_tensor, ShardedNdArray and tensor_to_ndarray from
the module. PiperOrigin-RevId: 315993646 Change-Id: If6277bfc27638b3874407a120622afaf04b24744
This commit is contained in:
parent
89910c62d6
commit
07c8612582
@ -24,12 +24,7 @@ from tensorflow.python.ops.numpy_ops import np_random as random
|
||||
# pylint: disable=wildcard-import
|
||||
|
||||
from tensorflow.python.ops.numpy_ops.np_array_ops import *
|
||||
# TODO(wangpeng): Move ShardedNdArray, convert_to_tensor, tensor_to_ndarray out
|
||||
# of here.
|
||||
from tensorflow.python.ops.numpy_ops.np_arrays import convert_to_tensor
|
||||
from tensorflow.python.ops.numpy_ops.np_arrays import ndarray
|
||||
from tensorflow.python.ops.numpy_ops.np_arrays import ShardedNdArray
|
||||
from tensorflow.python.ops.numpy_ops.np_arrays import tensor_to_ndarray
|
||||
from tensorflow.python.ops.numpy_ops.np_dtypes import *
|
||||
from tensorflow.python.ops.numpy_ops.np_math_ops import *
|
||||
from tensorflow.python.ops.numpy_ops.np_utils import finfo
|
||||
|
@ -254,49 +254,3 @@ def ndarray_to_tensor(arr, dtype=None, name=None, as_ref=False):
|
||||
ops.register_tensor_conversion_function(ndarray, ndarray_to_tensor)
|
||||
|
||||
|
||||
# Don't use a namedtuple since nest considers that a tuple and unflattens and
|
||||
# flattens it.
|
||||
class ShardedNdArray(object):
|
||||
"""Wrapper over ndarray that can contain tensors on multiple devices.
|
||||
|
||||
This is returned by extensions.pmap, and contains the individual tensors on
|
||||
different devices.
|
||||
"""
|
||||
|
||||
def __init__(self, tensors):
|
||||
"""Initializes the ShardedNdArray.
|
||||
|
||||
Note that the tensors should be ordered in the way the pmap producing these
|
||||
tensors is run.
|
||||
|
||||
Args:
|
||||
tensors: list or tuple of eager tensors, one for each device.
|
||||
"""
|
||||
|
||||
if not isinstance(tensors, (list, tuple)) or not tensors:
|
||||
raise ValueError(
|
||||
'Unable to create a ShardedNdArray without a list of tensors.')
|
||||
self.tensors = tensors
|
||||
self.n_devices = len(tensors)
|
||||
|
||||
def __getitem__(self, i):
|
||||
return self.tensors[i]
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return (self.n_devices,) + self.tensors[0]._shape_tuple() # pylint: disable=protected-access
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.tensors[0].dtype
|
||||
|
||||
|
||||
def convert_sharded_tensor_to_eager_tensor(value, *args, **kwargs):
|
||||
del args, kwargs
|
||||
# TODO(nareshmodi): Consider a collective op to gather the tensors from the
|
||||
# various devices for performance reasons.
|
||||
return array_ops.stack(value.tensors)
|
||||
|
||||
|
||||
ops.register_tensor_conversion_function(ShardedNdArray,
|
||||
convert_sharded_tensor_to_eager_tensor)
|
||||
|
@ -119,7 +119,7 @@ def result_type(*arrays_and_dtypes):
|
||||
def maybe_get_dtype(x):
|
||||
# Don't put np.ndarray in this list, because np.result_type looks at the
|
||||
# value (not just dtype) of np.ndarray to decide the result type.
|
||||
if isinstance(x, (np_arrays.ndarray, np_arrays.ShardedNdArray, ops.Tensor,
|
||||
if isinstance(x, (np_arrays.ndarray, ops.Tensor,
|
||||
indexed_slices.IndexedSlices)):
|
||||
return _to_numpy_type(x.dtype)
|
||||
elif isinstance(x, dtypes.DType):
|
||||
|
Loading…
Reference in New Issue
Block a user