Merge pull request #43269 from rhdong:RFC237-patch-for-recsys-sig
PiperOrigin-RevId: 339248366 Change-Id: I4bfb177937a173f20c8165a31ee5c5c7bdf9ef75
This commit is contained in:
commit
b0784e587b
@ -83,10 +83,17 @@ Status LookupInterface::CheckFindArguments(const Tensor& key,
|
|||||||
const Tensor& default_value) {
|
const Tensor& default_value) {
|
||||||
TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(key, default_value));
|
TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(key, default_value));
|
||||||
TF_RETURN_IF_ERROR(CheckKeyShape(key.shape()));
|
TF_RETURN_IF_ERROR(CheckKeyShape(key.shape()));
|
||||||
if (default_value.shape() != value_shape()) {
|
TensorShape fullsize_value_shape = key.shape();
|
||||||
|
for (int i = 0; i < key_shape().dims(); ++i) {
|
||||||
|
fullsize_value_shape.RemoveDim(fullsize_value_shape.dims() - 1);
|
||||||
|
}
|
||||||
|
fullsize_value_shape.AppendShape(value_shape());
|
||||||
|
if (default_value.shape() != value_shape() &&
|
||||||
|
default_value.shape() != fullsize_value_shape) {
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
"Expected shape ", value_shape().DebugString(),
|
"Expected shape ", value_shape().DebugString(), " or ",
|
||||||
" for default value, got ", default_value.shape().DebugString());
|
fullsize_value_shape.DebugString(), " for default value, got ",
|
||||||
|
default_value.shape().DebugString());
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -128,7 +128,8 @@ class LookupInterface : public ResourceBase {
|
|||||||
// requirements are satisfied, otherwise it returns InvalidArgument:
|
// requirements are satisfied, otherwise it returns InvalidArgument:
|
||||||
// - DataType of the tensor keys equals to the table key_dtype
|
// - DataType of the tensor keys equals to the table key_dtype
|
||||||
// - DataType of the tensor default_value equals to the table value_dtype
|
// - DataType of the tensor default_value equals to the table value_dtype
|
||||||
// - the default_value tensor shape matches the table's value shape.
|
// - the default_value tensor has the required shape given keys and the
|
||||||
|
// tables's value shape.
|
||||||
Status CheckFindArguments(const Tensor& keys, const Tensor& default_value);
|
Status CheckFindArguments(const Tensor& keys, const Tensor& default_value);
|
||||||
|
|
||||||
string DebugString() const override {
|
string DebugString() const override {
|
||||||
|
@ -56,14 +56,25 @@ class MutableHashTableOfScalars final : public LookupInterface {
|
|||||||
|
|
||||||
Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
|
Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
|
||||||
const Tensor& default_value) override {
|
const Tensor& default_value) override {
|
||||||
const V default_val = default_value.flat<V>()(0);
|
|
||||||
const auto key_values = key.flat<K>();
|
const auto key_values = key.flat<K>();
|
||||||
auto value_values = value->flat<V>();
|
auto value_values = value->flat<V>();
|
||||||
|
const auto default_flat = default_value.flat<V>();
|
||||||
|
|
||||||
|
int64 total = value_values.size();
|
||||||
|
int64 default_total = default_flat.size();
|
||||||
|
bool is_full_size_default = (total == default_total);
|
||||||
|
|
||||||
tf_shared_lock l(mu_);
|
tf_shared_lock l(mu_);
|
||||||
for (int64 i = 0; i < key_values.size(); ++i) {
|
for (int64 i = 0; i < key_values.size(); ++i) {
|
||||||
|
// is_full_size_default is true:
|
||||||
|
// Each key has an independent default value, key_values(i)
|
||||||
|
// corresponding uses default_flat(i) as its default value.
|
||||||
|
//
|
||||||
|
// is_full_size_default is false:
|
||||||
|
// All keys will share the default_flat(0) as default value.
|
||||||
value_values(i) = gtl::FindWithDefault(
|
value_values(i) = gtl::FindWithDefault(
|
||||||
table_, SubtleMustCopyIfIntegral(key_values(i)), default_val);
|
table_, SubtleMustCopyIfIntegral(key_values(i)),
|
||||||
|
is_full_size_default ? default_flat(i) : default_flat(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
@ -173,11 +184,15 @@ class MutableHashTableOfTensors final : public LookupInterface {
|
|||||||
|
|
||||||
Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
|
Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
|
||||||
const Tensor& default_value) override {
|
const Tensor& default_value) override {
|
||||||
const auto default_flat = default_value.flat<V>();
|
const auto default_flat = default_value.flat_inner_dims<V, 2>();
|
||||||
const auto key_values = key.flat<K>();
|
const auto key_values = key.flat<K>();
|
||||||
auto value_values = value->flat_inner_dims<V, 2>();
|
auto value_values = value->flat_inner_dims<V, 2>();
|
||||||
int64 value_dim = value_shape_.dim_size(0);
|
int64 value_dim = value_shape_.dim_size(0);
|
||||||
|
|
||||||
|
int64 total = value_values.size();
|
||||||
|
int64 default_total = default_flat.size();
|
||||||
|
bool is_full_size_default = (total == default_total);
|
||||||
|
|
||||||
tf_shared_lock l(mu_);
|
tf_shared_lock l(mu_);
|
||||||
for (int64 i = 0; i < key_values.size(); ++i) {
|
for (int64 i = 0; i < key_values.size(); ++i) {
|
||||||
ValueArray* value_vec =
|
ValueArray* value_vec =
|
||||||
@ -187,8 +202,15 @@ class MutableHashTableOfTensors final : public LookupInterface {
|
|||||||
value_values(i, j) = value_vec->at(j);
|
value_values(i, j) = value_vec->at(j);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
// is_full_size_default is true:
|
||||||
|
// Each key has an independent default value, key_values(i)
|
||||||
|
// corresponding uses default_flat(i) as its default value.
|
||||||
|
//
|
||||||
|
// is_full_size_default is false:
|
||||||
|
// All keys will share the default_flat(0) as default value.
|
||||||
for (int64 j = 0; j < value_dim; j++) {
|
for (int64 j = 0; j < value_dim; j++) {
|
||||||
value_values(i, j) = default_flat(j);
|
value_values(i, j) =
|
||||||
|
is_full_size_default ? default_flat(i, j) : default_flat(0, j);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -169,10 +169,6 @@ REGISTER_OP("LookupTableFindV2")
|
|||||||
ShapeHandle handle;
|
ShapeHandle handle;
|
||||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
|
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
|
||||||
|
|
||||||
// Default value must be scalar or vector.
|
|
||||||
ShapeHandle keys;
|
|
||||||
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &keys));
|
|
||||||
|
|
||||||
ShapeAndType value_shape_and_type;
|
ShapeAndType value_shape_and_type;
|
||||||
TF_RETURN_IF_ERROR(ValidateTableResourceHandle(
|
TF_RETURN_IF_ERROR(ValidateTableResourceHandle(
|
||||||
c,
|
c,
|
||||||
|
@ -25,7 +25,6 @@ namespace {
|
|||||||
TEST(LookupOpsTest, LookupTableFindV2_ShapeFn) {
|
TEST(LookupOpsTest, LookupTableFindV2_ShapeFn) {
|
||||||
ShapeInferenceTestOp op("LookupTableFindV2");
|
ShapeInferenceTestOp op("LookupTableFindV2");
|
||||||
INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[?];?;?");
|
INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[?];?;?");
|
||||||
INFER_ERROR("Shape must be at most rank 1 but is rank 2", op, "[];?;[1,1]");
|
|
||||||
TF_ASSERT_OK(NodeDefBuilder("test", "LookupTableFindV2")
|
TF_ASSERT_OK(NodeDefBuilder("test", "LookupTableFindV2")
|
||||||
.Input({"table_handle", 0, DT_RESOURCE})
|
.Input({"table_handle", 0, DT_RESOURCE})
|
||||||
.Input({"keys", 0, DT_INT64})
|
.Input({"keys", 0, DT_INT64})
|
||||||
|
@ -3375,6 +3375,71 @@ class MutableHashTableOpTest(test.TestCase):
|
|||||||
result = self.evaluate(output)
|
result = self.evaluate(output)
|
||||||
self.assertAllEqual([[0, 1], [-1, -1]], result)
|
self.assertAllEqual([[0, 1], [-1, -1]], result)
|
||||||
|
|
||||||
|
def testMutableHashTableFindWithInvalidShapeDefaultValue(self):
|
||||||
|
default_val = [-1, -1]
|
||||||
|
table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
|
||||||
|
default_val)
|
||||||
|
|
||||||
|
input_string = constant_op.constant([["brain", "salad"], ["tank",
|
||||||
|
"tarkus"]])
|
||||||
|
|
||||||
|
invalid_default_val = constant_op.constant(
|
||||||
|
[[-2, -3], [-4, -5], [-6, -7], [-8, -9]], dtypes.int64)
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
(ValueError, errors_impl.InvalidArgumentError),
|
||||||
|
"Expected shape \[2\] or \[2,2,2\] for default value, got \[4,2]"):
|
||||||
|
self.evaluate(table.lookup(input_string, invalid_default_val))
|
||||||
|
|
||||||
|
invalid_default_val = constant_op.constant([[[-2, -3], [-4, -5]]],
|
||||||
|
dtypes.int64)
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
(ValueError, errors_impl.InvalidArgumentError),
|
||||||
|
"Expected shape \[2\] or \[2,2,2\] for default value, got \[1,2,2\]"):
|
||||||
|
self.evaluate(table.lookup(input_string, invalid_default_val))
|
||||||
|
|
||||||
|
def testMutableHashTableFindHighRankScalarWithDynamicDefaultValue(self):
|
||||||
|
default_val = -1
|
||||||
|
keys = constant_op.constant(["brain", "salad", "surgery"])
|
||||||
|
values = constant_op.constant([0, 1, 2], dtypes.int64)
|
||||||
|
table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
|
||||||
|
default_val)
|
||||||
|
|
||||||
|
self.evaluate(table.insert(keys, values))
|
||||||
|
self.assertAllEqual(3, self.evaluate(table.size()))
|
||||||
|
|
||||||
|
input_string = constant_op.constant([["brain", "salad"], ["tank",
|
||||||
|
"tarkus"]])
|
||||||
|
|
||||||
|
dynamic_default_val = constant_op.constant([[-2, -3], [-4, -5]],
|
||||||
|
dtypes.int64)
|
||||||
|
output = table.lookup(input_string, dynamic_default_val)
|
||||||
|
self.assertAllEqual([2, 2], output.get_shape())
|
||||||
|
|
||||||
|
result = self.evaluate(output)
|
||||||
|
self.assertAllEqual([[0, 1], [-4, -5]], result)
|
||||||
|
|
||||||
|
def testMutableHashTableFindHighRankVectorWithDynamicDefaultValue(self):
|
||||||
|
default_val = [-1, -1]
|
||||||
|
keys = constant_op.constant(["brain", "salad", "surgery"])
|
||||||
|
values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
|
||||||
|
table = lookup_ops.MutableHashTable(dtypes.string, dtypes.int64,
|
||||||
|
default_val)
|
||||||
|
|
||||||
|
self.evaluate(table.insert(keys, values))
|
||||||
|
self.assertAllEqual(3, self.evaluate(table.size()))
|
||||||
|
|
||||||
|
input_string = constant_op.constant([["brain", "salad"], ["tank",
|
||||||
|
"tarkus"]])
|
||||||
|
|
||||||
|
dynamic_default_val = constant_op.constant(
|
||||||
|
[[[-2, -3], [-4, -5]], [[-6, -7], [-8, -9]]], dtypes.int64)
|
||||||
|
output = table.lookup(input_string, dynamic_default_val)
|
||||||
|
self.assertAllEqual([2, 2, 2], output.get_shape())
|
||||||
|
|
||||||
|
result = self.evaluate(output)
|
||||||
|
self.assertAllEqual([[[0, 1], [2, 3]], [[-6, -7], [-8, -9]]], result)
|
||||||
|
|
||||||
def testMutableHashTableInsertHighRank(self):
|
def testMutableHashTableInsertHighRank(self):
|
||||||
default_val = -1
|
default_val = -1
|
||||||
keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]])
|
keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]])
|
||||||
|
@ -1849,7 +1849,7 @@ class MutableHashTable(LookupInterface):
|
|||||||
|
|
||||||
return op
|
return op
|
||||||
|
|
||||||
def lookup(self, keys, name=None):
|
def lookup(self, keys, dynamic_default_values=None, name=None):
|
||||||
"""Looks up `keys` in a table, outputs the corresponding values.
|
"""Looks up `keys` in a table, outputs the corresponding values.
|
||||||
|
|
||||||
The `default_value` is used for keys not present in the table.
|
The `default_value` is used for keys not present in the table.
|
||||||
@ -1857,6 +1857,23 @@ class MutableHashTable(LookupInterface):
|
|||||||
Args:
|
Args:
|
||||||
keys: Keys to look up. Can be a tensor of any shape. Must match the
|
keys: Keys to look up. Can be a tensor of any shape. Must match the
|
||||||
table's key_dtype.
|
table's key_dtype.
|
||||||
|
dynamic_default_values: The values to use if a key is missing in the
|
||||||
|
table. If None (by default), the `table.default_value` will be used.
|
||||||
|
Shape of `dynamic_default_values` must be same with
|
||||||
|
`table.default_value` or the lookup result tensor.
|
||||||
|
In the latter case, each key will have a different default value.
|
||||||
|
|
||||||
|
For example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
keys = [0, 1, 3]
|
||||||
|
dynamic_default_values = [[1, 3, 4], [2, 3, 9], [8, 3, 0]]
|
||||||
|
|
||||||
|
# The key '0' will use [1, 3, 4] as default value.
|
||||||
|
# The key '1' will use [2, 3, 9] as default value.
|
||||||
|
# The key '3' will use [8, 3, 0] as default value.
|
||||||
|
```
|
||||||
|
|
||||||
name: A name for the operation (optional).
|
name: A name for the operation (optional).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -1870,8 +1887,9 @@ class MutableHashTable(LookupInterface):
|
|||||||
(self.resource_handle, keys, self._default_value)):
|
(self.resource_handle, keys, self._default_value)):
|
||||||
keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
|
keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
|
||||||
with ops.colocate_with(self.resource_handle):
|
with ops.colocate_with(self.resource_handle):
|
||||||
values = gen_lookup_ops.lookup_table_find_v2(self.resource_handle, keys,
|
values = gen_lookup_ops.lookup_table_find_v2(
|
||||||
self._default_value)
|
self.resource_handle, keys, dynamic_default_values
|
||||||
|
if dynamic_default_values is not None else self._default_value)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
def insert(self, keys, values, name=None):
|
def insert(self, keys, values, name=None):
|
||||||
|
Loading…
Reference in New Issue
Block a user