Move crossed_column to core.
PiperOrigin-RevId: 155687697
This commit is contained in:
parent 770a27161b
commit e09b0b6ebf
tensorflow/python/feature_column/feature_column.py
@@ -607,6 +607,18 @@ def bucketized_column(source_column, boundaries):
  dense_tensor = make_input_layer(features, columns)
  ```

  `bucketized_column` can also be crossed with another categorical column using
  `crossed_column`:
  ```python
  price = numeric_column('price')
  # bucketized_column converts a numerical feature into a categorical one.
  bucketized_price = bucketized_column(price, boundaries=[...])
  # 'keywords' is a string feature.
  price_x_keywords = crossed_column([bucketized_price, 'keywords'], 50000)
  all_feature_columns = [price_x_keywords, ...]
  linear_prediction = make_linear_model(features, all_feature_columns)
  ```

  Args:
    source_column: A one-dimensional dense column which is generated with
      `numeric_column`.
@@ -1036,6 +1048,107 @@ def weighted_categorical_column(
      dtype=dtype)


def crossed_column(keys, hash_bucket_size, hash_key=None):
  """Returns a column for performing crosses of categorical features.

  Crossed features will be hashed according to `hash_bucket_size`. Conceptually,
  the transformation can be thought of as:
    Hash(cartesian product of features) % `hash_bucket_size`

  For example, if the input features are:

  * SparseTensor referred by first key: shape = [2, 2]
    [0, 0]: "a"
    [1, 0]: "b"
    [1, 1]: "c"

  * SparseTensor referred by second key: shape = [2, 1]
    [0, 0]: "d"
    [1, 0]: "e"

  then the crossed feature will look like:

    shape = [2, 2]
    [0, 0]: Hash64("d", Hash64("a")) % hash_bucket_size
    [1, 0]: Hash64("e", Hash64("b")) % hash_bucket_size
    [1, 1]: Hash64("e", Hash64("c")) % hash_bucket_size
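To make the worked example concrete, here is a minimal pure-Python sketch of the same transformation (not part of this commit); Python's built-in `hash` stands in for the real `Hash64`/`FingerprintCat64` fingerprinting, and `conceptual_cross` is a made-up helper name:

```python
import itertools

def conceptual_cross(examples, hash_bucket_size):
  """examples: one entry per example, each a list of per-feature value lists."""
  crossed = []
  for feature_values in examples:
    row = []
    # Cartesian product of this example's values across all features.
    for combo in itertools.product(*feature_values):
      fingerprint = 0
      for value in combo:
        fingerprint = hash((value, fingerprint))  # stand-in fingerprint combiner
      row.append(fingerprint % hash_bucket_size)
    crossed.append(row)
  return crossed

# Example 0 crosses ["a"] with ["d"]; example 1 crosses ["b", "c"] with ["e"],
# matching the ragged shape = [2, 2] output above.
print(conceptual_cross([[["a"], ["d"]], [["b", "c"], ["e"]]], 100))
```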

  Here is an example of creating a linear model with crosses of string features:
  ```python
  keywords_x_doc_terms = crossed_column(['keywords', 'doc_terms'], 50000)
  all_feature_columns = [keywords_x_doc_terms, ...]
  linear_prediction = make_linear_model(features, all_feature_columns)
  ```

  You could also use a vocabulary lookup before crossing:
  ```python
  keywords = categorical_column_with_vocabulary_file(
      'keywords', '/path/to/vocabulary/file', vocabulary_size=1000)
  keywords_x_doc_terms = crossed_column([keywords, 'doc_terms'], 50000)
  all_feature_columns = [keywords_x_doc_terms, ...]
  linear_prediction = make_linear_model(features, all_feature_columns)
  ```

  If an input feature is of numeric type, you can use
  `categorical_column_with_identity` or `bucketized_column`, as in this example:
  ```python
  # vertical_id is an integer categorical feature.
  vertical_id = categorical_column_with_identity('vertical_id', 10000)
  price = numeric_column('price')
  # bucketized_column converts a numerical feature into a categorical one.
  bucketized_price = bucketized_column(price, boundaries=[...])
  vertical_id_x_price = crossed_column([vertical_id, bucketized_price], 50000)
  all_feature_columns = [vertical_id_x_price, ...]
  linear_prediction = make_linear_model(features, all_feature_columns)
  ```

  To use a crossed column in a DNN model, you need to wrap it in an embedding
  column, as in this example:
  ```python
  vertical_id_x_price = crossed_column([vertical_id, bucketized_price], 50000)
  vertical_id_x_price_embedded = embedding_column(vertical_id_x_price, 10)
  dense_tensor = make_input_layer(features, [vertical_id_x_price_embedded, ...])
  ```

  Args:
    keys: An iterable identifying the features to be crossed. Each element can
      be either:
      * string: Uses the corresponding feature, which must be of string type.
      * `_CategoricalColumn`: Uses the transformed tensor produced by this
        column. Does not support hashed categorical columns.
    hash_bucket_size: An int > 1. The number of buckets.
    hash_key: Optional. The hash key that will be used by the
      `FingerprintCat64` function to combine the fingerprints of the crossed
      values in `SparseCrossOp`.

  Returns:
    A `_CrossedColumn`.

  Raises:
    ValueError: If `len(keys) < 2`.
    ValueError: If any of the keys is neither a string nor `_CategoricalColumn`.
    ValueError: If any of the keys is `_HashedCategoricalColumn`.
    ValueError: If `hash_bucket_size < 1`.
  """
  if not hash_bucket_size or hash_bucket_size < 1:
    raise ValueError('hash_bucket_size must be > 1. '
                     'hash_bucket_size: {}'.format(hash_bucket_size))
  if not keys or len(keys) < 2:
    raise ValueError(
        'keys must be a list with length > 1. Given: {}'.format(keys))
  for key in keys:
    if (not isinstance(key, six.string_types) and
        not isinstance(key, _CategoricalColumn)):
      raise ValueError(
          'Unsupported key type. All keys must be either string, or '
          'categorical column except _HashedCategoricalColumn. '
          'Given: {}'.format(key))
    if isinstance(key, _HashedCategoricalColumn):
      raise ValueError(
          '_HashedCategoricalColumn is not supported. Instead, use the feature '
          'name as a string. Given: {}'.format(key))
  return _CrossedColumn(
      keys=tuple(keys), hash_bucket_size=hash_bucket_size,
      hash_key=hash_key)
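For reference, a quick illustrative sketch of what these checks accept and reject (`fc` is the module alias used in the tests below; the commented-out calls each raise `ValueError`):

```python
crossed1 = fc.crossed_column(['d1', 'd2'], 10)     # ok: two string keys
crossed2 = fc.crossed_column([crossed1, 'c'], 15)  # ok: crosses can be nested

# fc.crossed_column(['a'], 10)                     # keys must have length > 1
# fc.crossed_column(['a', 'c'], 0)                 # hash_bucket_size must be > 1
# fc.crossed_column(['a', fc.numeric_column('c')], 10)  # unsupported key type
```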


class _FeatureColumn(object):
  """Represents a feature column abstraction.

@@ -1969,6 +2082,80 @@ class _WeightedCategoricalColumn(
    return _CategoricalColumn.IdWeightPair(tensors[0], tensors[1])


class _CrossedColumn(
    _CategoricalColumn,
    collections.namedtuple('_CrossedColumn',
                           ['keys', 'hash_bucket_size', 'hash_key'])):
  """See `crossed_column`."""

  @property
  def name(self):
    feature_names = []
    for key in _collect_leaf_level_keys(self):
      if isinstance(key, _FeatureColumn):
        feature_names.append(key.name)
      else:  # key must be a string
        feature_names.append(key)
    return '_X_'.join(sorted(feature_names))

  @property
  def _parse_example_config(self):
    config = {}
    for key in self.keys:
      if isinstance(key, _FeatureColumn):
        config.update(key._parse_example_config)  # pylint: disable=protected-access
      else:  # key must be a string
        config.update({key: parsing_ops.VarLenFeature(dtypes.string)})
    return config

  def _transform_feature(self, inputs):
    feature_tensors = []
    for key in _collect_leaf_level_keys(self):
      if isinstance(key, six.string_types):
        feature_tensors.append(inputs.get(key))
      elif isinstance(key, _CategoricalColumn):
        ids_and_weights = key._get_sparse_tensors(inputs)  # pylint: disable=protected-access
        if ids_and_weights.weight_tensor is not None:
          raise ValueError(
              'crossed_column does not support weight_tensor, but the given '
              'column populates weight_tensor. '
              'Given column: {}'.format(key.name))
        feature_tensors.append(ids_and_weights.id_tensor)
      else:
        raise ValueError('Unsupported column type. Given: {}'.format(key))
    return sparse_ops._sparse_cross_hashed(  # pylint: disable=protected-access
        inputs=feature_tensors,
        num_buckets=self.hash_bucket_size,
        hash_key=self.hash_key)

  @property
  def _num_buckets(self):
    """Returns the number of buckets in this sparse feature."""
    return self.hash_bucket_size

  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    # The cross is already hashed into ids; there is never a weight tensor.
    return _CategoricalColumn.IdWeightPair(inputs.get(self), None)


def _collect_leaf_level_keys(cross):
  """Collects base keys by expanding all nested crosses.

  Args:
    cross: A `_CrossedColumn`.

  Returns:
    A list of strings or `_CategoricalColumn` instances.
  """
  leaf_level_keys = []
  for k in cross.keys:
    if isinstance(k, _CrossedColumn):
      leaf_level_keys.extend(_collect_leaf_level_keys(k))
    else:
      leaf_level_keys.append(k)
  return leaf_level_keys
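Illustratively, for the nested cross used in the tests below, the flattening behaves as follows (a sketch, not part of this commit):

```python
crossed1 = fc.crossed_column(['d1', 'd2'], 10)
crossed2 = fc.crossed_column([crossed1, 'c'], 15)
# _collect_leaf_level_keys(crossed2) expands the nested cross into its base
# keys: [crossed1, 'c'] -> ['d1', 'd2', 'c'].
```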


# TODO(zakaria): Move this to embedding_ops and make it public.
def _safe_embedding_lookup_sparse(embedding_weights,
                                  sparse_ids,

tensorflow/python/feature_column/feature_column_test.py
@@ -735,6 +735,262 @@ class HashedCategoricalColumnTest(test.TestCase):
      self.assertAllClose(((4.,), (6.,)), predictions.eval())


class CrossedColumnTest(test.TestCase):

  def test_keys_empty(self):
    with self.assertRaisesRegexp(
        ValueError, 'keys must be a list with length > 1'):
      fc.crossed_column([], 10)

  def test_keys_length_one(self):
    with self.assertRaisesRegexp(
        ValueError, 'keys must be a list with length > 1'):
      fc.crossed_column(['a'], 10)

  def test_key_type_unsupported(self):
    with self.assertRaisesRegexp(ValueError, 'Unsupported key type'):
      fc.crossed_column(['a', fc.numeric_column('c')], 10)

    with self.assertRaisesRegexp(
        ValueError, '_HashedCategoricalColumn is not supported'):
      fc.crossed_column(
          ['a', fc.categorical_column_with_hash_bucket('c', 10)], 10)

  def test_hash_bucket_size_negative(self):
    with self.assertRaisesRegexp(
        ValueError, 'hash_bucket_size must be > 1'):
      fc.crossed_column(['a', 'c'], -1)

  def test_hash_bucket_size_zero(self):
    with self.assertRaisesRegexp(
        ValueError, 'hash_bucket_size must be > 1'):
      fc.crossed_column(['a', 'c'], 0)

  def test_hash_bucket_size_none(self):
    with self.assertRaisesRegexp(
        ValueError, 'hash_bucket_size must be > 1'):
      fc.crossed_column(['a', 'c'], None)

  def test_name(self):
    a = fc.numeric_column('a', dtype=dtypes.int32)
    b = fc.bucketized_column(a, boundaries=[0, 1])
    crossed1 = fc.crossed_column(['d1', 'd2'], 10)

    crossed2 = fc.crossed_column([b, 'c', crossed1], 10)
    self.assertEqual('a_bucketized_X_c_X_d1_X_d2', crossed2.name)

  def test_name_ordered_alphabetically(self):
    """Tests that the name does not depend on the order of given columns."""
    a = fc.numeric_column('a', dtype=dtypes.int32)
    b = fc.bucketized_column(a, boundaries=[0, 1])
    crossed1 = fc.crossed_column(['d1', 'd2'], 10)

    crossed2 = fc.crossed_column([crossed1, 'c', b], 10)
    self.assertEqual('a_bucketized_X_c_X_d1_X_d2', crossed2.name)

  def test_name_leaf_keys_ordered_alphabetically(self):
    """Tests that the name does not depend on the order of leaf keys."""
    a = fc.numeric_column('a', dtype=dtypes.int32)
    b = fc.bucketized_column(a, boundaries=[0, 1])
    crossed1 = fc.crossed_column(['d2', 'c'], 10)

    crossed2 = fc.crossed_column([crossed1, 'd1', b], 10)
    self.assertEqual('a_bucketized_X_c_X_d1_X_d2', crossed2.name)

  def test_parse_config(self):
    a = fc.numeric_column('a', shape=[2], dtype=dtypes.int32)
    b = fc.bucketized_column(a, boundaries=[0, 1])
    crossed = fc.crossed_column([b, 'c'], 10)
    self.assertEqual({
        'a': parsing_ops.FixedLenFeature((2,), dtype=dtypes.int32),
        'c': parsing_ops.VarLenFeature(dtypes.string),
    }, crossed._parse_example_config)

  def test_num_buckets(self):
    a = fc.numeric_column('a', shape=[2], dtype=dtypes.int32)
    b = fc.bucketized_column(a, boundaries=[0, 1])
    crossed = fc.crossed_column([b, 'c'], 15)
    self.assertEqual(15, crossed._num_buckets)

  def test_deep_copy(self):
    a = fc.numeric_column('a', dtype=dtypes.int32)
    b = fc.bucketized_column(a, boundaries=[0, 1])
    crossed1 = fc.crossed_column(['d1', 'd2'], 10)
    crossed2 = fc.crossed_column([b, 'c', crossed1], 15, hash_key=5)
    crossed2_copy = copy.deepcopy(crossed2)
    self.assertEqual('a_bucketized_X_c_X_d1_X_d2', crossed2_copy.name)
    self.assertEqual(15, crossed2_copy.hash_bucket_size)
    self.assertEqual(5, crossed2_copy.hash_key)

  def test_parse_example(self):
    price = fc.numeric_column('price', shape=[2])
    bucketized_price = fc.bucketized_column(price, boundaries=[0, 50])
    price_cross_wire = fc.crossed_column([bucketized_price, 'wire'], 10)
    data = example_pb2.Example(features=feature_pb2.Features(
        feature={
            'price':
                feature_pb2.Feature(float_list=feature_pb2.FloatList(
                    value=[20., 110.])),
            'wire':
                feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
                    value=[b'omar', b'stringer'])),
        }))
    features = parsing_ops.parse_example(
        serialized=[data.SerializeToString()],
        features=price_cross_wire._parse_example_config)
    self.assertIn('price', features)
    self.assertIn('wire', features)
    with self.test_session():
      self.assertAllEqual([[20., 110.]], features['price'].eval())
      wire_sparse = features['wire']
      self.assertAllEqual([[0, 0], [0, 1]], wire_sparse.indices.eval())
      # Use byte constants to pass the open-source test.
      self.assertAllEqual([b'omar', b'stringer'], wire_sparse.values.eval())
      self.assertAllEqual([1, 2], wire_sparse.dense_shape.eval())

  def test_get_sparse_tensors(self):
    a = fc.numeric_column('a', dtype=dtypes.int32, shape=(2,))
    b = fc.bucketized_column(a, boundaries=(0, 1))
    crossed1 = fc.crossed_column(['d1', 'd2'], 10)
    crossed2 = fc.crossed_column([b, 'c', crossed1], 15, hash_key=5)
    with ops.Graph().as_default():
      builder = fc._LazyBuilder({
          'a': constant_op.constant(((-1., .5), (.5, 1.))),
          'c': sparse_tensor.SparseTensor(
              indices=((0, 0), (1, 0), (1, 1)),
              values=['cA', 'cB', 'cC'],
              dense_shape=(2, 2)),
          'd1': sparse_tensor.SparseTensor(
              indices=((0, 0), (1, 0), (1, 1)),
              values=['d1A', 'd1B', 'd1C'],
              dense_shape=(2, 2)),
          'd2': sparse_tensor.SparseTensor(
              indices=((0, 0), (1, 0), (1, 1)),
              values=['d2A', 'd2B', 'd2C'],
              dense_shape=(2, 2)),
      })
      id_weight_pair = crossed2._get_sparse_tensors(builder)
      with _initialized_session():
        id_tensor_eval = id_weight_pair.id_tensor.eval()
        self.assertAllEqual(
            ((0, 0), (0, 1), (1, 0), (1, 1), (1, 2), (1, 3), (1, 4), (1, 5),
             (1, 6), (1, 7), (1, 8), (1, 9), (1, 10), (1, 11), (1, 12), (1, 13),
             (1, 14), (1, 15)),
            id_tensor_eval.indices)
        # Check exact hashed output. If hashing changes this test will break.
        # All values are within [0, hash_bucket_size).
        expected_values = (
            6, 14, 0, 13, 8, 8, 10, 12, 2, 0, 1, 9, 8, 12, 2, 0, 10, 11)
        self.assertAllEqual(expected_values, id_tensor_eval.values)
        self.assertAllEqual((2, 16), id_tensor_eval.dense_shape)

  def test_get_sparse_tensors_simple(self):
    """Same as test_get_sparse_tensors, but with simpler values."""
    a = fc.numeric_column('a', dtype=dtypes.int32, shape=(2,))
    b = fc.bucketized_column(a, boundaries=(0, 1))
    crossed = fc.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5)
    with ops.Graph().as_default():
      builder = fc._LazyBuilder({
          'a': constant_op.constant(((-1., .5), (.5, 1.))),
          'c': sparse_tensor.SparseTensor(
              indices=((0, 0), (1, 0), (1, 1)),
              values=['cA', 'cB', 'cC'],
              dense_shape=(2, 2)),
      })
      id_weight_pair = crossed._get_sparse_tensors(builder)
      with _initialized_session():
        id_tensor_eval = id_weight_pair.id_tensor.eval()
        self.assertAllEqual(
            ((0, 0), (0, 1), (1, 0), (1, 1), (1, 2), (1, 3)),
            id_tensor_eval.indices)
        # Check exact hashed output. If hashing changes this test will break.
        # All values are within [0, hash_bucket_size).
        expected_values = (1, 0, 1, 3, 4, 2)
        self.assertAllEqual(expected_values, id_tensor_eval.values)
        self.assertAllEqual((2, 4), id_tensor_eval.dense_shape)

  def test_make_linear_model(self):
    """Tests make_linear_model.

    Uses data from test_get_sparse_tensors_simple.
    """
    a = fc.numeric_column('a', dtype=dtypes.int32, shape=(2,))
    b = fc.bucketized_column(a, boundaries=(0, 1))
    crossed = fc.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5)
    with ops.Graph().as_default():
      predictions = fc.make_linear_model({
          'a': constant_op.constant(((-1., .5), (.5, 1.))),
          'c': sparse_tensor.SparseTensor(
              indices=((0, 0), (1, 0), (1, 1)),
              values=['cA', 'cB', 'cC'],
              dense_shape=(2, 2)),
      }, (crossed,))
      bias = get_linear_model_bias()
      crossed_var = get_linear_model_column_var(crossed)
      with _initialized_session() as sess:
        self.assertAllClose((0.,), bias.eval())
        self.assertAllClose(
            ((0.,), (0.,), (0.,), (0.,), (0.,)), crossed_var.eval())
        self.assertAllClose(((0.,), (0.,)), predictions.eval())
        sess.run(crossed_var.assign(((1.,), (2.,), (3.,), (4.,), (5.,))))
        # Expected ids after cross = (1, 0, 1, 3, 4, 2)
        self.assertAllClose(((3.,), (14.,)), predictions.eval())
        sess.run(bias.assign((.1,)))
        self.assertAllClose(((3.1,), (14.1,)), predictions.eval())

  def test_make_linear_model_with_weights(self):
    class _TestColumnWithWeights(fc._CategoricalColumn):
      """Produces sparse IDs and sparse weights."""

      @property
      def name(self):
        return 'test_column'

      @property
      def _parse_example_config(self):
        return {
            self.name: parsing_ops.VarLenFeature(dtypes.int32),
            '{}_weights'.format(self.name): parsing_ops.VarLenFeature(
                dtypes.float32),
        }

      @property
      def _num_buckets(self):
        return 5

      def _transform_feature(self, inputs):
        return (inputs.get(self.name),
                inputs.get('{}_weights'.format(self.name)))

      def _get_sparse_tensors(self, inputs, weight_collections=None,
                              trainable=None):
        """Populates both id_tensor and weight_tensor."""
        ids_and_weights = inputs.get(self)
        return fc._CategoricalColumn.IdWeightPair(
            id_tensor=ids_and_weights[0], weight_tensor=ids_and_weights[1])

    t = _TestColumnWithWeights()
    crossed = fc.crossed_column([t, 'c'], hash_bucket_size=5, hash_key=5)
    with ops.Graph().as_default():
      with self.assertRaisesRegexp(
          ValueError,
          'crossed_column does not support weight_tensor.*{}'.format(t.name)):
        fc.make_linear_model({
            t.name: sparse_tensor.SparseTensor(
                indices=((0, 0), (1, 0), (1, 1)),
                values=[0, 1, 2],
                dense_shape=(2, 2)),
            '{}_weights'.format(t.name): sparse_tensor.SparseTensor(
                indices=((0, 0), (1, 0), (1, 1)),
                values=[1., 10., 2.],
                dense_shape=(2, 2)),
            'c': sparse_tensor.SparseTensor(
                indices=((0, 0), (1, 0), (1, 1)),
                values=['cA', 'cB', 'cC'],
                dense_shape=(2, 2)),
        }, (crossed,))


def get_linear_model_bias():
  with variable_scope.variable_scope('make_linear_model', reuse=True):
    return variable_scope.get_variable('bias_weights')