Change not exposed in public API

PiperOrigin-RevId: 324263999
Change-Id: I7f776a5ba2e735f0bf6695dc0ad79801ee15fca7
Amy Skerry-Ryan 2020-07-31 12:41:43 -07:00 committed by TensorFlower Gardener
parent b1da7fd091
commit 56aa1b17ed
3 changed files with 128 additions and 17 deletions

tensorflow/python/tpu/feature_column.py

@@ -372,10 +372,12 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
               trainable=True,
               max_sequence_length=0,
               learning_rate_fn=None,
-              use_safe_embedding_lookup=True):
+              use_safe_embedding_lookup=True,
+              bypass_scope_validation=False):
     # Note, args ckpt_to_load_from, tensor_name_in_ckpt, max_norm and trainable
     # are not supported on TPU. They are solely for matching the signature of
     # __new__ of parent class fc._EmbeddingColumn.
+    del bypass_scope_validation
     return fc._EmbeddingColumn.__new__(
         cls,
         categorical_column,
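
An aside on the `del bypass_scope_validation` line above: fc._EmbeddingColumn is a namedtuple, so __new__ may only forward the tuple's own fields; constructor-only arguments are dropped in __new__ and stored by __init__, which Python calls with the same arguments. A minimal standalone sketch of the pattern (the names _ParentColumn and _ChildColumn are invented for illustration, not TensorFlow source):

import collections

_ParentColumn = collections.namedtuple(
    '_ParentColumn', ['categorical_column', 'dimension'])


class _ChildColumn(_ParentColumn):

  def __new__(cls, categorical_column, dimension,
              bypass_scope_validation=False):
    # The flag is not a tuple field, so it must not reach the parent __new__.
    del bypass_scope_validation
    return _ParentColumn.__new__(cls, categorical_column, dimension)

  def __init__(self, categorical_column, dimension,
               bypass_scope_validation=False):
    # __init__ receives the same arguments and keeps the flag as extra state.
    self._bypass_scope_validation = bypass_scope_validation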
@@ -399,13 +401,18 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
               trainable=True,
               max_sequence_length=0,
               learning_rate_fn=None,
-              use_safe_embedding_lookup=True):
+              use_safe_embedding_lookup=True,
+              bypass_scope_validation=False):
     _TPUBaseEmbeddingColumn.__init__(
         self,
         categorical_column,
         max_sequence_length=max_sequence_length,
         learning_rate_fn=learning_rate_fn)
     self._key = None
+    # If true, scope validation is skipped to allow the same column to be used
+    # in multiple variable scopes. By default, this is False, and we expect a
+    # 1:1 mapping between feature columns and scopes.
+    self._bypass_scope_validation = bypass_scope_validation

   def get_combiner(self):
     return self.combiner
@@ -459,8 +466,10 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
     tensor = inputs.get(self.get_feature_key_name())

     # Add to collection for _create_tpu_embedding_variables_and_ops
-    _record_variable_scope_and_name(self.get_embedding_var_name(),
-                                    'embedding_weights')
+    _record_variable_scope_and_name(
+        self.get_embedding_var_name(),
+        'embedding_weights',
+        bypass_scope_validation=self._bypass_scope_validation)
     return tensor
@@ -484,8 +493,10 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
     tensor_lengths = array_ops.squeeze(tensor_lengths, -1)

     # Add to collection for _create_tpu_embedding_variables_and_ops
-    _record_variable_scope_and_name(self.get_embedding_var_name(),
-                                    'embedding_weights')
+    _record_variable_scope_and_name(
+        self.get_embedding_var_name(),
+        'embedding_weights',
+        bypass_scope_validation=self._bypass_scope_validation)
     return fc._SequenceDenseColumn.TensorSequenceLengthPair(
         dense_tensor=tensor, sequence_length=tensor_lengths)
@@ -627,7 +638,8 @@ class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,
 def _record_variable_scope_and_name(embedding_var_name,
                                     embedding_var_name_in_fc,
-                                    is_shared_embedding=False):
+                                    is_shared_embedding=False,
+                                    bypass_scope_validation=False):
   """Add embedding variable name and scope to collection."""
   g = ops.get_default_graph()
   collection = g.get_collection_ref(_TPU_FC_TO_SCOPE)
@@ -640,8 +652,8 @@ def _record_variable_scope_and_name(embedding_var_name,
   captured_scope_name = captured_scope.name

   if embedding_var_name in var_def_dict:
-    if (var_def_dict[embedding_var_name][0] != captured_scope_name
-        and not is_shared_embedding):
+    if (var_def_dict[embedding_var_name][0] != captured_scope_name and
+        not is_shared_embedding and not bypass_scope_validation):
       raise ValueError(
           'For embedding var name {}, the variable scope name is different, '
           'got {}; expected {}'.format(embedding_var_name,
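
To see what the relaxed condition does, here is a hedged, self-contained sketch of the check above (check_scope and the literal values are illustrative, not TensorFlow code; var_def_dict maps an embedding variable name to a (scope_name, variable_name) tuple, as recorded by _record_variable_scope_and_name):

def check_scope(var_def_dict, embedding_var_name, captured_scope_name,
                is_shared_embedding=False, bypass_scope_validation=False):
  if embedding_var_name in var_def_dict:
    if (var_def_dict[embedding_var_name][0] != captured_scope_name and
        not is_shared_embedding and not bypass_scope_validation):
      raise ValueError('scope mismatch for {}'.format(embedding_var_name))

recorded = {'aaa_embedding': ('tower1/scope1', 'embedding_weights')}
check_scope(recorded, 'aaa_embedding', 'tower2/scope2',
            bypass_scope_validation=True)  # passes: validation bypassed
check_scope(recorded, 'aaa_embedding', 'tower2/scope2')  # raises ValueError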

tensorflow/python/tpu/feature_column_v2.py

@@ -427,7 +427,9 @@ class _TPUEmbeddingColumnV2(_TPUBaseEmbeddingColumn, fc_lib.EmbeddingColumn):
               initializer=None,
               max_sequence_length=0,
               learning_rate_fn=None,
-              use_safe_embedding_lookup=True):
+              use_safe_embedding_lookup=True,
+              bypass_scope_validation=False):
+    del bypass_scope_validation
     return fc_lib.EmbeddingColumn.__new__(
         cls,
         categorical_column,
@@ -455,13 +457,18 @@ class _TPUEmbeddingColumnV2(_TPUBaseEmbeddingColumn, fc_lib.EmbeddingColumn):
               initializer=None,
               max_sequence_length=0,
               learning_rate_fn=None,
-              use_safe_embedding_lookup=True):
+              use_safe_embedding_lookup=True,
+              bypass_scope_validation=False):
     _TPUBaseEmbeddingColumn.__init__(
         self,
         categorical_column,
         max_sequence_length=max_sequence_length,
         learning_rate_fn=learning_rate_fn)
     self._key = None
+    # If true, scope validation is skipped to allow the same column to be used
+    # in multiple variable scopes. By default, this is False, and we expect a
+    # 1:1 mapping between feature columns and scopes.
+    self._bypass_scope_validation = bypass_scope_validation

   def get_combiner(self):
     return self.combiner
@@ -515,8 +522,10 @@ class _TPUEmbeddingColumnV2(_TPUBaseEmbeddingColumn, fc_lib.EmbeddingColumn):
     tensor = inputs.get(self.get_feature_key_name())

     # Add to collection for _create_tpu_embedding_variables_and_ops
-    _record_variable_scope_and_name(self.get_embedding_var_name(),
-                                    'embedding_weights')
+    _record_variable_scope_and_name(
+        self.get_embedding_var_name(),
+        'embedding_weights',
+        bypass_scope_validation=self._bypass_scope_validation)
     return tensor
@@ -528,8 +537,10 @@ class _TPUEmbeddingColumnV2(_TPUBaseEmbeddingColumn, fc_lib.EmbeddingColumn):
     # Create state is called for the EmbeddingColumn to create its embedding
     # variables under feature column V2, if we are on TPU so record the scope
     # here.
-    _record_variable_scope_and_name(self.get_embedding_var_name(),
-                                    'embedding_weights')
+    _record_variable_scope_and_name(
+        self.get_embedding_var_name(),
+        'embedding_weights',
+        bypass_scope_validation=self._bypass_scope_validation)

   def get_dense_tensor(self, transformation_cache, state_manager):
     if tpu.under_tpu_inference_context():
@@ -569,8 +580,10 @@ class _TPUEmbeddingColumnV2(_TPUBaseEmbeddingColumn, fc_lib.EmbeddingColumn):
     tensor_lengths = array_ops.squeeze(tensor_lengths, -1)

     # Add to collection for _create_tpu_embedding_variables_and_ops
-    _record_variable_scope_and_name(self.get_embedding_var_name(),
-                                    'embedding_weights')
+    _record_variable_scope_and_name(
+        self.get_embedding_var_name(),
+        'embedding_weights',
+        bypass_scope_validation=self._bypass_scope_validation)
     return fc_lib.SequenceDenseColumn.TensorSequenceLengthPair(
         dense_tensor=tensor, sequence_length=tensor_lengths)
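
Taken together, the flag stored in __init__ now reaches every _record_variable_scope_and_name call site, so one column can record itself under several variable scopes. A hedged end-to-end sketch, mirroring the test added below (state_manager, the TPU shard context, and the scope names are assumptions borrowed from that test):

column = tpu_fc._TPUEmbeddingColumnV2(
    categorical_column=fc_lib.categorical_column_with_identity(
        key='aaa', num_buckets=3),
    dimension=2,
    combiner='mean',
    initializer=None,
    max_sequence_length=0,
    learning_rate_fn=None,
    use_safe_embedding_lookup=True,
    bypass_scope_validation=True)

with tpu_function.tpu_shard_context(1):
  with variable_scope.variable_scope('tower1/scope1'):
    column.create_state(state_manager)  # records 'tower1/scope1'
  with variable_scope.variable_scope('tower2/scope2'):
    column.create_state(state_manager)  # allowed: validation bypassed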

tensorflow/python/tpu/feature_column_v2_test.py

@@ -28,8 +28,10 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import test_util
+from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import lookup_ops
 from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.platform import test
 from tensorflow.python.tpu import feature_column_v2 as tpu_fc
@@ -44,6 +46,40 @@ def _initialized_session():
   return sess


+class _TestStateManager(fc_lib.StateManager):
+
+  def __init__(self, trainable=True):
+    self._all_variables = {}
+    self._trainable = trainable
+
+  def create_variable(self,
+                      feature_column,
+                      name,
+                      shape,
+                      dtype=None,
+                      trainable=True,
+                      use_resource=True,
+                      initializer=None):
+    if feature_column not in self._all_variables:
+      self._all_variables[feature_column] = {}
+    var_dict = self._all_variables[feature_column]
+    if name in var_dict:
+      return var_dict[name]
+    else:
+      var = variable_scope.get_variable(
+          name=name,
+          shape=shape,
+          dtype=dtype,
+          trainable=self._trainable and trainable,
+          use_resource=use_resource,
+          initializer=initializer)
+      var_dict[name] = var
+      return var
+
+  def get_variable(self, feature_column, name):
+    return self._all_variables[feature_column][name]
+
+
 class EmbeddingColumnTestV2(test.TestCase, parameterized.TestCase):

   def test_defaults(self):
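
A quick note on the helper added above: create_variable caches one variable per (feature_column, name) pair, so a second call with the same name returns the cached variable instead of creating another. Hypothetical usage, assuming fc_obj is any feature column object and a default graph is active:

manager = _TestStateManager()
v1 = manager.create_variable(fc_obj, 'embedding_weights', shape=[3, 2])
v2 = manager.create_variable(fc_obj, 'embedding_weights', shape=[3, 2])
assert v1 is v2
assert manager.get_variable(fc_obj, 'embedding_weights') is v1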
@@ -193,6 +229,56 @@ class EmbeddingColumnTestV2(test.TestCase, parameterized.TestCase):
     self.assertEqual(embedding_column._max_sequence_length,
                      embedding_column_copy._max_sequence_length)

+  def test_with_scope_validation(self):
+    categorical_column = fc_lib.categorical_column_with_identity(
+        key='aaa', num_buckets=3)
+    embedding_dimension = 2
+    initializer = init_ops.truncated_normal_initializer(mean=0.0, stddev=.5)
+    embedding_column = tpu_fc._TPUEmbeddingColumnV2(
+        categorical_column=categorical_column,
+        dimension=embedding_dimension,
+        combiner='mean',
+        initializer=initializer,
+        max_sequence_length=0,
+        learning_rate_fn=None,
+        use_safe_embedding_lookup=True,
+        bypass_scope_validation=False)
+    self.assertIs(categorical_column, embedding_column.categorical_column)
+    self.assertEqual(embedding_dimension, embedding_column.dimension)
+    state_manager = _TestStateManager()
+    with tpu_function.tpu_shard_context(1):
+      with variable_scope.variable_scope('tower1/scope1'):
+        embedding_column.create_state(state_manager)
+      with variable_scope.variable_scope('tower2/scope2'):
+        # With default scope validation, the same column cannot be used in a new
+        # variable scope.
+        with self.assertRaisesRegex(ValueError,
+                                    'the variable scope name is different'):
+          embedding_column.create_state(state_manager)
+
+  def test_bypass_scope_validation(self):
+    categorical_column = fc_lib.categorical_column_with_identity(
+        key='aaa', num_buckets=3)
+    embedding_dimension = 2
+    initializer = init_ops.truncated_normal_initializer(mean=0.0, stddev=.5)
+    embedding_column = tpu_fc._TPUEmbeddingColumnV2(
+        categorical_column=categorical_column,
+        dimension=embedding_dimension,
+        combiner='mean',
+        initializer=initializer,
+        max_sequence_length=0,
+        learning_rate_fn=None,
+        use_safe_embedding_lookup=True,
+        bypass_scope_validation=True)
+    self.assertIs(categorical_column, embedding_column.categorical_column)
+    self.assertEqual(embedding_dimension, embedding_column.dimension)
+    state_manager = _TestStateManager()
+    with tpu_function.tpu_shard_context(1):
+      with variable_scope.variable_scope('tower1/scope1'):
+        embedding_column.create_state(state_manager)
+      with variable_scope.variable_scope('tower2/scope2'):
+        embedding_column.create_state(state_manager)
+

 class SharedEmbeddingColumnTestV2(test.TestCase, parameterized.TestCase):