diff --git a/tensorflow/core/api_def/base_api/api_def_TensorMapStackKeys.pbtxt b/tensorflow/core/api_def/base_api/api_def_TensorMapStackKeys.pbtxt new file mode 100644 index 00000000000..a8ecb43328a --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_TensorMapStackKeys.pbtxt @@ -0,0 +1,8 @@ +op { + graph_op_name: "TensorMapStackKeys" + summary: "Returns a Tensor stack of all keys in a tensor map." + description: < - #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/variant_encode_decode.h" #include "tensorflow/core/kernels/tensor_map.h" +#include "tensorflow/core/util/batch_util.h" #include "tensorflow/core/util/tensor_ops_util.h" namespace tensorflow { -Status GetInputMap(OpKernelContext* c, int index, const TensorMap** map) { - if (!TensorShapeUtils::IsScalar(c->input(index).shape())) { +Status GetInputMap(OpKernelContext* ctx, int index, const TensorMap** ret_map) { + if (!TensorShapeUtils::IsScalar(ctx->input(index).shape())) { return errors::InvalidArgument("Input map must be a scalar. Saw: ", - c->input(index).shape().DebugString()); + ctx->input(index).shape().DebugString()); } - const TensorMap* m = c->input(index).scalar()().get(); - if (m == nullptr) { + const TensorMap* map = ctx->input(index).scalar()().get(); + if (map == nullptr) { return errors::InvalidArgument( "Input handle is not a map. Saw: '", - c->input(index).scalar()().DebugString(), "'"); + ctx->input(index).scalar()().DebugString(), "'"); } - *map = m; + *ret_map = map; return Status::OK(); } // TODO(kattian): change into templated function -Status ForwardInputOrCreateNewMap(OpKernelContext* c, int32 input_index, +Status ForwardInputOrCreateNewMap(OpKernelContext* ctx, int32 input_index, int32 output_index, const TensorMap& input_map, TensorMap** output_map) { // Attempt to forward the input tensor to the output if possible. - std::unique_ptr maybe_output = c->forward_input( + std::unique_ptr maybe_output = ctx->forward_input( input_index, output_index, DT_VARIANT, TensorShape{}, - c->input_memory_type(input_index), AllocatorAttributes()); + ctx->input_memory_type(input_index), AllocatorAttributes()); Tensor* output_tensor; if (maybe_output != nullptr && maybe_output->dtype() == DT_VARIANT && maybe_output->NumElements() == 1) { @@ -60,7 +58,7 @@ Status ForwardInputOrCreateNewMap(OpKernelContext* c, int32 input_index, } if (tmp_out->RefCountIsOne()) { // Woohoo, forwarding succeeded! - c->set_output(output_index, *output_tensor); + ctx->set_output(output_index, *output_tensor); *output_map = tmp_out; return Status::OK(); } @@ -71,7 +69,7 @@ Status ForwardInputOrCreateNewMap(OpKernelContext* c, int32 input_index, AllocatorAttributes attr; attr.set_on_host(true); TF_RETURN_IF_ERROR( - c->allocate_output(output_index, {}, &output_tensor, attr)); + ctx->allocate_output(output_index, {}, &output_tensor, attr)); output_tensor->scalar()() = input_map.Copy(); *output_map = output_tensor->scalar()().get(); @@ -80,13 +78,13 @@ Status ForwardInputOrCreateNewMap(OpKernelContext* c, int32 input_index, class EmptyTensorMap : public OpKernel { public: - explicit EmptyTensorMap(OpKernelConstruction* c) : OpKernel(c) {} + explicit EmptyTensorMap(OpKernelConstruction* ctx) : OpKernel(ctx) {} - void Compute(OpKernelContext* c) override { + void Compute(OpKernelContext* ctx) override { Tensor* result; AllocatorAttributes attr; attr.set_on_host(true); - OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result, attr)); + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result, attr)); TensorMap empty; result->scalar()() = std::move(empty); } @@ -94,87 +92,136 @@ class EmptyTensorMap : public OpKernel { class TensorMapSize : public OpKernel { public: - explicit TensorMapSize(OpKernelConstruction* c) : OpKernel(c) {} + explicit TensorMapSize(OpKernelConstruction* ctx) : OpKernel(ctx) {} ~TensorMapSize() override {} - void Compute(OpKernelContext* c) override { - const TensorMap* m = nullptr; - OP_REQUIRES_OK(c, GetInputMap(c, 0, &m)); + void Compute(OpKernelContext* ctx) override { + const TensorMap* map = nullptr; + OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map)); Tensor* result; - OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result)); - result->scalar()() = m->tensors().size(); + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result)); + result->scalar()() = map->tensors().size(); } }; class TensorMapLookup : public OpKernel { public: - explicit TensorMapLookup(OpKernelConstruction* c) : OpKernel(c) {} + explicit TensorMapLookup(OpKernelConstruction* ctx) : OpKernel(ctx) {} ~TensorMapLookup() override {} - void Compute(OpKernelContext* c) override { - const TensorKey& key = c->input(1); - const TensorMap* m = nullptr; - OP_REQUIRES_OK(c, GetInputMap(c, 0, &m)); + void Compute(OpKernelContext* ctx) override { + const TensorKey& key = ctx->input(1); + const TensorMap* map = nullptr; + OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map)); - OP_REQUIRES(c, m->tensors().find(key) != m->tensors().end(), - errors::InvalidArgument("Trying to lookup non-existent key.")); + OP_REQUIRES( + ctx, map->tensors().find(key) != map->tensors().end(), + errors::InvalidArgument("Trying to lookup non-existent key. Could not " + "find key \"" + + key.SummarizeValue(100) + "\".")); - c->set_output(0, m->tensors().find(key)->second); + ctx->set_output(0, map->tensors().find(key)->second); } }; class TensorMapInsert : public OpKernel { public: - explicit TensorMapInsert(OpKernelConstruction* c) : OpKernel(c) {} + explicit TensorMapInsert(OpKernelConstruction* ctx) : OpKernel(ctx) {} ~TensorMapInsert() override {} - void Compute(OpKernelContext* c) override { - const TensorKey& key = c->input(1); - const Tensor& value = c->input(2); - const TensorMap* m = nullptr; - OP_REQUIRES_OK(c, GetInputMap(c, 0, &m)); + void Compute(OpKernelContext* ctx) override { + const TensorKey& key = ctx->input(1); + const Tensor& value = ctx->input(2); + const TensorMap* map = nullptr; + OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map)); TensorMap* output_map = nullptr; - OP_REQUIRES_OK(c, ForwardInputOrCreateNewMap(c, 0, 0, *m, &output_map)); + OP_REQUIRES_OK(ctx, + ForwardInputOrCreateNewMap(ctx, 0, 0, *map, &output_map)); output_map->replace(key, value); } }; class TensorMapErase : public OpKernel { public: - explicit TensorMapErase(OpKernelConstruction* c) : OpKernel(c) {} + explicit TensorMapErase(OpKernelConstruction* ctx) : OpKernel(ctx) {} - void Compute(OpKernelContext* c) override { - const TensorKey& key = c->input(1); - const TensorMap* m = nullptr; - OP_REQUIRES_OK(c, GetInputMap(c, 0, &m)); + void Compute(OpKernelContext* ctx) override { + const TensorKey& key = ctx->input(1); + const TensorMap* map = nullptr; + OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map)); - OP_REQUIRES(c, m->tensors().find(key) != m->tensors().end(), - errors::InvalidArgument("Trying to erase non-existent item.")); + OP_REQUIRES( + ctx, map->tensors().find(key) != map->tensors().end(), + errors::InvalidArgument("Trying to erase non-existent item. Could not " + "find key \"" + + key.SummarizeValue(100) + "\".")); TensorMap* output_map = nullptr; - OP_REQUIRES_OK(c, ForwardInputOrCreateNewMap(c, 0, 0, *m, &output_map)); + OP_REQUIRES_OK(ctx, + ForwardInputOrCreateNewMap(ctx, 0, 0, *map, &output_map)); output_map->tensors().erase(key); } }; class TensorMapHasKey : public OpKernel { public: - explicit TensorMapHasKey(OpKernelConstruction* c) : OpKernel(c) {} + explicit TensorMapHasKey(OpKernelConstruction* ctx) : OpKernel(ctx) {} ~TensorMapHasKey() override {} - void Compute(OpKernelContext* c) override { - const TensorKey& key = c->input(1); - const TensorMap* m = nullptr; - OP_REQUIRES_OK(c, GetInputMap(c, 0, &m)); + void Compute(OpKernelContext* ctx) override { + const TensorKey& key = ctx->input(1); + const TensorMap* map = nullptr; + OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map)); Tensor* result; - OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape{}, &result)); - result->scalar()() = m->tensors().find(key) != m->tensors().end(); + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &result)); + result->scalar()() = map->tensors().find(key) != map->tensors().end(); } }; +class TensorMapStackKeys : public OpKernel { + public: + explicit TensorMapStackKeys(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("key_dtype", &key_dtype_)); + } + ~TensorMapStackKeys() override {} + + void Compute(OpKernelContext* ctx) override { + const TensorMap* map = nullptr; + OP_REQUIRES_OK(ctx, GetInputMap(ctx, 0, &map)); + + OP_REQUIRES(ctx, map->size() != 0, + errors::InvalidArgument( + "TensorMapStackKeys cannot be called on empty map.")); + + auto it = map->tensors().begin(); + TensorShape output_shape = it->first.shape(); + output_shape.InsertDim(0, map->tensors().size()); + Tensor* result; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &result)); + + int i = 0; + size_t sz = map->tensors().size(); + TensorShape key_shape = it->first.shape(); + while (it != map->tensors().end() && i < sz) { + OP_REQUIRES( + ctx, it->first.dtype() == key_dtype_, + errors::InvalidArgument("Key does not match requested dtype.")); + OP_REQUIRES( + ctx, it->first.shape() == key_shape, + errors::InvalidArgument("Keys must all have the same shape.")); + OP_REQUIRES_OK(ctx, batch_util::CopyElementToSlice(it->first, result, i)); + i++; + it++; + } + } + + private: + DataType key_dtype_; +}; + template -Status TensorMapBinaryAdd(OpKernelContext* c, const TensorMap& a, +Status TensorMapBinaryAdd(OpKernelContext* ctx, const TensorMap& a, const TensorMap& b, TensorMap* out) { // Binary add returns a map containing the union of keys. // Values with keys in the intersection are added. @@ -185,7 +232,7 @@ Status TensorMapBinaryAdd(OpKernelContext* c, const TensorMap& a, if (it != out->tensors().end()) { Tensor out_tensor; TF_RETURN_IF_ERROR( - BinaryAddTensors(c, p.second, it->second, &out_tensor)); + BinaryAddTensors(ctx, p.second, it->second, &out_tensor)); it->second = out_tensor; } else { out->tensors().emplace(p.first, p.second); @@ -195,7 +242,7 @@ Status TensorMapBinaryAdd(OpKernelContext* c, const TensorMap& a, } template -Status TensorMapZerosLike(OpKernelContext* c, const TensorMap& x, +Status TensorMapZerosLike(OpKernelContext* ctx, const TensorMap& x, TensorMap* y) { // Zeros like returns an empty map. return Status::OK(); diff --git a/tensorflow/core/kernels/tensor_map.h b/tensorflow/core/kernels/tensor_map.h index d29d244f1ca..cb4c827cc3c 100644 --- a/tensorflow/core/kernels/tensor_map.h +++ b/tensorflow/core/kernels/tensor_map.h @@ -144,7 +144,19 @@ class TensorMap { size_t erase(TensorKey key) { return tensors_->values_.erase(key); } // Size returns the number of elements in the map - size_t size() { return tensors_->values_.size(); } + size_t size() const { return tensors_->values_.size(); } + + std::vector keys() const { + std::vector keys; + keys.reserve(tensors_->values_.size()); + absl::flat_hash_map::iterator it = + tensors_->values_.begin(); + while (it != tensors_->values_.end()) { + keys.push_back(it->first); + it++; + } + return keys; + } // Is this TensorMap the only one with a reference to the underlying // container? diff --git a/tensorflow/core/kernels/tensor_map_test.cc b/tensorflow/core/kernels/tensor_map_test.cc index beaff6fc622..76c903f047c 100644 --- a/tensorflow/core/kernels/tensor_map_test.cc +++ b/tensorflow/core/kernels/tensor_map_test.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/kernels/tensor_map.h" - #include "absl/container/flat_hash_map.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_testutil.h" @@ -45,7 +44,6 @@ TEST(TensorKeyTest, Equal) { } TEST(TensorMapTest, Insert) { - EXPECT_EQ(1, 1); TensorMap tm; TensorKey k = Tensor(11); Tensor v = Tensor(22); @@ -102,12 +100,49 @@ TEST(TensorMapTest, Replace) { Tensor v1 = Tensor(22); Tensor v2 = Tensor(23); tm[k] = v2; - absl::flat_hash_map::iterator map_it = tm.find(k); EXPECT_EQ(map_it->first, k); test::ExpectTensorEqual(map_it->second, v2); } +TEST(TensorMapTest, ListKeys) { + TensorMap tm; + TensorKey k = Tensor(11.0); + TensorKey k2 = Tensor(12.0); + Tensor v = Tensor(22); + Tensor v2 = Tensor(23); + tm.insert(k, v); + tm.insert(k2, v2); + std::vector keys = tm.keys(); + + // Extract and sort double value for each key Tensor. + std::vector> key_doubles; + for (int i = 0; i < keys.size(); i++) { + double x = keys[i].scalar()(); + std::pair p = std::pair(x, i); + key_doubles.push_back(p); + } + sort(key_doubles.begin(), key_doubles.end()); + // Check number of keys and each key. + EXPECT_EQ(keys.size(), 2); + EXPECT_EQ(key_doubles[0].first, 11.0); + EXPECT_EQ(key_doubles[1].first, 12.0); + // Check key shapes. + int ind1 = key_doubles[0].second; + int ind2 = key_doubles[1].second; + EXPECT_EQ(keys[ind1].shape(), k.shape()); + EXPECT_EQ(keys[ind2].shape(), k2.shape()); +} + +TEST(TensorMapTest, Size) { + TensorMap tm; + EXPECT_EQ(tm.size(), 0); + TensorKey k = Tensor(11); + Tensor v = Tensor(22); + tm.insert(k, v); + EXPECT_EQ(tm.size(), 1); +} + TEST(TensorMapTest, Copy) { TensorMap tm; TensorKey k = Tensor(11); diff --git a/tensorflow/core/ops/compat/ops_history_v2/TensorMapHasKey.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/TensorMapHasKey.pbtxt deleted file mode 100644 index 437822797af..00000000000 --- a/tensorflow/core/ops/compat/ops_history_v2/TensorMapHasKey.pbtxt +++ /dev/null @@ -1,19 +0,0 @@ -op { - name: "TensorMapHasKey" - input_arg { - name: "input_handle" - type: DT_VARIANT - } - input_arg { - name: "key" - type_attr: "element_dtype" - } - output_arg { - name: "has_key" - type: DT_BOOL - } - attr { - name: "element_dtype" - type: "type" - } -} diff --git a/tensorflow/core/ops/map_ops.cc b/tensorflow/core/ops/map_ops.cc index f52075132eb..d54ef54b481 100644 --- a/tensorflow/core/ops/map_ops.cc +++ b/tensorflow/core/ops/map_ops.cc @@ -63,16 +63,25 @@ REGISTER_OP("TensorMapErase") .Attr("key_dtype: type") .Attr("value_dtype: type") .SetShapeFn([](shape_inference::InferenceContext* c) { - c->set_output(0, c->Scalar()); // output map + c->set_output(0, c->Scalar()); // output map return Status::OK(); }); REGISTER_OP("TensorMapHasKey") .Input("input_handle: variant") - .Input("key: element_dtype") + .Input("key: key_dtype") .Output("has_key: bool") - .Attr("element_dtype: type") + .Attr("key_dtype: type") .SetShapeFn(shape_inference::ScalarShape); +REGISTER_OP("TensorMapStackKeys") + .Input("input_handle: variant") + .Output("keys: key_dtype") + .Attr("key_dtype: type") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->UnknownShape()); // output keys + return Status::OK(); + }); + } // namespace } // namespace tensorflow diff --git a/tensorflow/python/kernel_tests/map_ops_test.py b/tensorflow/python/kernel_tests/map_ops_test.py index 8db5cd7a6b1..771e22e5c5e 100644 --- a/tensorflow/python/kernel_tests/map_ops_test.py +++ b/tensorflow/python/kernel_tests/map_ops_test.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import map_ops +from tensorflow.python.ops import sort_ops from tensorflow.python.platform import test @@ -57,7 +58,7 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase): m = map_ops.empty_tensor_map() k = constant_op.constant(1.0) with self.assertRaisesRegex(errors.InvalidArgumentError, - "Trying to lookup non-existent key."): + "Trying to lookup non-existent key. *"): l = map_ops.tensor_map_lookup(m, k, dtypes.float32) self.evaluate(l) @@ -68,7 +69,7 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase): v = constant_op.constant(11.0) m = map_ops.tensor_map_insert(m, k, v) with self.assertRaisesRegex(errors.InvalidArgumentError, - "Trying to lookup non-existent key."): + "Trying to lookup non-existent key. *"): l = map_ops.tensor_map_lookup(m, k2, dtypes.float32) self.evaluate(l) @@ -87,7 +88,7 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase): m = map_ops.empty_tensor_map() k = constant_op.constant(1.0) with self.assertRaisesRegex(errors.InvalidArgumentError, - "Trying to erase non-existent item."): + "Trying to erase non-existent item. *"): m = map_ops.tensor_map_erase(m, k, dtypes.float32) self.evaluate(m) @@ -98,7 +99,7 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase): v = constant_op.constant(2.0) m = map_ops.tensor_map_insert(m, k2, v) with self.assertRaisesRegex(errors.InvalidArgumentError, - "Trying to erase non-existent item."): + "Trying to erase non-existent item. *"): m = map_ops.tensor_map_erase(m, k, dtypes.float32) self.evaluate(m) @@ -133,6 +134,58 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase): self.assertAllClose(l, v) self.assertAllClose(l2, default_value) + def testStackKeys(self): + m = map_ops.empty_tensor_map() + k = constant_op.constant(1.0) + k2 = constant_op.constant(2.0) + k3 = constant_op.constant(3.0) + v = constant_op.constant(21.0) + v2 = constant_op.constant(22.0) + v3 = constant_op.constant(23.0) + m = map_ops.tensor_map_insert(m, k, v) + m = map_ops.tensor_map_insert(m, k2, v2) + keys = map_ops.tensor_map_stack_keys(m, k.dtype) + expected = constant_op.constant([1.0, 2.0]) + self.assertAllClose(array_ops.shape(keys), array_ops.shape(expected)) + self.assertAllClose(sort_ops.sort(keys), expected) + + m = map_ops.tensor_map_insert(m, k3, v3) + keys = map_ops.tensor_map_stack_keys(m, k.dtype) + expected = constant_op.constant([1.0, 2.0, 3.0]) + self.assertAllClose(array_ops.shape(keys), array_ops.shape(expected)) + self.assertAllClose(sort_ops.sort(keys), expected) + + def testStackKeysEmptyMapFails(self): + m = map_ops.empty_tensor_map() + with self.assertRaisesRegex( + errors.InvalidArgumentError, "TensorMapStackKeys cannot be called " + "on empty map."): + keys = map_ops.tensor_map_stack_keys(m, dtypes.float32) + self.evaluate(keys) + + def testStackKeysIncorrectDtypeFails(self): + m = map_ops.empty_tensor_map() + k = constant_op.constant("key_with_wrong_dtype") + v = constant_op.constant(2.0) + m = map_ops.tensor_map_insert(m, k, v) + simple = "Key does not match requested dtype." + with self.assertRaisesRegex(errors.InvalidArgumentError, simple): + keys = map_ops.tensor_map_stack_keys(m, dtypes.float32) + self.evaluate(keys) + + def testStackKeysIncorrectShapeFails(self): + m = map_ops.empty_tensor_map() + k = constant_op.constant(1.0) + k2 = constant_op.constant([1.0, 11.0]) + v = constant_op.constant(2.0) + v2 = constant_op.constant(22.0) + m = map_ops.tensor_map_insert(m, k, v) + m = map_ops.tensor_map_insert(m, k2, v2) + with self.assertRaisesRegex(errors.InvalidArgumentError, + "Keys must all have the same shape."): + keys = map_ops.tensor_map_stack_keys(m, dtypes.float32) + self.evaluate(keys) + def testInsertLookupGrad(self): with backprop.GradientTape() as tape: m = map_ops.empty_tensor_map() @@ -397,6 +450,5 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase): self.assertAllEqual(s, 0) self.assertAllEqual(map_ops.tensor_map_has_key(m, k), False) - if __name__ == "__main__": test.main() diff --git a/tensorflow/python/ops/map_ops.py b/tensorflow/python/ops/map_ops.py index ce8b7a6bc2f..7315e7e18bd 100644 --- a/tensorflow/python/ops/map_ops.py +++ b/tensorflow/python/ops/map_ops.py @@ -46,6 +46,11 @@ def tensor_map_erase(input_handle, key, value_dtype): def tensor_map_has_key(input_handle, key): return gen_map_ops.tensor_map_has_key(input_handle, key) + +def tensor_map_stack_keys(input_handle, key_dtype): + return gen_map_ops.tensor_map_stack_keys(input_handle, key_dtype) + + @ops.RegisterGradient("TensorMapLookup") def LookupGrad(op, dval): _, k = op.inputs