Update the base class for _LinearModel in Feature column v1.

PiperOrigin-RevId: 315282469
Change-Id: I86a793c89b8098723750c631d858bc360746942d
This commit is contained in:
Scott Zhu 2020-06-08 08:46:51 -07:00 committed by TensorFlower Gardener
parent 319df5224c
commit 3b2109f7de
2 changed files with 8 additions and 4 deletions

View File

@ -48,7 +48,6 @@ py_library(
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//tensorflow/python/eager:context",
"//tensorflow/python/keras/engine",
# TODO(scottzhu): Remove metrics after we cleanup the keras internal cyclar dependency.
# //third_party/tensorflow/python/feature_column:feature_column
# //third_party/tensorflow/python/keras/engine:engine

View File

@ -144,7 +144,6 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras.engine import training
from tensorflow.python.layers import base
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
@ -165,8 +164,8 @@ from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_utils
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export
def _internal_input_layer(features,
@ -616,7 +615,7 @@ def _strip_leading_slashes(name):
return name.rsplit('/', 1)[-1]
class _LinearModel(training.Model):
class _LinearModel(base.Layer):
"""Creates a linear model using feature columns.
See `linear_model` for details.
@ -631,6 +630,12 @@ class _LinearModel(training.Model):
name=None,
**kwargs):
super(_LinearModel, self).__init__(name=name, **kwargs)
# We force the keras_style to be True here, as a workaround to not being
# able to inherit keras.layers.Layer as base class. Setting this will let
# us skip all the legacy behavior for base.Layer.
# Also note that we use Layer as base class, instead of Model, since there
# isn't any Model specific behavior gets used, eg compile/fit.
self._keras_style = True
self._feature_columns = _normalize_feature_columns(
feature_columns)
self._weight_collections = list(weight_collections or [])