Merge pull request #42300 from kttian:list_keys

Adds a TensorMapStackKeys op that stacks all keys of a tensor map into a single Tensor, plus a keys() accessor on TensorMap, a key_dtype attr for TensorMapHasKey, Python bindings, and C++ and Python tests.

PiperOrigin-RevId: 328350823
Change-Id: I41fc8ec95e2b70b9f86b58882f01a3e6a987b205

Commit: 51636d1795
@@ -0,0 +1,8 @@
+op {
+  graph_op_name: "TensorMapStackKeys"
+  summary: "Returns a Tensor stack of all keys in a tensor map."
+  description: <<END
+input_handle: the input map
+keys: the returned Tensor of all keys in the map
+END
+}
@@ -41,6 +41,9 @@ REGISTER_KERNEL_BUILDER(Name("TensorMapErase").Device(DEVICE_CPU),
 REGISTER_KERNEL_BUILDER(Name("TensorMapHasKey").Device(DEVICE_CPU),
                         TensorMapHasKey);

+REGISTER_KERNEL_BUILDER(Name("TensorMapStackKeys").Device(DEVICE_CPU),
+                        TensorMapStackKeys);
+
 #undef REGISTER_TENSOR_MAP_OPS_CPU

 #define REGISTER_TENSOR_MAP_OPS_CPU(T)
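Note: REGISTER_KERNEL_BUILDER links the op name registered in map_ops.cc to the C++ kernel class for one device type. For comparison, a registration for a second device would follow the same macro shape (hypothetical; this commit registers CPU only, and the map kernels as written assume host memory):

    REGISTER_KERNEL_BUILDER(Name("TensorMapStackKeys").Device(DEVICE_GPU),
                            TensorMapStackKeys);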
@@ -15,39 +15,37 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_KERNELS_MAP_KERNELS_H_
 #define TENSORFLOW_CORE_KERNELS_MAP_KERNELS_H_

-#include <iostream>
-
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/variant_encode_decode.h"
 #include "tensorflow/core/kernels/tensor_map.h"
 #include "tensorflow/core/util/batch_util.h"
 #include "tensorflow/core/util/tensor_ops_util.h"

 namespace tensorflow {

-Status GetInputMap(OpKernelContext* c, int index, const TensorMap** map) {
-  if (!TensorShapeUtils::IsScalar(c->input(index).shape())) {
+Status GetInputMap(OpKernelContext* ctx, int index, const TensorMap** ret_map) {
+  if (!TensorShapeUtils::IsScalar(ctx->input(index).shape())) {
     return errors::InvalidArgument("Input map must be a scalar. Saw: ",
-                                   c->input(index).shape().DebugString());
+                                   ctx->input(index).shape().DebugString());
   }
-  const TensorMap* m = c->input(index).scalar<Variant>()().get<TensorMap>();
-  if (m == nullptr) {
+  const TensorMap* map = ctx->input(index).scalar<Variant>()().get<TensorMap>();
+  if (map == nullptr) {
     return errors::InvalidArgument(
         "Input handle is not a map. Saw: '",
-        c->input(index).scalar<Variant>()().DebugString(), "'");
+        ctx->input(index).scalar<Variant>()().DebugString(), "'");
   }
-  *map = m;
+  *ret_map = map;
   return Status::OK();
 }

 // TODO(kattian): change into templated function
-Status ForwardInputOrCreateNewMap(OpKernelContext* c, int32 input_index,
+Status ForwardInputOrCreateNewMap(OpKernelContext* ctx, int32 input_index,
                                   int32 output_index,
                                   const TensorMap& input_map,
                                   TensorMap** output_map) {
   // Attempt to forward the input tensor to the output if possible.
-  std::unique_ptr<Tensor> maybe_output = c->forward_input(
+  std::unique_ptr<Tensor> maybe_output = ctx->forward_input(
       input_index, output_index, DT_VARIANT, TensorShape{},
-      c->input_memory_type(input_index), AllocatorAttributes());
+      ctx->input_memory_type(input_index), AllocatorAttributes());
   Tensor* output_tensor;
   if (maybe_output != nullptr && maybe_output->dtype() == DT_VARIANT &&
       maybe_output->NumElements() == 1) {
@@ -60,7 +58,7 @@ Status ForwardInputOrCreateNewMap(OpKernelContext* c, int32 input_index,
     }
     if (tmp_out->RefCountIsOne()) {
       // Woohoo, forwarding succeeded!
-      c->set_output(output_index, *output_tensor);
+      ctx->set_output(output_index, *output_tensor);
       *output_map = tmp_out;
       return Status::OK();
     }
@@ -71,7 +69,7 @@ Status ForwardInputOrCreateNewMap(OpKernelContext* c, int32 input_index,
   AllocatorAttributes attr;
   attr.set_on_host(true);
   TF_RETURN_IF_ERROR(
-      c->allocate_output(output_index, {}, &output_tensor, attr));
+      ctx->allocate_output(output_index, {}, &output_tensor, attr));
   output_tensor->scalar<Variant>()() = input_map.Copy();

   *output_map = output_tensor->scalar<Variant>()().get<TensorMap>();
@@ -80,13 +78,13 @@ Status ForwardInputOrCreateNewMap(OpKernelContext* c, int32 input_index,

 class EmptyTensorMap : public OpKernel {
  public:
-  explicit EmptyTensorMap(OpKernelConstruction* c) : OpKernel(c) {}
+  explicit EmptyTensorMap(OpKernelConstruction* ctx) : OpKernel(ctx) {}

-  void Compute(OpKernelContext* c) override {
+  void Compute(OpKernelContext* ctx) override {
     Tensor* result;
     AllocatorAttributes attr;
     attr.set_on_host(true);
-    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr));
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result, attr));
     TensorMap empty;
     result->scalar<Variant>()() = std::move(empty);
   }
@@ -94,87 +92,136 @@ class EmptyTensorMap : public OpKernel {

 class TensorMapSize : public OpKernel {
  public:
-  explicit TensorMapSize(OpKernelConstruction* c) : OpKernel(c) {}
+  explicit TensorMapSize(OpKernelConstruction* ctx) : OpKernel(ctx) {}
   ~TensorMapSize() override {}

-  void Compute(OpKernelContext* c) override {
-    const TensorMap* m = nullptr;
-    OP_REQUIRES_OK(c, GetInputMap(c, 0, &m));
+  void Compute(OpKernelContext* ctx) override {
+    const TensorMap* map = nullptr;
+    OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));
     Tensor* result;
-    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result));
-    result->scalar<int32>()() = m->tensors().size();
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result));
+    result->scalar<int32>()() = map->tensors().size();
   }
 };

 class TensorMapLookup : public OpKernel {
  public:
-  explicit TensorMapLookup(OpKernelConstruction* c) : OpKernel(c) {}
+  explicit TensorMapLookup(OpKernelConstruction* ctx) : OpKernel(ctx) {}
   ~TensorMapLookup() override {}

-  void Compute(OpKernelContext* c) override {
-    const TensorKey& key = c->input(1);
-    const TensorMap* m = nullptr;
-    OP_REQUIRES_OK(c, GetInputMap(c, 0, &m));
+  void Compute(OpKernelContext* ctx) override {
+    const TensorKey& key = ctx->input(1);
+    const TensorMap* map = nullptr;
+    OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));

-    OP_REQUIRES(c, m->tensors().find(key) != m->tensors().end(),
-                errors::InvalidArgument("Trying to lookup non-existent key."));
+    OP_REQUIRES(
+        ctx, map->tensors().find(key) != map->tensors().end(),
+        errors::InvalidArgument("Trying to lookup non-existent key. Could not "
+                                "find key \"" +
+                                key.SummarizeValue(100) + "\"."));

-    c->set_output(0, m->tensors().find(key)->second);
+    ctx->set_output(0, map->tensors().find(key)->second);
   }
 };

 class TensorMapInsert : public OpKernel {
  public:
-  explicit TensorMapInsert(OpKernelConstruction* c) : OpKernel(c) {}
+  explicit TensorMapInsert(OpKernelConstruction* ctx) : OpKernel(ctx) {}
   ~TensorMapInsert() override {}

-  void Compute(OpKernelContext* c) override {
-    const TensorKey& key = c->input(1);
-    const Tensor& value = c->input(2);
-    const TensorMap* m = nullptr;
-    OP_REQUIRES_OK(c, GetInputMap(c, 0, &m));
+  void Compute(OpKernelContext* ctx) override {
+    const TensorKey& key = ctx->input(1);
+    const Tensor& value = ctx->input(2);
+    const TensorMap* map = nullptr;
+    OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));

     TensorMap* output_map = nullptr;
-    OP_REQUIRES_OK(c, ForwardInputOrCreateNewMap(c, 0, 0, *m, &output_map));
+    OP_REQUIRES_OK(ctx,
+                   ForwardInputOrCreateNewMap(ctx, 0, 0, *map, &output_map));
     output_map->replace(key, value);
   }
 };

 class TensorMapErase : public OpKernel {
  public:
-  explicit TensorMapErase(OpKernelConstruction* c) : OpKernel(c) {}
+  explicit TensorMapErase(OpKernelConstruction* ctx) : OpKernel(ctx) {}

-  void Compute(OpKernelContext* c) override {
-    const TensorKey& key = c->input(1);
-    const TensorMap* m = nullptr;
-    OP_REQUIRES_OK(c, GetInputMap(c, 0, &m));
+  void Compute(OpKernelContext* ctx) override {
+    const TensorKey& key = ctx->input(1);
+    const TensorMap* map = nullptr;
+    OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));

-    OP_REQUIRES(c, m->tensors().find(key) != m->tensors().end(),
-                errors::InvalidArgument("Trying to erase non-existent item."));
+    OP_REQUIRES(
+        ctx, map->tensors().find(key) != map->tensors().end(),
+        errors::InvalidArgument("Trying to erase non-existent item. Could not "
+                                "find key \"" +
+                                key.SummarizeValue(100) + "\"."));

     TensorMap* output_map = nullptr;
-    OP_REQUIRES_OK(c, ForwardInputOrCreateNewMap(c, 0, 0, *m, &output_map));
+    OP_REQUIRES_OK(ctx,
+                   ForwardInputOrCreateNewMap(ctx, 0, 0, *map, &output_map));
     output_map->tensors().erase(key);
   }
 };

 class TensorMapHasKey : public OpKernel {
  public:
-  explicit TensorMapHasKey(OpKernelConstruction* c) : OpKernel(c) {}
+  explicit TensorMapHasKey(OpKernelConstruction* ctx) : OpKernel(ctx) {}
   ~TensorMapHasKey() override {}

-  void Compute(OpKernelContext* c) override {
-    const TensorKey& key = c->input(1);
-    const TensorMap* m = nullptr;
-    OP_REQUIRES_OK(c, GetInputMap(c, 0, &m));
+  void Compute(OpKernelContext* ctx) override {
+    const TensorKey& key = ctx->input(1);
+    const TensorMap* map = nullptr;
+    OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));
     Tensor* result;
-    OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result));
-    result->scalar<bool>()() = m->tensors().find(key) != m->tensors().end();
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result));
+    result->scalar<bool>()() = map->tensors().find(key) != map->tensors().end();
   }
 };

+class TensorMapStackKeys : public OpKernel {
+ public:
+  explicit TensorMapStackKeys(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("key_dtype", &key_dtype_));
+  }
+  ~TensorMapStackKeys() override {}
+
+  void Compute(OpKernelContext* ctx) override {
+    const TensorMap* map = nullptr;
+    OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map));
+
+    OP_REQUIRES(ctx, map->size() != 0,
+                errors::InvalidArgument(
+                    "TensorMapStackKeys cannot be called on empty map."));
+
+    auto it = map->tensors().begin();
+    TensorShape output_shape = it->first.shape();
+    output_shape.InsertDim(0, map->tensors().size());
+    Tensor* result;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &result));
+
+    int i = 0;
+    size_t sz = map->tensors().size();
+    TensorShape key_shape = it->first.shape();
+    while (it != map->tensors().end() && i < sz) {
+      OP_REQUIRES(
+          ctx, it->first.dtype() == key_dtype_,
+          errors::InvalidArgument("Key does not match requested dtype."));
+      OP_REQUIRES(
+          ctx, it->first.shape() == key_shape,
+          errors::InvalidArgument("Keys must all have the same shape."));
+      OP_REQUIRES_OK(ctx, batch_util::CopyElementToSlice(it->first, result, i));
+      i++;
+      it++;
+    }
+  }
+
+ private:
+  DataType key_dtype_;
+};
+
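Note: the stacking loop above leans on batch_util::CopyElementToSlice(element, parent, index), which copies `element` into slice `index` along dimension 0 of the one-rank-higher `parent` tensor, handling arbitrary key shapes and dtypes. A minimal sketch of what it amounts to in the simplest case, assuming float scalar keys (illustration only, not part of the commit):

    // For a map with N float scalar keys, result has shape [N];
    // write key i into row i of the output.
    result->flat<float>()(i) = it->first.scalar<float>()();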
 template <typename Device>
-Status TensorMapBinaryAdd(OpKernelContext* c, const TensorMap& a,
+Status TensorMapBinaryAdd(OpKernelContext* ctx, const TensorMap& a,
                           const TensorMap& b, TensorMap* out) {
   // Binary add returns a map containing the union of keys.
   // Values with keys in the intersection are added.
@@ -185,7 +232,7 @@ Status TensorMapBinaryAdd(OpKernelContext* c, const TensorMap& a,
     if (it != out->tensors().end()) {
       Tensor out_tensor;
       TF_RETURN_IF_ERROR(
-          BinaryAddTensors<Device>(c, p.second, it->second, &out_tensor));
+          BinaryAddTensors<Device>(ctx, p.second, it->second, &out_tensor));
       it->second = out_tensor;
     } else {
       out->tensors().emplace(p.first, p.second);
@@ -195,7 +242,7 @@ Status TensorMapBinaryAdd(OpKernelContext* c, const TensorMap& a,
 }

 template <typename Device>
-Status TensorMapZerosLike(OpKernelContext* c, const TensorMap& x,
+Status TensorMapZerosLike(OpKernelContext* ctx, const TensorMap& x,
                           TensorMap* y) {
   // Zeros like returns an empty map.
   return Status::OK();
@@ -144,7 +144,19 @@ class TensorMap {
   size_t erase(TensorKey key) { return tensors_->values_.erase(key); }

   // Size returns the number of elements in the map
-  size_t size() { return tensors_->values_.size(); }
+  size_t size() const { return tensors_->values_.size(); }
+
+  std::vector<Tensor> keys() const {
+    std::vector<Tensor> keys;
+    keys.reserve(tensors_->values_.size());
+    absl::flat_hash_map<TensorKey, Tensor>::iterator it =
+        tensors_->values_.begin();
+    while (it != tensors_->values_.end()) {
+      keys.push_back(it->first);
+      it++;
+    }
+    return keys;
+  }

   // Is this TensorMap the only one with a reference to the underlying
   // container?
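Note: keys() above walks the underlying absl::flat_hash_map with an explicit iterator. A range-based for loop over the same tensors_->values_ member is an equivalent sketch (same behavior, shown only for comparison):

    std::vector<Tensor> keys() const {
      std::vector<Tensor> keys;
      keys.reserve(tensors_->values_.size());
      for (const auto& p : tensors_->values_) {  // p: (TensorKey, Tensor) pair
        keys.push_back(p.first);
      }
      return keys;
    }

Iteration order of a flat_hash_map is unspecified, so the returned keys come in no particular order; callers (including the tests below) must not rely on ordering.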
@@ -14,7 +14,6 @@ limitations under the License.
 ==============================================================================*/

 #include "tensorflow/core/kernels/tensor_map.h"

 #include "absl/container/flat_hash_map.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
@@ -45,7 +44,6 @@ TEST(TensorKeyTest, Equal) {
 }

 TEST(TensorMapTest, Insert) {
-  EXPECT_EQ(1, 1);
   TensorMap tm;
   TensorKey k = Tensor(11);
   Tensor v = Tensor(22);
@@ -102,12 +100,49 @@ TEST(TensorMapTest, Replace) {
   Tensor v1 = Tensor(22);
   Tensor v2 = Tensor(23);
   tm[k] = v2;

   absl::flat_hash_map<TensorKey, Tensor>::iterator map_it = tm.find(k);
   EXPECT_EQ(map_it->first, k);
   test::ExpectTensorEqual<int32>(map_it->second, v2);
 }

+TEST(TensorMapTest, ListKeys) {
+  TensorMap tm;
+  TensorKey k = Tensor(11.0);
+  TensorKey k2 = Tensor(12.0);
+  Tensor v = Tensor(22);
+  Tensor v2 = Tensor(23);
+  tm.insert(k, v);
+  tm.insert(k2, v2);
+  std::vector<Tensor> keys = tm.keys();
+
+  // Extract and sort the double value of each key Tensor.
+  std::vector<std::pair<double, int>> key_doubles;
+  for (int i = 0; i < keys.size(); i++) {
+    double x = keys[i].scalar<double>()();
+    std::pair<double, int> p = std::pair<double, int>(x, i);
+    key_doubles.push_back(p);
+  }
+  sort(key_doubles.begin(), key_doubles.end());
+  // Check the number of keys and each key value.
+  EXPECT_EQ(keys.size(), 2);
+  EXPECT_EQ(key_doubles[0].first, 11.0);
+  EXPECT_EQ(key_doubles[1].first, 12.0);
+  // Check key shapes.
+  int ind1 = key_doubles[0].second;
+  int ind2 = key_doubles[1].second;
+  EXPECT_EQ(keys[ind1].shape(), k.shape());
+  EXPECT_EQ(keys[ind2].shape(), k2.shape());
+}
+
+TEST(TensorMapTest, Size) {
+  TensorMap tm;
+  EXPECT_EQ(tm.size(), 0);
+  TensorKey k = Tensor(11);
+  Tensor v = Tensor(22);
+  tm.insert(k, v);
+  EXPECT_EQ(tm.size(), 1);
+}
+
 TEST(TensorMapTest, Copy) {
   TensorMap tm;
   TensorKey k = Tensor(11);
@@ -1,19 +0,0 @@
-op {
-  name: "TensorMapHasKey"
-  input_arg {
-    name: "input_handle"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "key"
-    type_attr: "element_dtype"
-  }
-  output_arg {
-    name: "has_key"
-    type: DT_BOOL
-  }
-  attr {
-    name: "element_dtype"
-    type: "type"
-  }
-}

(The generated TensorMapHasKey definition above is dropped because the op's key input and attr change from element_dtype to key_dtype in the next hunk, so the old entry no longer matches the registration.)
@@ -69,10 +69,19 @@ REGISTER_OP("TensorMapErase")

 REGISTER_OP("TensorMapHasKey")
     .Input("input_handle: variant")
-    .Input("key: element_dtype")
+    .Input("key: key_dtype")
     .Output("has_key: bool")
-    .Attr("element_dtype: type")
+    .Attr("key_dtype: type")
     .SetShapeFn(shape_inference::ScalarShape);

+REGISTER_OP("TensorMapStackKeys")
+    .Input("input_handle: variant")
+    .Output("keys: key_dtype")
+    .Attr("key_dtype: type")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      c->set_output(0, c->UnknownShape());  // output keys
+      return Status::OK();
+    });
+
 }  // namespace
 }  // namespace tensorflow
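Note: TensorMapStackKeys reports a fully unknown output shape because map keys may be tensors of any rank. Purely as a hypothetical alternative (not what this commit does): if keys were guaranteed to be scalars, the output could be constrained to rank 1 with an unknown extent:

    .SetShapeFn([](shape_inference::InferenceContext* c) {
      // Rank-1 output of unknown length; valid only if every key is a scalar.
      c->set_output(0, c->Vector(c->UnknownDim()));
      return Status::OK();
    });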
@@ -26,6 +26,7 @@ from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import map_ops
+from tensorflow.python.ops import sort_ops
 from tensorflow.python.platform import test

@@ -57,7 +58,7 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     m = map_ops.empty_tensor_map()
     k = constant_op.constant(1.0)
     with self.assertRaisesRegex(errors.InvalidArgumentError,
-                                "Trying to lookup non-existent key."):
+                                "Trying to lookup non-existent key. *"):
       l = map_ops.tensor_map_lookup(m, k, dtypes.float32)
       self.evaluate(l)
@@ -68,7 +69,7 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     v = constant_op.constant(11.0)
     m = map_ops.tensor_map_insert(m, k, v)
     with self.assertRaisesRegex(errors.InvalidArgumentError,
-                                "Trying to lookup non-existent key."):
+                                "Trying to lookup non-existent key. *"):
       l = map_ops.tensor_map_lookup(m, k2, dtypes.float32)
       self.evaluate(l)
@@ -87,7 +88,7 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     m = map_ops.empty_tensor_map()
     k = constant_op.constant(1.0)
     with self.assertRaisesRegex(errors.InvalidArgumentError,
-                                "Trying to erase non-existent item."):
+                                "Trying to erase non-existent item. *"):
       m = map_ops.tensor_map_erase(m, k, dtypes.float32)
       self.evaluate(m)
@@ -98,7 +99,7 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     v = constant_op.constant(2.0)
     m = map_ops.tensor_map_insert(m, k2, v)
     with self.assertRaisesRegex(errors.InvalidArgumentError,
-                                "Trying to erase non-existent item."):
+                                "Trying to erase non-existent item. *"):
       m = map_ops.tensor_map_erase(m, k, dtypes.float32)
       self.evaluate(m)
@@ -133,6 +134,58 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     self.assertAllClose(l, v)
     self.assertAllClose(l2, default_value)

+  def testStackKeys(self):
+    m = map_ops.empty_tensor_map()
+    k = constant_op.constant(1.0)
+    k2 = constant_op.constant(2.0)
+    k3 = constant_op.constant(3.0)
+    v = constant_op.constant(21.0)
+    v2 = constant_op.constant(22.0)
+    v3 = constant_op.constant(23.0)
+    m = map_ops.tensor_map_insert(m, k, v)
+    m = map_ops.tensor_map_insert(m, k2, v2)
+    keys = map_ops.tensor_map_stack_keys(m, k.dtype)
+    expected = constant_op.constant([1.0, 2.0])
+    self.assertAllClose(array_ops.shape(keys), array_ops.shape(expected))
+    self.assertAllClose(sort_ops.sort(keys), expected)
+
+    m = map_ops.tensor_map_insert(m, k3, v3)
+    keys = map_ops.tensor_map_stack_keys(m, k.dtype)
+    expected = constant_op.constant([1.0, 2.0, 3.0])
+    self.assertAllClose(array_ops.shape(keys), array_ops.shape(expected))
+    self.assertAllClose(sort_ops.sort(keys), expected)
+
+  def testStackKeysEmptyMapFails(self):
+    m = map_ops.empty_tensor_map()
+    with self.assertRaisesRegex(
+        errors.InvalidArgumentError, "TensorMapStackKeys cannot be called "
+        "on empty map."):
+      keys = map_ops.tensor_map_stack_keys(m, dtypes.float32)
+      self.evaluate(keys)
+
+  def testStackKeysIncorrectDtypeFails(self):
+    m = map_ops.empty_tensor_map()
+    k = constant_op.constant("key_with_wrong_dtype")
+    v = constant_op.constant(2.0)
+    m = map_ops.tensor_map_insert(m, k, v)
+    simple = "Key does not match requested dtype."
+    with self.assertRaisesRegex(errors.InvalidArgumentError, simple):
+      keys = map_ops.tensor_map_stack_keys(m, dtypes.float32)
+      self.evaluate(keys)
+
+  def testStackKeysIncorrectShapeFails(self):
+    m = map_ops.empty_tensor_map()
+    k = constant_op.constant(1.0)
+    k2 = constant_op.constant([1.0, 11.0])
+    v = constant_op.constant(2.0)
+    v2 = constant_op.constant(22.0)
+    m = map_ops.tensor_map_insert(m, k, v)
+    m = map_ops.tensor_map_insert(m, k2, v2)
+    with self.assertRaisesRegex(errors.InvalidArgumentError,
+                                "Keys must all have the same shape."):
+      keys = map_ops.tensor_map_stack_keys(m, dtypes.float32)
+      self.evaluate(keys)
+
   def testInsertLookupGrad(self):
     with backprop.GradientTape() as tape:
       m = map_ops.empty_tensor_map()
@@ -397,6 +450,5 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     self.assertAllEqual(s, 0)
     self.assertAllEqual(map_ops.tensor_map_has_key(m, k), False)

-
 if __name__ == "__main__":
   test.main()
@@ -46,6 +46,11 @@ def tensor_map_erase(input_handle, key, value_dtype):
 def tensor_map_has_key(input_handle, key):
   return gen_map_ops.tensor_map_has_key(input_handle, key)


+def tensor_map_stack_keys(input_handle, key_dtype):
+  return gen_map_ops.tensor_map_stack_keys(input_handle, key_dtype)
+
+
 @ops.RegisterGradient("TensorMapLookup")
 def LookupGrad(op, dval):
   _, k = op.inputs