Implement __gt__ method on FeatureColumn base class so that they are sortable in Python 3.

PiperOrigin-RevId: 259409662
This commit is contained in:
Sundeep Gottipati 2019-07-22 14:41:43 -07:00 committed by TensorFlower Gardener
parent 1d29d5d79f
commit ed2d1fe63d
3 changed files with 55 additions and 7 deletions

View File

@ -1758,12 +1758,15 @@ class _FeatureColumn(object):
pass pass
def __lt__(self, other): def __lt__(self, other):
"""Allows feature columns to be sortable in Python 3 as they are in 2. """Allows feature columns to be sorted in Python 3 as they are in Python 2.
Feature columns need to occasionally be sortable, for example when used as Feature columns need to occasionally be sortable, for example when used as
keys in a features dictionary passed to a layer. keys in a features dictionary passed to a layer.
`__lt__` is the only method needed for sorting in CPython: In CPython, `__lt__` must be defined for all objects in the
sequence being sorted. If any objects do not have an `__lt__` compatible
with feature column objects (such as strings), then CPython will fall back
to using the `__gt__` method below.
https://docs.python.org/3/library/stdtypes.html#list.sort https://docs.python.org/3/library/stdtypes.html#list.sort
Args: Args:
@ -1772,10 +1775,30 @@ class _FeatureColumn(object):
Returns: Returns:
True if the string representation of this object is lexicographically less True if the string representation of this object is lexicographically less
than the string representation of `other`. For FeatureColumn objects, than the string representation of `other`. For FeatureColumn objects,
this looks like "<__main__.FeatureColumn object at 0x7fa1fc02bba8>". this looks like "<__main__.FeatureColumn object at 0xa>".
""" """
return str(self) < str(other) return str(self) < str(other)
def __gt__(self, other):
"""Allows feature columns to be sorted in Python 3 as they are in Python 2.
Feature columns need to occasionally be sortable, for example when used as
keys in a features dictionary passed to a layer.
`__gt__` is called when the "other" object being compared during the sort
does not have `__lt__` defined.
Example: http://gpaste/4803354716798976
Args:
other: The other object to compare to.
Returns:
True if the string representation of this object is lexicographically
greater than the string representation of `other`. For FeatureColumn
objects, this looks like "<__main__.FeatureColumn object at 0xa>".
"""
return str(self) > str(other)
@property @property
def _var_scope_name(self): def _var_scope_name(self):
"""Returns string. Used for variable_scope. Defaults to self.name.""" """Returns string. Used for variable_scope. Defaults to self.name."""

View File

@ -2198,12 +2198,17 @@ class FeatureColumn(object):
pass pass
def __lt__(self, other): def __lt__(self, other):
"""Allows feature columns to be sortable in Python 3 as they are in 2. """Allows feature columns to be sorted in Python 3 as they are in Python 2.
Feature columns need to occasionally be sortable, for example when used as Feature columns need to occasionally be sortable, for example when used as
keys in a features dictionary passed to a layer. keys in a features dictionary passed to a layer.
`__lt__` is the only method needed for sorting in CPython: In CPython, `__lt__` must be defined for all objects in the
sequence being sorted.
If any objects in teh sequence being sorted do not have an `__lt__` method
compatible with feature column objects (such as strings), then CPython will
fall back to using the `__gt__` method below.
https://docs.python.org/3/library/stdtypes.html#list.sort https://docs.python.org/3/library/stdtypes.html#list.sort
Args: Args:
@ -2212,10 +2217,30 @@ class FeatureColumn(object):
Returns: Returns:
True if the string representation of this object is lexicographically less True if the string representation of this object is lexicographically less
than the string representation of `other`. For FeatureColumn objects, than the string representation of `other`. For FeatureColumn objects,
this looks like "<__main__.FeatureColumn object at 0x7fa1fc02bba8>". this looks like "<__main__.FeatureColumn object at 0xa>".
""" """
return str(self) < str(other) return str(self) < str(other)
def __gt__(self, other):
"""Allows feature columns to be sorted in Python 3 as they are in Python 2.
Feature columns need to occasionally be sortable, for example when used as
keys in a features dictionary passed to a layer.
`__gt__` is called when the "other" object being compared during the sort
does not have `__lt__` defined.
Example: http://gpaste/4803354716798976
Args:
other: The other object to compare to.
Returns:
True if the string representation of this object is lexicographically
greater than the string representation of `other`. For FeatureColumn
objects, this looks like "<__main__.FeatureColumn object at 0xa>".
"""
return str(self) > str(other)
@abc.abstractmethod @abc.abstractmethod
def transform_feature(self, transformation_cache, state_manager): def transform_feature(self, transformation_cache, state_manager):
"""Returns intermediate representation (usually a `Tensor`). """Returns intermediate representation (usually a `Tensor`).

View File

@ -99,7 +99,7 @@ class SortableFeatureColumnTest(test.TestCase):
a = fc.numeric_column('first') # '<__main__.NumericColumn object at 0xa>' a = fc.numeric_column('first') # '<__main__.NumericColumn object at 0xa>'
b = fc.numeric_column('second') # '<__main__.NumericColumn object at 0xb>' b = fc.numeric_column('second') # '<__main__.NumericColumn object at 0xb>'
c = fc_old._numeric_column('third') # '<__main__._NumericColumn ...>' c = fc_old._numeric_column('third') # '<__main__._NumericColumn ...>'
self.assertAllEqual(sorted(['d', c, b, a]), [a, b, c, 'd']) self.assertAllEqual(sorted(['d', c, b, a, '0']), ['0', a, b, c, 'd'])
class LazyColumnTest(test.TestCase): class LazyColumnTest(test.TestCase):