rename ListKeys to StackKeys
parent a93f5e719f
commit 39284c4ebb
@@ -41,6 +41,9 @@ REGISTER_KERNEL_BUILDER(Name("TensorMapErase").Device(DEVICE_CPU),
 REGISTER_KERNEL_BUILDER(Name("TensorMapHasKey").Device(DEVICE_CPU),
                         TensorMapHasKey);
 
+REGISTER_KERNEL_BUILDER(Name("TensorMapStackKeys").Device(DEVICE_CPU),
+                        TensorMapStackKeys);
+
 #undef REGISTER_TENSOR_MAP_OPS_CPU
 
 #define REGISTER_TENSOR_MAP_OPS_CPU(T)
@@ -174,13 +174,12 @@ class TensorMapHasKey : public OpKernel {
   }
 };
 
-<<<<<<< HEAD
-class TensorMapListKeys : public OpKernel {
+class TensorMapStackKeys : public OpKernel {
  public:
-  explicit TensorMapListKeys(OpKernelConstruction* c) : OpKernel(c) {
+  explicit TensorMapStackKeys(OpKernelConstruction* c) : OpKernel(c) {
     OP_REQUIRES_OK(c, c->GetAttr("key_dtype", &key_dtype_));
   }
-  ~TensorMapListKeys() override {}
+  ~TensorMapStackKeys() override {}
 
   void Compute(OpKernelContext* c) override {
     const TensorMap* m = nullptr;
@@ -203,8 +202,6 @@ class TensorMapListKeys : public OpKernel {
   DataType key_dtype_;
 };
 
-=======
->>>>>>> erase_change
 template <typename Device>
 Status TensorMapBinaryAdd(OpKernelContext* c, const TensorMap& a,
                           const TensorMap& b, TensorMap* out) {
@@ -74,14 +74,12 @@ REGISTER_OP("TensorMapHasKey")
     .Attr("key_dtype: type")
     .SetShapeFn(shape_inference::ScalarShape);
 
-REGISTER_OP("TensorMapListKeys")
+REGISTER_OP("TensorMapStackKeys")
     .Input("input_handle: variant")
     .Output("keys: key_dtype")
     .Attr("key_dtype: type")
     .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->UnknownShape());  // output keys
-      //c->set_output(0, c->MakeShape({c->UnknownDim()}));
-      //c->set_output(0, c->Vector(2));
      return Status::OK();
     });
 
@@ -132,7 +132,7 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     self.assertAllClose(l, v)
     self.assertAllClose(l2, default_value)
 
-  def testListKeys(self):
+  def testStackKeys(self):
     m = map_ops.empty_tensor_map()
     k = constant_op.constant(1.0)
     k2 = constant_op.constant(2.0)
@@ -142,13 +142,13 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     v3 = constant_op.constant(23.0)
     m = map_ops.tensor_map_insert(m, k, v)
     m = map_ops.tensor_map_insert(m, k2, v2)
-    keys = map_ops.tensor_map_list_keys(m, k.dtype)
+    keys = map_ops.tensor_map_stack_keys(m, k.dtype)
     expected = constant_op.constant([1.0, 2.0])
     self.assertAllClose(array_ops.shape(keys), array_ops.shape(expected))
     self.assertAllClose(sort_ops.sort(keys), expected)
 
     m = map_ops.tensor_map_insert(m, k3, v3)
-    keys = map_ops.tensor_map_list_keys(m, k.dtype)
+    keys = map_ops.tensor_map_stack_keys(m, k.dtype)
     expected = constant_op.constant([1.0, 2.0, 3.0])
     self.assertAllClose(array_ops.shape(keys), array_ops.shape(expected))
     self.assertAllClose(sort_ops.sort(keys), expected)
@@ -46,8 +46,8 @@ def tensor_map_erase(input_handle, key, value_dtype):
 def tensor_map_has_key(input_handle, key):
   return gen_map_ops.tensor_map_has_key(input_handle, key)
 
-def tensor_map_list_keys(input_handle, key_dtype):
-  return gen_map_ops.tensor_map_list_keys(input_handle, key_dtype)
+def tensor_map_stack_keys(input_handle, key_dtype):
+  return gen_map_ops.tensor_map_stack_keys(input_handle, key_dtype)
 
 @ops.RegisterGradient("TensorMapLookup")
 def LookupGrad(op, dval):