Merge pull request #40921 from rahul-kamat:generic-tensor
PiperOrigin-RevId: 320640683 Change-Id: Iaf606eb19770c2495cd7da4196566f986d584b67
This commit is contained in:
commit
c14d4f0e21
@ -31,7 +31,11 @@
|
||||
* <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
|
||||
* <NOTES SHOULD BE GROUPED PER AREA>
|
||||
* TF Core:
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
* `tf.Tensor` is now a subclass of `typing.Generic`, allowing type annotations
|
||||
to be parameterized by dtype: `tf.Tensor[tf.Int32]`. This requires Python 3,
|
||||
and will become fully compatible with static type checkers in the future.
|
||||
|
||||
* `tf.data`:
|
||||
* Added optional `exclude_cols` parameter to CsvDataset. This parameter is
|
||||
the complement of `select_cols`; at most one of these should be specified.
|
||||
|
@ -24,6 +24,7 @@ import sys
|
||||
import threading
|
||||
import types
|
||||
|
||||
from typing import Generic, TypeVar
|
||||
import numpy as np
|
||||
import six
|
||||
from six.moves import map # pylint: disable=redefined-builtin
|
||||
@ -254,9 +255,19 @@ def disable_tensor_equality():
|
||||
Tensor._USE_EQUALITY = False # pylint: disable=protected-access
|
||||
|
||||
|
||||
DataType = TypeVar("DataType", bound=dtypes.DType)
|
||||
|
||||
# TODO(rahulkamat): Remove this and make Tensor a generic class
|
||||
# once compatibility with Python 2 is dropped.
|
||||
if sys.version_info[0] >= 3:
|
||||
TensorTypeBase = Generic[DataType]
|
||||
else:
|
||||
TensorTypeBase = object
|
||||
|
||||
|
||||
# TODO(mdan): This object should subclass Symbol, not just Tensor.
|
||||
@tf_export("Tensor")
|
||||
class Tensor(internal.NativeObject, core_tf_types.Tensor):
|
||||
class Tensor(internal.NativeObject, core_tf_types.Tensor, TensorTypeBase):
|
||||
"""A tensor is a multidimensional array of elements represented by a
|
||||
|
||||
`tf.Tensor` object. All elements are of a single known data type.
|
||||
|
@ -3,7 +3,7 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.framework.ops.Tensor\'>"
|
||||
is_instance: "<class \'tensorflow.python.types.internal.NativeObject\'>"
|
||||
is_instance: "<class \'tensorflow.python.types.core.Tensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
is_instance: "<class \'typing.Generic\'>"
|
||||
member {
|
||||
name: "OVERLOADABLE_OPERATORS"
|
||||
mtype: "<type \'set\'>"
|
||||
|
@ -258,7 +258,7 @@ tf_module {
|
||||
}
|
||||
member {
|
||||
name: "Tensor"
|
||||
mtype: "<type \'type\'>"
|
||||
mtype: "<class \'typing.GenericMeta\'>"
|
||||
}
|
||||
member {
|
||||
name: "TensorArray"
|
||||
|
@ -3,7 +3,7 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.framework.ops.Tensor\'>"
|
||||
is_instance: "<class \'tensorflow.python.types.internal.NativeObject\'>"
|
||||
is_instance: "<class \'tensorflow.python.types.core.Tensor\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
is_instance: "<class \'typing.Generic\'>"
|
||||
member {
|
||||
name: "OVERLOADABLE_OPERATORS"
|
||||
mtype: "<type \'set\'>"
|
||||
|
@ -66,7 +66,7 @@ tf_module {
|
||||
}
|
||||
member {
|
||||
name: "Tensor"
|
||||
mtype: "<type \'type\'>"
|
||||
mtype: "<class \'typing.GenericMeta\'>"
|
||||
}
|
||||
member {
|
||||
name: "TensorArray"
|
||||
|
Loading…
Reference in New Issue
Block a user