# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional tests for ops used with embeddings."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools
import math

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
from tensorflow.python.util import compat


def _AsLong(array):
  """Casts array elements to int. Used to convert from numpy to tf."""
  return [int(x) for x in array]


class ScatterAddSubTest(test.TestCase):

  def _TestCase(self, shape, indices, scatter_op=state_ops.scatter_add):
    """Run a random test case with the given shape and indices.

    Args:
      shape: Shape of the parameters array.
      indices: One-dimensional array of ints, the indices into the first
        dimension of the parameters to update.
      scatter_op: ScatterAdd or ScatterSub.
    """
    super(ScatterAddSubTest, self).setUp()
    with self.cached_session(use_gpu=False):
      # Create a random parameter array of given shape
      p_init = np.random.rand(*shape).astype("f")
      # Create the shape of the update array. All dimensions except the first
      # match the parameter array; the first dimension equals the # of indices.
      vals_shape = [len(indices)] + shape[1:]
      vals_init = np.random.rand(*vals_shape).astype("f")
      v_i = [float(x) for x in vals_init.ravel()]
      p = variables.Variable(p_init)
      vals = constant_op.constant(v_i, shape=vals_shape, name="vals")
      ind = constant_op.constant(indices, dtype=dtypes.int32)
      p2 = scatter_op(p, ind, vals, name="updated_p")
      # p = init
      variables.global_variables_initializer().run()
      # p += vals
      result = self.evaluate(p2)
      # Compute the expected 'p' using numpy operations.
      for i, ind in enumerate(indices):
        if scatter_op == state_ops.scatter_add:
          p_init.reshape(shape[0], -1)[ind, :] += (vals_init.reshape(
              vals_shape[0], -1)[i, :])
        else:
          p_init.reshape(shape[0], -1)[ind, :] -= (vals_init.reshape(
              vals_shape[0], -1)[i, :])
      self.assertTrue(all((p_init == result).ravel()))

  @test_util.run_deprecated_v1
  def testNoRepetitions(self):
    self._TestCase([2, 2], [1])
    self._TestCase([4, 4, 4], [2, 0])
    self._TestCase([43, 20, 10, 10], [42, 5, 6, 1, 3, 5, 7, 9])

  @test_util.run_deprecated_v1
  def testWithRepetitions(self):
    self._TestCase([2, 2], [1, 1])
    self._TestCase([5, 3, 9, 5], [2, 0, 4, 1, 3, 1, 4, 0, 4, 3])
    self._TestCase([32, 4, 4], [31] * 8)

  @test_util.run_deprecated_v1
  def testRandom(self):
    # Random shapes of rank 4, random indices
    for _ in range(5):
      shape = np.random.randint(1, 20, size=4)
      indices = np.random.randint(shape[0], size=2 * shape[0])
      self._TestCase(_AsLong(list(shape)), list(indices))

  @test_util.run_deprecated_v1
  def testSubRandom(self):
    # Random shapes of rank 4, random indices
    for _ in range(5):
      shape = np.random.randint(1, 20, size=4)
      indices = np.random.randint(shape[0], size=2 * shape[0])
      self._TestCase(_AsLong(list(shape)), list(indices), state_ops.scatter_sub)

  @test_util.run_deprecated_v1
  def testWrongShape(self):
    # Indices and values mismatch.
    var = variables.Variable(
        array_ops.zeros(shape=[1024, 64, 64], dtype=dtypes.float32))
    indices = array_ops.placeholder(dtypes.int32, shape=[32])
    values = array_ops.placeholder(dtypes.float32, shape=[33, 64, 64])
    with self.assertRaises(ValueError):
      state_ops.scatter_add(var, indices, values)

    # Var and values mismatch.
    values = array_ops.placeholder(dtypes.float32, shape=[32, 64, 63])
    with self.assertRaises(ValueError):
      state_ops.scatter_add(var, indices, values)


def _PName(param_id):
  return "p" + str(param_id)


def _EmbeddingParams(num_shards,
                     vocab_size,
                     dtype=dtypes.float32,
                     shape=None,
                     use_shapeless_placeholder=False):
  p = []
  params = {}
  feed_dict = {}
  if not shape:
    shape = [10]
  for i in range(num_shards):
    shard_shape = [vocab_size // num_shards] + shape
    if i < vocab_size % num_shards:  # Excess goes evenly on the first shards
      shard_shape[0] += 1

    param_name = _PName(i)

    if use_shapeless_placeholder:
      param = array_ops.placeholder(dtype, shape=None, name=param_name)
    else:
      param = constant_op.constant(
          1.0, shape=shard_shape, dtype=dtype, name=param_name)
    p.append(param)
    np_type = "f" if dtype == dtypes.float32 else "d"
    val = (np.random.rand(*shard_shape).astype(np_type)) + 1
    params[param_name + ":0"] = val
    feed_dict[param.name] = val
  return p, params, feed_dict


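# A minimal illustrative sketch (not part of the original test suite; the name
# is ours): how _EmbeddingParams above splits the vocabulary across shards. The
# base shard size is vocab_size // num_shards and the remainder goes to the
# first shards, e.g. vocab_size=13, num_shards=5 gives sizes [3, 3, 3, 2, 2].
def _IllustrativeShardSizes(vocab_size, num_shards):
  """Returns the first-dimension size of each shard (illustration only)."""
  base, extras = divmod(vocab_size, num_shards)
  return [base + 1 if i < extras else base for i in range(num_shards)]

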
def _EmbeddingParamsAsPartitionedVariable(num_shards,
                                          vocab_size,
                                          dtype=dtypes.float32,
                                          shape=None,
                                          use_resource=False):
  p, params, feed_dict = _EmbeddingParams(
      num_shards, vocab_size, dtype=dtype, shape=shape)
  shape = shape or [10]
  partitioned_variable = variable_scope.get_variable(
      "p",
      shape=[vocab_size] + shape,
      initializer=array_ops.concat([params[p_i.name] for p_i in p], 0),
      partitioner=partitioned_variables.min_max_variable_partitioner(
          max_partitions=num_shards, min_slice_size=1),
      use_resource=use_resource)
  return p, partitioned_variable, params, feed_dict


def _EmbeddingResult(params,
                     id_vals,
                     num_shards,
                     vocab_size,
                     partition_strategy="mod",
                     weight_vals=None):
  if weight_vals is None:
    weight_vals = np.copy(id_vals)
    weight_vals.fill(1)
  values = []
  weights = []
  weights_squared = []
  for ids, wts in zip(id_vals, weight_vals):
    value_aggregation = None
    weight_aggregation = None
    squared_weight_aggregation = None
    if isinstance(ids, compat.integral_types):
      ids = [ids]
      wts = [wts]
    for i, weight_value in zip(ids, wts):
      if partition_strategy == "mod":
        val = np.copy(params[_PName(i % num_shards) + ":0"][
            i // num_shards, :]) * weight_value
      elif partition_strategy == "div":
        ids_per_partition, extras = divmod(vocab_size, num_shards)
        threshold = extras * (ids_per_partition + 1)
        if i < threshold:
          partition = i // (ids_per_partition + 1)
          offset = i % (ids_per_partition + 1)
        else:
          partition = extras + (i - threshold) // ids_per_partition
          offset = (i - threshold) % ids_per_partition
        val = np.copy(
            params[_PName(partition) + ":0"][offset, :]) * weight_value
      else:
        assert False
      if value_aggregation is None:
        assert weight_aggregation is None
        assert squared_weight_aggregation is None
        value_aggregation = val
        weight_aggregation = weight_value
        squared_weight_aggregation = weight_value * weight_value
      else:
        assert weight_aggregation is not None
        assert squared_weight_aggregation is not None
        value_aggregation += val
        weight_aggregation += weight_value
        squared_weight_aggregation += weight_value * weight_value
    values.append(value_aggregation)
    weights.append(weight_aggregation)
    weights_squared.append(squared_weight_aggregation)
  values = np.array(values).astype(np.float32)
  weights = np.array(weights).astype(np.float32)
  weights_squared = np.array(weights_squared).astype(np.float32)
  return values, weights, weights_squared


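# A minimal illustrative sketch (not part of the original test suite; the name
# is ours) of the "div" partition-strategy index math used in _EmbeddingResult
# above: with vocab_size=13 and num_shards=5, ids_per_partition=2 and extras=3,
# so threshold=9; id 7 maps to (partition 2, offset 1) and id 11 maps to
# (partition 4, offset 0).
def _IllustrativeDivPartition(i, vocab_size, num_shards):
  """Returns (partition, offset) for id `i` under the "div" strategy."""
  ids_per_partition, extras = divmod(vocab_size, num_shards)
  threshold = extras * (ids_per_partition + 1)
  if i < threshold:
    return i // (ids_per_partition + 1), i % (ids_per_partition + 1)
  return (extras + (i - threshold) // ids_per_partition,
          (i - threshold) % ids_per_partition)

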
class EmbeddingLookupTest(test.TestCase):

  # This test looks up [0, 0] in a parameter matrix sharded 2 ways. Since
  # both the ids are in the first shard, one of the resulting lookup
  # vectors is going to be empty. The subsequent DivOp fails because of that.
  # TODO(keveman): Disabling the test until the underlying problem is fixed.
  @test_util.run_deprecated_v1
  def testSimpleSharded(self):
    with self.cached_session():
      num_shards = 2
      vocab_size = 4
      p, params, feed_dict = _EmbeddingParams(num_shards, vocab_size)

      id_vals = np.array([0, 0])
      ids = constant_op.constant(list(id_vals), dtype=dtypes.int32)
      print("Construct ids", ids.get_shape())
      embedding = embedding_ops.embedding_lookup(p, ids)

      tf_result = embedding.eval(feed_dict=feed_dict)
    np_result, _, _ = _EmbeddingResult(params, id_vals, num_shards, vocab_size)
    self.assertAllEqual(np_result, tf_result)
    self.assertShapeEqual(np_result, embedding)

  @test_util.run_deprecated_v1
  def testMaxNorm(self):
    with self.cached_session():
      embeddings = constant_op.constant([[2.0]])

      ids = constant_op.constant([0], dtype=dtypes.int32)
      embedding = embedding_ops.embedding_lookup(
          [embeddings], ids, max_norm=1.0)

      self.assertAllEqual(embedding.eval(), [[1.0]])

  @test_util.run_deprecated_v1
  def testMaxNormNontrivial(self):
    with self.cached_session():
      embeddings = constant_op.constant([[2.0, 4.0], [3.0, 1.0]])

      ids = constant_op.constant([0, 1], dtype=dtypes.int32)
      embedding = embedding_ops.embedding_lookup(
          [embeddings], ids, max_norm=2.0)

      norms = math_ops.sqrt(
          math_ops.reduce_sum(embeddings * embeddings, axis=1))
      normalized = embeddings / array_ops.stack([norms, norms], axis=1)
      self.assertAllEqual(embedding.eval(), 2 * self.evaluate(normalized))

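  # A minimal numeric sketch of the max_norm clipping exercised by the two
  # tests above (illustration only, not used by the test suite): the row
  # [2.0, 4.0] has L2 norm sqrt(2**2 + 4**2) = sqrt(20) ~= 4.472, so with
  # max_norm=2.0 the lookup returns that row rescaled by 2.0 / 4.472, roughly
  # [0.894, 1.789]; rows whose norm is already <= max_norm come back unchanged.
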
  @test_util.run_deprecated_v1
  def testSimpleShardedPartitionedVariable(self):
    with self.cached_session() as sess:
      num_shards = 2
      vocab_size = 4
      p, p_variable, params, feed_dict = _EmbeddingParamsAsPartitionedVariable(
          num_shards, vocab_size)

      id_vals = np.array([0, 0])
      ids = constant_op.constant(list(id_vals), dtype=dtypes.int32)
      print("Construct ids", ids.get_shape())
      embedding = embedding_ops.embedding_lookup(p_variable, ids)
      variables.global_variables_initializer().run()
      params_values = [params[p_i.name] for p_i in p]
      # Test that the PartitionedVariable components equal the list in p
      p_var_val = self.evaluate(list(p_variable))
      # Actual test
      tf_result = embedding.eval(feed_dict=feed_dict)
    np_result, _, _ = _EmbeddingResult(params, id_vals, num_shards, vocab_size)
    self.assertAllEqual(params_values, p_var_val)
    self.assertAllEqual(np_result, tf_result)
    self.assertShapeEqual(np_result, embedding)

  @test_util.run_deprecated_v1
  def testSimpleShardedPartitionedResourceVariable(self):
    with self.cached_session() as sess:
      num_shards = 2
      vocab_size = 4
      p, p_variable, params, _ = _EmbeddingParamsAsPartitionedVariable(
          num_shards, vocab_size, use_resource=True)

      id_vals = np.array([0, 0])
      ids = constant_op.constant(list(id_vals), dtype=dtypes.int32)
      print("Construct ids", ids.get_shape())
      embedding = embedding_ops.embedding_lookup(p_variable, ids)
      variables.global_variables_initializer().run()
      params_values = [params[p_i.name] for p_i in p]
      # Test that the PartitionedVariable components equal the list in p
      p_var_val = self.evaluate(list(p_variable))
      # Actual test
      print(ops.get_default_graph().as_graph_def())
      tf_result = self.evaluate(embedding)
    np_result, _, _ = _EmbeddingResult(params, id_vals, num_shards, vocab_size)
    self.assertAllEqual(params_values, p_var_val)
    self.assertAllEqual(np_result, tf_result)
    self.assertShapeEqual(np_result, embedding)

  @test_util.run_deprecated_v1
  def testShardedModPartitioningInt32Ids(self):
    with self.cached_session():
      num_shards = 5
      vocab_size = 13
      # Embedding dimension is 10. The vocab_size x 10 embedding
      # parameters are spread in num_shards matrices, so the first
      # 3 shards are 3 x 10 and the last 2 shards are 2 x 10.
      p, params, feed_dict = _EmbeddingParams(num_shards, vocab_size)

      num_vals = 30
      # Fetch num_vals embeddings for random word ids. Since
      # num_vals > vocab_size, this ought to have repetitions, so
      # will test that aspect.
      id_vals = np.random.randint(vocab_size, size=num_vals)
      ids = constant_op.constant(list(id_vals), dtype=dtypes.int32)

      embedding = embedding_ops.embedding_lookup(p, ids)
      tf_result = embedding.eval(feed_dict=feed_dict)
    np_result, _, _ = _EmbeddingResult(params, id_vals, num_shards, vocab_size)
    self.assertAllEqual(np_result, tf_result)
    self.assertShapeEqual(np_result, embedding)

  @test_util.run_deprecated_v1
  def testShardedModPartitioningInt64Ids(self):
    with self.cached_session():
      num_shards = 5
      vocab_size = 13
      # Embedding dimension is 10. The vocab_size x 10 embedding
      # parameters are spread in num_shards matrices, so the first
      # 3 shards are 3 x 10 and the last 2 shards are 2 x 10.
      p, params, feed_dict = _EmbeddingParams(num_shards, vocab_size)

      num_vals = 30
      # Fetch num_vals embeddings for random word ids. Since
      # num_vals > vocab_size, this ought to have repetitions, so
      # will test that aspect.
      id_vals = np.random.randint(vocab_size, size=num_vals)
      ids = constant_op.constant(list(id_vals), dtype=dtypes.int64)

      embedding = embedding_ops.embedding_lookup(p, ids)
      tf_result = embedding.eval(feed_dict=feed_dict)
    np_result, _, _ = _EmbeddingResult(params, id_vals, num_shards, vocab_size)
    self.assertAllEqual(np_result, tf_result)
    self.assertShapeEqual(np_result, embedding)

  @test_util.run_deprecated_v1
  def testShardedDivPartitioningInt32Ids(self):
    with self.cached_session():
      num_shards = 5
      vocab_size = 13
      # Embedding dimension is 10. The vocab_size x 10 embedding
      # parameters are spread in num_shards matrices, so the first
      # 3 shards are 3 x 10 and the last 2 shards are 2 x 10.
      p, params, feed_dict = _EmbeddingParams(num_shards, vocab_size)

      num_vals = 30
      # Fetch num_vals embeddings for random word ids. Since
      # num_vals > vocab_size, this ought to have repetitions, so
      # will test that aspect.
      id_vals = np.random.randint(vocab_size, size=num_vals)
      ids = constant_op.constant(list(id_vals), dtype=dtypes.int32)

      embedding = embedding_ops.embedding_lookup(
          p, ids, partition_strategy="div")
      tf_result = embedding.eval(feed_dict=feed_dict)
    np_result, _, _ = _EmbeddingResult(
        params, id_vals, num_shards, vocab_size, partition_strategy="div")
    self.assertAllEqual(np_result, tf_result)
    self.assertShapeEqual(np_result, embedding)

  @test_util.run_deprecated_v1
  def testShardedDivPartitioningInt32IdsPartitionedVariable(self):
    with self.cached_session():
      num_shards = 5
      vocab_size = 13
      # Embedding dimension is 10. The vocab_size x 10 embedding
      # parameters are spread in num_shards matrices, so the first
      # 3 shards are 3 x 10 and the last 2 shards are 2 x 10.
      _, p_variable, params, feed_dict = _EmbeddingParamsAsPartitionedVariable(
          num_shards, vocab_size)

      num_vals = 30
      # Fetch num_vals embeddings for random word ids. Since
      # num_vals > vocab_size, this ought to have repetitions, so
      # will test that aspect.
      id_vals = np.random.randint(vocab_size, size=num_vals)
      ids = constant_op.constant(list(id_vals), dtype=dtypes.int32)
      variables.global_variables_initializer().run()
      embedding = embedding_ops.embedding_lookup(
          p_variable, ids, partition_strategy="div")
      tf_result = embedding.eval(feed_dict=feed_dict)
    np_result, _, _ = _EmbeddingResult(
        params, id_vals, num_shards, vocab_size, partition_strategy="div")
    self.assertAllEqual(np_result, tf_result)
    self.assertShapeEqual(np_result, embedding)

  @test_util.run_deprecated_v1
  def testShardedDivPartitioningInt64Ids(self):
    with self.cached_session():
      num_shards = 5
      vocab_size = 13
      # Embedding dimension is 10. The vocab_size x 10 embedding
      # parameters are spread in num_shards matrices, so the first
      # 3 shards are 3 x 10 and the last 2 shards are 2 x 10.
      p, params, feed_dict = _EmbeddingParams(num_shards, vocab_size)

      num_vals = 30
      # Fetch num_vals embeddings for random word ids. Since
      # num_vals > vocab_size, this ought to have repetitions, so
      # will test that aspect.
      id_vals = np.random.randint(vocab_size, size=num_vals)
      ids = constant_op.constant(list(id_vals), dtype=dtypes.int64)

      embedding = embedding_ops.embedding_lookup(
          p, ids, partition_strategy="div")
      tf_result = embedding.eval(feed_dict=feed_dict)
    np_result, _, _ = _EmbeddingResult(
        params, id_vals, num_shards, vocab_size, partition_strategy="div")
    self.assertAllEqual(np_result, tf_result)
    self.assertShapeEqual(np_result, embedding)

  @test_util.run_deprecated_v1
  def testShardedDivPartitioningUnknownParamShape(self):
    with self.cached_session():
      num_shards = 5
      vocab_size = 13
      # Embedding dimension is 10. The vocab_size x 10 embedding
      # parameters are spread in num_shards matrices, so the first
      # 3 shards are 3 x 10 and the last 2 shards are 2 x 10.

      # We clear parameter shapes, to test when shape is not statically known.
      p, params, feed_dict = _EmbeddingParams(
          num_shards, vocab_size, use_shapeless_placeholder=True)

      num_vals = 30
      # Fetch num_vals embeddings for random word ids. Since
      # num_vals > vocab_size, this ought to have repetitions, so
      # will test that aspect.
      id_vals = np.random.randint(vocab_size, size=num_vals)
      ids = constant_op.constant(list(id_vals), dtype=dtypes.int64)

      embedding = embedding_ops.embedding_lookup(
          p, ids, partition_strategy="div")
      tf_result = embedding.eval(feed_dict=feed_dict)
    np_result, _, _ = _EmbeddingResult(
        params, id_vals, num_shards, vocab_size, partition_strategy="div")
    self.assertAllEqual(np_result, tf_result)

  @test_util.run_deprecated_v1
  def testGradientsEmbeddingLookup(self):
    vocab_size = 9
    num_ids = 10
    id_vals = list(np.random.randint(vocab_size, size=num_ids))
    tf_logging.vlog(1, id_vals)
    for ids_shape in [(10,), (2, 5)]:
      for num_shards in [1, 3]:
        with self.cached_session():
          ids = constant_op.constant(
              id_vals, shape=ids_shape, dtype=dtypes.int32)
          x, params, _ = _EmbeddingParams(num_shards, vocab_size, shape=[2])
          y = embedding_ops.embedding_lookup(x, ids)
          y_shape = ids_shape + tuple(params[_PName(0) + ":0"].shape[1:])
          x_name = [_PName(i) for i in range(num_shards)]
          x_init_value = [params[x_n + ":0"] for x_n in x_name]
          x_shape = [i.shape for i in x_init_value]
          err = gradient_checker.compute_gradient_error(
              x, x_shape, y, y_shape, x_init_value=x_init_value)
          self.assertLess(err, 1e-4)

  @test_util.run_deprecated_v1
  def testGradientsEmbeddingLookupWithComputedParams(self):
    vocab_size = 9
    num_ids = 5
    id_vals = list(np.random.randint(vocab_size, size=num_ids))
    tf_logging.vlog(1, id_vals)
    for num_shards in [1, 3]:
      with self.cached_session():
        ids = constant_op.constant(id_vals, dtype=dtypes.int32)
        x, params, _ = _EmbeddingParams(num_shards, vocab_size, shape=[2])
        # This will force a conversion from IndexedSlices to Tensor.
        x_squared = [math_ops.square(elem) for elem in x]
        y = embedding_ops.embedding_lookup(x_squared, ids)
        y_shape = [num_ids] + list(params[_PName(0) + ":0"].shape[1:])
        x_name = [_PName(i) for i in range(num_shards)]
        x_init_value = [params[x_n + ":0"] for x_n in x_name]
        x_shape = [i.shape for i in x_init_value]
        err = gradient_checker.compute_gradient_error(
            x, x_shape, y, y_shape, x_init_value=x_init_value)
        self.assertLess(err, 1e-3)

  def testConstructionNonSharded(self):
    with ops.Graph().as_default():
      p = variables.Variable(
          array_ops.zeros(shape=[100, 100], dtype=dtypes.float32))
      ids = constant_op.constant([0, 1, 1, 7], dtype=dtypes.int32)
      embedding_ops.embedding_lookup([p], ids)

  def testConstructionSharded(self):
    with ops.Graph().as_default():
      p = []
      for _ in range(2):
        p += [
            variables.Variable(
                array_ops.zeros(shape=[100, 100], dtype=dtypes.float32))
        ]
      ids = constant_op.constant([0, 1, 1, 17], dtype=dtypes.int32)
      embedding_ops.embedding_lookup(p, ids)

  @test_util.run_deprecated_v1
  def testHigherRank(self):
    np.random.seed(8)
    with self.cached_session():
      for params_shape in (12,), (6, 3):
        params = np.random.randn(*params_shape)
        for ids_shape in (3, 2), (4, 3):
          ids = np.random.randint(
              params.shape[0], size=np.prod(ids_shape)).reshape(ids_shape)
          # Compare nonsharded to gather
          simple = embedding_ops.embedding_lookup(params, ids).eval()
          self.assertAllEqual(simple, array_ops.gather(params, ids).eval())
          # Run a few random sharded versions
          for procs in 1, 2, 3:
            stride = procs * math_ops.range(params.shape[0] // procs)
            split_params = [
                array_ops.gather(params, stride + p) for p in xrange(procs)
            ]
            sharded = embedding_ops.embedding_lookup(split_params, ids).eval()
            self.assertAllEqual(simple, sharded)

  @test_util.run_deprecated_v1
  def testHigherRankMaxNorm(self):
    np.random.seed(8)
    with self.cached_session():
      for params_shape in (12,), (6, 3), (6, 2, 3):
        # Test embedding rank 0, 1, 2.
        # Note: the first dimension must be a common multiple of procs below.
        params = 2 * np.ones(params_shape)
        params_norm = params / np.sqrt(
            np.sum(
                params * params, tuple(range(params.ndim)[1:]), keepdims=True))
        for ids_shape in (), (3), (4, 3), (2, 3, 4):
          ids = np.random.randint(
              params.shape[0], size=np.prod(ids_shape,
                                            dtype=np.int64)).reshape(ids_shape)
          # Compare nonsharded to gather
          simple = embedding_ops.embedding_lookup(
              params, ids, max_norm=1.0).eval()
          # assertAllClose is used here as different implementations of sqrt may
          # be used to compute each of the values being compared. For example,
          # on AVX512 builds the embedding operation makes use of Eigen's fast
          # vectorized square root algorithm for doubles. These different
          # implementations of sqrt are not guaranteed to produce exactly the
          # same results. Therefore, an exact comparison cannot be made.
          self.assertAllClose(simple, array_ops.gather(params_norm, ids).eval())
          # Run a few different sharded versions.
          for procs in 1, 2, 3:
            stride = procs * math_ops.range(params.shape[0] // procs)
            split_params = [
                array_ops.gather(params, stride + p) for p in xrange(procs)
            ]
            sharded = embedding_ops.embedding_lookup(
                split_params, ids, max_norm=1.0).eval()
            self.assertAllEqual(simple, sharded)

  @test_util.run_deprecated_v1
  def testTransform(self):
    # This tests all combinations of:
    #   - ids rank 0, 1, >1
    #   - params sharded/unsharded
    # It always applies max_norm.
    np.random.seed(8)
    l2_norm = 2.
    with self.cached_session():
      # Param values are in [l2_norm, l2_norm+1) so it will always clip.
      params = np.random.rand(6, 3) + l2_norm
      params_norm = l2_norm * params / np.sqrt(
          np.sum(params * params, axis=1, keepdims=True))
      # Compute the norm of each embedding. This will change the embedding
      # rank to 0.
      params_norm = np.linalg.norm(params_norm, axis=1)
      transform = lambda x: linalg_ops.norm(x, axis=1)
      for ids_shape in (), (3), (4, 3), (2, 3, 4):
        # Test ids rank 0, 1, 2, 3.
        ids = np.random.randint(
            params.shape[0], size=np.prod(ids_shape,
                                          dtype=np.int64)).reshape(ids_shape)
        # Compare nonsharded to gather.
        simple = embedding_ops._embedding_lookup_and_transform(
            params, ids, max_norm=l2_norm, transform_fn=transform).eval()
        self.assertAllClose(simple, array_ops.gather(params_norm, ids).eval())
        # Run a few different sharded versions.
        for procs in 1, 2, 3:
          stride = procs * math_ops.range(params.shape[0] // procs)
          split_params = [
              array_ops.gather(params, stride + p) for p in xrange(procs)
          ]
          sharded = embedding_ops._embedding_lookup_and_transform(
              split_params, ids, max_norm=l2_norm,
              transform_fn=transform).eval()
          # assertAllClose is used here as different implementations of sqrt
          # may be used to compute each of the values being compared. For
          # example, on AVX512 builds the embedding operation makes use of
          # Eigen's fast vectorized square root algorithm for doubles. These
          # different implementations of sqrt are not guaranteed to produce
          # exactly the same results. Therefore, an exact comparison cannot
          # be made.
          self.assertAllClose(simple, sharded)

  def testRaggedMaxNorm(self):
    embeddings = constant_op.constant([[2.0]])
    ids = ragged_factory_ops.constant([[0, 0], [0]], dtype=dtypes.int32)
    embedding = embedding_ops.embedding_lookup([embeddings], ids, max_norm=1.0)
    self.assertAllEqual(embedding, [[[1.0], [1.0]], [[1.0]]])


class EmbeddingLookupSparseTest(test.TestCase):

  def _RandomIdsAndWeights(self, batch_size, vocab_size):
    max_val_per_entry = 6
    vals_per_batch_entry = np.random.randint(
        1, max_val_per_entry, size=batch_size)
    num_vals = np.sum(vals_per_batch_entry)

    ids = np.random.randint(vocab_size, size=num_vals)
    weights = 1 + np.random.rand(num_vals)

    indices = []
    for batch_entry, num_val in enumerate(vals_per_batch_entry):
      for val_index in range(num_val):
        indices.append([batch_entry, val_index])

    shape = [batch_size, max_val_per_entry]

    sp_ids = sparse_tensor.SparseTensor(
        constant_op.constant(indices, dtypes.int64),
        constant_op.constant(ids, dtypes.int32),
        constant_op.constant(shape, dtypes.int64))
    sp_weights = sparse_tensor.SparseTensor(
        constant_op.constant(indices, dtypes.int64),
        constant_op.constant(weights, dtypes.float32),
        constant_op.constant(shape, dtypes.int64))

    return sp_ids, sp_weights, ids, weights, vals_per_batch_entry

  def _GroupByBatchEntry(self, vals, vals_per_batch_entry):
    grouped_vals = []
    index = 0
    for num_val in vals_per_batch_entry:
      grouped_vals.append(list(vals[index:(index + num_val)]))
      index += num_val
    return grouped_vals

  @test_util.run_deprecated_v1
  def testEmbeddingLookupSparse(self):
    vocab_size = 13
    batch_size = 10
    param_shape = [2, 5]
    expected_lookup_result_shape = [None] + param_shape

    sp_ids, sp_weights, ids, weights, vals_per_batch_entry = (
        self._RandomIdsAndWeights(batch_size, vocab_size))

    grouped_ids = self._GroupByBatchEntry(ids, vals_per_batch_entry)
    grouped_weights = self._GroupByBatchEntry(weights, vals_per_batch_entry)
    grouped_ignored_weights = self._GroupByBatchEntry(
        np.ones(np.sum(vals_per_batch_entry)), vals_per_batch_entry)

    for num_shards, combiner, dtype, ignore_weights in itertools.product(
        [1, 5], ["sum", "mean", "sqrtn"],
        [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64],
        [True, False]):

      with self.cached_session():
        p, params, feed_dict = _EmbeddingParams(
            num_shards, vocab_size, shape=param_shape, dtype=dtype)
        embedding_sum = embedding_ops.embedding_lookup_sparse(
            p,
            sp_ids,
            None if ignore_weights else sp_weights,
            combiner=combiner)

        self.assertEqual(embedding_sum.get_shape().as_list(),
                         expected_lookup_result_shape)
        if dtype in (dtypes.float16, dtypes.bfloat16):
          self.assertEqual(embedding_sum.dtype, dtypes.float32)
        else:
          self.assertEqual(embedding_sum.dtype, dtype)

        tf_embedding_sum = embedding_sum.eval(feed_dict=feed_dict)

        np_embedding_sum, np_weight_sum, np_weight_sq_sum = _EmbeddingResult(
            params,
            grouped_ids,
            num_shards,
            vocab_size,
            weight_vals=grouped_ignored_weights
            if ignore_weights else grouped_weights)
        if combiner == "mean":
          np_embedding_sum /= np.reshape(np_weight_sum, (batch_size, 1, 1))
        if combiner == "sqrtn":
          np_embedding_sum /= np.reshape(
              np.sqrt(np_weight_sq_sum), (batch_size, 1, 1))

        rtol = 1e-6
        if dtype == dtypes.bfloat16:
          rtol = 1e-2
        elif dtype == dtypes.float16:
          rtol = 1e-3
        atol = rtol
        self.assertAllClose(np_embedding_sum, tf_embedding_sum, rtol, atol)

  @test_util.run_deprecated_v1
  def testGradientsEmbeddingLookupSparse(self):
    vocab_size = 12
    batch_size = 4
    param_shape = [2, 3]
    sp_ids, sp_weights, _, _, _ = (self._RandomIdsAndWeights(
        batch_size, vocab_size))

    for num_shards, combiner, dtype, ignore_weights in itertools.product(
        [1, 3], ["sum", "mean", "sqrtn"], [dtypes.float32,
                                           dtypes.float64], [True, False]):
      with self.cached_session():
        x, params, _ = _EmbeddingParams(
            num_shards, vocab_size, shape=param_shape, dtype=dtype)

        y = embedding_ops.embedding_lookup_sparse(
            x,
            sp_ids,
            None if ignore_weights else sp_weights,
            combiner=combiner)
        x_name = [_PName(i) for i in range(num_shards)]
        x_init_value = [params[x_n + ":0"] for x_n in x_name]
        x_shape = [i.shape for i in x_init_value]
        y_shape = [batch_size] + list(params[_PName(0) + ":0"].shape[1:])
        err = gradient_checker.compute_gradient_error(
            x, x_shape, y, y_shape, x_init_value=x_init_value)
        self.assertLess(err, 1e-5 if dtype == dtypes.float64 else 2e-3)

  @test_util.run_deprecated_v1
  def testIncompatibleShapes(self):
    with self.cached_session():
      x, _, _ = _EmbeddingParams(1, 10, dtype=dtypes.float32)
      sp_ids = sparse_tensor.SparseTensor(
          constant_op.constant([[0, 0], [0, 1], [1, 0]], dtypes.int64),
          constant_op.constant([0, 1, 2], dtypes.int32),
          constant_op.constant([2, 2], dtypes.int64))
      sp_weights = sparse_tensor.SparseTensor(
          constant_op.constant([[0, 0], [0, 1]], dtypes.int64),
          constant_op.constant([12.0, 5.0], dtypes.float32),
          constant_op.constant([1, 2], dtypes.int64))

      with self.assertRaises(ValueError):
        embedding_ops.embedding_lookup_sparse(
            x, sp_ids, sp_weights, combiner="mean")


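# A minimal illustrative sketch (not part of the original test suite; the name
# is ours) of how testEmbeddingLookupSparse normalizes the summed embeddings
# per combiner: "sum" leaves them as-is, "mean" divides by the sum of the
# weights, and "sqrtn" divides by the square root of the sum of squared
# weights.
def _IllustrativeCombine(weighted_sum, weights, combiner):
  """Applies the "sum"/"mean"/"sqrtn" normalization to a weighted sum."""
  if combiner == "mean":
    return weighted_sum / np.sum(weights)
  if combiner == "sqrtn":
    return weighted_sum / np.sqrt(np.sum(np.square(weights)))
  return weighted_sum

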
class SafeEmbeddingLookupSparseTest(test.TestCase):

  def _random_weights(self, vocab_size=4, embed_dim=4, num_shards=1):
    assert vocab_size > 0
    assert embed_dim > 0
    assert num_shards > 0
    assert num_shards <= vocab_size

    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1.0 / math.sqrt(vocab_size), dtype=dtypes.float32)
    embedding_weights = list(variable_scope.get_variable(
        name="embedding_weights",
        shape=[vocab_size, embed_dim],
        partitioner=partitioned_variables.fixed_size_partitioner(num_shards),
        initializer=initializer))
    for w in embedding_weights:
      w.initializer.run()
    embedding_weights = [w.eval() for w in embedding_weights]
    return embedding_weights

  def _ids_and_weights_2d(self):
    # Each row demonstrates a test case:
    #   Row 0: multiple valid ids, 1 invalid id, weighted mean
    #   Row 1: all ids are invalid (leaving no valid ids after pruning)
    #   Row 2: no ids to begin with
    #   Row 3: single id
    #   Row 4: all ids have <=0 weight
    indices = [[0, 0], [0, 1], [0, 2], [1, 0], [3, 0], [4, 0], [4, 1]]
    ids = [0, 1, -1, -1, 2, 0, 1]
    weights = [1.0, 2.0, 1.0, 1.0, 3.0, 0.0, -0.5]
    shape = [5, 4]

    sparse_ids = sparse_tensor.SparseTensor(
        constant_op.constant(indices, dtypes.int64),
        constant_op.constant(ids, dtypes.int64),
        constant_op.constant(shape, dtypes.int64))

    sparse_weights = sparse_tensor.SparseTensor(
        constant_op.constant(indices, dtypes.int64),
        constant_op.constant(weights, dtypes.float32),
        constant_op.constant(shape, dtypes.int64))

    return sparse_ids, sparse_weights

  def _ids_and_weights_3d(self):
    # Each (2-D) index demonstrates a test case:
    #   Index 0, 0: multiple valid ids, 1 invalid id, weighted mean
    #   Index 0, 1: all ids are invalid (leaving no valid ids after pruning)
    #   Index 0, 2: no ids to begin with
    #   Index 1, 0: single id
    #   Index 1, 1: all ids have <=0 weight
    #   Index 1, 2: no ids to begin with
    indices = [[0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 1, 0], [1, 0, 0], [1, 1, 0],
               [1, 1, 1]]
    ids = [0, 1, -1, -1, 2, 0, 1]
    weights = [1.0, 2.0, 1.0, 1.0, 3.0, 0.0, -0.5]
    shape = [2, 3, 4]

    sparse_ids = sparse_tensor.SparseTensor(
        constant_op.constant(indices, dtypes.int64),
        constant_op.constant(ids, dtypes.int64),
        constant_op.constant(shape, dtypes.int64))

    sparse_weights = sparse_tensor.SparseTensor(
        constant_op.constant(indices, dtypes.int64),
        constant_op.constant(weights, dtypes.float32),
        constant_op.constant(shape, dtypes.int64))

    return sparse_ids, sparse_weights

  @test_util.run_deprecated_v1
  def test_safe_embedding_lookup_sparse_return_zero_vector(self):
    with self.cached_session():
      embedding_weights = self._random_weights()
      sparse_ids, sparse_weights = self._ids_and_weights_2d()

      embedding_lookup_result = (
          embedding_ops.safe_embedding_lookup_sparse_v2(
              embedding_weights, sparse_ids, sparse_weights).eval())

      self.assertAllClose(
          embedding_lookup_result,
          [(1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) /
           3.0, [0] * 4, [0] * 4, embedding_weights[0][2], [0] * 4])

  @test_util.run_deprecated_v1
  def test_safe_embedding_lookup_sparse_return_special_vector(self):
    with self.cached_session():
      embedding_weights = self._random_weights()
      sparse_ids, sparse_weights = self._ids_and_weights_2d()

      embedding_lookup_result = (
          embedding_ops.safe_embedding_lookup_sparse_v2(
              embedding_weights, sparse_ids, sparse_weights,
              default_id=3).eval())

      self.assertAllClose(
          embedding_lookup_result,
          [(1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) /
           3.0, embedding_weights[0][3], embedding_weights[0][3],
           embedding_weights[0][2], embedding_weights[0][3]])

  @test_util.run_deprecated_v1
  def test_safe_embedding_lookup_sparse_no_weights(self):
    with self.cached_session():
      embedding_weights = self._random_weights()
      sparse_ids, _ = self._ids_and_weights_2d()

      embedding_lookup_result = (
          embedding_ops.safe_embedding_lookup_sparse_v2(
              embedding_weights, sparse_ids, None).eval())

      self.assertAllClose(
          embedding_lookup_result,
          [(embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4,
           [0] * 4, embedding_weights[0][2], (
               embedding_weights[0][0] + embedding_weights[0][1]) / 2.0])

  @test_util.run_deprecated_v1
  def test_safe_embedding_lookup_sparse_partitioned(self):
    with self.cached_session():
      embedding_weights = self._random_weights(num_shards=3)
      sparse_ids, _ = self._ids_and_weights_2d()

      embedding_lookup_result = (
          embedding_ops.safe_embedding_lookup_sparse_v2(
              embedding_weights, sparse_ids, None).eval())

      embedding_weights = list(itertools.chain(*embedding_weights))
      self.assertAllClose(embedding_lookup_result,
                          [(embedding_weights[0] + embedding_weights[1]) / 2.0,
                           [0] * 4, [0] * 4, embedding_weights[2],
                           (embedding_weights[0] + embedding_weights[1]) / 2.0])

  @test_util.run_deprecated_v1
  def test_safe_embedding_lookup_sparse_partitioned_inconsistent_weights(self):
    with self.cached_session():
      embedding_weights = self._random_weights(num_shards=3)
      sparse_ids, sparse_weights = self._ids_and_weights_2d()

      embedding_weights[1] = embedding_weights[1].astype(np.float64)
      self.assertRaises(TypeError, embedding_ops.safe_embedding_lookup_sparse,
                        embedding_weights, sparse_ids)
      embedding_weights = [
          constant_op.constant(w, dtype=dtypes.float64)
          for w in embedding_weights
      ]
      self.assertRaises(ValueError, embedding_ops.safe_embedding_lookup_sparse,
                        embedding_weights, sparse_ids, sparse_weights)

  @test_util.run_deprecated_v1
  def test_safe_embedding_lookup_sparse_3d_return_zero_vector(self):
    with self.cached_session():
      embedding_weights = self._random_weights()
      sparse_ids, sparse_weights = self._ids_and_weights_3d()

      embedding_lookup_result = (
          embedding_ops.safe_embedding_lookup_sparse_v2(
              embedding_weights, sparse_ids, sparse_weights).eval())

      self.assertAllClose(embedding_lookup_result, [[
          (1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) / 3.0,
          [0] * 4, [0] * 4
      ], [embedding_weights[0][2], [0] * 4, [0] * 4]])

  @test_util.run_deprecated_v1
  def test_safe_embedding_lookup_sparse_3d_return_special_vector(self):
    with self.cached_session():
      embedding_weights = self._random_weights()
      sparse_ids, sparse_weights = self._ids_and_weights_3d()

      embedding_lookup_result = (
          embedding_ops.safe_embedding_lookup_sparse_v2(
              embedding_weights, sparse_ids, sparse_weights,
              default_id=3).eval())

      self.assertAllClose(
          embedding_lookup_result,
          [[(1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) /
            3.0, embedding_weights[0][3], embedding_weights[0][3]], [
                embedding_weights[0][2], embedding_weights[0][3],
                embedding_weights[0][3]
            ]])

  @test_util.run_deprecated_v1
  def test_safe_embedding_lookup_sparse_3d_no_weights(self):
    with self.cached_session():
      embedding_weights = self._random_weights()
      sparse_ids, _ = self._ids_and_weights_3d()

      embedding_lookup_result = (
          embedding_ops.safe_embedding_lookup_sparse_v2(
              embedding_weights, sparse_ids, None).eval())

      self.assertAllClose(embedding_lookup_result, [[(
          embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4, [
              0
          ] * 4], [
              embedding_weights[0][2],
              (embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4
          ]])

  @test_util.run_deprecated_v1
  def test_safe_embedding_lookup_sparse_3d_partitioned(self):
    with self.cached_session():
      embedding_weights = self._random_weights(num_shards=3)
      sparse_ids, _ = self._ids_and_weights_3d()

      embedding_lookup_result = (
          embedding_ops.safe_embedding_lookup_sparse_v2(
              embedding_weights, sparse_ids, None).eval())

      embedding_weights = list(itertools.chain(*embedding_weights))
      self.assertAllClose(embedding_lookup_result, [[
          (embedding_weights[0] + embedding_weights[1]) / 2.0, [0] * 4, [0] * 4
      ], [
          embedding_weights[2],
          (embedding_weights[0] + embedding_weights[1]) / 2.0, [0] * 4
      ]])

  @test_util.run_deprecated_v1
  def test_safe_embedding_lookup_sparse_3d_partitioned_inconsistent_weights(
      self):
    with self.cached_session():
      embedding_weights = self._random_weights(num_shards=3)
      sparse_ids, sparse_weights = self._ids_and_weights_3d()

      embedding_weights[1] = embedding_weights[1].astype(np.float64)
      self.assertRaises(TypeError, embedding_ops.safe_embedding_lookup_sparse,
                        embedding_weights, sparse_ids)
      embedding_weights = [
          constant_op.constant(w, dtype=dtypes.float64)
          for w in embedding_weights
      ]
      self.assertRaises(ValueError, embedding_ops.safe_embedding_lookup_sparse,
                        embedding_weights, sparse_ids, sparse_weights)


class DynamicStitchOpTest(test.TestCase):

  @test_util.run_deprecated_v1
  def testCint32Cpu(self):
    with self.session(use_gpu=False):
      indices = [
          ops.convert_to_tensor([0, 1, 2]),
          ops.convert_to_tensor([2, 3])
      ]
      values = [
          ops.convert_to_tensor([12, 23, 34]),
          ops.convert_to_tensor([1, 2])
      ]
      self.assertAllEqual(
          data_flow_ops.dynamic_stitch(indices, values).eval(), [12, 23, 1, 2])

  @test_util.run_deprecated_v1
  def testCint32Gpu(self):
    with self.session(use_gpu=True):
      indices = [
          ops.convert_to_tensor([0, 1, 2]),
          ops.convert_to_tensor([2, 3])
      ]
      values = [
          ops.convert_to_tensor([12, 23, 34]),
          ops.convert_to_tensor([1, 2])
      ]
      self.assertAllEqual(
          data_flow_ops.dynamic_stitch(indices, values).eval(), [12, 23, 1, 2])

  @test_util.run_deprecated_v1
  def testInt32Cpu(self):
    with self.session(use_gpu=False):
      indices = [
          ops.convert_to_tensor([0, 1, 2]),
          ops.convert_to_tensor([2, 3])
      ]
      values = [
          ops.convert_to_tensor([12, 23, 34]),
          ops.convert_to_tensor([1, 2])
      ]
      self.assertAllEqual(
          data_flow_ops.dynamic_stitch(indices, values).eval(), [12, 23, 1, 2])

  @test_util.run_deprecated_v1
  def testInt32Gpu(self):
    with self.session(use_gpu=True):
      indices = [
          ops.convert_to_tensor([0, 1, 2]),
          ops.convert_to_tensor([2, 3])
      ]
      values = [
          ops.convert_to_tensor([12, 23, 34]),
          ops.convert_to_tensor([1, 2])
      ]
      self.assertAllEqual(
          data_flow_ops.dynamic_stitch(indices, values).eval(), [12, 23, 1, 2])

  @test_util.run_deprecated_v1
  def testSumGradArgs(self):
    with self.session(use_gpu=False):
      indices = [
          ops.convert_to_tensor([0, 1, 2, 3]),
          ops.convert_to_tensor([2, 3])
      ]
      values = [
          ops.convert_to_tensor([2, 3, 5, 7]),
          ops.convert_to_tensor([1, 1])
      ]
      self.assertAllEqual(
          data_flow_ops.dynamic_stitch(indices, values).eval(), [2, 3, 1, 1])

  # We expect that the values are merged in order.
  @test_util.run_deprecated_v1
  def testStitchOrder(self):
    with self.cached_session():
      indices = []
      np_values = []
      values = []
      for _ in range(10):
        indices.extend([ops.convert_to_tensor(np.arange(100).astype(np.int32))])
        np_values.extend([np.random.uniform(size=100)])
        values.extend([ops.convert_to_tensor(np_values[-1])])
      stitched = data_flow_ops.dynamic_stitch(indices, values).eval()
    self.assertAllEqual(np_values[-1], stitched)


class ParallelDynamicStitchOpTest(test.TestCase):

  @test_util.run_deprecated_v1
  def testCint32Cpu(self):
    with self.session(use_gpu=False):
      indices = [
          ops.convert_to_tensor([0, 1, 4, 6]),
          ops.convert_to_tensor([2, 3, 5])
      ]
      values = [
          ops.convert_to_tensor([12, 23, 34, 45]),
          ops.convert_to_tensor([1, 2, 3])
      ]
      self.assertAllEqual(
          data_flow_ops.parallel_dynamic_stitch(indices, values).eval(),
          [12, 23, 1, 2, 34, 3, 45])

  @test_util.run_deprecated_v1
  def testInt32Cpu(self):
    with self.session(use_gpu=False):
      indices = [
          ops.convert_to_tensor([0, 1, 5, 6, 7]),
          ops.convert_to_tensor([2, 4, 3])
      ]
      values = [
          ops.convert_to_tensor([12, 23, 34, 45, 56]),
          ops.convert_to_tensor([1, 3, 2])
      ]
      self.assertAllEqual(
          data_flow_ops.parallel_dynamic_stitch(indices, values).eval(),
          [12, 23, 1, 2, 3, 34, 45, 56])

  @test_util.run_deprecated_v1
  def testSimple(self):
    with self.session(use_gpu=False):
      indices = [ops.convert_to_tensor([0, 1]), ops.convert_to_tensor([2, 3])]
      values = [ops.convert_to_tensor([2, 3]), ops.convert_to_tensor([1, 1])]
      self.assertAllEqual(
          data_flow_ops.parallel_dynamic_stitch(indices, values).eval(),
          [2, 3, 1, 1])


if __name__ == "__main__":
  test.main()