Replaced deprecated tf.create_partitioned_variables with tf.get_variable

PiperOrigin-RevId: 222963167
Author: Sergei Lebedev, 2018-11-27 02:54:53 -08:00; committed by TensorFlower Gardener
parent ac167c39b4
commit b5b6628931
5 changed files with 67 additions and 72 deletions
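
The change is the same mechanical substitution in every file: the deprecated create_partitioned_variables(slicing=...) call becomes get_variable() with a fixed_size_partitioner, wrapped in list(...) because get_variable returns a PartitionedVariable object rather than the list of per-shard variables the old API returned. A minimal sketch of the pattern (TF 1.x graph mode; the variable name, shape, and shard count here are illustrative, not taken from the diff):

import tensorflow as tf

vocab_size, embed_dim, num_shards = 100, 8, 4

# Before (deprecated): returned a Python list of shard variables.
# ws = tf.create_partitioned_variables(
#     shape=[vocab_size, embed_dim],
#     slicing=[num_shards, 1],  # num_shards pieces on axis 0, 1 on axis 1
#     initializer=tf.truncated_normal_initializer(stddev=0.1))

# After: a partitioner on get_variable; list(...) recovers the shards,
# since PartitionedVariable is iterable over its per-shard variables.
ws = list(tf.get_variable(
    "embedding_weights",
    shape=[vocab_size, embed_dim],
    partitioner=tf.fixed_size_partitioner(num_shards),  # splits axis 0
    initializer=tf.truncated_normal_initializer(stddev=0.1)))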


@@ -21,6 +21,7 @@ from __future__ import print_function

 import itertools
 import math
+import sys

 import numpy as np
@@ -36,6 +37,7 @@ from tensorflow.python.ops import gradient_checker
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import test
 from tensorflow.python.util import compat
@@ -48,11 +50,13 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
     assert num_shards > 0
     assert num_shards <= vocab_size
-    embedding_weights = partitioned_variables.create_partitioned_variables(
+    initializer = init_ops.truncated_normal_initializer(
+        mean=0.0, stddev=1.0 / math.sqrt(vocab_size), dtype=dtypes.float32)
+    embedding_weights = list(variable_scope.get_variable(
+        "embedding_weights",
         shape=[vocab_size, embed_dim],
-        slicing=[num_shards, 1],
-        initializer=init_ops.truncated_normal_initializer(
-            mean=0.0, stddev=1.0 / math.sqrt(vocab_size), dtype=dtypes.float32))
+        partitioner=partitioned_variables.fixed_size_partitioner(num_shards),
+        initializer=initializer))
     for w in embedding_weights:
       w.initializer.run()
     embedding_weights = [w.eval() for w in embedding_weights]
@@ -256,6 +260,13 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
           embedding_weights, sparse_ids, sparse_weights)


+# pylint: disable=invalid-name
+def local_variable_scope():
+  """Create a variable scope named like the caller function."""
+  return variable_scope.variable_scope(sys._getframe(1).f_code.co_name)
+# pylint: enable=invalid-name
+
+
 class ScatteredEmbeddingLookupTest(test.TestCase):

   def setUp(self):
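
The helper exists because get_variable is stricter than the raw Variable constructor that create_partitioned_variables used under the hood: a duplicate name raises ValueError instead of being uniquified, and with cached_session the graph can be shared across test methods, so repeated "embedding_weights" variables would now collide. Wrapping each test body in a scope named after the test keeps the names distinct. A sketch of how the helper behaves (TF 1.x; the test functions are illustrative):

import sys
import tensorflow as tf

def local_variable_scope():
  # Frame 1 is the caller, so the scope takes the calling function's name.
  return tf.variable_scope(sys._getframe(1).f_code.co_name)

def test_a():
  with local_variable_scope():               # opens scope "test_a"
    return tf.get_variable("w", shape=[2])   # creates "test_a/w"

def test_b():
  with local_variable_scope():               # opens scope "test_b"
    return tf.get_variable("w", shape=[2])   # "test_b/w", no collision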
@@ -266,17 +277,18 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
     assert num_shards > 0
     assert num_shards <= size
-    embedding_weights = partitioned_variables.create_partitioned_variables(
+    embedding_weights = list(variable_scope.get_variable(
+        "embedding_weights",
         shape=[size],
-        slicing=[num_shards],
+        partitioner=partitioned_variables.fixed_size_partitioner(num_shards),
         initializer=init_ops.truncated_normal_initializer(
-            mean=0.0, stddev=1.0, dtype=dtypes.float32))
+            mean=0.0, stddev=1.0, dtype=dtypes.float32)))
     for w in embedding_weights:
       w.initializer.run()
     return embedding_weights

   def test_scattered_embedding_consistency(self):
-    with self.cached_session():
+    with self.cached_session(), local_variable_scope():
       embedding_weights = self._random_weights()
       values = constant_op.constant(["foo", "foo"])
@@ -288,7 +300,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
                           embedding_lookup_result[1])

   def test_scattered_embedding_multiple_partition(self):
-    with self.cached_session():
+    with self.cached_session(), local_variable_scope():
       embedding_weights = self._random_weights(num_shards=7)
       values = constant_op.constant([4, 4, 5])
@@ -304,7 +316,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
       self.assertGreater(embedding_diff, 0)

   def test_scattered_embedding_coverage(self):
-    with self.cached_session():
+    with self.cached_session(), local_variable_scope():
       size = 8
       embedding_weights = self._random_weights(size=size, num_shards=3)
       values = constant_op.constant(["foo"])
@@ -316,7 +328,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
       self.assertEqual(len(np.unique(embedding_lookup_result[0])), size)

   def test_scattered_embedding_multi_dimension(self):
-    with self.cached_session():
+    with self.cached_session(), local_variable_scope():
       embedding_weights = self._random_weights()
       values = constant_op.constant([["foo", "bar", "bar"],
                                      ["bar", "bar", "foo"]])
@@ -329,7 +341,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
                           embedding_lookup_result[1][2])

   def test_scattered_embedding_lookup_sparse(self):
-    with self.cached_session():
+    with self.cached_session(), local_variable_scope():
       embedding_weights = self._random_weights(num_shards=3)
       sparse_tensor = sparse_tensor_lib.SparseTensor(
           values=["foo", "bar", "foo", "bar"],
@@ -358,7 +370,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
     embeds = np.random.randn(n_embed, d_embed)
     idx = np.random.randint(0, n_embed, idx_shape)

-    with self.cached_session():
+    with self.cached_session(), local_variable_scope():
       embedded_np = embeds[idx]
       embedded_tf = embedding_ops.embedding_lookup_unique(embeds, idx).eval()
@@ -370,7 +382,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
     idx = np.random.randint(0, 5, 10)
     idx2d = np.random.randint(0, 5, (10, 2))

-    with self.cached_session():
+    with self.cached_session(), local_variable_scope():
       embedded_np = embeds[idx]
       embedded_np2d = embeds[idx2d]
       embedded_tf = embedding_ops.embedding_lookup_unique(embeds, idx).eval()
@@ -398,17 +410,18 @@ class SampledScatteredEmbeddingLookupTest(test.TestCase):
     assert num_shards > 0
     assert num_shards <= size
-    embedding_weights = partitioned_variables.create_partitioned_variables(
+    embedding_weights = list(variable_scope.get_variable(
+        "embedding_weights",
         shape=[size],
-        slicing=[num_shards],
+        partitioner=partitioned_variables.fixed_size_partitioner(num_shards),
         initializer=init_ops.truncated_normal_initializer(
-            mean=0.0, stddev=1.0, dtype=dtypes.float32))
+            mean=0.0, stddev=1.0, dtype=dtypes.float32)))
     for w in embedding_weights:
       w.initializer.run()
     return embedding_weights

   def test_hashed_embedding_consistency(self):
-    with self.cached_session():
+    with self.cached_session(), local_variable_scope():
       embedding_weights = self._random_weights()
       values = constant_op.constant(["foo", "foo"])
       # The first three sampled_candidates are equal, so the first three
@@ -429,7 +442,7 @@ class SampledScatteredEmbeddingLookupTest(test.TestCase):
                           embedding_lookup_result[1][3])

   def test_hashed_embedding_multi_dimension(self):
-    with self.cached_session():
+    with self.cached_session(), local_variable_scope():
       embedding_weights = self._random_weights()
       values = constant_op.constant([["foo", "bar", "bar"],
                                      ["bar", "bar", "foo"]])


@@ -1069,17 +1069,14 @@ def _create_partitioned_variables(name,
                      'As TPU embedding is not optimized for small tables, '
                      'please consider other ways for this embedding lookup.')

-  slicing = [num_hosts, 1]
-
-  # TODO(shizhiw): deprecated, use tf.get_variable()?
-  return partitioned_variables.create_partitioned_variables(
-      name=name,
-      slicing=slicing,
+  return list(variable_scope.get_variable(
+      name,
       shape=(vocabulary_size, embedding_dimension),
+      partitioner=partitioned_variables.fixed_size_partitioner(num_hosts),
       dtype=dtypes.float32,
       initializer=initializer,
       collections=collections,
-      trainable=False)
+      trainable=False))


 @ops.RegisterGradient('TPUEmbeddingActivations')
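
Here fixed_size_partitioner(num_hosts) reproduces the old slicing=[num_hosts, 1]: the table is split along the vocabulary axis into one shard per host and left whole along the embedding axis. A sketch of the resulting shard shapes, assuming TF's usual rule of assigning the remainder rows to the first shards (the helper below is illustrative, not part of the diff):

def shard_shapes(vocabulary_size, embedding_dimension, num_hosts):
  """Per-host shard shapes for a [vocab, dim] table split on axis 0."""
  base, extra = divmod(vocabulary_size, num_hosts)
  # The first `extra` shards each take one additional row.
  return [(base + (1 if i < extra else 0), embedding_dimension)
          for i in range(num_hosts)]

print(shard_shapes(10, 4, 3))  # [(4, 4), (3, 4), (3, 4)]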


@@ -758,11 +758,13 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
     assert num_shards > 0
     assert num_shards <= vocab_size
-    embedding_weights = partitioned_variables.create_partitioned_variables(
+    initializer = init_ops.truncated_normal_initializer(
+        mean=0.0, stddev=1.0 / math.sqrt(vocab_size), dtype=dtypes.float32)
+    embedding_weights = list(variable_scope.get_variable(
+        name="embedding_weights",
         shape=[vocab_size, embed_dim],
-        slicing=[num_shards, 1],
-        initializer=init_ops.truncated_normal_initializer(
-            mean=0.0, stddev=1.0 / math.sqrt(vocab_size), dtype=dtypes.float32))
+        partitioner=partitioned_variables.fixed_size_partitioner(num_shards),
+        initializer=initializer))
     for w in embedding_weights:
       w.initializer.run()
     embedding_weights = [w.eval() for w in embedding_weights]


@@ -25,6 +25,7 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 from tensorflow.python.training import saver
@@ -44,8 +45,12 @@ class SaverLargePartitionedVariableTest(test.TestCase):
       # split into smaller sized variables.
       init = lambda shape, dtype, partition_info: constant_op.constant(
           True, dtype, shape)
-      partitioned_var = partitioned_variables.create_partitioned_variables(
-          [1 << 31], [4], init, dtype=dtypes.bool, name=var_name)
+      partitioned_var = list(variable_scope.get_variable(
+          var_name,
+          shape=[1 << 31],
+          partitioner=partitioned_variables.fixed_size_partitioner(4),
+          initializer=init,
+          dtype=dtypes.bool))
       variables.global_variables_initializer().run()
       save = saver.Saver(partitioned_var)
       val = save.save(sess, save_path)
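
The Saver treats the shard list as slices of one logical variable (each shard carries its offset within the full shape), so the checkpoint format is unchanged by this rewrite. A minimal save round-trip under the new API (TF 1.x; the small shape and the path are illustrative, the test's 1 << 31 bool shape exists to exercise sizes past the int32 limit):

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
  v = tf.get_variable(
      "big_var",
      shape=[8],
      dtype=tf.float32,
      partitioner=tf.fixed_size_partitioner(4),
      initializer=tf.zeros_initializer())
  sess.run(tf.global_variables_initializer())
  saver = tf.train.Saver(list(v))            # four shards, one logical variable
  saver.save(sess, "/tmp/partitioned_demo")  # illustrative path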


@@ -998,19 +998,12 @@ class SaveRestoreShardedTest(test.TestCase):
     call_saver_with_dict = False  # updated by test loop below

-    def _save(slices=None, partitioner=None):
+    def _save(partitioner=None):
       with self.session(graph=ops_lib.Graph()) as sess:
         # Calls .eval() to return the ndarray that makes up the full variable.
         rnd = random_ops.random_uniform(var_full_shape).eval()

-        if slices:
-          assert not partitioner
-          # TODO(apassos): make create_partitioned_variables take use_resource
-          # option to make this test passable without creating a named
-          # variable_scope.
-          vs = partitioned_variables.create_partitioned_variables(
-              var_full_shape, slices, rnd, name=var_name)
-        elif partitioner:
+        if partitioner:
           vs = [
               variable_scope.get_variable(
                   var_name,
@@ -1027,7 +1020,7 @@ class SaveRestoreShardedTest(test.TestCase):
         variables.global_variables_initializer().run()
         if call_saver_with_dict:
-          saver = saver_module.Saver({var_name: (vs if slices else vs[0])})
+          saver = saver_module.Saver({var_name: vs[0]})
         else:
           saver = saver_module.Saver(vs)
         actual_path = saver.save(sess, saved_path)
@@ -1035,16 +1028,9 @@ class SaveRestoreShardedTest(test.TestCase):
       return rnd

-    def _restore(slices=None, partitioner=None):
+    def _restore(partitioner=None):
       with self.session(graph=ops_lib.Graph()) as sess:
-        if slices:
-          assert not partitioner
-          new_vs = partitioned_variables.create_partitioned_variables(
-              var_full_shape,
-              slices,
-              array_ops.zeros(var_full_shape),  # != original contents.
-              name=var_name)
-        elif partitioner:
+        if partitioner:
           new_vs = [
               variable_scope.get_variable(
                   var_name,
@@ -1063,7 +1049,7 @@ class SaveRestoreShardedTest(test.TestCase):
         variables.global_variables_initializer().run()
         if call_saver_with_dict:
           saver = saver_module.Saver({
-              var_name: (new_vs if slices else new_vs[0])
+              var_name: new_vs[0]
           })
         else:
           saver = saver_module.Saver(new_vs)
@@ -1071,11 +1057,7 @@ class SaveRestoreShardedTest(test.TestCase):
         if partitioner:
           return new_vs[0].as_tensor().eval()
-        elif slices and slices[0] != 1:
-          return array_ops.concat(new_vs, 0).eval()
-        elif slices and slices[1] != 1:
-          return array_ops.concat(new_vs, 1).eval()
-        else:  # Non-sliced.
+        else:
           return new_vs[0].eval()

       for call_saver_with_dict in {False, True}:
@@ -1086,27 +1068,23 @@ class SaveRestoreShardedTest(test.TestCase):
         restored_full = _restore()
         self.assertAllEqual(saved_full, restored_full)

-        # Saves 10 horizontal parts of a partitioned variable.
-        # Restores into a full variable, non-sliced.
-        saved_full = _save(slices=[10, 1])
-        restored_full = _restore()
-        self.assertAllEqual(saved_full, restored_full)
-
-        # Restores into a different number/orientation of slices.
-        restored_full = _restore(slices=[2, 1])  # 2 horizon parts.
-        self.assertAllEqual(saved_full, restored_full)
-        restored_full = _restore(slices=[1, 3])  # 3 vertical parts.
-        self.assertAllEqual(saved_full, restored_full)
-
-        # Restores into a PartitionedVariable
+        # Restores into the same number of partitions.
         restored_full = _restore(
             partitioner=partitioned_variables.fixed_size_partitioner(
                 num_shards=2))
         self.assertAllEqual(saved_full, restored_full)

-        # Now, saves a full variable and restores in slices.
+        # Restores into a different number of partitions.
+        restored_full = _restore(
+            partitioner=partitioned_variables.fixed_size_partitioner(
+                num_shards=3))
+        self.assertAllEqual(saved_full, restored_full)
+
+        # Now, saves a full variable and restores PartitionedVariable.
         saved_full = _save()
-        restored_full = _restore(slices=[1, 3])
+        restored_full = _restore(
+            partitioner=partitioned_variables.fixed_size_partitioner(
+                num_shards=3))
         self.assertAllEqual(saved_full, restored_full)

   def testPartitionedVariable(self):
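
The dropped slices-based cases are folded into the partitioner-based ones without losing coverage: because each shard is saved as a slice of the full variable, a checkpoint written under one sharding can be restored under another, which is what the rewritten assertions check. A sketch of that round trip outside the test harness (TF 1.x; shape and path are illustrative):

import tensorflow as tf

# Save "var" split into 2 shards.
with tf.Graph().as_default(), tf.Session() as sess:
  v = tf.get_variable(
      "var", shape=[10],
      partitioner=tf.fixed_size_partitioner(2),
      initializer=tf.random_uniform_initializer())
  sess.run(tf.global_variables_initializer())
  tf.train.Saver(list(v)).save(sess, "/tmp/shard_demo")

# Restore the same logical variable split into 3 shards.
with tf.Graph().as_default(), tf.Session() as sess:
  w = tf.get_variable(
      "var", shape=[10],
      partitioner=tf.fixed_size_partitioner(3),
      initializer=tf.zeros_initializer())
  tf.train.Saver(list(w)).restore(sess, "/tmp/shard_demo")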