Replaced deprecated tf.create_partitioned_variables with tf.get_variable

PiperOrigin-RevId: 222963167
Sergei Lebedev 2018-11-27 02:54:53 -08:00 committed by TensorFlower Gardener
parent ac167c39b4
commit b5b6628931
5 changed files with 67 additions and 72 deletions
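Every hunk below applies the same pattern: a call to the deprecated partitioned_variables.create_partitioned_variables (which took a slicing=[...] argument and returned a list of per-shard variables) becomes a variable_scope.get_variable call with a fixed_size_partitioner, wrapped in list() so callers that expect a list of shards keep working. A minimal sketch of the pattern, written against the public TF 1.x API rather than the internal modules used in the diff; vocab_size, embed_dim, and num_shards are placeholder values, not from the commit:

    import math
    import tensorflow as tf

    vocab_size, embed_dim, num_shards = 100, 8, 4
    initializer = tf.truncated_normal_initializer(
        mean=0.0, stddev=1.0 / math.sqrt(vocab_size), dtype=tf.float32)

    # Before (deprecated): returns a list of per-shard variables.
    # embedding_weights = tf.create_partitioned_variables(
    #     shape=[vocab_size, embed_dim],
    #     slicing=[num_shards, 1],
    #     initializer=initializer)

    # After: get_variable with a partitioner returns a PartitionedVariable;
    # iterating it with list() yields the per-shard variables, preserving the
    # old list-of-shards interface.
    embedding_weights = list(tf.get_variable(
        "embedding_weights",
        shape=[vocab_size, embed_dim],
        partitioner=tf.fixed_size_partitioner(num_shards),
        initializer=initializer))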


@@ -21,6 +21,7 @@ from __future__ import print_function
 import itertools
 import math
+import sys
 import numpy as np
@@ -36,6 +37,7 @@ from tensorflow.python.ops import gradient_checker
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import test
 from tensorflow.python.util import compat
@@ -48,11 +50,13 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
     assert num_shards > 0
     assert num_shards <= vocab_size
-    embedding_weights = partitioned_variables.create_partitioned_variables(
+    initializer = init_ops.truncated_normal_initializer(
+        mean=0.0, stddev=1.0 / math.sqrt(vocab_size), dtype=dtypes.float32)
+    embedding_weights = list(variable_scope.get_variable(
+        "embedding_weights",
         shape=[vocab_size, embed_dim],
-        slicing=[num_shards, 1],
-        initializer=init_ops.truncated_normal_initializer(
-            mean=0.0, stddev=1.0 / math.sqrt(vocab_size), dtype=dtypes.float32))
+        partitioner=partitioned_variables.fixed_size_partitioner(num_shards),
+        initializer=initializer))
     for w in embedding_weights:
       w.initializer.run()
     embedding_weights = [w.eval() for w in embedding_weights]
@@ -256,6 +260,13 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
           embedding_weights, sparse_ids, sparse_weights)

+
+# pylint: disable=invalid-name
+def local_variable_scope():
+  """Create a variable scope named like the caller function."""
+  return variable_scope.variable_scope(sys._getframe(1).f_code.co_name)
+# pylint: enable=invalid-name
+

 class ScatteredEmbeddingLookupTest(test.TestCase):

   def setUp(self):
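The local_variable_scope helper added above wraps each test's graph construction in a variable scope named after the calling test method (via sys._getframe(1)). This is presumably needed because, unlike the old create_partitioned_variables, get_variable raises a ValueError when a variable with the same fully qualified name already exists in the graph. A hypothetical illustration, not part of the commit, of how distinct scopes avoid that collision:

    import tensorflow as tf

    # Two get_variable calls with the same name in one graph collide unless
    # they live in different variable scopes (or variable reuse is enabled).
    with tf.variable_scope("test_a"):
      w_a = tf.get_variable("embedding_weights", shape=[10, 4])
    with tf.variable_scope("test_b"):
      w_b = tf.get_variable("embedding_weights", shape=[10, 4])  # distinct scope, no error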
@@ -266,17 +277,18 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
     assert num_shards > 0
     assert num_shards <= size
-    embedding_weights = partitioned_variables.create_partitioned_variables(
+    embedding_weights = list(variable_scope.get_variable(
+        "embedding_weights",
         shape=[size],
-        slicing=[num_shards],
+        partitioner=partitioned_variables.fixed_size_partitioner(num_shards),
         initializer=init_ops.truncated_normal_initializer(
-            mean=0.0, stddev=1.0, dtype=dtypes.float32))
+            mean=0.0, stddev=1.0, dtype=dtypes.float32)))
     for w in embedding_weights:
       w.initializer.run()
     return embedding_weights

   def test_scattered_embedding_consistency(self):
-    with self.cached_session():
+    with self.cached_session(), local_variable_scope():
       embedding_weights = self._random_weights()
       values = constant_op.constant(["foo", "foo"])
@@ -288,7 +300,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
                           embedding_lookup_result[1])

   def test_scattered_embedding_multiple_partition(self):
-    with self.cached_session():
+    with self.cached_session(), local_variable_scope():
       embedding_weights = self._random_weights(num_shards=7)
       values = constant_op.constant([4, 4, 5])
@@ -304,7 +316,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
       self.assertGreater(embedding_diff, 0)

   def test_scattered_embedding_coverage(self):
-    with self.cached_session():
+    with self.cached_session(), local_variable_scope():
       size = 8
       embedding_weights = self._random_weights(size=size, num_shards=3)
       values = constant_op.constant(["foo"])
@@ -316,7 +328,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
       self.assertEqual(len(np.unique(embedding_lookup_result[0])), size)

   def test_scattered_embedding_multi_dimension(self):
-    with self.cached_session():
+    with self.cached_session(), local_variable_scope():
       embedding_weights = self._random_weights()
       values = constant_op.constant([["foo", "bar", "bar"],
                                      ["bar", "bar", "foo"]])
@@ -329,7 +341,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
                           embedding_lookup_result[1][2])

   def test_scattered_embedding_lookup_sparse(self):
-    with self.cached_session():
+    with self.cached_session(), local_variable_scope():
       embedding_weights = self._random_weights(num_shards=3)
       sparse_tensor = sparse_tensor_lib.SparseTensor(
           values=["foo", "bar", "foo", "bar"],
@@ -358,7 +370,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
     embeds = np.random.randn(n_embed, d_embed)
     idx = np.random.randint(0, n_embed, idx_shape)

-    with self.cached_session():
+    with self.cached_session(), local_variable_scope():
       embedded_np = embeds[idx]
       embedded_tf = embedding_ops.embedding_lookup_unique(embeds, idx).eval()
@@ -370,7 +382,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
     idx = np.random.randint(0, 5, 10)
     idx2d = np.random.randint(0, 5, (10, 2))

-    with self.cached_session():
+    with self.cached_session(), local_variable_scope():
       embedded_np = embeds[idx]
       embedded_np2d = embeds[idx2d]
       embedded_tf = embedding_ops.embedding_lookup_unique(embeds, idx).eval()
@@ -398,17 +410,18 @@ class SampledScatteredEmbeddingLookupTest(test.TestCase):
     assert num_shards > 0
     assert num_shards <= size
-    embedding_weights = partitioned_variables.create_partitioned_variables(
+    embedding_weights = list(variable_scope.get_variable(
+        "embedding_weights",
         shape=[size],
-        slicing=[num_shards],
+        partitioner=partitioned_variables.fixed_size_partitioner(num_shards),
         initializer=init_ops.truncated_normal_initializer(
-            mean=0.0, stddev=1.0, dtype=dtypes.float32))
+            mean=0.0, stddev=1.0, dtype=dtypes.float32)))
     for w in embedding_weights:
       w.initializer.run()
     return embedding_weights

   def test_hashed_embedding_consistency(self):
-    with self.cached_session():
+    with self.cached_session(), local_variable_scope():
       embedding_weights = self._random_weights()
       values = constant_op.constant(["foo", "foo"])
       # The first three sampled_candidates are equal, so the first three
@@ -429,7 +442,7 @@ class SampledScatteredEmbeddingLookupTest(test.TestCase):
                           embedding_lookup_result[1][3])

   def test_hashed_embedding_multi_dimension(self):
-    with self.cached_session():
+    with self.cached_session(), local_variable_scope():
       embedding_weights = self._random_weights()
       values = constant_op.constant([["foo", "bar", "bar"],
                                      ["bar", "bar", "foo"]])


@@ -1069,17 +1069,14 @@ def _create_partitioned_variables(name,
         'As TPU embedding is not optimized for small tables, '
         'please consider other ways for this embedding lookup.')

-  slicing = [num_hosts, 1]
-
-  # TODO(shizhiw): deprecated, use tf.get_variable()?
-  return partitioned_variables.create_partitioned_variables(
-      name=name,
-      slicing=slicing,
+  return list(variable_scope.get_variable(
+      name,
       shape=(vocabulary_size, embedding_dimension),
+      partitioner=partitioned_variables.fixed_size_partitioner(num_hosts),
       dtype=dtypes.float32,
       initializer=initializer,
       collections=collections,
-      trainable=False)
+      trainable=False))


 @ops.RegisterGradient('TPUEmbeddingActivations')


@@ -758,11 +758,13 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
     assert num_shards > 0
     assert num_shards <= vocab_size
-    embedding_weights = partitioned_variables.create_partitioned_variables(
+    initializer = init_ops.truncated_normal_initializer(
+        mean=0.0, stddev=1.0 / math.sqrt(vocab_size), dtype=dtypes.float32)
+    embedding_weights = list(variable_scope.get_variable(
+        name="embedding_weights",
         shape=[vocab_size, embed_dim],
-        slicing=[num_shards, 1],
-        initializer=init_ops.truncated_normal_initializer(
-            mean=0.0, stddev=1.0 / math.sqrt(vocab_size), dtype=dtypes.float32))
+        partitioner=partitioned_variables.fixed_size_partitioner(num_shards),
+        initializer=initializer))
     for w in embedding_weights:
       w.initializer.run()
     embedding_weights = [w.eval() for w in embedding_weights]


@@ -25,6 +25,7 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 from tensorflow.python.training import saver
@@ -44,8 +45,12 @@ class SaverLargePartitionedVariableTest(test.TestCase):
       # split into smaller sized variables.
       init = lambda shape, dtype, partition_info: constant_op.constant(
           True, dtype, shape)
-      partitioned_var = partitioned_variables.create_partitioned_variables(
-          [1 << 31], [4], init, dtype=dtypes.bool, name=var_name)
+      partitioned_var = list(variable_scope.get_variable(
+          var_name,
+          shape=[1 << 31],
+          partitioner=partitioned_variables.fixed_size_partitioner(4),
+          initializer=init,
+          dtype=dtypes.bool))
       variables.global_variables_initializer().run()
       save = saver.Saver(partitioned_var)
       val = save.save(sess, save_path)


@@ -998,19 +998,12 @@ class SaveRestoreShardedTest(test.TestCase):
     call_saver_with_dict = False  # updated by test loop below

-    def _save(slices=None, partitioner=None):
+    def _save(partitioner=None):
       with self.session(graph=ops_lib.Graph()) as sess:
         # Calls .eval() to return the ndarray that makes up the full variable.
         rnd = random_ops.random_uniform(var_full_shape).eval()

-        if slices:
-          assert not partitioner
-          # TODO(apassos): make create_partitioned_variables take use_resource
-          # option to make this test passable without creating a named
-          # variable_scope.
-          vs = partitioned_variables.create_partitioned_variables(
-              var_full_shape, slices, rnd, name=var_name)
-        elif partitioner:
+        if partitioner:
           vs = [
               variable_scope.get_variable(
                   var_name,
@@ -1027,7 +1020,7 @@
         variables.global_variables_initializer().run()
         if call_saver_with_dict:
-          saver = saver_module.Saver({var_name: (vs if slices else vs[0])})
+          saver = saver_module.Saver({var_name: vs[0]})
         else:
           saver = saver_module.Saver(vs)
         actual_path = saver.save(sess, saved_path)
@@ -1035,16 +1028,9 @@
         return rnd

-    def _restore(slices=None, partitioner=None):
+    def _restore(partitioner=None):
       with self.session(graph=ops_lib.Graph()) as sess:
-        if slices:
-          assert not partitioner
-          new_vs = partitioned_variables.create_partitioned_variables(
-              var_full_shape,
-              slices,
-              array_ops.zeros(var_full_shape),  # != original contents.
-              name=var_name)
-        elif partitioner:
+        if partitioner:
           new_vs = [
               variable_scope.get_variable(
                   var_name,
@@ -1063,7 +1049,7 @@
         variables.global_variables_initializer().run()
         if call_saver_with_dict:
           saver = saver_module.Saver({
-              var_name: (new_vs if slices else new_vs[0])
+              var_name: new_vs[0]
           })
         else:
           saver = saver_module.Saver(new_vs)
@@ -1071,11 +1057,7 @@
         if partitioner:
           return new_vs[0].as_tensor().eval()
-        elif slices and slices[0] != 1:
-          return array_ops.concat(new_vs, 0).eval()
-        elif slices and slices[1] != 1:
-          return array_ops.concat(new_vs, 1).eval()
-        else:  # Non-sliced.
+        else:
           return new_vs[0].eval()

     for call_saver_with_dict in {False, True}:
@@ -1086,27 +1068,23 @@
       restored_full = _restore()
       self.assertAllEqual(saved_full, restored_full)

-      # Saves 10 horizontal parts of a partitioned variable.
-      # Restores into a full variable, non-sliced.
-      saved_full = _save(slices=[10, 1])
-      restored_full = _restore()
-      self.assertAllEqual(saved_full, restored_full)
-
-      # Restores into a different number/orientation of slices.
-      restored_full = _restore(slices=[2, 1])  # 2 horizon parts.
-      self.assertAllEqual(saved_full, restored_full)
-      restored_full = _restore(slices=[1, 3])  # 3 vertical parts.
-      self.assertAllEqual(saved_full, restored_full)
-
-      # Restores into a PartitionedVariable
+      # Restores into the same number of partitions.
       restored_full = _restore(
           partitioner=partitioned_variables.fixed_size_partitioner(
               num_shards=2))
       self.assertAllEqual(saved_full, restored_full)

-      # Now, saves a full variable and restores in slices.
+      # Restores into a different number of partitions.
+      restored_full = _restore(
+          partitioner=partitioned_variables.fixed_size_partitioner(
+              num_shards=3))
+      self.assertAllEqual(saved_full, restored_full)
+
+      # Now, saves a full variable and restores PartitionedVariable.
       saved_full = _save()
-      restored_full = _restore(slices=[1, 3])
+      restored_full = _restore(
+          partitioner=partitioned_variables.fixed_size_partitioner(
+              num_shards=3))
       self.assertAllEqual(saved_full, restored_full)

   def testPartitionedVariable(self):