Adds a shape property to LabeledTensor.

#labeledtensor

PiperOrigin-RevId: 186073035
This commit is contained in:
A. Unique TensorFlower 2018-02-16 18:18:35 -08:00 committed by TensorFlower Gardener
parent 090bb9168c
commit 128572c316
2 changed files with 7 additions and 0 deletions

View File

@ -361,6 +361,10 @@ class LabeledTensor(object):
def dtype(self):
return self._tensor.dtype
@property
def shape(self):
return self._tensor.shape
@property
def name(self):
return self._tensor.name

View File

@ -244,6 +244,9 @@ class LabeledTensorTest(test_util.Base):
def test_dtype(self):
self.assertEqual(self.lt.dtype, self.lt.tensor.dtype)
def test_shape(self):
self.assertEqual(self.lt.shape, self.lt.tensor.shape)
def test_get_shape(self):
self.assertEqual(self.lt.get_shape(), self.lt.tensor.get_shape())