Update the base class for _LinearModel in Feature column v1.
PiperOrigin-RevId: 315282469 Change-Id: I86a793c89b8098723750c631d858bc360746942d
This commit is contained in:
parent
319df5224c
commit
3b2109f7de
@ -48,7 +48,6 @@ py_library(
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/keras/engine",
|
||||
# TODO(scottzhu): Remove metrics after we cleanup the keras internal cyclar dependency.
|
||||
# //third_party/tensorflow/python/feature_column:feature_column
|
||||
# //third_party/tensorflow/python/keras/engine:engine
|
||||
|
@ -144,7 +144,6 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.keras.engine import training
|
||||
from tensorflow.python.layers import base
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
@ -165,8 +164,8 @@ from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import checkpoint_utils
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
from tensorflow.python.util.compat import collections_abc
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
def _internal_input_layer(features,
|
||||
@ -616,7 +615,7 @@ def _strip_leading_slashes(name):
|
||||
return name.rsplit('/', 1)[-1]
|
||||
|
||||
|
||||
class _LinearModel(training.Model):
|
||||
class _LinearModel(base.Layer):
|
||||
"""Creates a linear model using feature columns.
|
||||
|
||||
See `linear_model` for details.
|
||||
@ -631,6 +630,12 @@ class _LinearModel(training.Model):
|
||||
name=None,
|
||||
**kwargs):
|
||||
super(_LinearModel, self).__init__(name=name, **kwargs)
|
||||
# We force the keras_style to be True here, as a workaround to not being
|
||||
# able to inherit keras.layers.Layer as base class. Setting this will let
|
||||
# us skip all the legacy behavior for base.Layer.
|
||||
# Also note that we use Layer as base class, instead of Model, since there
|
||||
# isn't any Model specific behavior gets used, eg compile/fit.
|
||||
self._keras_style = True
|
||||
self._feature_columns = _normalize_feature_columns(
|
||||
feature_columns)
|
||||
self._weight_collections = list(weight_collections or [])
|
||||
|
Loading…
Reference in New Issue
Block a user