Adds a shape
property to LabeledTensor.
#labeledtensor PiperOrigin-RevId: 186073035
This commit is contained in:
parent
090bb9168c
commit
128572c316
@ -361,6 +361,10 @@ class LabeledTensor(object):
|
||||
def dtype(self):
|
||||
return self._tensor.dtype
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self._tensor.shape
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._tensor.name
|
||||
|
@ -244,6 +244,9 @@ class LabeledTensorTest(test_util.Base):
|
||||
def test_dtype(self):
|
||||
self.assertEqual(self.lt.dtype, self.lt.tensor.dtype)
|
||||
|
||||
def test_shape(self):
|
||||
self.assertEqual(self.lt.shape, self.lt.tensor.shape)
|
||||
|
||||
def test_get_shape(self):
|
||||
self.assertEqual(self.lt.get_shape(), self.lt.tensor.get_shape())
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user