From b5b6628931a2db98a94c55e90c0df5db5cb4fa5b Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 27 Nov 2018 02:54:53 -0800 Subject: [PATCH] Replaced deprecated tf.create_partitioned_variables with tf.get_variable PiperOrigin-RevId: 222963167 --- .../python/layers/embedding_ops_test.py | 51 ++++++++++------ .../contrib/tpu/python/tpu/tpu_embedding.py | 11 ++-- .../python/kernel_tests/embedding_ops_test.py | 10 ++-- .../saver_large_partitioned_variable_test.py | 9 ++- tensorflow/python/training/saver_test.py | 58 ++++++------------- 5 files changed, 67 insertions(+), 72 deletions(-) diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py index 8015a571e14..295c721fced 100644 --- a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import itertools import math +import sys import numpy as np @@ -36,6 +37,7 @@ from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test from tensorflow.python.util import compat @@ -48,11 +50,13 @@ class SafeEmbeddingLookupSparseTest(test.TestCase): assert num_shards > 0 assert num_shards <= vocab_size - embedding_weights = partitioned_variables.create_partitioned_variables( + initializer = init_ops.truncated_normal_initializer( + mean=0.0, stddev=1.0 / math.sqrt(vocab_size), dtype=dtypes.float32) + embedding_weights = list(variable_scope.get_variable( + "embedding_weights", shape=[vocab_size, embed_dim], - slicing=[num_shards, 1], - initializer=init_ops.truncated_normal_initializer( - mean=0.0, stddev=1.0 / math.sqrt(vocab_size), dtype=dtypes.float32)) + partitioner=partitioned_variables.fixed_size_partitioner(num_shards), + initializer=initializer)) for w in embedding_weights: w.initializer.run() embedding_weights = [w.eval() for w in embedding_weights] @@ -256,6 +260,13 @@ class SafeEmbeddingLookupSparseTest(test.TestCase): embedding_weights, sparse_ids, sparse_weights) +# pylint: disable=invalid-name +def local_variable_scope(): + """Create a variable scope named like the caller function.""" + return variable_scope.variable_scope(sys._getframe(1).f_code.co_name) +# pylint: enable=invalid-name + + class ScatteredEmbeddingLookupTest(test.TestCase): def setUp(self): @@ -266,17 +277,18 @@ class ScatteredEmbeddingLookupTest(test.TestCase): assert num_shards > 0 assert num_shards <= size - embedding_weights = partitioned_variables.create_partitioned_variables( + embedding_weights = list(variable_scope.get_variable( + "embedding_weights", shape=[size], - slicing=[num_shards], + partitioner=partitioned_variables.fixed_size_partitioner(num_shards), initializer=init_ops.truncated_normal_initializer( - mean=0.0, stddev=1.0, dtype=dtypes.float32)) + mean=0.0, stddev=1.0, dtype=dtypes.float32))) for w in embedding_weights: w.initializer.run() return embedding_weights def test_scattered_embedding_consistency(self): - with self.cached_session(): + with self.cached_session(), local_variable_scope(): embedding_weights = self._random_weights() values = constant_op.constant(["foo", "foo"]) @@ -288,7 +300,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase): embedding_lookup_result[1]) def test_scattered_embedding_multiple_partition(self): - with self.cached_session(): + with self.cached_session(), local_variable_scope(): embedding_weights = self._random_weights(num_shards=7) values = constant_op.constant([4, 4, 5]) @@ -304,7 +316,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase): self.assertGreater(embedding_diff, 0) def test_scattered_embedding_coverage(self): - with self.cached_session(): + with self.cached_session(), local_variable_scope(): size = 8 embedding_weights = self._random_weights(size=size, num_shards=3) values = constant_op.constant(["foo"]) @@ -316,7 +328,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase): self.assertEqual(len(np.unique(embedding_lookup_result[0])), size) def test_scattered_embedding_multi_dimension(self): - with self.cached_session(): + with self.cached_session(), local_variable_scope(): embedding_weights = self._random_weights() values = constant_op.constant([["foo", "bar", "bar"], ["bar", "bar", "foo"]]) @@ -329,7 +341,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase): embedding_lookup_result[1][2]) def test_scattered_embedding_lookup_sparse(self): - with self.cached_session(): + with self.cached_session(), local_variable_scope(): embedding_weights = self._random_weights(num_shards=3) sparse_tensor = sparse_tensor_lib.SparseTensor( values=["foo", "bar", "foo", "bar"], @@ -358,7 +370,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase): embeds = np.random.randn(n_embed, d_embed) idx = np.random.randint(0, n_embed, idx_shape) - with self.cached_session(): + with self.cached_session(), local_variable_scope(): embedded_np = embeds[idx] embedded_tf = embedding_ops.embedding_lookup_unique(embeds, idx).eval() @@ -370,7 +382,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase): idx = np.random.randint(0, 5, 10) idx2d = np.random.randint(0, 5, (10, 2)) - with self.cached_session(): + with self.cached_session(), local_variable_scope(): embedded_np = embeds[idx] embedded_np2d = embeds[idx2d] embedded_tf = embedding_ops.embedding_lookup_unique(embeds, idx).eval() @@ -398,17 +410,18 @@ class SampledScatteredEmbeddingLookupTest(test.TestCase): assert num_shards > 0 assert num_shards <= size - embedding_weights = partitioned_variables.create_partitioned_variables( + embedding_weights = list(variable_scope.get_variable( + "embedding_weights", shape=[size], - slicing=[num_shards], + partitioner=partitioned_variables.fixed_size_partitioner(num_shards), initializer=init_ops.truncated_normal_initializer( - mean=0.0, stddev=1.0, dtype=dtypes.float32)) + mean=0.0, stddev=1.0, dtype=dtypes.float32))) for w in embedding_weights: w.initializer.run() return embedding_weights def test_hashed_embedding_consistency(self): - with self.cached_session(): + with self.cached_session(), local_variable_scope(): embedding_weights = self._random_weights() values = constant_op.constant(["foo", "foo"]) # The first three sampled_candidates are equal, so the first three @@ -429,7 +442,7 @@ class SampledScatteredEmbeddingLookupTest(test.TestCase): embedding_lookup_result[1][3]) def test_hashed_embedding_multi_dimension(self): - with self.cached_session(): + with self.cached_session(), local_variable_scope(): embedding_weights = self._random_weights() values = constant_op.constant([["foo", "bar", "bar"], ["bar", "bar", "foo"]]) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py b/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py index 3fe896426a7..ccba8a46c7c 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py @@ -1069,17 +1069,14 @@ def _create_partitioned_variables(name, 'As TPU embedding is not optimized for small tables, ' 'please consider other ways for this embedding lookup.') - slicing = [num_hosts, 1] - - # TODO(shizhiw): deprecated, use tf.get_variable()? - return partitioned_variables.create_partitioned_variables( - name=name, - slicing=slicing, + return list(variable_scope.get_variable( + name, shape=(vocabulary_size, embedding_dimension), + partitioner=partitioned_variables.fixed_size_partitioner(num_hosts), dtype=dtypes.float32, initializer=initializer, collections=collections, - trainable=False) + trainable=False)) @ops.RegisterGradient('TPUEmbeddingActivations') diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py index 443f54a9586..dba3409c9e1 100644 --- a/tensorflow/python/kernel_tests/embedding_ops_test.py +++ b/tensorflow/python/kernel_tests/embedding_ops_test.py @@ -758,11 +758,13 @@ class SafeEmbeddingLookupSparseTest(test.TestCase): assert num_shards > 0 assert num_shards <= vocab_size - embedding_weights = partitioned_variables.create_partitioned_variables( + initializer = init_ops.truncated_normal_initializer( + mean=0.0, stddev=1.0 / math.sqrt(vocab_size), dtype=dtypes.float32) + embedding_weights = list(variable_scope.get_variable( + name="embedding_weights", shape=[vocab_size, embed_dim], - slicing=[num_shards, 1], - initializer=init_ops.truncated_normal_initializer( - mean=0.0, stddev=1.0 / math.sqrt(vocab_size), dtype=dtypes.float32)) + partitioner=partitioned_variables.fixed_size_partitioner(num_shards), + initializer=initializer)) for w in embedding_weights: w.initializer.run() embedding_weights = [w.eval() for w in embedding_weights] diff --git a/tensorflow/python/training/saver_large_partitioned_variable_test.py b/tensorflow/python/training/saver_large_partitioned_variable_test.py index 1a44511cfeb..84458836d06 100644 --- a/tensorflow/python/training/saver_large_partitioned_variable_test.py +++ b/tensorflow/python/training/saver_large_partitioned_variable_test.py @@ -25,6 +25,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import saver @@ -44,8 +45,12 @@ class SaverLargePartitionedVariableTest(test.TestCase): # split into smaller sized variables. init = lambda shape, dtype, partition_info: constant_op.constant( True, dtype, shape) - partitioned_var = partitioned_variables.create_partitioned_variables( - [1 << 31], [4], init, dtype=dtypes.bool, name=var_name) + partitioned_var = list(variable_scope.get_variable( + var_name, + shape=[1 << 31], + partitioner=partitioned_variables.fixed_size_partitioner(4), + initializer=init, + dtype=dtypes.bool)) variables.global_variables_initializer().run() save = saver.Saver(partitioned_var) val = save.save(sess, save_path) diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index eb2690985d5..be49e6e7157 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -998,19 +998,12 @@ class SaveRestoreShardedTest(test.TestCase): call_saver_with_dict = False # updated by test loop below - def _save(slices=None, partitioner=None): + def _save(partitioner=None): with self.session(graph=ops_lib.Graph()) as sess: # Calls .eval() to return the ndarray that makes up the full variable. rnd = random_ops.random_uniform(var_full_shape).eval() - if slices: - assert not partitioner - # TODO(apassos): make create_partitioned_variables take use_resource - # option to make this test passable without creating a named - # variable_scope. - vs = partitioned_variables.create_partitioned_variables( - var_full_shape, slices, rnd, name=var_name) - elif partitioner: + if partitioner: vs = [ variable_scope.get_variable( var_name, @@ -1027,7 +1020,7 @@ class SaveRestoreShardedTest(test.TestCase): variables.global_variables_initializer().run() if call_saver_with_dict: - saver = saver_module.Saver({var_name: (vs if slices else vs[0])}) + saver = saver_module.Saver({var_name: vs[0]}) else: saver = saver_module.Saver(vs) actual_path = saver.save(sess, saved_path) @@ -1035,16 +1028,9 @@ class SaveRestoreShardedTest(test.TestCase): return rnd - def _restore(slices=None, partitioner=None): + def _restore(partitioner=None): with self.session(graph=ops_lib.Graph()) as sess: - if slices: - assert not partitioner - new_vs = partitioned_variables.create_partitioned_variables( - var_full_shape, - slices, - array_ops.zeros(var_full_shape), # != original contents. - name=var_name) - elif partitioner: + if partitioner: new_vs = [ variable_scope.get_variable( var_name, @@ -1063,7 +1049,7 @@ class SaveRestoreShardedTest(test.TestCase): variables.global_variables_initializer().run() if call_saver_with_dict: saver = saver_module.Saver({ - var_name: (new_vs if slices else new_vs[0]) + var_name: new_vs[0] }) else: saver = saver_module.Saver(new_vs) @@ -1071,11 +1057,7 @@ class SaveRestoreShardedTest(test.TestCase): if partitioner: return new_vs[0].as_tensor().eval() - elif slices and slices[0] != 1: - return array_ops.concat(new_vs, 0).eval() - elif slices and slices[1] != 1: - return array_ops.concat(new_vs, 1).eval() - else: # Non-sliced. + else: return new_vs[0].eval() for call_saver_with_dict in {False, True}: @@ -1086,27 +1068,23 @@ class SaveRestoreShardedTest(test.TestCase): restored_full = _restore() self.assertAllEqual(saved_full, restored_full) - # Saves 10 horizontal parts of a partitioned variable. - # Restores into a full variable, non-sliced. - saved_full = _save(slices=[10, 1]) - restored_full = _restore() - self.assertAllEqual(saved_full, restored_full) - - # Restores into a different number/orientation of slices. - restored_full = _restore(slices=[2, 1]) # 2 horizon parts. - self.assertAllEqual(saved_full, restored_full) - restored_full = _restore(slices=[1, 3]) # 3 vertical parts. - self.assertAllEqual(saved_full, restored_full) - - # Restores into a PartitionedVariable + # Restores into the same number of partitions. restored_full = _restore( partitioner=partitioned_variables.fixed_size_partitioner( num_shards=2)) self.assertAllEqual(saved_full, restored_full) - # Now, saves a full variable and restores in slices. + # Restores into a different number of partitions. + restored_full = _restore( + partitioner=partitioned_variables.fixed_size_partitioner( + num_shards=3)) + self.assertAllEqual(saved_full, restored_full) + + # Now, saves a full variable and restores PartitionedVariable. saved_full = _save() - restored_full = _restore(slices=[1, 3]) + restored_full = _restore( + partitioner=partitioned_variables.fixed_size_partitioner( + num_shards=3)) self.assertAllEqual(saved_full, restored_full) def testPartitionedVariable(self):