Merge pull request #43269 from rhdong:RFC237-patch-for-recsys-sig

PiperOrigin-RevId: 339248366
Change-Id: I4bfb177937a173f20c8165a31ee5c5c7bdf9ef75
This commit is contained in:
TensorFlower Gardener 2020-10-27 08:06:45 -07:00
commit b0784e587b
7 changed files with 124 additions and 16 deletions

View File

@ -83,10 +83,17 @@ Status LookupInterface::CheckFindArguments(const Tensor& key,
const Tensor& default_value) {
TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(key, default_value));
TF_RETURN_IF_ERROR(CheckKeyShape(key.shape()));
if (default_value.shape() != value_shape()) {
TensorShape fullsize_value_shape = key.shape();
for (int i = 0; i < key_shape().dims(); ++i) {
fullsize_value_shape.RemoveDim(fullsize_value_shape.dims() - 1);
}
fullsize_value_shape.AppendShape(value_shape());
if (default_value.shape() != value_shape() &&
default_value.shape() != fullsize_value_shape) {
return errors::InvalidArgument(
"Expected shape ", value_shape().DebugString(),
" for default value, got ", default_value.shape().DebugString());
"Expected shape ", value_shape().DebugString(), " or ",
fullsize_value_shape.DebugString(), " for default value, got ",
default_value.shape().DebugString());
}
return Status::OK();
}

View File

@ -128,7 +128,8 @@ class LookupInterface : public ResourceBase {
// requirements are satisfied, otherwise it returns InvalidArgument:
// - DataType of the tensor keys equals to the table key_dtype
// - DataType of the tensor default_value equals to the table value_dtype
// - the default_value tensor shape matches the table's value shape.
// - the default_value tensor has the required shape given keys and the
// table's value shape.
Status CheckFindArguments(const Tensor& keys, const Tensor& default_value);
string DebugString() const override {

View File

@ -56,14 +56,25 @@ class MutableHashTableOfScalars final : public LookupInterface {
Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
const Tensor& default_value) override {
const V default_val = default_value.flat<V>()(0);
const auto key_values = key.flat<K>();
auto value_values = value->flat<V>();
const auto default_flat = default_value.flat<V>();
int64 total = value_values.size();
int64 default_total = default_flat.size();
bool is_full_size_default = (total == default_total);
tf_shared_lock l(mu_);
for (int64 i = 0; i < key_values.size(); ++i) {
// is_full_size_default is true:
// Each key has its own default value: key_values(i) uses
// default_flat(i) as its default value.
//
// is_full_size_default is false:
// All keys share default_flat(0) as their default value.
value_values(i) = gtl::FindWithDefault(
table_, SubtleMustCopyIfIntegral(key_values(i)), default_val);
table_, SubtleMustCopyIfIntegral(key_values(i)),
is_full_size_default ? default_flat(i) : default_flat(0));
}
return Status::OK();
@ -173,11 +184,15 @@ class MutableHashTableOfTensors final : public LookupInterface {
Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
const Tensor& default_value) override {
const auto default_flat = default_value.flat<V>();
const auto default_flat = default_value.flat_inner_dims<V, 2>();
const auto key_values = key.flat<K>();
auto value_values = value->flat_inner_dims<V, 2>();
int64 value_dim = value_shape_.dim_size(0);
int64 total = value_values.size();
int64 default_total = default_flat.size();
bool is_full_size_default = (total == default_total);
tf_shared_lock l(mu_);
for (int64 i = 0; i < key_values.size(); ++i) {
ValueArray* value_vec =
@ -187,8 +202,15 @@ class MutableHashTableOfTensors final : public LookupInterface {
value_values(i, j) = value_vec->at(j);
}
} else {
// is_full_size_default is true:
// Each key has its own default value: key_values(i) uses row i of
// default_flat as its default value.
//
// is_full_size_default is false:
// All keys share row 0 of default_flat as their default value.
for (int64 j = 0; j < value_dim; j++) {
value_values(i, j) = default_flat(j);
value_values(i, j) =
is_full_size_default ? default_flat(i, j) : default_flat(0, j);
}
}
}

View File

@ -169,10 +169,6 @@ REGISTER_OP("LookupTableFindV2")
ShapeHandle handle;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
// Default value must be scalar or vector.
ShapeHandle keys;
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &keys));
ShapeAndType value_shape_and_type;
TF_RETURN_IF_ERROR(ValidateTableResourceHandle(
c,

View File

@ -25,7 +25,6 @@ namespace {
TEST(LookupOpsTest, LookupTableFindV2_ShapeFn) {
ShapeInferenceTestOp op("LookupTableFindV2");
INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[?];?;?");
INFER_ERROR("Shape must be at most rank 1 but is rank 2", op, "[];?;[1,1]");
TF_ASSERT_OK(NodeDefBuilder("test", "LookupTableFindV2")
.Input({"table_handle", 0, DT_RESOURCE})
.Input({"keys", 0, DT_INT64})

View File

@ -3375,6 +3375,71 @@ class MutableHashTableOpTest(test.TestCase):
result = self.evaluate(output)
self.assertAllEqual([[0, 1], [-1, -1]], result)
def testMutableHashTableFindWithInvalidShapeDefaultValue(self):
default_val = [-1, -1]
table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
default_val)
input_string = constant_op.constant([["brain", "salad"], ["tank",
"tarkus"]])
invalid_default_val = constant_op.constant(
[[-2, -3], [-4, -5], [-6, -7], [-8, -9]], dtypes.int64)
with self.assertRaisesRegex(
(ValueError, errors_impl.InvalidArgumentError),
"Expected shape \[2\] or \[2,2,2\] for default value, got \[4,2]"):
self.evaluate(table.lookup(input_string, invalid_default_val))
invalid_default_val = constant_op.constant([[[-2, -3], [-4, -5]]],
dtypes.int64)
with self.assertRaisesRegex(
(ValueError, errors_impl.InvalidArgumentError),
"Expected shape \[2\] or \[2,2,2\] for default value, got \[1,2,2\]"):
self.evaluate(table.lookup(input_string, invalid_default_val))
def testMutableHashTableFindHighRankScalarWithDynamicDefaultValue(self):
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
default_val)
self.evaluate(table.insert(keys, values))
self.assertAllEqual(3, self.evaluate(table.size()))
input_string = constant_op.constant([["brain", "salad"], ["tank",
"tarkus"]])
dynamic_default_val = constant_op.constant([[-2, -3], [-4, -5]],
dtypes.int64)
output = table.lookup(input_string, dynamic_default_val)
self.assertAllEqual([2, 2], output.get_shape())
result = self.evaluate(output)
self.assertAllEqual([[0, 1], [-4, -5]], result)
def testMutableHashTableFindHighRankVectorWithDynamicDefaultValue(self):
default_val = [-1, -1]
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
default_val)
self.evaluate(table.insert(keys, values))
self.assertAllEqual(3, self.evaluate(table.size()))
input_string = constant_op.constant([["brain", "salad"], ["tank",
"tarkus"]])
dynamic_default_val = constant_op.constant(
[[[-2, -3], [-4, -5]], [[-6, -7], [-8, -9]]], dtypes.int64)
output = table.lookup(input_string, dynamic_default_val)
self.assertAllEqual([2, 2, 2], output.get_shape())
result = self.evaluate(output)
self.assertAllEqual([[[0, 1], [2, 3]], [[-6, -7], [-8, -9]]], result)
def testMutableHashTableInsertHighRank(self):
default_val = -1
keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]])

View File

@ -1849,7 +1849,7 @@ class MutableHashTable(LookupInterface):
return op
def lookup(self, keys, name=None):
def lookup(self, keys, dynamic_default_values=None, name=None):
"""Looks up `keys` in a table, outputs the corresponding values.
The `default_value` is used for keys not present in the table.
@ -1857,6 +1857,23 @@ class MutableHashTable(LookupInterface):
Args:
keys: Keys to look up. Can be a tensor of any shape. Must match the
table's key_dtype.
dynamic_default_values: The values to use if a key is missing in the
table. If None (by default), the `table.default_value` will be used.
The shape of `dynamic_default_values` must match either the shape of
`table.default_value` or the shape of the lookup result tensor.
In the latter case, each key will have a different default value.
For example:
```python
keys = [0, 1, 3]
dynamic_default_values = [[1, 3, 4], [2, 3, 9], [8, 3, 0]]
# The key '0' will use [1, 3, 4] as default value.
# The key '1' will use [2, 3, 9] as default value.
# The key '3' will use [8, 3, 0] as default value.
```
name: A name for the operation (optional).
Returns:
@ -1870,8 +1887,9 @@ class MutableHashTable(LookupInterface):
(self.resource_handle, keys, self._default_value)):
keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
with ops.colocate_with(self.resource_handle):
values = gen_lookup_ops.lookup_table_find_v2(self.resource_handle, keys,
self._default_value)
values = gen_lookup_ops.lookup_table_find_v2(
self.resource_handle, keys, dynamic_default_values
if dynamic_default_values is not None else self._default_value)
return values
def insert(self, keys, values, name=None):