Merge pull request #43269 from rhdong:RFC237-patch-for-recsys-sig
PiperOrigin-RevId: 339248366 Change-Id: I4bfb177937a173f20c8165a31ee5c5c7bdf9ef75
This commit is contained in:
commit
b0784e587b
@ -83,10 +83,17 @@ Status LookupInterface::CheckFindArguments(const Tensor& key,
|
||||
const Tensor& default_value) {
|
||||
TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(key, default_value));
|
||||
TF_RETURN_IF_ERROR(CheckKeyShape(key.shape()));
|
||||
if (default_value.shape() != value_shape()) {
|
||||
TensorShape fullsize_value_shape = key.shape();
|
||||
for (int i = 0; i < key_shape().dims(); ++i) {
|
||||
fullsize_value_shape.RemoveDim(fullsize_value_shape.dims() - 1);
|
||||
}
|
||||
fullsize_value_shape.AppendShape(value_shape());
|
||||
if (default_value.shape() != value_shape() &&
|
||||
default_value.shape() != fullsize_value_shape) {
|
||||
return errors::InvalidArgument(
|
||||
"Expected shape ", value_shape().DebugString(),
|
||||
" for default value, got ", default_value.shape().DebugString());
|
||||
"Expected shape ", value_shape().DebugString(), " or ",
|
||||
fullsize_value_shape.DebugString(), " for default value, got ",
|
||||
default_value.shape().DebugString());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -128,7 +128,8 @@ class LookupInterface : public ResourceBase {
|
||||
// requirements are satisfied, otherwise it returns InvalidArgument:
|
||||
// - DataType of the tensor keys equals to the table key_dtype
|
||||
// - DataType of the tensor default_value equals to the table value_dtype
|
||||
// - the default_value tensor shape matches the table's value shape.
|
||||
// - the default_value tensor has the required shape given keys and the
|
||||
// table's value shape.
|
||||
Status CheckFindArguments(const Tensor& keys, const Tensor& default_value);
|
||||
|
||||
string DebugString() const override {
|
||||
|
@ -56,14 +56,25 @@ class MutableHashTableOfScalars final : public LookupInterface {
|
||||
|
||||
Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
|
||||
const Tensor& default_value) override {
|
||||
const V default_val = default_value.flat<V>()(0);
|
||||
const auto key_values = key.flat<K>();
|
||||
auto value_values = value->flat<V>();
|
||||
const auto default_flat = default_value.flat<V>();
|
||||
|
||||
int64 total = value_values.size();
|
||||
int64 default_total = default_flat.size();
|
||||
bool is_full_size_default = (total == default_total);
|
||||
|
||||
tf_shared_lock l(mu_);
|
||||
for (int64 i = 0; i < key_values.size(); ++i) {
|
||||
// is_full_size_default is true:
|
||||
// Each key has an independent default value; key_values(i)
|
||||
// uses default_flat(i) as its default value.
|
||||
//
|
||||
// is_full_size_default is false:
|
||||
// All keys share default_flat(0) as their default value.
|
||||
value_values(i) = gtl::FindWithDefault(
|
||||
table_, SubtleMustCopyIfIntegral(key_values(i)), default_val);
|
||||
table_, SubtleMustCopyIfIntegral(key_values(i)),
|
||||
is_full_size_default ? default_flat(i) : default_flat(0));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
@ -173,11 +184,15 @@ class MutableHashTableOfTensors final : public LookupInterface {
|
||||
|
||||
Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
|
||||
const Tensor& default_value) override {
|
||||
const auto default_flat = default_value.flat<V>();
|
||||
const auto default_flat = default_value.flat_inner_dims<V, 2>();
|
||||
const auto key_values = key.flat<K>();
|
||||
auto value_values = value->flat_inner_dims<V, 2>();
|
||||
int64 value_dim = value_shape_.dim_size(0);
|
||||
|
||||
int64 total = value_values.size();
|
||||
int64 default_total = default_flat.size();
|
||||
bool is_full_size_default = (total == default_total);
|
||||
|
||||
tf_shared_lock l(mu_);
|
||||
for (int64 i = 0; i < key_values.size(); ++i) {
|
||||
ValueArray* value_vec =
|
||||
@ -187,8 +202,15 @@ class MutableHashTableOfTensors final : public LookupInterface {
|
||||
value_values(i, j) = value_vec->at(j);
|
||||
}
|
||||
} else {
|
||||
// is_full_size_default is true:
|
||||
// Each key has an independent default value; key_values(i)
|
||||
// uses row i of default_flat as its default value.
|
||||
//
|
||||
// is_full_size_default is false:
|
||||
// All keys share row 0 of default_flat as their default value.
|
||||
for (int64 j = 0; j < value_dim; j++) {
|
||||
value_values(i, j) = default_flat(j);
|
||||
value_values(i, j) =
|
||||
is_full_size_default ? default_flat(i, j) : default_flat(0, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -169,10 +169,6 @@ REGISTER_OP("LookupTableFindV2")
|
||||
ShapeHandle handle;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
|
||||
|
||||
// Default value must be scalar or vector.
|
||||
ShapeHandle keys;
|
||||
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &keys));
|
||||
|
||||
ShapeAndType value_shape_and_type;
|
||||
TF_RETURN_IF_ERROR(ValidateTableResourceHandle(
|
||||
c,
|
||||
|
@ -25,7 +25,6 @@ namespace {
|
||||
TEST(LookupOpsTest, LookupTableFindV2_ShapeFn) {
|
||||
ShapeInferenceTestOp op("LookupTableFindV2");
|
||||
INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[?];?;?");
|
||||
INFER_ERROR("Shape must be at most rank 1 but is rank 2", op, "[];?;[1,1]");
|
||||
TF_ASSERT_OK(NodeDefBuilder("test", "LookupTableFindV2")
|
||||
.Input({"table_handle", 0, DT_RESOURCE})
|
||||
.Input({"keys", 0, DT_INT64})
|
||||
|
@ -3375,6 +3375,71 @@ class MutableHashTableOpTest(test.TestCase):
|
||||
result = self.evaluate(output)
|
||||
self.assertAllEqual([[0, 1], [-1, -1]], result)
|
||||
|
||||
def testMutableHashTableFindWithInvalidShapeDefaultValue(self):
|
||||
default_val = [-1, -1]
|
||||
table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
|
||||
default_val)
|
||||
|
||||
input_string = constant_op.constant([["brain", "salad"], ["tank",
|
||||
"tarkus"]])
|
||||
|
||||
invalid_default_val = constant_op.constant(
|
||||
[[-2, -3], [-4, -5], [-6, -7], [-8, -9]], dtypes.int64)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
(ValueError, errors_impl.InvalidArgumentError),
|
||||
"Expected shape \[2\] or \[2,2,2\] for default value, got \[4,2]"):
|
||||
self.evaluate(table.lookup(input_string, invalid_default_val))
|
||||
|
||||
invalid_default_val = constant_op.constant([[[-2, -3], [-4, -5]]],
|
||||
dtypes.int64)
|
||||
with self.assertRaisesRegex(
|
||||
(ValueError, errors_impl.InvalidArgumentError),
|
||||
"Expected shape \[2\] or \[2,2,2\] for default value, got \[1,2,2\]"):
|
||||
self.evaluate(table.lookup(input_string, invalid_default_val))
|
||||
|
||||
def testMutableHashTableFindHighRankScalarWithDynamicDefaultValue(self):
|
||||
default_val = -1
|
||||
keys = constant_op.constant(["brain", "salad", "surgery"])
|
||||
values = constant_op.constant([0, 1, 2], dtypes.int64)
|
||||
table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
|
||||
default_val)
|
||||
|
||||
self.evaluate(table.insert(keys, values))
|
||||
self.assertAllEqual(3, self.evaluate(table.size()))
|
||||
|
||||
input_string = constant_op.constant([["brain", "salad"], ["tank",
|
||||
"tarkus"]])
|
||||
|
||||
dynamic_default_val = constant_op.constant([[-2, -3], [-4, -5]],
|
||||
dtypes.int64)
|
||||
output = table.lookup(input_string, dynamic_default_val)
|
||||
self.assertAllEqual([2, 2], output.get_shape())
|
||||
|
||||
result = self.evaluate(output)
|
||||
self.assertAllEqual([[0, 1], [-4, -5]], result)
|
||||
|
||||
def testMutableHashTableFindHighRankVectorWithDynamicDefaultValue(self):
|
||||
default_val = [-1, -1]
|
||||
keys = constant_op.constant(["brain", "salad", "surgery"])
|
||||
values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
|
||||
table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
|
||||
default_val)
|
||||
|
||||
self.evaluate(table.insert(keys, values))
|
||||
self.assertAllEqual(3, self.evaluate(table.size()))
|
||||
|
||||
input_string = constant_op.constant([["brain", "salad"], ["tank",
|
||||
"tarkus"]])
|
||||
|
||||
dynamic_default_val = constant_op.constant(
|
||||
[[[-2, -3], [-4, -5]], [[-6, -7], [-8, -9]]], dtypes.int64)
|
||||
output = table.lookup(input_string, dynamic_default_val)
|
||||
self.assertAllEqual([2, 2, 2], output.get_shape())
|
||||
|
||||
result = self.evaluate(output)
|
||||
self.assertAllEqual([[[0, 1], [2, 3]], [[-6, -7], [-8, -9]]], result)
|
||||
|
||||
def testMutableHashTableInsertHighRank(self):
|
||||
default_val = -1
|
||||
keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]])
|
||||
|
@ -1849,7 +1849,7 @@ class MutableHashTable(LookupInterface):
|
||||
|
||||
return op
|
||||
|
||||
def lookup(self, keys, name=None):
|
||||
def lookup(self, keys, dynamic_default_values=None, name=None):
|
||||
"""Looks up `keys` in a table, outputs the corresponding values.
|
||||
|
||||
The `default_value` is used for keys not present in the table.
|
||||
@ -1857,6 +1857,23 @@ class MutableHashTable(LookupInterface):
|
||||
Args:
|
||||
keys: Keys to look up. Can be a tensor of any shape. Must match the
|
||||
table's key_dtype.
|
||||
dynamic_default_values: The values to use if a key is missing in the
|
||||
table. If None (by default), the `table.default_value` will be used.
|
||||
The shape of `dynamic_default_values` must match that of
|
||||
`table.default_value` or the lookup result tensor.
|
||||
In the latter case, each key will have a different default value.
|
||||
|
||||
For example:
|
||||
|
||||
```python
|
||||
keys = [0, 1, 3]
|
||||
dynamic_default_values = [[1, 3, 4], [2, 3, 9], [8, 3, 0]]
|
||||
|
||||
# The key '0' will use [1, 3, 4] as default value.
|
||||
# The key '1' will use [2, 3, 9] as default value.
|
||||
# The key '3' will use [8, 3, 0] as default value.
|
||||
```
|
||||
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
@ -1870,8 +1887,9 @@ class MutableHashTable(LookupInterface):
|
||||
(self.resource_handle, keys, self._default_value)):
|
||||
keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
|
||||
with ops.colocate_with(self.resource_handle):
|
||||
values = gen_lookup_ops.lookup_table_find_v2(self.resource_handle, keys,
|
||||
self._default_value)
|
||||
values = gen_lookup_ops.lookup_table_find_v2(
|
||||
self.resource_handle, keys, dynamic_default_values
|
||||
if dynamic_default_values is not None else self._default_value)
|
||||
return values
|
||||
|
||||
def insert(self, keys, values, name=None):
|
||||
|
Loading…
Reference in New Issue
Block a user