Replaced deprecated tf.create_partitioned_variables with tf.get_variable
PiperOrigin-RevId: 222963167
This commit is contained in:
parent
ac167c39b4
commit
b5b6628931
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
|
||||
import itertools
|
||||
import math
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -36,6 +37,7 @@ from tensorflow.python.ops import gradient_checker
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import partitioned_variables
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
@ -48,11 +50,13 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
|
||||
assert num_shards > 0
|
||||
assert num_shards <= vocab_size
|
||||
|
||||
embedding_weights = partitioned_variables.create_partitioned_variables(
|
||||
initializer = init_ops.truncated_normal_initializer(
|
||||
mean=0.0, stddev=1.0 / math.sqrt(vocab_size), dtype=dtypes.float32)
|
||||
embedding_weights = list(variable_scope.get_variable(
|
||||
"embedding_weights",
|
||||
shape=[vocab_size, embed_dim],
|
||||
slicing=[num_shards, 1],
|
||||
initializer=init_ops.truncated_normal_initializer(
|
||||
mean=0.0, stddev=1.0 / math.sqrt(vocab_size), dtype=dtypes.float32))
|
||||
partitioner=partitioned_variables.fixed_size_partitioner(num_shards),
|
||||
initializer=initializer))
|
||||
for w in embedding_weights:
|
||||
w.initializer.run()
|
||||
embedding_weights = [w.eval() for w in embedding_weights]
|
||||
@ -256,6 +260,13 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
|
||||
embedding_weights, sparse_ids, sparse_weights)
|
||||
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
def local_variable_scope():
|
||||
"""Create a variable scope named like the caller function."""
|
||||
return variable_scope.variable_scope(sys._getframe(1).f_code.co_name)
|
||||
# pylint: enable=invalid-name
|
||||
|
||||
|
||||
class ScatteredEmbeddingLookupTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -266,17 +277,18 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
|
||||
assert num_shards > 0
|
||||
assert num_shards <= size
|
||||
|
||||
embedding_weights = partitioned_variables.create_partitioned_variables(
|
||||
embedding_weights = list(variable_scope.get_variable(
|
||||
"embedding_weights",
|
||||
shape=[size],
|
||||
slicing=[num_shards],
|
||||
partitioner=partitioned_variables.fixed_size_partitioner(num_shards),
|
||||
initializer=init_ops.truncated_normal_initializer(
|
||||
mean=0.0, stddev=1.0, dtype=dtypes.float32))
|
||||
mean=0.0, stddev=1.0, dtype=dtypes.float32)))
|
||||
for w in embedding_weights:
|
||||
w.initializer.run()
|
||||
return embedding_weights
|
||||
|
||||
def test_scattered_embedding_consistency(self):
|
||||
with self.cached_session():
|
||||
with self.cached_session(), local_variable_scope():
|
||||
embedding_weights = self._random_weights()
|
||||
values = constant_op.constant(["foo", "foo"])
|
||||
|
||||
@ -288,7 +300,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
|
||||
embedding_lookup_result[1])
|
||||
|
||||
def test_scattered_embedding_multiple_partition(self):
|
||||
with self.cached_session():
|
||||
with self.cached_session(), local_variable_scope():
|
||||
embedding_weights = self._random_weights(num_shards=7)
|
||||
values = constant_op.constant([4, 4, 5])
|
||||
|
||||
@ -304,7 +316,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
|
||||
self.assertGreater(embedding_diff, 0)
|
||||
|
||||
def test_scattered_embedding_coverage(self):
|
||||
with self.cached_session():
|
||||
with self.cached_session(), local_variable_scope():
|
||||
size = 8
|
||||
embedding_weights = self._random_weights(size=size, num_shards=3)
|
||||
values = constant_op.constant(["foo"])
|
||||
@ -316,7 +328,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
|
||||
self.assertEqual(len(np.unique(embedding_lookup_result[0])), size)
|
||||
|
||||
def test_scattered_embedding_multi_dimension(self):
|
||||
with self.cached_session():
|
||||
with self.cached_session(), local_variable_scope():
|
||||
embedding_weights = self._random_weights()
|
||||
values = constant_op.constant([["foo", "bar", "bar"],
|
||||
["bar", "bar", "foo"]])
|
||||
@ -329,7 +341,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
|
||||
embedding_lookup_result[1][2])
|
||||
|
||||
def test_scattered_embedding_lookup_sparse(self):
|
||||
with self.cached_session():
|
||||
with self.cached_session(), local_variable_scope():
|
||||
embedding_weights = self._random_weights(num_shards=3)
|
||||
sparse_tensor = sparse_tensor_lib.SparseTensor(
|
||||
values=["foo", "bar", "foo", "bar"],
|
||||
@ -358,7 +370,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
|
||||
embeds = np.random.randn(n_embed, d_embed)
|
||||
idx = np.random.randint(0, n_embed, idx_shape)
|
||||
|
||||
with self.cached_session():
|
||||
with self.cached_session(), local_variable_scope():
|
||||
embedded_np = embeds[idx]
|
||||
embedded_tf = embedding_ops.embedding_lookup_unique(embeds, idx).eval()
|
||||
|
||||
@ -370,7 +382,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
|
||||
idx = np.random.randint(0, 5, 10)
|
||||
idx2d = np.random.randint(0, 5, (10, 2))
|
||||
|
||||
with self.cached_session():
|
||||
with self.cached_session(), local_variable_scope():
|
||||
embedded_np = embeds[idx]
|
||||
embedded_np2d = embeds[idx2d]
|
||||
embedded_tf = embedding_ops.embedding_lookup_unique(embeds, idx).eval()
|
||||
@ -398,17 +410,18 @@ class SampledScatteredEmbeddingLookupTest(test.TestCase):
|
||||
assert num_shards > 0
|
||||
assert num_shards <= size
|
||||
|
||||
embedding_weights = partitioned_variables.create_partitioned_variables(
|
||||
embedding_weights = list(variable_scope.get_variable(
|
||||
"embedding_weights",
|
||||
shape=[size],
|
||||
slicing=[num_shards],
|
||||
partitioner=partitioned_variables.fixed_size_partitioner(num_shards),
|
||||
initializer=init_ops.truncated_normal_initializer(
|
||||
mean=0.0, stddev=1.0, dtype=dtypes.float32))
|
||||
mean=0.0, stddev=1.0, dtype=dtypes.float32)))
|
||||
for w in embedding_weights:
|
||||
w.initializer.run()
|
||||
return embedding_weights
|
||||
|
||||
def test_hashed_embedding_consistency(self):
|
||||
with self.cached_session():
|
||||
with self.cached_session(), local_variable_scope():
|
||||
embedding_weights = self._random_weights()
|
||||
values = constant_op.constant(["foo", "foo"])
|
||||
# The first three sampled_candidates are equal, so the first three
|
||||
@ -429,7 +442,7 @@ class SampledScatteredEmbeddingLookupTest(test.TestCase):
|
||||
embedding_lookup_result[1][3])
|
||||
|
||||
def test_hashed_embedding_multi_dimension(self):
|
||||
with self.cached_session():
|
||||
with self.cached_session(), local_variable_scope():
|
||||
embedding_weights = self._random_weights()
|
||||
values = constant_op.constant([["foo", "bar", "bar"],
|
||||
["bar", "bar", "foo"]])
|
||||
|
@ -1069,17 +1069,14 @@ def _create_partitioned_variables(name,
|
||||
'As TPU embedding is not optimized for small tables, '
|
||||
'please consider other ways for this embedding lookup.')
|
||||
|
||||
slicing = [num_hosts, 1]
|
||||
|
||||
# TODO(shizhiw): deprecated, use tf.get_variable()?
|
||||
return partitioned_variables.create_partitioned_variables(
|
||||
name=name,
|
||||
slicing=slicing,
|
||||
return list(variable_scope.get_variable(
|
||||
name,
|
||||
shape=(vocabulary_size, embedding_dimension),
|
||||
partitioner=partitioned_variables.fixed_size_partitioner(num_hosts),
|
||||
dtype=dtypes.float32,
|
||||
initializer=initializer,
|
||||
collections=collections,
|
||||
trainable=False)
|
||||
trainable=False))
|
||||
|
||||
|
||||
@ops.RegisterGradient('TPUEmbeddingActivations')
|
||||
|
@ -758,11 +758,13 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
|
||||
assert num_shards > 0
|
||||
assert num_shards <= vocab_size
|
||||
|
||||
embedding_weights = partitioned_variables.create_partitioned_variables(
|
||||
initializer = init_ops.truncated_normal_initializer(
|
||||
mean=0.0, stddev=1.0 / math.sqrt(vocab_size), dtype=dtypes.float32)
|
||||
embedding_weights = list(variable_scope.get_variable(
|
||||
name="embedding_weights",
|
||||
shape=[vocab_size, embed_dim],
|
||||
slicing=[num_shards, 1],
|
||||
initializer=init_ops.truncated_normal_initializer(
|
||||
mean=0.0, stddev=1.0 / math.sqrt(vocab_size), dtype=dtypes.float32))
|
||||
partitioner=partitioned_variables.fixed_size_partitioner(num_shards),
|
||||
initializer=initializer))
|
||||
for w in embedding_weights:
|
||||
w.initializer.run()
|
||||
embedding_weights = [w.eval() for w in embedding_weights]
|
||||
|
@ -25,6 +25,7 @@ from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import partitioned_variables
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import saver
|
||||
@ -44,8 +45,12 @@ class SaverLargePartitionedVariableTest(test.TestCase):
|
||||
# split into smaller sized variables.
|
||||
init = lambda shape, dtype, partition_info: constant_op.constant(
|
||||
True, dtype, shape)
|
||||
partitioned_var = partitioned_variables.create_partitioned_variables(
|
||||
[1 << 31], [4], init, dtype=dtypes.bool, name=var_name)
|
||||
partitioned_var = list(variable_scope.get_variable(
|
||||
var_name,
|
||||
shape=[1 << 31],
|
||||
partitioner=partitioned_variables.fixed_size_partitioner(4),
|
||||
initializer=init,
|
||||
dtype=dtypes.bool))
|
||||
variables.global_variables_initializer().run()
|
||||
save = saver.Saver(partitioned_var)
|
||||
val = save.save(sess, save_path)
|
||||
|
@ -998,19 +998,12 @@ class SaveRestoreShardedTest(test.TestCase):
|
||||
|
||||
call_saver_with_dict = False # updated by test loop below
|
||||
|
||||
def _save(slices=None, partitioner=None):
|
||||
def _save(partitioner=None):
|
||||
with self.session(graph=ops_lib.Graph()) as sess:
|
||||
# Calls .eval() to return the ndarray that makes up the full variable.
|
||||
rnd = random_ops.random_uniform(var_full_shape).eval()
|
||||
|
||||
if slices:
|
||||
assert not partitioner
|
||||
# TODO(apassos): make create_partitioned_variables take use_resource
|
||||
# option to make this test passable without creating a named
|
||||
# variable_scope.
|
||||
vs = partitioned_variables.create_partitioned_variables(
|
||||
var_full_shape, slices, rnd, name=var_name)
|
||||
elif partitioner:
|
||||
if partitioner:
|
||||
vs = [
|
||||
variable_scope.get_variable(
|
||||
var_name,
|
||||
@ -1027,7 +1020,7 @@ class SaveRestoreShardedTest(test.TestCase):
|
||||
|
||||
variables.global_variables_initializer().run()
|
||||
if call_saver_with_dict:
|
||||
saver = saver_module.Saver({var_name: (vs if slices else vs[0])})
|
||||
saver = saver_module.Saver({var_name: vs[0]})
|
||||
else:
|
||||
saver = saver_module.Saver(vs)
|
||||
actual_path = saver.save(sess, saved_path)
|
||||
@ -1035,16 +1028,9 @@ class SaveRestoreShardedTest(test.TestCase):
|
||||
|
||||
return rnd
|
||||
|
||||
def _restore(slices=None, partitioner=None):
|
||||
def _restore(partitioner=None):
|
||||
with self.session(graph=ops_lib.Graph()) as sess:
|
||||
if slices:
|
||||
assert not partitioner
|
||||
new_vs = partitioned_variables.create_partitioned_variables(
|
||||
var_full_shape,
|
||||
slices,
|
||||
array_ops.zeros(var_full_shape), # != original contents.
|
||||
name=var_name)
|
||||
elif partitioner:
|
||||
if partitioner:
|
||||
new_vs = [
|
||||
variable_scope.get_variable(
|
||||
var_name,
|
||||
@ -1063,7 +1049,7 @@ class SaveRestoreShardedTest(test.TestCase):
|
||||
variables.global_variables_initializer().run()
|
||||
if call_saver_with_dict:
|
||||
saver = saver_module.Saver({
|
||||
var_name: (new_vs if slices else new_vs[0])
|
||||
var_name: new_vs[0]
|
||||
})
|
||||
else:
|
||||
saver = saver_module.Saver(new_vs)
|
||||
@ -1071,11 +1057,7 @@ class SaveRestoreShardedTest(test.TestCase):
|
||||
|
||||
if partitioner:
|
||||
return new_vs[0].as_tensor().eval()
|
||||
elif slices and slices[0] != 1:
|
||||
return array_ops.concat(new_vs, 0).eval()
|
||||
elif slices and slices[1] != 1:
|
||||
return array_ops.concat(new_vs, 1).eval()
|
||||
else: # Non-sliced.
|
||||
else:
|
||||
return new_vs[0].eval()
|
||||
|
||||
for call_saver_with_dict in {False, True}:
|
||||
@ -1086,27 +1068,23 @@ class SaveRestoreShardedTest(test.TestCase):
|
||||
restored_full = _restore()
|
||||
self.assertAllEqual(saved_full, restored_full)
|
||||
|
||||
# Saves 10 horizontal parts of a partitioned variable.
|
||||
# Restores into a full variable, non-sliced.
|
||||
saved_full = _save(slices=[10, 1])
|
||||
restored_full = _restore()
|
||||
self.assertAllEqual(saved_full, restored_full)
|
||||
|
||||
# Restores into a different number/orientation of slices.
|
||||
restored_full = _restore(slices=[2, 1]) # 2 horizon parts.
|
||||
self.assertAllEqual(saved_full, restored_full)
|
||||
restored_full = _restore(slices=[1, 3]) # 3 vertical parts.
|
||||
self.assertAllEqual(saved_full, restored_full)
|
||||
|
||||
# Restores into a PartitionedVariable
|
||||
# Restores into the same number of partitions.
|
||||
restored_full = _restore(
|
||||
partitioner=partitioned_variables.fixed_size_partitioner(
|
||||
num_shards=2))
|
||||
self.assertAllEqual(saved_full, restored_full)
|
||||
|
||||
# Now, saves a full variable and restores in slices.
|
||||
# Restores into a different number of partitions.
|
||||
restored_full = _restore(
|
||||
partitioner=partitioned_variables.fixed_size_partitioner(
|
||||
num_shards=3))
|
||||
self.assertAllEqual(saved_full, restored_full)
|
||||
|
||||
# Now, saves a full variable and restores PartitionedVariable.
|
||||
saved_full = _save()
|
||||
restored_full = _restore(slices=[1, 3])
|
||||
restored_full = _restore(
|
||||
partitioner=partitioned_variables.fixed_size_partitioner(
|
||||
num_shards=3))
|
||||
self.assertAllEqual(saved_full, restored_full)
|
||||
|
||||
def testPartitionedVariable(self):
|
||||
|
Loading…
Reference in New Issue
Block a user