Move the _BaseFeaturesLayer back to Keras.

All its subclasses are in keras/feature_column

PiperOrigin-RevId: 312521092
Change-Id: Icba59f4be0299487df5e2fd86ab697ab9e7317b3
Scott Zhu 2020-05-20 11:30:15 -07:00 committed by TensorFlower Gardener
parent 07898e752c
commit 0992a65a5d
10 changed files with 171 additions and 118 deletions
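In practical terms, the subclasses stop importing the base class from tensorflow/python/feature_column and import it from the new Keras-local module instead. A minimal before/after sketch (illustrative only; the layer bodies are elided, not copied from the diff):

# Before: Keras feature layers reached across packages for the base class.
#   from tensorflow.python.feature_column import feature_column_v2 as fc
#   class DenseFeatures(fc._BaseFeaturesLayer): ...

# After: the base class lives next to its subclasses in keras/feature_column.
from tensorflow.python.keras.feature_column import base_feature_layer as kfc

class DenseFeatures(kfc._BaseFeaturesLayer):
  ...  # implementation unchanged; only the import and base-class reference move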

View File

@@ -383,117 +383,6 @@ class _StateManagerImplV2(_StateManagerImpl):
    return var


class _BaseFeaturesLayer(Layer):
  """Base class for DenseFeatures and SequenceFeatures.

  Defines common methods and helpers.

  Args:
    feature_columns: An iterable containing the FeatureColumns to use as
      inputs to your model.
    expected_column_type: Expected class for provided feature columns.
    trainable: Boolean, whether the layer's variables will be updated via
      gradient descent during training.
    name: Name to give to the DenseFeatures.
    **kwargs: Keyword arguments to construct a layer.

  Raises:
    ValueError: if an item in `feature_columns` doesn't match
      `expected_column_type`.
  """

  def __init__(self,
               feature_columns,
               expected_column_type,
               trainable,
               name,
               partitioner=None,
               **kwargs):
    super(_BaseFeaturesLayer, self).__init__(
        name=name, trainable=trainable, **kwargs)
    self._feature_columns = _normalize_feature_columns(feature_columns)
    self._state_manager = _StateManagerImpl(self, self.trainable)
    self._partitioner = partitioner
    for column in self._feature_columns:
      if not isinstance(column, expected_column_type):
        raise ValueError(
            'Items of feature_columns must be a {}. '
            'You can wrap a categorical column with an '
            'embedding_column or indicator_column. Given: {}'.format(
                expected_column_type, column))

  def build(self, _):
    for column in self._feature_columns:
      with variable_scope._pure_variable_scope(  # pylint: disable=protected-access
          self.name,
          partitioner=self._partitioner):
        with variable_scope._pure_variable_scope(  # pylint: disable=protected-access
            _sanitize_column_name_for_variable_scope(column.name)):
          column.create_state(self._state_manager)
    super(_BaseFeaturesLayer, self).build(None)

  def _output_shape(self, input_shape, num_elements):
    """Computes expected output shape of the layer or a column's dense tensor.

    Args:
      input_shape: Tensor or array with batch shape.
      num_elements: Size of the last dimension of the output.

    Returns:
      Tuple with output shape.
    """
    raise NotImplementedError('Calling an abstract method.')

  def compute_output_shape(self, input_shape):
    total_elements = 0
    for column in self._feature_columns:
      total_elements += column.variable_shape.num_elements()
    return self._target_shape(input_shape, total_elements)

  def _process_dense_tensor(self, column, tensor):
    """Reshapes the dense tensor output of a column based on expected shape.

    Args:
      column: A DenseColumn or SequenceDenseColumn object.
      tensor: A dense tensor obtained from the same column.

    Returns:
      Reshaped dense tensor."""
    num_elements = column.variable_shape.num_elements()
    target_shape = self._target_shape(array_ops.shape(tensor), num_elements)
    return array_ops.reshape(tensor, shape=target_shape)

  def _verify_and_concat_tensors(self, output_tensors):
    """Verifies and concatenates the dense output of several columns."""
    _verify_static_batch_size_equality(output_tensors, self._feature_columns)
    return array_ops.concat(output_tensors, -1)

  def get_config(self):
    # Import here to avoid circular imports.
    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
    column_configs = serialization.serialize_feature_columns(
        self._feature_columns)
    config = {'feature_columns': column_configs}
    config['partitioner'] = generic_utils.serialize_keras_object(
        self._partitioner)
    base_config = super(  # pylint: disable=bad-super-call
        _BaseFeaturesLayer, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    # Import here to avoid circular imports.
    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
    config_cp = config.copy()
    config_cp['feature_columns'] = serialization.deserialize_feature_columns(
        config['feature_columns'], custom_objects=custom_objects)
    config_cp['partitioner'] = generic_utils.deserialize_keras_object(
        config['partitioner'], custom_objects)
    return cls(**config_cp)

class _LinearModelLayer(Layer):
"""Layer that contains logic for `LinearModel`."""

View File

@@ -14,18 +14,32 @@ py_library(
    name = "feature_column",
    srcs = ["__init__.py"],
    deps = [
        ":base_feature_layer",
        ":dense_features",
        ":dense_features_v2",
        ":sequence_feature_column",
    ],
)

py_library(
    name = "base_feature_layer",
    srcs = ["base_feature_layer.py"],
    deps = [
        "//tensorflow/python:array_ops",
        "//tensorflow/python:variable_scope",
        "//tensorflow/python/feature_column:feature_column_v2",
        "//tensorflow/python/keras/engine:base_layer",
        "//tensorflow/python/keras/utils:generic_utils",
    ],
)

py_library(
    name = "dense_features",
    srcs = [
        "dense_features.py",
    ],
    deps = [
        ":base_feature_layer",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:tf_export",
        "//tensorflow/python:util",
@@ -40,6 +54,7 @@ py_library(
        "dense_features_v2.py",
    ],
    deps = [
        ":base_feature_layer",
        ":dense_features",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:tf_export",
@@ -98,6 +113,7 @@ py_library(
    name = "sequence_feature_column",
    srcs = ["sequence_feature_column.py"],
    deps = [
        ":base_feature_layer",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:check_ops",
        "//tensorflow/python:framework_ops",

View File

@@ -0,0 +1,145 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This API defines FeatureColumn abstraction."""
# This file was originally under tf/python/feature_column, and was moved to
# Keras package in order to remove the reverse dependency from TF to Keras.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope
class _BaseFeaturesLayer(Layer):
"""Base class for DenseFeatures and SequenceFeatures.
Defines common methods and helpers.
Args:
feature_columns: An iterable containing the FeatureColumns to use as
inputs to your model.
expected_column_type: Expected class for provided feature columns.
trainable: Boolean, whether the layer's variables will be updated via
gradient descent during training.
name: Name to give to the DenseFeatures.
**kwargs: Keyword arguments to construct a layer.
Raises:
ValueError: if an item in `feature_columns` doesn't match
`expected_column_type`.
"""
def __init__(self,
feature_columns,
expected_column_type,
trainable,
name,
partitioner=None,
**kwargs):
super(_BaseFeaturesLayer, self).__init__(
name=name, trainable=trainable, **kwargs)
self._feature_columns = feature_column_v2._normalize_feature_columns( # pylint: disable=protected-access
feature_columns)
self._state_manager = feature_column_v2._StateManagerImpl( # pylint: disable=protected-access
self, self.trainable)
self._partitioner = partitioner
for column in self._feature_columns:
if not isinstance(column, expected_column_type):
raise ValueError(
'Items of feature_columns must be a {}. '
'You can wrap a categorical column with an '
'embedding_column or indicator_column. Given: {}'.format(
expected_column_type, column))
def build(self, _):
for column in self._feature_columns:
with variable_scope._pure_variable_scope( # pylint: disable=protected-access
self.name,
partitioner=self._partitioner):
with variable_scope._pure_variable_scope( # pylint: disable=protected-access
feature_column_v2._sanitize_column_name_for_variable_scope( # pylint: disable=protected-access
column.name)):
column.create_state(self._state_manager)
super(_BaseFeaturesLayer, self).build(None)
def _output_shape(self, input_shape, num_elements):
"""Computes expected output shape of the layer or a column's dense tensor.
Args:
input_shape: Tensor or array with batch shape.
num_elements: Size of the last dimension of the output.
Returns:
Tuple with output shape.
"""
raise NotImplementedError('Calling an abstract method.')
def compute_output_shape(self, input_shape):
total_elements = 0
for column in self._feature_columns:
total_elements += column.variable_shape.num_elements()
return self._target_shape(input_shape, total_elements)
def _process_dense_tensor(self, column, tensor):
"""Reshapes the dense tensor output of a column based on expected shape.
Args:
column: A DenseColumn or SequenceDenseColumn object.
tensor: A dense tensor obtained from the same column.
Returns:
Reshaped dense tensor.
"""
num_elements = column.variable_shape.num_elements()
target_shape = self._target_shape(array_ops.shape(tensor), num_elements)
return array_ops.reshape(tensor, shape=target_shape)
def _verify_and_concat_tensors(self, output_tensors):
"""Verifies and concatenates the dense output of several columns."""
feature_column_v2._verify_static_batch_size_equality( # pylint: disable=protected-access
output_tensors, self._feature_columns)
return array_ops.concat(output_tensors, -1)
def get_config(self):
# Import here to avoid circular imports.
from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top
column_configs = serialization.serialize_feature_columns(
self._feature_columns)
config = {'feature_columns': column_configs}
config['partitioner'] = generic_utils.serialize_keras_object(
self._partitioner)
base_config = super( # pylint: disable=bad-super-call
_BaseFeaturesLayer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@classmethod
def from_config(cls, config, custom_objects=None):
# Import here to avoid circular imports.
from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top
config_cp = config.copy()
config_cp['feature_columns'] = serialization.deserialize_feature_columns(
config['feature_columns'], custom_objects=custom_objects)
config_cp['partitioner'] = generic_utils.deserialize_keras_object(
config['partitioner'], custom_objects)
return cls(**config_cp)
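To see how the hooks above fit together, here is a rough, hypothetical subclass in the spirit of DenseFeatures (names and body are illustrative, not the actual implementation): it supplies _target_shape and drives the feature-column machinery from call, reusing _process_dense_tensor and _verify_and_concat_tensors from the base class.

class _MyDenseFeatures(_BaseFeaturesLayer):
  """Sketch of a concrete subclass; not the real DenseFeatures code."""

  def _target_shape(self, input_shape, total_elements):
    # Dense output is (batch_size, total number of elements across columns).
    return (input_shape[0], total_elements)

  def call(self, features):
    # Cache transformed features so columns shared by several feature columns
    # are only computed once.
    transformation_cache = feature_column_v2.FeatureTransformationCache(features)
    output_tensors = []
    for column in self._feature_columns:
      tensor = column.get_dense_tensor(transformation_cache, self._state_manager)
      output_tensors.append(self._process_dense_tensor(column, tensor))
    # Checks that static batch sizes agree, then concatenates on the last axis.
    return self._verify_and_concat_tensors(output_tensors)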

View File

@@ -23,12 +23,13 @@ import json
from tensorflow.python.feature_column import feature_column_v2 as fc
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend
from tensorflow.python.keras.feature_column import base_feature_layer as kfc
from tensorflow.python.util import serialization
from tensorflow.python.util.tf_export import keras_export


@keras_export(v1=['keras.layers.DenseFeatures'])
class DenseFeatures(fc._BaseFeaturesLayer):  # pylint: disable=protected-access
class DenseFeatures(kfc._BaseFeaturesLayer):  # pylint: disable=protected-access
  """A layer that produces a dense `Tensor` based on given `feature_columns`.

  Generally a single example in training data is described with FeatureColumns.

View File

@@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.python.feature_column import feature_column_v2 as fc
from tensorflow.python.framework import ops
from tensorflow.python.keras.feature_column import base_feature_layer as kfc
from tensorflow.python.keras.feature_column import dense_features
from tensorflow.python.util.tf_export import keras_export
@@ -92,4 +93,4 @@ class DenseFeatures(dense_features.DenseFeatures):
      column.create_state(self._state_manager)
    # We would like to call Layer.build and not _DenseFeaturesHelper.build.
    # pylint: disable=protected-access
    super(fc._BaseFeaturesLayer, self).build(None)  # pylint: disable=bad-super-call
    super(kfc._BaseFeaturesLayer, self).build(None)  # pylint: disable=bad-super-call
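The two-argument super in the last line is the point of this hunk: it starts the method lookup after _BaseFeaturesLayer in the MRO, so Layer.build runs while the base feature layer's own build is skipped. A small self-contained illustration of that pattern with toy classes (not TensorFlow code):

class A(object):
  def build(self):
    print('A.build')

class B(A):
  def build(self):
    print('B.build')

class C(B):
  def build(self):
    # super(B, self) starts the MRO search after B, so A.build is called
    # and B.build is deliberately skipped.
    super(B, self).build()

C().build()  # prints: A.build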

View File

@@ -24,6 +24,7 @@ from __future__ import print_function
from tensorflow.python.feature_column import feature_column_v2 as fc
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend
from tensorflow.python.keras.feature_column import base_feature_layer as kfc
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.util.tf_export import keras_export
@@ -32,7 +33,7 @@ from tensorflow.python.util.tf_export import keras_export
@keras_export('keras.experimental.SequenceFeatures')
class SequenceFeatures(fc._BaseFeaturesLayer):
class SequenceFeatures(kfc._BaseFeaturesLayer):
  """A layer for sequence input.

  All `feature_columns` must be sequence dense columns with the same

View File

@@ -1,7 +1,7 @@
path: "tensorflow.keras.experimental.SequenceFeatures"
tf_class {
is_instance: "<class \'tensorflow.python.keras.feature_column.sequence_feature_column.SequenceFeatures\'>"
is_instance: "<class \'tensorflow.python.feature_column.feature_column_v2._BaseFeaturesLayer\'>"
is_instance: "<class \'tensorflow.python.keras.feature_column.base_feature_layer._BaseFeaturesLayer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"

View File

@@ -1,7 +1,7 @@
path: "tensorflow.keras.layers.DenseFeatures"
tf_class {
is_instance: "<class \'tensorflow.python.keras.feature_column.dense_features.DenseFeatures\'>"
is_instance: "<class \'tensorflow.python.feature_column.feature_column_v2._BaseFeaturesLayer\'>"
is_instance: "<class \'tensorflow.python.keras.feature_column.base_feature_layer._BaseFeaturesLayer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"

View File

@@ -1,7 +1,7 @@
path: "tensorflow.keras.experimental.SequenceFeatures"
tf_class {
is_instance: "<class \'tensorflow.python.keras.feature_column.sequence_feature_column.SequenceFeatures\'>"
is_instance: "<class \'tensorflow.python.feature_column.feature_column_v2._BaseFeaturesLayer\'>"
is_instance: "<class \'tensorflow.python.keras.feature_column.base_feature_layer._BaseFeaturesLayer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"

View File

@@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.DenseFeatures"
tf_class {
is_instance: "<class \'tensorflow.python.keras.feature_column.dense_features_v2.DenseFeatures\'>"
is_instance: "<class \'tensorflow.python.keras.feature_column.dense_features.DenseFeatures\'>"
is_instance: "<class \'tensorflow.python.feature_column.feature_column_v2._BaseFeaturesLayer\'>"
is_instance: "<class \'tensorflow.python.keras.feature_column.base_feature_layer._BaseFeaturesLayer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"