Change not exposed in public API
PiperOrigin-RevId: 324263999
Change-Id: I7f776a5ba2e735f0bf6695dc0ad79801ee15fca7
parent b1da7fd091
commit 56aa1b17ed
tensorflow/python/tpu/feature_column.py
@@ -372,10 +372,12 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
               trainable=True,
               max_sequence_length=0,
               learning_rate_fn=None,
-              use_safe_embedding_lookup=True):
+              use_safe_embedding_lookup=True,
+              bypass_scope_validation=False):
     # Note, args ckpt_to_load_from, tensor_name_in_ckpt, max_norm and trainable
     # are not supported on TPU. They are solely for matching the signature of
     # __new__ of parent class fc._EmbeddingColumn.
+    del bypass_scope_validation
     return fc._EmbeddingColumn.__new__(
         cls,
         categorical_column,
@@ -399,13 +401,18 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
                trainable=True,
                max_sequence_length=0,
                learning_rate_fn=None,
-               use_safe_embedding_lookup=True):
+               use_safe_embedding_lookup=True,
+               bypass_scope_validation=False):
     _TPUBaseEmbeddingColumn.__init__(
         self,
         categorical_column,
         max_sequence_length=max_sequence_length,
         learning_rate_fn=learning_rate_fn)
     self._key = None
+    # If true, scope validation is skipped to allow the same column to be used
+    # in multiple variable scopes. By default, this is False, and we expect a
+    # 1:1 mapping between feature columns and scopes.
+    self._bypass_scope_validation = bypass_scope_validation
 
   def get_combiner(self):
     return self.combiner
@@ -459,8 +466,10 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
     tensor = inputs.get(self.get_feature_key_name())
 
     # Add to collection for _create_tpu_embedding_variables_and_ops
-    _record_variable_scope_and_name(self.get_embedding_var_name(),
-                                    'embedding_weights')
+    _record_variable_scope_and_name(
+        self.get_embedding_var_name(),
+        'embedding_weights',
+        bypass_scope_validation=self._bypass_scope_validation)
 
     return tensor
 
@@ -484,8 +493,10 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
     tensor_lengths = array_ops.squeeze(tensor_lengths, -1)
 
     # Add to collection for _create_tpu_embedding_variables_and_ops
-    _record_variable_scope_and_name(self.get_embedding_var_name(),
-                                    'embedding_weights')
+    _record_variable_scope_and_name(
+        self.get_embedding_var_name(),
+        'embedding_weights',
+        bypass_scope_validation=self._bypass_scope_validation)
 
     return fc._SequenceDenseColumn.TensorSequenceLengthPair(
         dense_tensor=tensor, sequence_length=tensor_lengths)
@@ -627,7 +638,8 @@ class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,
 
 def _record_variable_scope_and_name(embedding_var_name,
                                     embedding_var_name_in_fc,
-                                    is_shared_embedding=False):
+                                    is_shared_embedding=False,
+                                    bypass_scope_validation=False):
   """Add embedding variable name and scope to collection."""
   g = ops.get_default_graph()
   collection = g.get_collection_ref(_TPU_FC_TO_SCOPE)
@@ -640,8 +652,8 @@ def _record_variable_scope_and_name(embedding_var_name,
   captured_scope_name = captured_scope.name
 
   if embedding_var_name in var_def_dict:
-    if (var_def_dict[embedding_var_name][0] != captured_scope_name
-        and not is_shared_embedding):
+    if (var_def_dict[embedding_var_name][0] != captured_scope_name and
+        not is_shared_embedding and not bypass_scope_validation):
       raise ValueError(
           'For embedding var name {}, the variable scope name is different, '
           'got {}; expected {}'.format(embedding_var_name,
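The guard in the last hunk is the heart of the change: bypass_scope_validation joins is_shared_embedding as a second way to opt out of the 1:1 column-to-scope check. Below is a minimal, self-contained sketch of that pattern; record_scope and the tuple shape of its dict values are illustrative stand-ins for _record_variable_scope_and_name and var_def_dict, not the actual TensorFlow code.

# Illustrative sketch only: record_scope stands in for
# _record_variable_scope_and_name, and var_def_dict maps
# {embedding_var_name: (scope_name, ...)} as in the diff above.
def record_scope(var_def_dict, var_name, scope_name,
                 is_shared_embedding=False, bypass_scope_validation=False):
  if var_name in var_def_dict:
    recorded_scope = var_def_dict[var_name][0]
    # The new flag is a second opt-out alongside is_shared_embedding.
    if (recorded_scope != scope_name and
        not is_shared_embedding and not bypass_scope_validation):
      raise ValueError(
          'For embedding var name {}, the variable scope name is different, '
          'got {}; expected {}'.format(var_name, scope_name, recorded_scope))
  else:
    var_def_dict[var_name] = (scope_name,)

scopes = {}
record_scope(scopes, 'embedding_weights', 'tower1/scope1')
record_scope(scopes, 'embedding_weights', 'tower2/scope2',
             bypass_scope_validation=True)  # tolerated; no ValueError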
tensorflow/python/tpu/feature_column_v2.py

@@ -427,7 +427,9 @@ class _TPUEmbeddingColumnV2(_TPUBaseEmbeddingColumn, fc_lib.EmbeddingColumn):
               initializer=None,
               max_sequence_length=0,
               learning_rate_fn=None,
-              use_safe_embedding_lookup=True):
+              use_safe_embedding_lookup=True,
+              bypass_scope_validation=False):
+    del bypass_scope_validation
     return fc_lib.EmbeddingColumn.__new__(
         cls,
         categorical_column,
@@ -455,13 +457,18 @@ class _TPUEmbeddingColumnV2(_TPUBaseEmbeddingColumn, fc_lib.EmbeddingColumn):
                initializer=None,
                max_sequence_length=0,
                learning_rate_fn=None,
-               use_safe_embedding_lookup=True):
+               use_safe_embedding_lookup=True,
+               bypass_scope_validation=False):
     _TPUBaseEmbeddingColumn.__init__(
         self,
         categorical_column,
         max_sequence_length=max_sequence_length,
         learning_rate_fn=learning_rate_fn)
     self._key = None
+    # If true, scope validation is skipped to allow the same column to be used
+    # in multiple variable scopes. By default, this is False, and we expect a
+    # 1:1 mapping between feature columns and scopes.
+    self._bypass_scope_validation = bypass_scope_validation
 
   def get_combiner(self):
     return self.combiner
@@ -515,8 +522,10 @@ class _TPUEmbeddingColumnV2(_TPUBaseEmbeddingColumn, fc_lib.EmbeddingColumn):
     tensor = inputs.get(self.get_feature_key_name())
 
     # Add to collection for _create_tpu_embedding_variables_and_ops
-    _record_variable_scope_and_name(self.get_embedding_var_name(),
-                                    'embedding_weights')
+    _record_variable_scope_and_name(
+        self.get_embedding_var_name(),
+        'embedding_weights',
+        bypass_scope_validation=self._bypass_scope_validation)
 
     return tensor
 
@@ -528,8 +537,10 @@ class _TPUEmbeddingColumnV2(_TPUBaseEmbeddingColumn, fc_lib.EmbeddingColumn):
     # Create state is called for the EmbeddingColumn to create its embedding
     # variables under feature column V2, if we are on TPU so record the scope
     # here.
-    _record_variable_scope_and_name(self.get_embedding_var_name(),
-                                    'embedding_weights')
+    _record_variable_scope_and_name(
+        self.get_embedding_var_name(),
+        'embedding_weights',
+        bypass_scope_validation=self._bypass_scope_validation)
 
   def get_dense_tensor(self, transformation_cache, state_manager):
     if tpu.under_tpu_inference_context():
@@ -569,8 +580,10 @@ class _TPUEmbeddingColumnV2(_TPUBaseEmbeddingColumn, fc_lib.EmbeddingColumn):
     tensor_lengths = array_ops.squeeze(tensor_lengths, -1)
 
     # Add to collection for _create_tpu_embedding_variables_and_ops
-    _record_variable_scope_and_name(self.get_embedding_var_name(),
-                                    'embedding_weights')
+    _record_variable_scope_and_name(
+        self.get_embedding_var_name(),
+        'embedding_weights',
+        bypass_scope_validation=self._bypass_scope_validation)
 
     return fc_lib.SequenceDenseColumn.TensorSequenceLengthPair(
         dense_tensor=tensor, sequence_length=tensor_lengths)
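One detail worth calling out: in both __new__ methods the new argument is immediately dropped with del bypass_scope_validation. Feature columns are namedtuple subclasses, so __new__ must forward exactly the fields the parent namedtuple defines; a constructor-only extra is discarded there and persisted as a plain attribute in __init__ instead. A minimal sketch of this pattern follows, using hypothetical Base/Derived names rather than the TF classes.

import collections

# Base plays the role of fc_lib.EmbeddingColumn (a namedtuple), Derived
# the role of _TPUEmbeddingColumnV2; both names are hypothetical.
Base = collections.namedtuple('Base', ['dimension'])

class Derived(Base):

  def __new__(cls, dimension, bypass_scope_validation=False):
    # The namedtuple has no field for the extra flag, so drop it here...
    del bypass_scope_validation
    return Base.__new__(cls, dimension)

  def __init__(self, dimension, bypass_scope_validation=False):
    # ...and keep it as an ordinary instance attribute instead, mirroring
    # self._bypass_scope_validation in the diff.
    self._bypass_scope_validation = bypass_scope_validation

col = Derived(2, bypass_scope_validation=True)
assert col.dimension == 2 and col._bypass_scope_validation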
tensorflow/python/tpu/feature_column_v2_test.py

@@ -28,8 +28,10 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import test_util
+from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import lookup_ops
 from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.platform import test
 from tensorflow.python.tpu import feature_column_v2 as tpu_fc
@@ -44,6 +46,40 @@ def _initialized_session():
   return sess
 
 
+class _TestStateManager(fc_lib.StateManager):
+
+  def __init__(self, trainable=True):
+    self._all_variables = {}
+    self._trainable = trainable
+
+  def create_variable(self,
+                      feature_column,
+                      name,
+                      shape,
+                      dtype=None,
+                      trainable=True,
+                      use_resource=True,
+                      initializer=None):
+    if feature_column not in self._all_variables:
+      self._all_variables[feature_column] = {}
+    var_dict = self._all_variables[feature_column]
+    if name in var_dict:
+      return var_dict[name]
+    else:
+      var = variable_scope.get_variable(
+          name=name,
+          shape=shape,
+          dtype=dtype,
+          trainable=self._trainable and trainable,
+          use_resource=use_resource,
+          initializer=initializer)
+      var_dict[name] = var
+      return var
+
+  def get_variable(self, feature_column, name):
+    return self._all_variables[feature_column][name]
+
+
 class EmbeddingColumnTestV2(test.TestCase, parameterized.TestCase):
 
   def test_defaults(self):
@@ -193,6 +229,56 @@ class EmbeddingColumnTestV2(test.TestCase, parameterized.TestCase):
     self.assertEqual(embedding_column._max_sequence_length,
                      embedding_column_copy._max_sequence_length)
 
+  def test_with_scope_validation(self):
+    categorical_column = fc_lib.categorical_column_with_identity(
+        key='aaa', num_buckets=3)
+    embedding_dimension = 2
+    initializer = init_ops.truncated_normal_initializer(mean=0.0, stddev=.5)
+    embedding_column = tpu_fc._TPUEmbeddingColumnV2(
+        categorical_column=categorical_column,
+        dimension=embedding_dimension,
+        combiner='mean',
+        initializer=initializer,
+        max_sequence_length=0,
+        learning_rate_fn=None,
+        use_safe_embedding_lookup=True,
+        bypass_scope_validation=False)
+    self.assertIs(categorical_column, embedding_column.categorical_column)
+    self.assertEqual(embedding_dimension, embedding_column.dimension)
+    state_manager = _TestStateManager()
+    with tpu_function.tpu_shard_context(1):
+      with variable_scope.variable_scope('tower1/scope1'):
+        embedding_column.create_state(state_manager)
+      with variable_scope.variable_scope('tower2/scope2'):
+        # With default scope validation, the same column cannot be used in a
+        # new variable scope.
+        with self.assertRaisesRegex(ValueError,
+                                    'the variable scope name is different'):
+          embedding_column.create_state(state_manager)
+
+  def test_bypass_scope_validation(self):
+    categorical_column = fc_lib.categorical_column_with_identity(
+        key='aaa', num_buckets=3)
+    embedding_dimension = 2
+    initializer = init_ops.truncated_normal_initializer(mean=0.0, stddev=.5)
+    embedding_column = tpu_fc._TPUEmbeddingColumnV2(
+        categorical_column=categorical_column,
+        dimension=embedding_dimension,
+        combiner='mean',
+        initializer=initializer,
+        max_sequence_length=0,
+        learning_rate_fn=None,
+        use_safe_embedding_lookup=True,
+        bypass_scope_validation=True)
+    self.assertIs(categorical_column, embedding_column.categorical_column)
+    self.assertEqual(embedding_dimension, embedding_column.dimension)
+    state_manager = _TestStateManager()
+    with tpu_function.tpu_shard_context(1):
+      with variable_scope.variable_scope('tower1/scope1'):
+        embedding_column.create_state(state_manager)
+      with variable_scope.variable_scope('tower2/scope2'):
+        embedding_column.create_state(state_manager)
+
+
 class SharedEmbeddingColumnTestV2(test.TestCase, parameterized.TestCase):