Add properties SparseTensorSpec.dtype and SparseTensorSpec.shape
PiperOrigin-RevId: 260991316
This commit is contained in:
parent
644e458325
commit
17555d3c37
tensorflow
python/framework
tools/api/golden
@ -279,6 +279,16 @@ class SparseTensorSpec(type_spec.BatchableTypeSpec):
|
||||
def _serialize(self):
|
||||
return (self._shape, self._dtype)
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
"""The `tf.dtypes.DType` specified by this type for the SparseTensor."""
|
||||
return self._dtype
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
"""The `tf.TensorShape` specified by this type for the SparseTensor."""
|
||||
return self._shape
|
||||
|
||||
@property
|
||||
def _component_specs(self):
|
||||
rank = self._shape.ndims
|
||||
|
@ -122,12 +122,12 @@ class SparseTensorSpecTest(test_util.TensorFlowTestCase,
|
||||
|
||||
def testConstruction(self):
|
||||
spec1 = sparse_tensor.SparseTensorSpec()
|
||||
self.assertEqual(spec1._shape.rank, None)
|
||||
self.assertEqual(spec1._dtype, dtypes.float32)
|
||||
self.assertEqual(spec1.shape.rank, None)
|
||||
self.assertEqual(spec1.dtype, dtypes.float32)
|
||||
|
||||
spec2 = sparse_tensor.SparseTensorSpec([None, None], dtypes.string)
|
||||
self.assertEqual(spec2._shape.as_list(), [None, None])
|
||||
self.assertEqual(spec2._dtype, dtypes.string)
|
||||
self.assertEqual(spec2.shape.as_list(), [None, None])
|
||||
self.assertEqual(spec2.dtype, dtypes.string)
|
||||
|
||||
def testValueType(self):
|
||||
spec1 = sparse_tensor.SparseTensorSpec()
|
||||
|
@ -4,6 +4,14 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "dtype"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "shape"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "value_type"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -4,6 +4,14 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "dtype"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "shape"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "value_type"
|
||||
mtype: "<type \'property\'>"
|
||||
|
Loading…
Reference in New Issue
Block a user