Merge pull request #42300 from kttian:list_keys

PiperOrigin-RevId: 328350823
Change-Id: I41fc8ec95e2b70b9f86b58882f01a3e6a987b205
commit 51636d1795
Author: TensorFlower Gardener
Date:   2020-08-25 10:12:32 -07:00
9 changed files with 240 additions and 88 deletions


@@ -0,0 +1,8 @@
+op {
+  graph_op_name: "TensorMapStackKeys"
+  summary: "Returns a Tensor stack of all keys in a tensor map."
+  description: <<END
+input_handle: the input map
+keys: the returned Tensor of all keys in the map
+END
+}
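
For orientation, the new op is exposed in Python as map_ops.tensor_map_stack_keys (wired up in map_ops.py later in this change). A minimal usage sketch, assuming eager execution; because the underlying container is a hash map, the key order in the result is unspecified:

    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import map_ops

    m = map_ops.empty_tensor_map()
    m = map_ops.tensor_map_insert(m, constant_op.constant(1.0),
                                  constant_op.constant(21.0))
    m = map_ops.tensor_map_insert(m, constant_op.constant(2.0),
                                  constant_op.constant(22.0))
    # Stacks both scalar keys into a rank-1 tensor, e.g. [1.0, 2.0] or [2.0, 1.0].
    keys = map_ops.tensor_map_stack_keys(m, dtypes.float32)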


@@ -41,6 +41,9 @@ REGISTER_KERNEL_BUILDER(Name("TensorMapErase").Device(DEVICE_CPU),
 REGISTER_KERNEL_BUILDER(Name("TensorMapHasKey").Device(DEVICE_CPU),
                         TensorMapHasKey);
+REGISTER_KERNEL_BUILDER(Name("TensorMapStackKeys").Device(DEVICE_CPU),
+                        TensorMapStackKeys);
 
 #undef REGISTER_TENSOR_MAP_OPS_CPU
 #define REGISTER_TENSOR_MAP_OPS_CPU(T)


@@ -15,39 +15,37 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_KERNELS_MAP_KERNELS_H_
 #define TENSORFLOW_CORE_KERNELS_MAP_KERNELS_H_
 
-#include <iostream>
 #include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/variant_encode_decode.h"
 #include "tensorflow/core/kernels/tensor_map.h"
+#include "tensorflow/core/util/batch_util.h"
 #include "tensorflow/core/util/tensor_ops_util.h"
 
 namespace tensorflow {
 
-Status GetInputMap(OpKernelContext* c, int index, const TensorMap** map) {
-  if (!TensorShapeUtils::IsScalar(c->input(index).shape())) {
+Status GetInputMap(OpKernelContext* ctx, int index, const TensorMap** ret_map) {
+  if (!TensorShapeUtils::IsScalar(ctx->input(index).shape())) {
     return errors::InvalidArgument("Input map must be a scalar. Saw: ",
-                                   c->input(index).shape().DebugString());
+                                   ctx->input(index).shape().DebugString());
   }
-  const TensorMap* m = c->input(index).scalar<Variant>()().get<TensorMap>();
-  if (m == nullptr) {
+  const TensorMap* map = ctx->input(index).scalar<Variant>()().get<TensorMap>();
+  if (map == nullptr) {
     return errors::InvalidArgument(
         "Input handle is not a map. Saw: '",
-        c->input(index).scalar<Variant>()().DebugString(), "'");
+        ctx->input(index).scalar<Variant>()().DebugString(), "'");
   }
-  *map = m;
+  *ret_map = map;
   return Status::OK();
 }
 
 // TODO(kattian): change into templated function
-Status ForwardInputOrCreateNewMap(OpKernelContext* c, int32 input_index,
+Status ForwardInputOrCreateNewMap(OpKernelContext* ctx, int32 input_index,
                                   int32 output_index,
                                   const TensorMap& input_map,
                                   TensorMap** output_map) {
   // Attempt to forward the input tensor to the output if possible.
-  std::unique_ptr<Tensor> maybe_output = c->forward_input(
+  std::unique_ptr<Tensor> maybe_output = ctx->forward_input(
       input_index, output_index, DT_VARIANT, TensorShape{},
-      c->input_memory_type(input_index), AllocatorAttributes());
+      ctx->input_memory_type(input_index), AllocatorAttributes());
   Tensor* output_tensor;
   if (maybe_output != nullptr && maybe_output->dtype() == DT_VARIANT &&
       maybe_output->NumElements() == 1) {
@@ -60,7 +58,7 @@ Status ForwardInputOrCreateNewMap(OpKernelContext* c, int32 input_index,
     }
     if (tmp_out->RefCountIsOne()) {
       // Woohoo, forwarding succeeded!
-      c->set_output(output_index, *output_tensor);
+      ctx->set_output(output_index, *output_tensor);
       *output_map = tmp_out;
       return Status::OK();
     }
@@ -71,7 +69,7 @@ Status ForwardInputOrCreateNewMap(OpKernelContext* c, int32 input_index,
   AllocatorAttributes attr;
   attr.set_on_host(true);
   TF_RETURN_IF_ERROR(
-      c->allocate_output(output_index, {}, &output_tensor, attr));
+      ctx->allocate_output(output_index, {}, &output_tensor, attr));
   output_tensor->scalar<Variant>()() = input_map.Copy();
 
   *output_map = output_tensor->scalar<Variant>()().get<TensorMap>();
@@ -80,13 +78,13 @@ Status ForwardInputOrCreateNewMap(OpKernelContext* c, int32 input_index,
 class EmptyTensorMap : public OpKernel {
  public:
-  explicit EmptyTensorMap(OpKernelConstruction* c) : OpKernel(c) {}
+  explicit EmptyTensorMap(OpKernelConstruction* ctx) : OpKernel(ctx) {}
 
-  void Compute(OpKernelContext* c) override {
+  void Compute(OpKernelContext* ctx) override {
     Tensor* result;
     AllocatorAttributes attr;
     attr.set_on_host(true);
-    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result, attr));
     TensorMap empty;
     result->scalar<Variant>()() = std::move(empty);
   }
@@ -94,87 +92,136 @@ class EmptyTensorMap : public OpKernel {
 class TensorMapSize : public OpKernel {
  public:
-  explicit TensorMapSize(OpKernelConstruction* c) : OpKernel(c) {}
+  explicit TensorMapSize(OpKernelConstruction* ctx) : OpKernel(ctx) {}
   ~TensorMapSize() override {}
 
-  void Compute(OpKernelContext* c) override {
-    const TensorMap* m = nullptr;
-    OP_REQUIRES_OK(c, GetInputMap(c, 0, &m));
+  void Compute(OpKernelContext* ctx) override {
+    const TensorMap* map = nullptr;
+    OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));
     Tensor* result;
-    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result));
-    result->scalar<int32>()() = m->tensors().size();
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result));
+    result->scalar<int32>()() = map->tensors().size();
   }
 };
 
 class TensorMapLookup : public OpKernel {
  public:
-  explicit TensorMapLookup(OpKernelConstruction* c) : OpKernel(c) {}
+  explicit TensorMapLookup(OpKernelConstruction* ctx) : OpKernel(ctx) {}
   ~TensorMapLookup() override {}
 
-  void Compute(OpKernelContext* c) override {
-    const TensorKey& key = c->input(1);
-    const TensorMap* m = nullptr;
-    OP_REQUIRES_OK(c, GetInputMap(c, 0, &m));
-    OP_REQUIRES(c, m->tensors().find(key) != m->tensors().end(),
-                errors::InvalidArgument("Trying to lookup non-existent key."));
-    c->set_output(0, m->tensors().find(key)->second);
+  void Compute(OpKernelContext* ctx) override {
+    const TensorKey& key = ctx->input(1);
+    const TensorMap* map = nullptr;
+    OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));
+    OP_REQUIRES(
+        ctx, map->tensors().find(key) != map->tensors().end(),
+        errors::InvalidArgument("Trying to lookup non-existent key. Could not "
+                                "find key \"" +
+                                key.SummarizeValue(100) + "\"."));
+    ctx->set_output(0, map->tensors().find(key)->second);
   }
 };
 
 class TensorMapInsert : public OpKernel {
  public:
-  explicit TensorMapInsert(OpKernelConstruction* c) : OpKernel(c) {}
+  explicit TensorMapInsert(OpKernelConstruction* ctx) : OpKernel(ctx) {}
   ~TensorMapInsert() override {}
 
-  void Compute(OpKernelContext* c) override {
-    const TensorKey& key = c->input(1);
-    const Tensor& value = c->input(2);
-    const TensorMap* m = nullptr;
-    OP_REQUIRES_OK(c, GetInputMap(c, 0, &m));
+  void Compute(OpKernelContext* ctx) override {
+    const TensorKey& key = ctx->input(1);
+    const Tensor& value = ctx->input(2);
+    const TensorMap* map = nullptr;
+    OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));
 
     TensorMap* output_map = nullptr;
-    OP_REQUIRES_OK(c, ForwardInputOrCreateNewMap(c, 0, 0, *m, &output_map));
+    OP_REQUIRES_OK(ctx,
+                   ForwardInputOrCreateNewMap(ctx, 0, 0, *map, &output_map));
    output_map->replace(key, value);
   }
 };
 
 class TensorMapErase : public OpKernel {
  public:
-  explicit TensorMapErase(OpKernelConstruction* c) : OpKernel(c) {}
+  explicit TensorMapErase(OpKernelConstruction* ctx) : OpKernel(ctx) {}
 
-  void Compute(OpKernelContext* c) override {
-    const TensorKey& key = c->input(1);
-    const TensorMap* m = nullptr;
-    OP_REQUIRES_OK(c, GetInputMap(c, 0, &m));
-    OP_REQUIRES(c, m->tensors().find(key) != m->tensors().end(),
-                errors::InvalidArgument("Trying to erase non-existent item."));
+  void Compute(OpKernelContext* ctx) override {
+    const TensorKey& key = ctx->input(1);
+    const TensorMap* map = nullptr;
+    OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));
+    OP_REQUIRES(
+        ctx, map->tensors().find(key) != map->tensors().end(),
+        errors::InvalidArgument("Trying to erase non-existent item. Could not "
+                                "find key \"" +
+                                key.SummarizeValue(100) + "\"."));
 
     TensorMap* output_map = nullptr;
-    OP_REQUIRES_OK(c, ForwardInputOrCreateNewMap(c, 0, 0, *m, &output_map));
+    OP_REQUIRES_OK(ctx,
+                   ForwardInputOrCreateNewMap(ctx, 0, 0, *map, &output_map));
     output_map->tensors().erase(key);
   }
 };
 
 class TensorMapHasKey : public OpKernel {
  public:
-  explicit TensorMapHasKey(OpKernelConstruction* c) : OpKernel(c) {}
+  explicit TensorMapHasKey(OpKernelConstruction* ctx) : OpKernel(ctx) {}
   ~TensorMapHasKey() override {}
 
-  void Compute(OpKernelContext* c) override {
-    const TensorKey& key = c->input(1);
-    const TensorMap* m = nullptr;
-    OP_REQUIRES_OK(c, GetInputMap(c, 0, &m));
+  void Compute(OpKernelContext* ctx) override {
+    const TensorKey& key = ctx->input(1);
+    const TensorMap* map = nullptr;
+    OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));
     Tensor* result;
-    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result));
-    result->scalar<bool>()() = m->tensors().find(key) != m->tensors().end();
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result));
+    result->scalar<bool>()() = map->tensors().find(key) != map->tensors().end();
   }
 };
 
+class TensorMapStackKeys : public OpKernel {
+ public:
+  explicit TensorMapStackKeys(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("key_dtype", &key_dtype_));
+  }
+  ~TensorMapStackKeys() override {}
+
+  void Compute(OpKernelContext* ctx) override {
+    const TensorMap* map = nullptr;
+    OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));
+    OP_REQUIRES(ctx, map->size() != 0,
+                errors::InvalidArgument(
+                    "TensorMapStackKeys cannot be called on empty map."));
+
+    auto it = map->tensors().begin();
+    TensorShape output_shape = it->first.shape();
+    output_shape.InsertDim(0, map->tensors().size());
+    Tensor* result;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &result));
+
+    int i = 0;
+    size_t sz = map->tensors().size();
+    TensorShape key_shape = it->first.shape();
+    while (it != map->tensors().end() && i < sz) {
+      OP_REQUIRES(
+          ctx, it->first.dtype() == key_dtype_,
+          errors::InvalidArgument("Key does not match requested dtype."));
+      OP_REQUIRES(
+          ctx, it->first.shape() == key_shape,
+          errors::InvalidArgument("Keys must all have the same shape."));
+      OP_REQUIRES_OK(ctx, batch_util::CopyElementToSlice(it->first, result, i));
+      i++;
+      it++;
+    }
+  }
+
+ private:
+  DataType key_dtype_;
+};
+
 template <typename Device>
-Status TensorMapBinaryAdd(OpKernelContext* c, const TensorMap& a,
+Status TensorMapBinaryAdd(OpKernelContext* ctx, const TensorMap& a,
                           const TensorMap& b, TensorMap* out) {
   // Binary add returns a map containing the union of keys.
   // Values with keys in the intersection are added.
@@ -185,7 +232,7 @@ Status TensorMapBinaryAdd(OpKernelContext* c, const TensorMap& a,
     if (it != out->tensors().end()) {
       Tensor out_tensor;
       TF_RETURN_IF_ERROR(
-          BinaryAddTensors<Device>(c, p.second, it->second, &out_tensor));
+          BinaryAddTensors<Device>(ctx, p.second, it->second, &out_tensor));
       it->second = out_tensor;
     } else {
       out->tensors().emplace(p.first, p.second);
@@ -195,7 +242,7 @@ Status TensorMapBinaryAdd(OpKernelContext* c, const TensorMap& a,
 }
 
 template <typename Device>
-Status TensorMapZerosLike(OpKernelContext* c, const TensorMap& x,
+Status TensorMapZerosLike(OpKernelContext* ctx, const TensorMap& x,
                           TensorMap* y) {
   // Zeros like returns an empty map.
   return Status::OK();
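
For intuition, here is a small Python model of what the TensorMapStackKeys copy loop above computes. This is an illustrative sketch only (plain NumPy, not part of the change; stack_keys_model is a hypothetical name), mirroring how batch_util::CopyElementToSlice writes key i into slice i of the output:

    import numpy as np

    def stack_keys_model(keys):
        # keys: non-empty list of numpy arrays, all with the same dtype/shape,
        # standing in for the map's key tensors in iteration order.
        if not keys:
            raise ValueError("TensorMapStackKeys cannot be called on empty map.")
        out = np.empty((len(keys),) + keys[0].shape, dtype=keys[0].dtype)
        for i, k in enumerate(keys):
            if k.shape != keys[0].shape:
                raise ValueError("Keys must all have the same shape.")
            out[i] = k  # CopyElementToSlice(it->first, result, i)
        return out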


@@ -144,7 +144,19 @@ class TensorMap {
   size_t erase(TensorKey key) { return tensors_->values_.erase(key); }
 
   // Size returns the number of elements in the map
-  size_t size() { return tensors_->values_.size(); }
+  size_t size() const { return tensors_->values_.size(); }
+
+  std::vector<Tensor> keys() const {
+    std::vector<Tensor> keys;
+    keys.reserve(tensors_->values_.size());
+    absl::flat_hash_map<TensorKey, Tensor>::iterator it =
+        tensors_->values_.begin();
+    while (it != tensors_->values_.end()) {
+      keys.push_back(it->first);
+      it++;
+    }
+    return keys;
+  }
 
   // Is this TensorMap the only one with a reference to the underlying
   // container?


@@ -14,7 +14,6 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/kernels/tensor_map.h"
 #include "absl/container/flat_hash_map.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
@@ -45,7 +44,6 @@ TEST(TensorKeyTest, Equal) {
 }
 
 TEST(TensorMapTest, Insert) {
-  EXPECT_EQ(1, 1);
   TensorMap tm;
   TensorKey k = Tensor(11);
   Tensor v = Tensor(22);
@@ -102,12 +100,49 @@ TEST(TensorMapTest, Replace) {
   Tensor v1 = Tensor(22);
   Tensor v2 = Tensor(23);
   tm[k] = v2;
 
   absl::flat_hash_map<TensorKey, Tensor>::iterator map_it = tm.find(k);
   EXPECT_EQ(map_it->first, k);
   test::ExpectTensorEqual<int32>(map_it->second, v2);
 }
 
+TEST(TensorMapTest, ListKeys) {
+  TensorMap tm;
+  TensorKey k = Tensor(11.0);
+  TensorKey k2 = Tensor(12.0);
+  Tensor v = Tensor(22);
+  Tensor v2 = Tensor(23);
+  tm.insert(k, v);
+  tm.insert(k2, v2);
+  std::vector<Tensor> keys = tm.keys();
+
+  // Extract and sort double value for each key Tensor.
+  std::vector<std::pair<double, int>> key_doubles;
+  for (int i = 0; i < keys.size(); i++) {
+    double x = keys[i].scalar<double>()();
+    std::pair<double, int> p = std::pair<double, int>(x, i);
+    key_doubles.push_back(p);
+  }
+  sort(key_doubles.begin(), key_doubles.end());
+
+  // Check number of keys and each key.
+  EXPECT_EQ(keys.size(), 2);
+  EXPECT_EQ(key_doubles[0].first, 11.0);
+  EXPECT_EQ(key_doubles[1].first, 12.0);
+
+  // Check key shapes.
+  int ind1 = key_doubles[0].second;
+  int ind2 = key_doubles[1].second;
+  EXPECT_EQ(keys[ind1].shape(), k.shape());
+  EXPECT_EQ(keys[ind2].shape(), k2.shape());
+}
+
+TEST(TensorMapTest, Size) {
+  TensorMap tm;
+  EXPECT_EQ(tm.size(), 0);
+  TensorKey k = Tensor(11);
+  Tensor v = Tensor(22);
+  tm.insert(k, v);
+  EXPECT_EQ(tm.size(), 1);
+}
+
 TEST(TensorMapTest, Copy) {
   TensorMap tm;
   TensorKey k = Tensor(11);


@@ -1,19 +0,0 @@
-op {
-  name: "TensorMapHasKey"
-  input_arg {
-    name: "input_handle"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "key"
-    type_attr: "element_dtype"
-  }
-  output_arg {
-    name: "has_key"
-    type: DT_BOOL
-  }
-  attr {
-    name: "element_dtype"
-    type: "type"
-  }
-}


@@ -69,10 +69,19 @@ REGISTER_OP("TensorMapErase")
 REGISTER_OP("TensorMapHasKey")
     .Input("input_handle: variant")
-    .Input("key: element_dtype")
+    .Input("key: key_dtype")
     .Output("has_key: bool")
-    .Attr("element_dtype: type")
+    .Attr("key_dtype: type")
     .SetShapeFn(shape_inference::ScalarShape);
 
+REGISTER_OP("TensorMapStackKeys")
+    .Input("input_handle: variant")
+    .Output("keys: key_dtype")
+    .Attr("key_dtype: type")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      c->set_output(0, c->UnknownShape());  // output keys
+      return Status::OK();
+    });
+
 }  // namespace
 }  // namespace tensorflow
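
A note on the shape function: the number of keys (and their shape) is only known once the map is populated at runtime, so the registration above reports an unknown output shape. A sketch of the visible consequence, assuming tf.function tracing (stacked_keys is a hypothetical name, not part of the change):

    import tensorflow as tf
    from tensorflow.python.ops import map_ops

    @tf.function
    def stacked_keys():
      m = map_ops.empty_tensor_map()
      m = map_ops.tensor_map_insert(m, tf.constant(1.0), tf.constant(2.0))
      keys = map_ops.tensor_map_stack_keys(m, tf.float32)
      print(keys.shape)  # <unknown> during tracing, per UnknownShape above
      return keys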


@@ -26,6 +26,7 @@ from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import map_ops
+from tensorflow.python.ops import sort_ops
 from tensorflow.python.platform import test
@@ -57,7 +58,7 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     m = map_ops.empty_tensor_map()
     k = constant_op.constant(1.0)
     with self.assertRaisesRegex(errors.InvalidArgumentError,
-                                "Trying to lookup non-existent key."):
+                                "Trying to lookup non-existent key. *"):
       l = map_ops.tensor_map_lookup(m, k, dtypes.float32)
       self.evaluate(l)
@@ -68,7 +69,7 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     v = constant_op.constant(11.0)
     m = map_ops.tensor_map_insert(m, k, v)
     with self.assertRaisesRegex(errors.InvalidArgumentError,
-                                "Trying to lookup non-existent key."):
+                                "Trying to lookup non-existent key. *"):
       l = map_ops.tensor_map_lookup(m, k2, dtypes.float32)
       self.evaluate(l)
@@ -87,7 +88,7 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     m = map_ops.empty_tensor_map()
     k = constant_op.constant(1.0)
     with self.assertRaisesRegex(errors.InvalidArgumentError,
-                                "Trying to erase non-existent item."):
+                                "Trying to erase non-existent item. *"):
       m = map_ops.tensor_map_erase(m, k, dtypes.float32)
       self.evaluate(m)
@@ -98,7 +99,7 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     v = constant_op.constant(2.0)
     m = map_ops.tensor_map_insert(m, k2, v)
     with self.assertRaisesRegex(errors.InvalidArgumentError,
-                                "Trying to erase non-existent item."):
+                                "Trying to erase non-existent item. *"):
       m = map_ops.tensor_map_erase(m, k, dtypes.float32)
       self.evaluate(m)
@@ -133,6 +134,58 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     self.assertAllClose(l, v)
     self.assertAllClose(l2, default_value)
 
+  def testStackKeys(self):
+    m = map_ops.empty_tensor_map()
+    k = constant_op.constant(1.0)
+    k2 = constant_op.constant(2.0)
+    k3 = constant_op.constant(3.0)
+    v = constant_op.constant(21.0)
+    v2 = constant_op.constant(22.0)
+    v3 = constant_op.constant(23.0)
+    m = map_ops.tensor_map_insert(m, k, v)
+    m = map_ops.tensor_map_insert(m, k2, v2)
+    keys = map_ops.tensor_map_stack_keys(m, k.dtype)
+    expected = constant_op.constant([1.0, 2.0])
+    self.assertAllClose(array_ops.shape(keys), array_ops.shape(expected))
+    self.assertAllClose(sort_ops.sort(keys), expected)
+
+    m = map_ops.tensor_map_insert(m, k3, v3)
+    keys = map_ops.tensor_map_stack_keys(m, k.dtype)
+    expected = constant_op.constant([1.0, 2.0, 3.0])
+    self.assertAllClose(array_ops.shape(keys), array_ops.shape(expected))
+    self.assertAllClose(sort_ops.sort(keys), expected)
+
+  def testStackKeysEmptyMapFails(self):
+    m = map_ops.empty_tensor_map()
+    with self.assertRaisesRegex(
+        errors.InvalidArgumentError, "TensorMapStackKeys cannot be called "
+        "on empty map."):
+      keys = map_ops.tensor_map_stack_keys(m, dtypes.float32)
+      self.evaluate(keys)
+
+  def testStackKeysIncorrectDtypeFails(self):
+    m = map_ops.empty_tensor_map()
+    k = constant_op.constant("key_with_wrong_dtype")
+    v = constant_op.constant(2.0)
+    m = map_ops.tensor_map_insert(m, k, v)
+    with self.assertRaisesRegex(errors.InvalidArgumentError,
+                                "Key does not match requested dtype."):
+      keys = map_ops.tensor_map_stack_keys(m, dtypes.float32)
+      self.evaluate(keys)
+
+  def testStackKeysIncorrectShapeFails(self):
+    m = map_ops.empty_tensor_map()
+    k = constant_op.constant(1.0)
+    k2 = constant_op.constant([1.0, 11.0])
+    v = constant_op.constant(2.0)
+    v2 = constant_op.constant(22.0)
+    m = map_ops.tensor_map_insert(m, k, v)
+    m = map_ops.tensor_map_insert(m, k2, v2)
+    with self.assertRaisesRegex(errors.InvalidArgumentError,
+                                "Keys must all have the same shape."):
+      keys = map_ops.tensor_map_stack_keys(m, dtypes.float32)
+      self.evaluate(keys)
+
   def testInsertLookupGrad(self):
     with backprop.GradientTape() as tape:
       m = map_ops.empty_tensor_map()
@@ -397,6 +450,5 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     self.assertAllEqual(s, 0)
     self.assertAllEqual(map_ops.tensor_map_has_key(m, k), False)
 
 if __name__ == "__main__":
   test.main()


@@ -46,6 +46,11 @@ def tensor_map_erase(input_handle, key, value_dtype):
 def tensor_map_has_key(input_handle, key):
   return gen_map_ops.tensor_map_has_key(input_handle, key)
 
+
+def tensor_map_stack_keys(input_handle, key_dtype):
+  return gen_map_ops.tensor_map_stack_keys(input_handle, key_dtype)
+
+
 @ops.RegisterGradient("TensorMapLookup")
 def LookupGrad(op, dval):
   _, k = op.inputs
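
Since tensor_map_stack_keys is a thin pass-through to the generated op binding, the op should also be reachable through the generic raw-op surface. A hedged sketch, assuming the standard tf.raw_ops export that registered ops receive:

    import tensorflow as tf
    from tensorflow.python.ops import map_ops

    m = map_ops.empty_tensor_map()
    m = map_ops.tensor_map_insert(m, tf.constant(1.0), tf.constant(2.0))
    # Equivalent to map_ops.tensor_map_stack_keys(m, tf.float32):
    keys = tf.raw_ops.TensorMapStackKeys(input_handle=m, key_dtype=tf.float32)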