Implement __gt__ method on FeatureColumn base class so that they are sortable in Python 3.

PiperOrigin-RevId: 259409662
This commit is contained in:
Sundeep Gottipati 2019-07-22 14:41:43 -07:00 committed by TensorFlower Gardener
parent 1d29d5d79f
commit ed2d1fe63d
3 changed files with 55 additions and 7 deletions

View File

@ -1758,12 +1758,15 @@ class _FeatureColumn(object):
pass
def __lt__(self, other):
"""Allows feature columns to be sortable in Python 3 as they are in 2.
"""Allows feature columns to be sorted in Python 3 as they are in Python 2.
Feature columns need to occasionally be sortable, for example when used as
keys in a features dictionary passed to a layer.
`__lt__` is the only method needed for sorting in CPython:
In CPython, `__lt__` must be defined for all objects in the
sequence being sorted. If any objects do not have an `__lt__` compatible
with feature column objects (such as strings), then CPython will fall back
to using the `__gt__` method below.
https://docs.python.org/3/library/stdtypes.html#list.sort
Args:
@ -1772,10 +1775,30 @@ class _FeatureColumn(object):
Returns:
True if the string representation of this object is lexicographically less
than the string representation of `other`. For FeatureColumn objects,
this looks like "<__main__.FeatureColumn object at 0x7fa1fc02bba8>".
this looks like "<__main__.FeatureColumn object at 0xa>".
"""
return str(self) < str(other)
def __gt__(self, other):
"""Allows feature columns to be sorted in Python 3 as they are in Python 2.
Feature columns need to occasionally be sortable, for example when used as
keys in a features dictionary passed to a layer.
`__gt__` is called when the "other" object being compared during the sort
does not have `__lt__` defined.
Example: http://gpaste/4803354716798976
Args:
other: The other object to compare to.
Returns:
True if the string representation of this object is lexicographically
greater than the string representation of `other`. For FeatureColumn
objects, this looks like "<__main__.FeatureColumn object at 0xa>".
"""
return str(self) > str(other)
@property
def _var_scope_name(self):
"""Returns string. Used for variable_scope. Defaults to self.name."""

View File

@ -2198,12 +2198,17 @@ class FeatureColumn(object):
pass
def __lt__(self, other):
"""Allows feature columns to be sortable in Python 3 as they are in 2.
"""Allows feature columns to be sorted in Python 3 as they are in Python 2.
Feature columns need to occasionally be sortable, for example when used as
keys in a features dictionary passed to a layer.
`__lt__` is the only method needed for sorting in CPython:
In CPython, `__lt__` must be defined for all objects in the
sequence being sorted.
If any objects in teh sequence being sorted do not have an `__lt__` method
compatible with feature column objects (such as strings), then CPython will
fall back to using the `__gt__` method below.
https://docs.python.org/3/library/stdtypes.html#list.sort
Args:
@ -2212,10 +2217,30 @@ class FeatureColumn(object):
Returns:
True if the string representation of this object is lexicographically less
than the string representation of `other`. For FeatureColumn objects,
this looks like "<__main__.FeatureColumn object at 0x7fa1fc02bba8>".
this looks like "<__main__.FeatureColumn object at 0xa>".
"""
return str(self) < str(other)
def __gt__(self, other):
"""Allows feature columns to be sorted in Python 3 as they are in Python 2.
Feature columns need to occasionally be sortable, for example when used as
keys in a features dictionary passed to a layer.
`__gt__` is called when the "other" object being compared during the sort
does not have `__lt__` defined.
Example: http://gpaste/4803354716798976
Args:
other: The other object to compare to.
Returns:
True if the string representation of this object is lexicographically
greater than the string representation of `other`. For FeatureColumn
objects, this looks like "<__main__.FeatureColumn object at 0xa>".
"""
return str(self) > str(other)
@abc.abstractmethod
def transform_feature(self, transformation_cache, state_manager):
"""Returns intermediate representation (usually a `Tensor`).

View File

@ -99,7 +99,7 @@ class SortableFeatureColumnTest(test.TestCase):
a = fc.numeric_column('first') # '<__main__.NumericColumn object at 0xa>'
b = fc.numeric_column('second') # '<__main__.NumericColumn object at 0xb>'
c = fc_old._numeric_column('third') # '<__main__._NumericColumn ...>'
self.assertAllEqual(sorted(['d', c, b, a]), [a, b, c, 'd'])
self.assertAllEqual(sorted(['d', c, b, a, '0']), ['0', a, b, c, 'd'])
class LazyColumnTest(test.TestCase):