Fixed cases where tf.TensorShape was constructed with float dimensions.

This is a prerequisite for making TensorShape and Dimension more strict about the types of their arguments.

PiperOrigin-RevId: 274700832
commit 2f245bdc5b (parent e2eb7e3641)
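All three hunks below replace expressions that silently produce Python or NumPy floats where an integral dimension is expected. A minimal standalone sketch of the two arithmetic pitfalls involved, using illustrative values rather than anything from the commit:

import numpy as np

tlen, num_chunks = 8, 4
print(type(tlen / num_chunks))   # <class 'float'>: Python 3 "/" is true division
print(type(tlen // num_chunks))  # <class 'int'>: "//" floor division keeps ints

print(np.prod([2, 3]).dtype)       # a platform integer dtype for integer input...
print(np.prod([]).dtype)           # ...but float64 for, e.g., empty input
print(np.prod([2, 3], dtype=int))  # 6: dtype=int forces an integer result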
@@ -123,12 +123,12 @@ class AllReduceTest(test_util.TensorFlowTestCase):
     for otl in output_tensors:
       self.assertEqual(len(otl), num_chunks)
       for ot in otl:
-        self.assertEqual(ot.shape, [tlen/num_chunks])
+        self.assertEqual(ot.shape, [tlen//num_chunks])
 
   def _buildInitialVars(self, shape, dev_list):
     values = []
     num_devices = len(dev_list)
-    dim = np.prod(shape) if shape else 1
+    dim = np.prod(shape, dtype=int) if shape else 1
     for d in range(0, num_devices):
       with ops.device(dev_list[d]):
         npt = np.zeros(shape).astype(np.float32)
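In the test above, the expected value feeds a shape comparison, so the old expression built the expectation from a float. A hedged sketch of why that matters (assumes TensorFlow is installed; the strict rejection is what this commit prepares for, not necessarily behavior at the time):

import tensorflow as tf

tlen, num_chunks = 8, 4                      # illustrative values
print(tf.TensorShape([tlen // num_chunks]))  # (2,): dimension built from an int
# tf.TensorShape([tlen / num_chunks])        # would build the dimension from 2.0,
#                                            # the float case a stricter Dimension
#                                            # constructor is meant to reject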
@@ -184,7 +184,7 @@ class ShardingPolicy(object):
       raise ValueError("shape %s cannot be sharded %d ways along dimension %d" %
                        (shape.as_list(), self._number_of_shards,
                         self._shard_dimension))
-    dims[self._shard_dimension] /= self._number_of_shards
+    dims[self._shard_dimension] //= self._number_of_shards
     return tensor_shape.as_shape(dims)
 
   def _unshard_shape(self, shape):
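The sharding fix is the same pitfall in in-place form. A standalone sketch with illustrative numbers; assuming the ValueError above enforces divisibility, floor division loses no information and only fixes the result type:

dims = [None, 16]           # 16 along the shard dimension
dims[1] /= 4
print(dims, type(dims[1]))  # [None, 4.0] <class 'float'>

dims = [None, 16]
dims[1] //= 4
print(dims, type(dims[1]))  # [None, 4] <class 'int'>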
@@ -398,7 +398,7 @@ class _SparseMetaData(object):
     """
     self._sparse = sparse
     self._map_op = map_op
-    self._rank = tensor_shape.Dimension(rank)
+    self._rank = tensor_shape.as_dimension(rank)
 
   def __eq__(self, other):
     if self.sparse != other.sparse:
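tensor_shape.as_dimension is the canonical converter: it returns an existing Dimension unchanged and wraps any other value, so rank can be passed either as an int or as a Dimension. A short sketch (uses TensorFlow's internal module path, as the diff itself does):

from tensorflow.python.framework import tensor_shape

r1 = tensor_shape.as_dimension(2)                          # wraps the plain int
r2 = tensor_shape.as_dimension(tensor_shape.Dimension(2))  # passes through as-is
assert r1 == r2 == 2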