Merge pull request #42300 from kttian:list_keys

PiperOrigin-RevId: 328350823
Change-Id: I41fc8ec95e2b70b9f86b58882f01a3e6a987b205
This commit is contained in:
TensorFlower Gardener 2020-08-25 10:12:32 -07:00
commit 51636d1795
9 changed files with 240 additions and 88 deletions

View File

@ -0,0 +1,8 @@
op {
graph_op_name: "TensorMapStackKeys"
summary: "Returns a Tensor stack of all keys in a tensor map."
description: <<END
input_handle: the input map
keys: the returned Tensor of all keys in the map
END
}

View File

@ -41,6 +41,9 @@ REGISTER_KERNEL_BUILDER(Name("TensorMapErase").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("TensorMapHasKey").Device(DEVICE_CPU),
TensorMapHasKey);
REGISTER_KERNEL_BUILDER(Name("TensorMapStackKeys").Device(DEVICE_CPU),
TensorMapStackKeys);
#undef REGISTER_TENSOR_MAP_OPS_CPU
#define REGISTER_TENSOR_MAP_OPS_CPU(T)

View File

@ -15,39 +15,37 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_MAP_KERNELS_H_
#define TENSORFLOW_CORE_KERNELS_MAP_KERNELS_H_
#include <iostream>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/kernels/tensor_map.h"
#include "tensorflow/core/util/batch_util.h"
#include "tensorflow/core/util/tensor_ops_util.h"
namespace tensorflow {
Status GetInputMap(OpKernelContext* c, int index, const TensorMap** map) {
if (!TensorShapeUtils::IsScalar(c->input(index).shape())) {
Status GetInputMap(OpKernelContext* ctx, int index, const TensorMap** ret_map) {
if (!TensorShapeUtils::IsScalar(ctx->input(index).shape())) {
return errors::InvalidArgument("Input map must be a scalar. Saw: ",
c->input(index).shape().DebugString());
ctx->input(index).shape().DebugString());
}
const TensorMap* m = c->input(index).scalar<Variant>()().get<TensorMap>();
if (m == nullptr) {
const TensorMap* map = ctx->input(index).scalar<Variant>()().get<TensorMap>();
if (map == nullptr) {
return errors::InvalidArgument(
"Input handle is not a map. Saw: '",
c->input(index).scalar<Variant>()().DebugString(), "'");
ctx->input(index).scalar<Variant>()().DebugString(), "'");
}
*map = m;
*ret_map = map;
return Status::OK();
}
// TODO(kattian): change into templated function
Status ForwardInputOrCreateNewMap(OpKernelContext* c, int32 input_index,
Status ForwardInputOrCreateNewMap(OpKernelContext* ctx, int32 input_index,
int32 output_index,
const TensorMap& input_map,
TensorMap** output_map) {
// Attempt to forward the input tensor to the output if possible.
std::unique_ptr<Tensor> maybe_output = c->forward_input(
std::unique_ptr<Tensor> maybe_output = ctx->forward_input(
input_index, output_index, DT_VARIANT, TensorShape{},
c->input_memory_type(input_index), AllocatorAttributes());
ctx->input_memory_type(input_index), AllocatorAttributes());
Tensor* output_tensor;
if (maybe_output != nullptr && maybe_output->dtype() == DT_VARIANT &&
maybe_output->NumElements() == 1) {
@ -60,7 +58,7 @@ Status ForwardInputOrCreateNewMap(OpKernelContext* c, int32 input_index,
}
if (tmp_out->RefCountIsOne()) {
// Woohoo, forwarding succeeded!
c->set_output(output_index, *output_tensor);
ctx->set_output(output_index, *output_tensor);
*output_map = tmp_out;
return Status::OK();
}
@ -71,7 +69,7 @@ Status ForwardInputOrCreateNewMap(OpKernelContext* c, int32 input_index,
AllocatorAttributes attr;
attr.set_on_host(true);
TF_RETURN_IF_ERROR(
c->allocate_output(output_index, {}, &output_tensor, attr));
ctx->allocate_output(output_index, {}, &output_tensor, attr));
output_tensor->scalar<Variant>()() = input_map.Copy();
*output_map = output_tensor->scalar<Variant>()().get<TensorMap>();
@ -80,13 +78,13 @@ Status ForwardInputOrCreateNewMap(OpKernelContext* c, int32 input_index,
class EmptyTensorMap : public OpKernel {
public:
explicit EmptyTensorMap(OpKernelConstruction* c) : OpKernel(c) {}
explicit EmptyTensorMap(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* c) override {
void Compute(OpKernelContext* ctx) override {
Tensor* result;
AllocatorAttributes attr;
attr.set_on_host(true);
OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result, attr));
TensorMap empty;
result->scalar<Variant>()() = std::move(empty);
}
@ -94,87 +92,136 @@ class EmptyTensorMap : public OpKernel {
class TensorMapSize : public OpKernel {
public:
explicit TensorMapSize(OpKernelConstruction* c) : OpKernel(c) {}
explicit TensorMapSize(OpKernelConstruction* ctx) : OpKernel(ctx) {}
~TensorMapSize() override {}
void Compute(OpKernelContext* c) override {
const TensorMap* m = nullptr;
OP_REQUIRES_OK(c, GetInputMap(c, 0, &m));
void Compute(OpKernelContext* ctx) override {
const TensorMap* map = nullptr;
OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));
Tensor* result;
OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result));
result->scalar<int32>()() = m->tensors().size();
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result));
result->scalar<int32>()() = map->tensors().size();
}
};
class TensorMapLookup : public OpKernel {
public:
explicit TensorMapLookup(OpKernelConstruction* c) : OpKernel(c) {}
explicit TensorMapLookup(OpKernelConstruction* ctx) : OpKernel(ctx) {}
~TensorMapLookup() override {}
void Compute(OpKernelContext* c) override {
const TensorKey& key = c->input(1);
const TensorMap* m = nullptr;
OP_REQUIRES_OK(c, GetInputMap(c, 0, &m));
void Compute(OpKernelContext* ctx) override {
const TensorKey& key = ctx->input(1);
const TensorMap* map = nullptr;
OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));
OP_REQUIRES(c, m->tensors().find(key) != m->tensors().end(),
errors::InvalidArgument("Trying to lookup non-existent key."));
OP_REQUIRES(
ctx, map->tensors().find(key) != map->tensors().end(),
errors::InvalidArgument("Trying to lookup non-existent key. Could not "
"find key \"" +
key.SummarizeValue(100) + "\"."));
c->set_output(0, m->tensors().find(key)->second);
ctx->set_output(0, map->tensors().find(key)->second);
}
};
class TensorMapInsert : public OpKernel {
public:
explicit TensorMapInsert(OpKernelConstruction* c) : OpKernel(c) {}
explicit TensorMapInsert(OpKernelConstruction* ctx) : OpKernel(ctx) {}
~TensorMapInsert() override {}
void Compute(OpKernelContext* c) override {
const TensorKey& key = c->input(1);
const Tensor& value = c->input(2);
const TensorMap* m = nullptr;
OP_REQUIRES_OK(c, GetInputMap(c, 0, &m));
void Compute(OpKernelContext* ctx) override {
const TensorKey& key = ctx->input(1);
const Tensor& value = ctx->input(2);
const TensorMap* map = nullptr;
OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));
TensorMap* output_map = nullptr;
OP_REQUIRES_OK(c, ForwardInputOrCreateNewMap(c, 0, 0, *m, &output_map));
OP_REQUIRES_OK(ctx,
ForwardInputOrCreateNewMap(ctx, 0, 0, *map, &output_map));
output_map->replace(key, value);
}
};
class TensorMapErase : public OpKernel {
public:
explicit TensorMapErase(OpKernelConstruction* c) : OpKernel(c) {}
explicit TensorMapErase(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* c) override {
const TensorKey& key = c->input(1);
const TensorMap* m = nullptr;
OP_REQUIRES_OK(c, GetInputMap(c, 0, &m));
void Compute(OpKernelContext* ctx) override {
const TensorKey& key = ctx->input(1);
const TensorMap* map = nullptr;
OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));
OP_REQUIRES(c, m->tensors().find(key) != m->tensors().end(),
errors::InvalidArgument("Trying to erase non-existent item."));
OP_REQUIRES(
ctx, map->tensors().find(key) != map->tensors().end(),
errors::InvalidArgument("Trying to erase non-existent item. Could not "
"find key \"" +
key.SummarizeValue(100) + "\"."));
TensorMap* output_map = nullptr;
OP_REQUIRES_OK(c, ForwardInputOrCreateNewMap(c, 0, 0, *m, &output_map));
OP_REQUIRES_OK(ctx,
ForwardInputOrCreateNewMap(ctx, 0, 0, *map, &output_map));
output_map->tensors().erase(key);
}
};
class TensorMapHasKey : public OpKernel {
public:
explicit TensorMapHasKey(OpKernelConstruction* c) : OpKernel(c) {}
explicit TensorMapHasKey(OpKernelConstruction* ctx) : OpKernel(ctx) {}
~TensorMapHasKey() override {}
void Compute(OpKernelContext* c) override {
const TensorKey& key = c->input(1);
const TensorMap* m = nullptr;
OP_REQUIRES_OK(c, GetInputMap(c, 0, &m));
void Compute(OpKernelContext* ctx) override {
const TensorKey& key = ctx->input(1);
const TensorMap* map = nullptr;
OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));
Tensor* result;
OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result));
result->scalar<bool>()() = m->tensors().find(key) != m->tensors().end();
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result));
result->scalar<bool>()() = map->tensors().find(key) != map->tensors().end();
}
};
class TensorMapStackKeys : public OpKernel {
public:
explicit TensorMapStackKeys(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("key_dtype", &key_dtype_));
}
~TensorMapStackKeys() override {}
void Compute(OpKernelContext* ctx) override {
const TensorMap* map = nullptr;
OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));
OP_REQUIRES(ctx, map->size() != 0,
errors::InvalidArgument(
"TensorMapStackKeys cannot be called on empty map."));
auto it = map->tensors().begin();
TensorShape output_shape = it->first.shape();
output_shape.InsertDim(0, map->tensors().size());
Tensor* result;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &result));
int i = 0;
size_t sz = map->tensors().size();
TensorShape key_shape = it->first.shape();
while (it != map->tensors().end() && i < sz) {
OP_REQUIRES(
ctx, it->first.dtype() == key_dtype_,
errors::InvalidArgument("Key does not match requested dtype."));
OP_REQUIRES(
ctx, it->first.shape() == key_shape,
errors::InvalidArgument("Keys must all have the same shape."));
OP_REQUIRES_OK(ctx, batch_util::CopyElementToSlice(it->first, result, i));
i++;
it++;
}
}
private:
DataType key_dtype_;
};
template <typename Device>
Status TensorMapBinaryAdd(OpKernelContext* c, const TensorMap& a,
Status TensorMapBinaryAdd(OpKernelContext* ctx, const TensorMap& a,
const TensorMap& b, TensorMap* out) {
// Binary add returns a map containing the union of keys.
// Values with keys in the intersection are added.
@ -185,7 +232,7 @@ Status TensorMapBinaryAdd(OpKernelContext* c, const TensorMap& a,
if (it != out->tensors().end()) {
Tensor out_tensor;
TF_RETURN_IF_ERROR(
BinaryAddTensors<Device>(c, p.second, it->second, &out_tensor));
BinaryAddTensors<Device>(ctx, p.second, it->second, &out_tensor));
it->second = out_tensor;
} else {
out->tensors().emplace(p.first, p.second);
@ -195,7 +242,7 @@ Status TensorMapBinaryAdd(OpKernelContext* c, const TensorMap& a,
}
template <typename Device>
Status TensorMapZerosLike(OpKernelContext* c, const TensorMap& x,
Status TensorMapZerosLike(OpKernelContext* ctx, const TensorMap& x,
TensorMap* y) {
// Zeros like returns an empty map.
return Status::OK();

View File

@ -144,7 +144,19 @@ class TensorMap {
size_t erase(TensorKey key) { return tensors_->values_.erase(key); }
// Size returns the number of elements in the map
size_t size() { return tensors_->values_.size(); }
size_t size() const { return tensors_->values_.size(); }
std::vector<Tensor> keys() const {
std::vector<Tensor> keys;
keys.reserve(tensors_->values_.size());
absl::flat_hash_map<TensorKey, Tensor>::iterator it =
tensors_->values_.begin();
while (it != tensors_->values_.end()) {
keys.push_back(it->first);
it++;
}
return keys;
}
// Is this TensorMap the only one with a reference to the underlying
// container?

View File

@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/tensor_map.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
@ -45,7 +44,6 @@ TEST(TensorKeyTest, Equal) {
}
TEST(TensorMapTest, Insert) {
EXPECT_EQ(1, 1);
TensorMap tm;
TensorKey k = Tensor(11);
Tensor v = Tensor(22);
@ -102,12 +100,49 @@ TEST(TensorMapTest, Replace) {
Tensor v1 = Tensor(22);
Tensor v2 = Tensor(23);
tm[k] = v2;
absl::flat_hash_map<TensorKey, Tensor>::iterator map_it = tm.find(k);
EXPECT_EQ(map_it->first, k);
test::ExpectTensorEqual<int32>(map_it->second, v2);
}
TEST(TensorMapTest, ListKeys) {
TensorMap tm;
TensorKey k = Tensor(11.0);
TensorKey k2 = Tensor(12.0);
Tensor v = Tensor(22);
Tensor v2 = Tensor(23);
tm.insert(k, v);
tm.insert(k2, v2);
std::vector<Tensor> keys = tm.keys();
// Extract and sort double value for each key Tensor.
std::vector<std::pair<double, int>> key_doubles;
for (int i = 0; i < keys.size(); i++) {
double x = keys[i].scalar<double>()();
std::pair<double, int> p = std::pair<double, int>(x, i);
key_doubles.push_back(p);
}
sort(key_doubles.begin(), key_doubles.end());
// Check number of keys and each key.
EXPECT_EQ(keys.size(), 2);
EXPECT_EQ(key_doubles[0].first, 11.0);
EXPECT_EQ(key_doubles[1].first, 12.0);
// Check key shapes.
int ind1 = key_doubles[0].second;
int ind2 = key_doubles[1].second;
EXPECT_EQ(keys[ind1].shape(), k.shape());
EXPECT_EQ(keys[ind2].shape(), k2.shape());
}
TEST(TensorMapTest, Size) {
TensorMap tm;
EXPECT_EQ(tm.size(), 0);
TensorKey k = Tensor(11);
Tensor v = Tensor(22);
tm.insert(k, v);
EXPECT_EQ(tm.size(), 1);
}
TEST(TensorMapTest, Copy) {
TensorMap tm;
TensorKey k = Tensor(11);

View File

@ -1,19 +0,0 @@
op {
name: "TensorMapHasKey"
input_arg {
name: "input_handle"
type: DT_VARIANT
}
input_arg {
name: "key"
type_attr: "element_dtype"
}
output_arg {
name: "has_key"
type: DT_BOOL
}
attr {
name: "element_dtype"
type: "type"
}
}

View File

@ -69,10 +69,19 @@ REGISTER_OP("TensorMapErase")
REGISTER_OP("TensorMapHasKey")
.Input("input_handle: variant")
.Input("key: element_dtype")
.Input("key: key_dtype")
.Output("has_key: bool")
.Attr("element_dtype: type")
.Attr("key_dtype: type")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("TensorMapStackKeys")
.Input("input_handle: variant")
.Output("keys: key_dtype")
.Attr("key_dtype: type")
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->UnknownShape()); // output keys
return Status::OK();
});
} // namespace
} // namespace tensorflow

View File

@ -26,6 +26,7 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import map_ops
from tensorflow.python.ops import sort_ops
from tensorflow.python.platform import test
@ -57,7 +58,7 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
m = map_ops.empty_tensor_map()
k = constant_op.constant(1.0)
with self.assertRaisesRegex(errors.InvalidArgumentError,
"Trying to lookup non-existent key."):
"Trying to lookup non-existent key. *"):
l = map_ops.tensor_map_lookup(m, k, dtypes.float32)
self.evaluate(l)
@ -68,7 +69,7 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
v = constant_op.constant(11.0)
m = map_ops.tensor_map_insert(m, k, v)
with self.assertRaisesRegex(errors.InvalidArgumentError,
"Trying to lookup non-existent key."):
"Trying to lookup non-existent key. *"):
l = map_ops.tensor_map_lookup(m, k2, dtypes.float32)
self.evaluate(l)
@ -87,7 +88,7 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
m = map_ops.empty_tensor_map()
k = constant_op.constant(1.0)
with self.assertRaisesRegex(errors.InvalidArgumentError,
"Trying to erase non-existent item."):
"Trying to erase non-existent item. *"):
m = map_ops.tensor_map_erase(m, k, dtypes.float32)
self.evaluate(m)
@ -98,7 +99,7 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
v = constant_op.constant(2.0)
m = map_ops.tensor_map_insert(m, k2, v)
with self.assertRaisesRegex(errors.InvalidArgumentError,
"Trying to erase non-existent item."):
"Trying to erase non-existent item. *"):
m = map_ops.tensor_map_erase(m, k, dtypes.float32)
self.evaluate(m)
@ -133,6 +134,58 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
self.assertAllClose(l, v)
self.assertAllClose(l2, default_value)
def testStackKeys(self):
m = map_ops.empty_tensor_map()
k = constant_op.constant(1.0)
k2 = constant_op.constant(2.0)
k3 = constant_op.constant(3.0)
v = constant_op.constant(21.0)
v2 = constant_op.constant(22.0)
v3 = constant_op.constant(23.0)
m = map_ops.tensor_map_insert(m, k, v)
m = map_ops.tensor_map_insert(m, k2, v2)
keys = map_ops.tensor_map_stack_keys(m, k.dtype)
expected = constant_op.constant([1.0, 2.0])
self.assertAllClose(array_ops.shape(keys), array_ops.shape(expected))
self.assertAllClose(sort_ops.sort(keys), expected)
m = map_ops.tensor_map_insert(m, k3, v3)
keys = map_ops.tensor_map_stack_keys(m, k.dtype)
expected = constant_op.constant([1.0, 2.0, 3.0])
self.assertAllClose(array_ops.shape(keys), array_ops.shape(expected))
self.assertAllClose(sort_ops.sort(keys), expected)
def testStackKeysEmptyMapFails(self):
m = map_ops.empty_tensor_map()
with self.assertRaisesRegex(
errors.InvalidArgumentError, "TensorMapStackKeys cannot be called "
"on empty map."):
keys = map_ops.tensor_map_stack_keys(m, dtypes.float32)
self.evaluate(keys)
def testStackKeysIncorrectDtypeFails(self):
m = map_ops.empty_tensor_map()
k = constant_op.constant("key_with_wrong_dtype")
v = constant_op.constant(2.0)
m = map_ops.tensor_map_insert(m, k, v)
simple = "Key does not match requested dtype."
with self.assertRaisesRegex(errors.InvalidArgumentError, simple):
keys = map_ops.tensor_map_stack_keys(m, dtypes.float32)
self.evaluate(keys)
def testStackKeysIncorrectShapeFails(self):
m = map_ops.empty_tensor_map()
k = constant_op.constant(1.0)
k2 = constant_op.constant([1.0, 11.0])
v = constant_op.constant(2.0)
v2 = constant_op.constant(22.0)
m = map_ops.tensor_map_insert(m, k, v)
m = map_ops.tensor_map_insert(m, k2, v2)
with self.assertRaisesRegex(errors.InvalidArgumentError,
"Keys must all have the same shape."):
keys = map_ops.tensor_map_stack_keys(m, dtypes.float32)
self.evaluate(keys)
def testInsertLookupGrad(self):
with backprop.GradientTape() as tape:
m = map_ops.empty_tensor_map()
@ -397,6 +450,5 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
self.assertAllEqual(s, 0)
self.assertAllEqual(map_ops.tensor_map_has_key(m, k), False)
if __name__ == "__main__":
test.main()

View File

@ -46,6 +46,11 @@ def tensor_map_erase(input_handle, key, value_dtype):
def tensor_map_has_key(input_handle, key):
return gen_map_ops.tensor_map_has_key(input_handle, key)
def tensor_map_stack_keys(input_handle, key_dtype):
return gen_map_ops.tensor_map_stack_keys(input_handle, key_dtype)
@ops.RegisterGradient("TensorMapLookup")
def LookupGrad(op, dval):
_, k = op.inputs