rename ListKeys to StackKeys

This commit is contained in:
Katherine Tian 2020-08-13 18:57:03 +00:00
parent a93f5e719f
commit 39284c4ebb
5 changed files with 12 additions and 14 deletions

View File

@ -41,6 +41,9 @@ REGISTER_KERNEL_BUILDER(Name("TensorMapErase").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("TensorMapHasKey").Device(DEVICE_CPU),
TensorMapHasKey);
REGISTER_KERNEL_BUILDER(Name("TensorMapStackKeys").Device(DEVICE_CPU),
TensorMapStackKeys);
#undef REGISTER_TENSOR_MAP_OPS_CPU
#define REGISTER_TENSOR_MAP_OPS_CPU(T)

View File

@ -174,13 +174,12 @@ class TensorMapHasKey : public OpKernel {
}
};
<<<<<<< HEAD
class TensorMapListKeys : public OpKernel {
class TensorMapStackKeys : public OpKernel {
public:
explicit TensorMapListKeys(OpKernelConstruction* c) : OpKernel(c) {
explicit TensorMapStackKeys(OpKernelConstruction* c) : OpKernel(c) {
OP_REQUIRES_OK(c, c->GetAttr("key_dtype", &key_dtype_));
}
~TensorMapListKeys() override {}
~TensorMapStackKeys() override {}
void Compute(OpKernelContext* c) override {
const TensorMap* m = nullptr;
@ -203,8 +202,6 @@ class TensorMapListKeys : public OpKernel {
DataType key_dtype_;
};
=======
>>>>>>> erase_change
template <typename Device>
Status TensorMapBinaryAdd(OpKernelContext* c, const TensorMap& a,
const TensorMap& b, TensorMap* out) {

View File

@ -74,14 +74,12 @@ REGISTER_OP("TensorMapHasKey")
.Attr("key_dtype: type")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("TensorMapListKeys")
REGISTER_OP("TensorMapStackKeys")
.Input("input_handle: variant")
.Output("keys: key_dtype")
.Attr("key_dtype: type")
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->UnknownShape()); // output keys
//c->set_output(0, c->MakeShape({c->UnknownDim()}));
//c->set_output(0, c->Vector(2));
return Status::OK();
});

View File

@ -132,7 +132,7 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
self.assertAllClose(l, v)
self.assertAllClose(l2, default_value)
def testListKeys(self):
def testStackKeys(self):
m = map_ops.empty_tensor_map()
k = constant_op.constant(1.0)
k2 = constant_op.constant(2.0)
@ -142,13 +142,13 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
v3 = constant_op.constant(23.0)
m = map_ops.tensor_map_insert(m, k, v)
m = map_ops.tensor_map_insert(m, k2, v2)
keys = map_ops.tensor_map_list_keys(m, k.dtype)
keys = map_ops.tensor_map_stack_keys(m, k.dtype)
expected = constant_op.constant([1.0, 2.0])
self.assertAllClose(array_ops.shape(keys), array_ops.shape(expected))
self.assertAllClose(sort_ops.sort(keys), expected)
m = map_ops.tensor_map_insert(m, k3, v3)
keys = map_ops.tensor_map_list_keys(m, k.dtype)
keys = map_ops.tensor_map_stack_keys(m, k.dtype)
expected = constant_op.constant([1.0, 2.0, 3.0])
self.assertAllClose(array_ops.shape(keys), array_ops.shape(expected))
self.assertAllClose(sort_ops.sort(keys), expected)

View File

@ -46,8 +46,8 @@ def tensor_map_erase(input_handle, key, value_dtype):
def tensor_map_has_key(input_handle, key):
return gen_map_ops.tensor_map_has_key(input_handle, key)
def tensor_map_list_keys(input_handle, key_dtype):
return gen_map_ops.tensor_map_list_keys(input_handle, key_dtype)
def tensor_map_stack_keys(input_handle, key_dtype):
return gen_map_ops.tensor_map_stack_keys(input_handle, key_dtype)
@ops.RegisterGradient("TensorMapLookup")
def LookupGrad(op, dval):