Fixed cases where tf.TensorShape was constructed with float dimensions

This is a prerequisite for making TensorShape and Dimension more strict
about the types of their arguments.

PiperOrigin-RevId: 274700832
This commit is contained in:
Sergei Lebedev 2019-10-14 17:50:22 -07:00 committed by TensorFlower Gardener
parent e2eb7e3641
commit 2f245bdc5b
3 changed files with 4 additions and 4 deletions

View File

@ -123,12 +123,12 @@ class AllReduceTest(test_util.TensorFlowTestCase):
for otl in output_tensors:
self.assertEqual(len(otl), num_chunks)
for ot in otl:
self.assertEqual(ot.shape, [tlen/num_chunks])
self.assertEqual(ot.shape, [tlen//num_chunks])
def _buildInitialVars(self, shape, dev_list):
values = []
num_devices = len(dev_list)
dim = np.prod(shape) if shape else 1
dim = np.prod(shape, dtype=int) if shape else 1
for d in range(0, num_devices):
with ops.device(dev_list[d]):
npt = np.zeros(shape).astype(np.float32)

View File

@ -184,7 +184,7 @@ class ShardingPolicy(object):
raise ValueError("shape %s cannot be sharded %d ways along dimension %d" %
(shape.as_list(), self._number_of_shards,
self._shard_dimension))
dims[self._shard_dimension] /= self._number_of_shards
dims[self._shard_dimension] //= self._number_of_shards
return tensor_shape.as_shape(dims)
def _unshard_shape(self, shape):

View File

@ -398,7 +398,7 @@ class _SparseMetaData(object):
"""
self._sparse = sparse
self._map_op = map_op
self._rank = tensor_shape.Dimension(rank)
self._rank = tensor_shape.as_dimension(rank)
def __eq__(self, other):
if self.sparse != other.sparse: