rename ListKeys to StackKeys
This commit is contained in:
parent
a93f5e719f
commit
39284c4ebb
@ -41,6 +41,9 @@ REGISTER_KERNEL_BUILDER(Name("TensorMapErase").Device(DEVICE_CPU),
|
|||||||
REGISTER_KERNEL_BUILDER(Name("TensorMapHasKey").Device(DEVICE_CPU),
|
REGISTER_KERNEL_BUILDER(Name("TensorMapHasKey").Device(DEVICE_CPU),
|
||||||
TensorMapHasKey);
|
TensorMapHasKey);
|
||||||
|
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("TensorMapStackKeys").Device(DEVICE_CPU),
|
||||||
|
TensorMapStackKeys);
|
||||||
|
|
||||||
#undef REGISTER_TENSOR_MAP_OPS_CPU
|
#undef REGISTER_TENSOR_MAP_OPS_CPU
|
||||||
|
|
||||||
#define REGISTER_TENSOR_MAP_OPS_CPU(T)
|
#define REGISTER_TENSOR_MAP_OPS_CPU(T)
|
||||||
|
@ -174,13 +174,12 @@ class TensorMapHasKey : public OpKernel {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
<<<<<<< HEAD
|
class TensorMapStackKeys : public OpKernel {
|
||||||
class TensorMapListKeys : public OpKernel {
|
|
||||||
public:
|
public:
|
||||||
explicit TensorMapListKeys(OpKernelConstruction* c) : OpKernel(c) {
|
explicit TensorMapStackKeys(OpKernelConstruction* c) : OpKernel(c) {
|
||||||
OP_REQUIRES_OK(c, c->GetAttr("key_dtype", &key_dtype_));
|
OP_REQUIRES_OK(c, c->GetAttr("key_dtype", &key_dtype_));
|
||||||
}
|
}
|
||||||
~TensorMapListKeys() override {}
|
~TensorMapStackKeys() override {}
|
||||||
|
|
||||||
void Compute(OpKernelContext* c) override {
|
void Compute(OpKernelContext* c) override {
|
||||||
const TensorMap* m = nullptr;
|
const TensorMap* m = nullptr;
|
||||||
@ -203,8 +202,6 @@ class TensorMapListKeys : public OpKernel {
|
|||||||
DataType key_dtype_;
|
DataType key_dtype_;
|
||||||
};
|
};
|
||||||
|
|
||||||
=======
|
|
||||||
>>>>>>> erase_change
|
|
||||||
template <typename Device>
|
template <typename Device>
|
||||||
Status TensorMapBinaryAdd(OpKernelContext* c, const TensorMap& a,
|
Status TensorMapBinaryAdd(OpKernelContext* c, const TensorMap& a,
|
||||||
const TensorMap& b, TensorMap* out) {
|
const TensorMap& b, TensorMap* out) {
|
||||||
|
@ -74,14 +74,12 @@ REGISTER_OP("TensorMapHasKey")
|
|||||||
.Attr("key_dtype: type")
|
.Attr("key_dtype: type")
|
||||||
.SetShapeFn(shape_inference::ScalarShape);
|
.SetShapeFn(shape_inference::ScalarShape);
|
||||||
|
|
||||||
REGISTER_OP("TensorMapListKeys")
|
REGISTER_OP("TensorMapStackKeys")
|
||||||
.Input("input_handle: variant")
|
.Input("input_handle: variant")
|
||||||
.Output("keys: key_dtype")
|
.Output("keys: key_dtype")
|
||||||
.Attr("key_dtype: type")
|
.Attr("key_dtype: type")
|
||||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||||
c->set_output(0, c->UnknownShape()); // output keys
|
c->set_output(0, c->UnknownShape()); // output keys
|
||||||
//c->set_output(0, c->MakeShape({c->UnknownDim()}));
|
|
||||||
//c->set_output(0, c->Vector(2));
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -132,7 +132,7 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
self.assertAllClose(l, v)
|
self.assertAllClose(l, v)
|
||||||
self.assertAllClose(l2, default_value)
|
self.assertAllClose(l2, default_value)
|
||||||
|
|
||||||
def testListKeys(self):
|
def testStackKeys(self):
|
||||||
m = map_ops.empty_tensor_map()
|
m = map_ops.empty_tensor_map()
|
||||||
k = constant_op.constant(1.0)
|
k = constant_op.constant(1.0)
|
||||||
k2 = constant_op.constant(2.0)
|
k2 = constant_op.constant(2.0)
|
||||||
@ -142,13 +142,13 @@ class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
v3 = constant_op.constant(23.0)
|
v3 = constant_op.constant(23.0)
|
||||||
m = map_ops.tensor_map_insert(m, k, v)
|
m = map_ops.tensor_map_insert(m, k, v)
|
||||||
m = map_ops.tensor_map_insert(m, k2, v2)
|
m = map_ops.tensor_map_insert(m, k2, v2)
|
||||||
keys = map_ops.tensor_map_list_keys(m, k.dtype)
|
keys = map_ops.tensor_map_stack_keys(m, k.dtype)
|
||||||
expected = constant_op.constant([1.0, 2.0])
|
expected = constant_op.constant([1.0, 2.0])
|
||||||
self.assertAllClose(array_ops.shape(keys), array_ops.shape(expected))
|
self.assertAllClose(array_ops.shape(keys), array_ops.shape(expected))
|
||||||
self.assertAllClose(sort_ops.sort(keys), expected)
|
self.assertAllClose(sort_ops.sort(keys), expected)
|
||||||
|
|
||||||
m = map_ops.tensor_map_insert(m, k3, v3)
|
m = map_ops.tensor_map_insert(m, k3, v3)
|
||||||
keys = map_ops.tensor_map_list_keys(m, k.dtype)
|
keys = map_ops.tensor_map_stack_keys(m, k.dtype)
|
||||||
expected = constant_op.constant([1.0, 2.0, 3.0])
|
expected = constant_op.constant([1.0, 2.0, 3.0])
|
||||||
self.assertAllClose(array_ops.shape(keys), array_ops.shape(expected))
|
self.assertAllClose(array_ops.shape(keys), array_ops.shape(expected))
|
||||||
self.assertAllClose(sort_ops.sort(keys), expected)
|
self.assertAllClose(sort_ops.sort(keys), expected)
|
||||||
|
@ -46,8 +46,8 @@ def tensor_map_erase(input_handle, key, value_dtype):
|
|||||||
def tensor_map_has_key(input_handle, key):
|
def tensor_map_has_key(input_handle, key):
|
||||||
return gen_map_ops.tensor_map_has_key(input_handle, key)
|
return gen_map_ops.tensor_map_has_key(input_handle, key)
|
||||||
|
|
||||||
def tensor_map_list_keys(input_handle, key_dtype):
|
def tensor_map_stack_keys(input_handle, key_dtype):
|
||||||
return gen_map_ops.tensor_map_list_keys(input_handle, key_dtype)
|
return gen_map_ops.tensor_map_stack_keys(input_handle, key_dtype)
|
||||||
|
|
||||||
@ops.RegisterGradient("TensorMapLookup")
|
@ops.RegisterGradient("TensorMapLookup")
|
||||||
def LookupGrad(op, dval):
|
def LookupGrad(op, dval):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user