Update the base class for _LinearModel in Feature column v1.
PiperOrigin-RevId: 315282469 Change-Id: I86a793c89b8098723750c631d858bc360746942d
This commit is contained in:
parent
319df5224c
commit
3b2109f7de
@ -48,7 +48,6 @@ py_library(
|
|||||||
"//tensorflow/python:variable_scope",
|
"//tensorflow/python:variable_scope",
|
||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
"//tensorflow/python/eager:context",
|
"//tensorflow/python/eager:context",
|
||||||
"//tensorflow/python/keras/engine",
|
|
||||||
# TODO(scottzhu): Remove metrics after we cleanup the keras internal cyclar dependency.
|
# TODO(scottzhu): Remove metrics after we cleanup the keras internal cyclar dependency.
|
||||||
# //third_party/tensorflow/python/feature_column:feature_column
|
# //third_party/tensorflow/python/feature_column:feature_column
|
||||||
# //third_party/tensorflow/python/keras/engine:engine
|
# //third_party/tensorflow/python/keras/engine:engine
|
||||||
|
@ -144,7 +144,6 @@ from tensorflow.python.framework import dtypes
|
|||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
|
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.keras.engine import training
|
|
||||||
from tensorflow.python.layers import base
|
from tensorflow.python.layers import base
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import check_ops
|
from tensorflow.python.ops import check_ops
|
||||||
@ -165,8 +164,8 @@ from tensorflow.python.platform import gfile
|
|||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.training import checkpoint_utils
|
from tensorflow.python.training import checkpoint_utils
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
|
||||||
from tensorflow.python.util.compat import collections_abc
|
from tensorflow.python.util.compat import collections_abc
|
||||||
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
def _internal_input_layer(features,
|
def _internal_input_layer(features,
|
||||||
@ -616,7 +615,7 @@ def _strip_leading_slashes(name):
|
|||||||
return name.rsplit('/', 1)[-1]
|
return name.rsplit('/', 1)[-1]
|
||||||
|
|
||||||
|
|
||||||
class _LinearModel(training.Model):
|
class _LinearModel(base.Layer):
|
||||||
"""Creates a linear model using feature columns.
|
"""Creates a linear model using feature columns.
|
||||||
|
|
||||||
See `linear_model` for details.
|
See `linear_model` for details.
|
||||||
@ -631,6 +630,12 @@ class _LinearModel(training.Model):
|
|||||||
name=None,
|
name=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super(_LinearModel, self).__init__(name=name, **kwargs)
|
super(_LinearModel, self).__init__(name=name, **kwargs)
|
||||||
|
# We force the keras_style to be True here, as a workaround to not being
|
||||||
|
# able to inherit keras.layers.Layer as base class. Setting this will let
|
||||||
|
# us skip all the legacy behavior for base.Layer.
|
||||||
|
# Also note that we use Layer as base class, instead of Model, since there
|
||||||
|
# isn't any Model specific behavior gets used, eg compile/fit.
|
||||||
|
self._keras_style = True
|
||||||
self._feature_columns = _normalize_feature_columns(
|
self._feature_columns = _normalize_feature_columns(
|
||||||
feature_columns)
|
feature_columns)
|
||||||
self._weight_collections = list(weight_collections or [])
|
self._weight_collections = list(weight_collections or [])
|
||||||
|
Loading…
Reference in New Issue
Block a user