Merge pull request #40921 from rahul-kamat:generic-tensor

PiperOrigin-RevId: 320640683
Change-Id: Iaf606eb19770c2495cd7da4196566f986d584b67
This commit is contained in:
TensorFlower Gardener 2020-07-10 11:38:17 -07:00
commit c14d4f0e21
6 changed files with 21 additions and 6 deletions

View File

@ -31,7 +31,11 @@
* <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
* <NOTES SHOULD BE GROUPED PER AREA>
* TF Core:
* <ADD RELEASE NOTES HERE>
* <ADD RELEASE NOTES HERE>
* `tf.Tensor` is now a subclass of `typing.Generic`, allowing type annotations
to be parameterized by dtype: `tf.Tensor[tf.Int32]`. This requires Python 3,
and will become fully compatible with static type checkers in the future.
* `tf.data`:
* Added optional `exclude_cols` parameter to CsvDataset. This parameter is
the complement of `select_cols`; at most one of these should be specified.

View File

@ -24,6 +24,7 @@ import sys
import threading
import types
from typing import Generic, TypeVar
import numpy as np
import six
from six.moves import map # pylint: disable=redefined-builtin
@ -254,9 +255,19 @@ def disable_tensor_equality():
Tensor._USE_EQUALITY = False # pylint: disable=protected-access
DataType = TypeVar("DataType", bound=dtypes.DType)
# TODO(rahulkamat): Remove this and make Tensor a generic class
# once compatibility with Python 2 is dropped.
if sys.version_info[0] >= 3:
TensorTypeBase = Generic[DataType]
else:
TensorTypeBase = object
# TODO(mdan): This object should subclass Symbol, not just Tensor.
@tf_export("Tensor")
class Tensor(internal.NativeObject, core_tf_types.Tensor):
class Tensor(internal.NativeObject, core_tf_types.Tensor, TensorTypeBase):
"""A tensor is a multidimensional array of elements represented by a
`tf.Tensor` object. All elements are of a single known data type.

View File

@ -3,7 +3,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.framework.ops.Tensor\'>"
is_instance: "<class \'tensorflow.python.types.internal.NativeObject\'>"
is_instance: "<class \'tensorflow.python.types.core.Tensor\'>"
is_instance: "<type \'object\'>"
is_instance: "<class \'typing.Generic\'>"
member {
name: "OVERLOADABLE_OPERATORS"
mtype: "<type \'set\'>"

View File

@ -258,7 +258,7 @@ tf_module {
}
member {
name: "Tensor"
mtype: "<type \'type\'>"
mtype: "<class \'typing.GenericMeta\'>"
}
member {
name: "TensorArray"

View File

@ -3,7 +3,7 @@ tf_class {
is_instance: "<class \'tensorflow.python.framework.ops.Tensor\'>"
is_instance: "<class \'tensorflow.python.types.internal.NativeObject\'>"
is_instance: "<class \'tensorflow.python.types.core.Tensor\'>"
is_instance: "<type \'object\'>"
is_instance: "<class \'typing.Generic\'>"
member {
name: "OVERLOADABLE_OPERATORS"
mtype: "<type \'set\'>"

View File

@ -66,7 +66,7 @@ tf_module {
}
member {
name: "Tensor"
mtype: "<type \'type\'>"
mtype: "<class \'typing.GenericMeta\'>"
}
member {
name: "TensorArray"