Fixed cases where tf.TensorShape was constructed with float dimensions
This is a prerequisite for making TensorShape and Dimension more strict about the types of their arguments. PiperOrigin-RevId: 274700832
This commit is contained in:
parent
e2eb7e3641
commit
2f245bdc5b
@ -123,12 +123,12 @@ class AllReduceTest(test_util.TensorFlowTestCase):
|
|||||||
for otl in output_tensors:
|
for otl in output_tensors:
|
||||||
self.assertEqual(len(otl), num_chunks)
|
self.assertEqual(len(otl), num_chunks)
|
||||||
for ot in otl:
|
for ot in otl:
|
||||||
self.assertEqual(ot.shape, [tlen/num_chunks])
|
self.assertEqual(ot.shape, [tlen//num_chunks])
|
||||||
|
|
||||||
def _buildInitialVars(self, shape, dev_list):
|
def _buildInitialVars(self, shape, dev_list):
|
||||||
values = []
|
values = []
|
||||||
num_devices = len(dev_list)
|
num_devices = len(dev_list)
|
||||||
dim = np.prod(shape) if shape else 1
|
dim = np.prod(shape, dtype=int) if shape else 1
|
||||||
for d in range(0, num_devices):
|
for d in range(0, num_devices):
|
||||||
with ops.device(dev_list[d]):
|
with ops.device(dev_list[d]):
|
||||||
npt = np.zeros(shape).astype(np.float32)
|
npt = np.zeros(shape).astype(np.float32)
|
||||||
|
@ -184,7 +184,7 @@ class ShardingPolicy(object):
|
|||||||
raise ValueError("shape %s cannot be sharded %d ways along dimension %d" %
|
raise ValueError("shape %s cannot be sharded %d ways along dimension %d" %
|
||||||
(shape.as_list(), self._number_of_shards,
|
(shape.as_list(), self._number_of_shards,
|
||||||
self._shard_dimension))
|
self._shard_dimension))
|
||||||
dims[self._shard_dimension] /= self._number_of_shards
|
dims[self._shard_dimension] //= self._number_of_shards
|
||||||
return tensor_shape.as_shape(dims)
|
return tensor_shape.as_shape(dims)
|
||||||
|
|
||||||
def _unshard_shape(self, shape):
|
def _unshard_shape(self, shape):
|
||||||
|
@ -398,7 +398,7 @@ class _SparseMetaData(object):
|
|||||||
"""
|
"""
|
||||||
self._sparse = sparse
|
self._sparse = sparse
|
||||||
self._map_op = map_op
|
self._map_op = map_op
|
||||||
self._rank = tensor_shape.Dimension(rank)
|
self._rank = tensor_shape.as_dimension(rank)
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
if self.sparse != other.sparse:
|
if self.sparse != other.sparse:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user