Restore the adding_an_op code examples that used to live under
g3doc/, now under examples/. Partial fix of #8029. Change: 149142119
This commit is contained in:
parent
20af1b7baa
commit
88c9fb09bd
157
tensorflow/examples/adding_an_op/BUILD
Normal file
157
tensorflow/examples/adding_an_op/BUILD
Normal file
@ -0,0 +1,157 @@
|
|||||||
|
# Description:
|
||||||
|
# Code examples referenced by adding_an_op
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = ["//tensorflow:internal"],
|
||||||
|
features = [
|
||||||
|
"-layering_check",
|
||||||
|
"-parse_headers",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
|
load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
|
||||||
|
load("//tensorflow:tensorflow.bzl", "tf_cuda_tests_tags")
|
||||||
|
|
||||||
|
exports_files(["LICENSE"])
|
||||||
|
|
||||||
|
tf_custom_op_library(
|
||||||
|
name = "zero_out_op_kernel_1.so",
|
||||||
|
srcs = ["zero_out_op_kernel_1.cc"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "zero_out_op_1",
|
||||||
|
srcs = ["zero_out_op_1.py"],
|
||||||
|
data = [":zero_out_op_kernel_1.so"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_custom_op_library(
|
||||||
|
name = "zero_out_op_kernel_2.so",
|
||||||
|
srcs = ["zero_out_op_kernel_2.cc"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "zero_out_op_2",
|
||||||
|
srcs = ["zero_out_op_2.py"],
|
||||||
|
data = [":zero_out_op_kernel_2.so"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_custom_op_library(
|
||||||
|
name = "zero_out_op_kernel_3.so",
|
||||||
|
srcs = ["zero_out_op_kernel_3.cc"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "zero_out_op_3",
|
||||||
|
srcs = ["zero_out_op_3.py"],
|
||||||
|
data = [":zero_out_op_kernel_3.so"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "zero_out_grad_2",
|
||||||
|
srcs = ["zero_out_grad_2.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":zero_out_op_2",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
|
"//tensorflow/python:array_ops",
|
||||||
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
|
"//tensorflow/python:sparse_ops",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "zero_out_1_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["zero_out_1_test.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":zero_out_op_1",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
|
"//third_party/py/numpy",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "zero_out_2_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["zero_out_2_test.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":zero_out_grad_2",
|
||||||
|
":zero_out_op_2",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "zero_out_3_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["zero_out_3_test.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":zero_out_op_3",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
|
"//third_party/py/numpy",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_custom_op_library(
|
||||||
|
name = "cuda_op_kernel.so",
|
||||||
|
srcs = ["cuda_op_kernel.cc"],
|
||||||
|
gpu_srcs = ["cuda_op_kernel.cu.cc"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "cuda_op",
|
||||||
|
srcs = ["cuda_op.py"],
|
||||||
|
data = [":cuda_op_kernel.so"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "cuda_op_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["cuda_op_test.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
tags = tf_cuda_tests_tags(),
|
||||||
|
deps = [
|
||||||
|
":cuda_op",
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "fact_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["fact_test.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow:tensorflow_py",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_binary(
|
||||||
|
name = "attr_examples",
|
||||||
|
srcs = ["attr_examples.cc"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "all_files",
|
||||||
|
srcs = glob(
|
||||||
|
["**/*"],
|
||||||
|
exclude = [
|
||||||
|
"**/METADATA",
|
||||||
|
"**/OWNERS",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
visibility = ["//tensorflow:__subpackages__"],
|
||||||
|
)
|
0
tensorflow/examples/adding_an_op/__init__.py
Normal file
0
tensorflow/examples/adding_an_op/__init__.py
Normal file
46
tensorflow/examples/adding_an_op/attr_examples.cc
Normal file
46
tensorflow/examples/adding_an_op/attr_examples.cc
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <stdio.h>
|
||||||
|
#include "tensorflow/core/framework/op.h"
|
||||||
|
|
||||||
|
REGISTER_OP("RestrictedTypeExample").Attr("t: {int32, float, bool}");
|
||||||
|
|
||||||
|
REGISTER_OP("NumberType").Attr("t: numbertype");
|
||||||
|
|
||||||
|
REGISTER_OP("EnumExample").Attr("e: {'apple', 'orange'}");
|
||||||
|
|
||||||
|
REGISTER_OP("MinIntExample").Attr("a: int >= 2");
|
||||||
|
|
||||||
|
REGISTER_OP("TypeListExample").Attr("a: list({int32, float}) >= 3");
|
||||||
|
|
||||||
|
REGISTER_OP("AttrDefaultExample").Attr("i: int = 0");
|
||||||
|
|
||||||
|
REGISTER_OP("AttrDefaultExampleForAllTypes")
|
||||||
|
.Attr("s: string = 'foo'")
|
||||||
|
.Attr("i: int = 0")
|
||||||
|
.Attr("f: float = 1.0")
|
||||||
|
.Attr("b: bool = true")
|
||||||
|
.Attr("ty: type = DT_INT32")
|
||||||
|
.Attr("sh: shape = { dim { size: 1 } dim { size: 2 } }")
|
||||||
|
.Attr("te: tensor = { dtype: DT_INT32 int_val: 5 }")
|
||||||
|
.Attr("l_empty: list(int) = []")
|
||||||
|
.Attr("l_int: list(int) = [2, 3, 5, 7]");
|
||||||
|
|
||||||
|
int main(int argc, char* argv[]) {
|
||||||
|
printf("All registered ops:\n%s\n",
|
||||||
|
tensorflow::OpRegistry::Global()->DebugString(false).c_str());
|
||||||
|
return 0;
|
||||||
|
}
|
27
tensorflow/examples/adding_an_op/cuda_op.py
Normal file
27
tensorflow/examples/adding_an_op/cuda_op.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Cuda op Python library."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os.path
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
if tf.test.is_built_with_cuda():
|
||||||
|
_cuda_op_module = tf.load_op_library(os.path.join(
|
||||||
|
tf.resource_loader.get_data_files_path(), 'cuda_op_kernel.so'))
|
||||||
|
add_one = _cuda_op_module.add_one
|
55
tensorflow/examples/adding_an_op/cuda_op_kernel.cc
Normal file
55
tensorflow/examples/adding_an_op/cuda_op_kernel.cc
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/op.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
|
||||||
|
using namespace tensorflow; // NOLINT(build/namespaces)
|
||||||
|
|
||||||
|
REGISTER_OP("AddOne")
|
||||||
|
.Input("input: int32")
|
||||||
|
.Output("output: int32")
|
||||||
|
.Doc(R"doc(
|
||||||
|
Adds 1 to all elements of the tensor.
|
||||||
|
|
||||||
|
output: A Tensor.
|
||||||
|
output = input + 1
|
||||||
|
)doc");
|
||||||
|
|
||||||
|
void AddOneKernelLauncher(const int* in, const int N, int* out);
|
||||||
|
|
||||||
|
class AddOneOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit AddOneOp(OpKernelConstruction* context) : OpKernel(context) {}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* context) override {
|
||||||
|
// Grab the input tensor
|
||||||
|
const Tensor& input_tensor = context->input(0);
|
||||||
|
auto input = input_tensor.flat<int32>();
|
||||||
|
|
||||||
|
// Create an output tensor
|
||||||
|
Tensor* output_tensor = NULL;
|
||||||
|
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
|
||||||
|
&output_tensor));
|
||||||
|
auto output = output_tensor->template flat<int32>();
|
||||||
|
|
||||||
|
// Set all but the first element of the output tensor to 0.
|
||||||
|
const int N = input.size();
|
||||||
|
// Call the cuda kernel launcher
|
||||||
|
AddOneKernelLauncher(input.data(), N, output.data());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("AddOne").Device(DEVICE_GPU), AddOneOp);
|
31
tensorflow/examples/adding_an_op/cuda_op_kernel.cu.cc
Normal file
31
tensorflow/examples/adding_an_op/cuda_op_kernel.cu.cc
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
#define EIGEN_USE_GPU
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
|
|
||||||
|
__global__ void AddOneKernel(const int* in, const int N, int* out) {
|
||||||
|
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
|
||||||
|
i += blockDim.x * gridDim.x) {
|
||||||
|
out[i] = in[i] + 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void AddOneKernelLauncher(const int* in, const int N, int* out) {
|
||||||
|
AddOneKernel<<<32, 256>>>(in, N, out);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
35
tensorflow/examples/adding_an_op/cuda_op_test.py
Normal file
35
tensorflow/examples/adding_an_op/cuda_op_test.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Test for version 1 of the zero_out op."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow.examples.adding_an_op import cuda_op
|
||||||
|
|
||||||
|
|
||||||
|
class AddOneTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def test(self):
|
||||||
|
if tf.test.is_built_with_cuda():
|
||||||
|
with self.test_session():
|
||||||
|
result = cuda_op.add_one([5, 4, 3, 2, 1])
|
||||||
|
self.assertAllEqual(result.eval(), [6, 5, 4, 3, 2])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
32
tensorflow/examples/adding_an_op/fact_test.py
Normal file
32
tensorflow/examples/adding_an_op/fact_test.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
"""Test that user ops can be used as expected."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
class FactTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def test(self):
|
||||||
|
with self.test_session():
|
||||||
|
print(tf.user_ops.my_fact().eval())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
42
tensorflow/examples/adding_an_op/zero_out_1_test.py
Normal file
42
tensorflow/examples/adding_an_op/zero_out_1_test.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
"""Test for version 1 of the zero_out op."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os.path
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow.examples.adding_an_op import zero_out_op_1
|
||||||
|
|
||||||
|
|
||||||
|
class ZeroOut1Test(tf.test.TestCase):
|
||||||
|
|
||||||
|
def test(self):
|
||||||
|
with self.test_session():
|
||||||
|
result = zero_out_op_1.zero_out([5, 4, 3, 2, 1])
|
||||||
|
self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
|
||||||
|
|
||||||
|
def testLoadTwice(self):
|
||||||
|
zero_out_loaded_again = tf.load_op_library(os.path.join(
|
||||||
|
tf.resource_loader.get_data_files_path(), 'zero_out_op_kernel_1.so'))
|
||||||
|
self.assertEqual(zero_out_loaded_again, zero_out_op_1._zero_out_module)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
59
tensorflow/examples/adding_an_op/zero_out_2_test.py
Normal file
59
tensorflow/examples/adding_an_op/zero_out_2_test.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
"""Test for version 2 of the zero_out op."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
from tensorflow.examples.adding_an_op import zero_out_grad_2 # pylint: disable=unused-import
|
||||||
|
from tensorflow.examples.adding_an_op import zero_out_op_2
|
||||||
|
|
||||||
|
|
||||||
|
class ZeroOut2Test(tf.test.TestCase):
|
||||||
|
|
||||||
|
def test(self):
|
||||||
|
with self.test_session():
|
||||||
|
result = zero_out_op_2.zero_out([5, 4, 3, 2, 1])
|
||||||
|
self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
|
||||||
|
|
||||||
|
def test_2d(self):
|
||||||
|
with self.test_session():
|
||||||
|
result = zero_out_op_2.zero_out([[6, 5, 4], [3, 2, 1]])
|
||||||
|
self.assertAllEqual(result.eval(), [[6, 0, 0], [0, 0, 0]])
|
||||||
|
|
||||||
|
def test_grad(self):
|
||||||
|
with self.test_session():
|
||||||
|
shape = (5,)
|
||||||
|
x = tf.constant([5, 4, 3, 2, 1], dtype=tf.float32)
|
||||||
|
y = zero_out_op_2.zero_out(x)
|
||||||
|
err = tf.test.compute_gradient_error(x, shape, y, shape)
|
||||||
|
self.assertLess(err, 1e-4)
|
||||||
|
|
||||||
|
def test_grad_2d(self):
|
||||||
|
with self.test_session():
|
||||||
|
shape = (2, 3)
|
||||||
|
x = tf.constant([[6, 5, 4], [3, 2, 1]], dtype=tf.float32)
|
||||||
|
y = zero_out_op_2.zero_out(x)
|
||||||
|
err = tf.test.compute_gradient_error(x, shape, y, shape)
|
||||||
|
self.assertLess(err, 1e-4)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
52
tensorflow/examples/adding_an_op/zero_out_3_test.py
Normal file
52
tensorflow/examples/adding_an_op/zero_out_3_test.py
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
"""Test for version 3 of the zero_out op."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow.examples.adding_an_op import zero_out_op_3
|
||||||
|
|
||||||
|
|
||||||
|
class ZeroOut3Test(tf.test.TestCase):
|
||||||
|
|
||||||
|
def test(self):
|
||||||
|
with self.test_session():
|
||||||
|
result = zero_out_op_3.zero_out([5, 4, 3, 2, 1])
|
||||||
|
self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
|
||||||
|
|
||||||
|
def testAttr(self):
|
||||||
|
with self.test_session():
|
||||||
|
result = zero_out_op_3.zero_out([5, 4, 3, 2, 1], preserve_index=3)
|
||||||
|
self.assertAllEqual(result.eval(), [0, 0, 0, 2, 0])
|
||||||
|
|
||||||
|
def testNegative(self):
|
||||||
|
with self.test_session():
|
||||||
|
result = zero_out_op_3.zero_out([5, 4, 3, 2, 1], preserve_index=-1)
|
||||||
|
with self.assertRaisesOpError("Need preserve_index >= 0, got -1"):
|
||||||
|
result.eval()
|
||||||
|
|
||||||
|
def testLarge(self):
|
||||||
|
with self.test_session():
|
||||||
|
result = zero_out_op_3.zero_out([5, 4, 3, 2, 1], preserve_index=17)
|
||||||
|
with self.assertRaisesOpError("preserve_index out of range"):
|
||||||
|
result.eval()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tf.test.main()
|
44
tensorflow/examples/adding_an_op/zero_out_grad_2.py
Normal file
44
tensorflow/examples/adding_an_op/zero_out_grad_2.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
"""The gradient of the tutorial zero_out op."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import sparse_ops
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterGradient("ZeroOut")
|
||||||
|
def _zero_out_grad(op, grad):
|
||||||
|
"""The gradients for `zero_out`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
op: The `zero_out` `Operation` that we are differentiating, which we can use
|
||||||
|
to find the inputs and outputs of the original op.
|
||||||
|
grad: Gradient with respect to the output of the `zero_out` op.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Gradients with respect to the input of `zero_out`.
|
||||||
|
"""
|
||||||
|
to_zero = op.inputs[0]
|
||||||
|
shape = array_ops.shape(to_zero)
|
||||||
|
index = array_ops.zeros_like(shape)
|
||||||
|
first_grad = array_ops.reshape(grad, [-1])[0]
|
||||||
|
to_zero_grad = sparse_ops.sparse_to_dense([index], shape, first_grad, 0)
|
||||||
|
return [to_zero_grad] # List of one Tensor, since we have one input
|
27
tensorflow/examples/adding_an_op/zero_out_op_1.py
Normal file
27
tensorflow/examples/adding_an_op/zero_out_op_1.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""ZeroOut op Python library."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os.path
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
_zero_out_module = tf.load_op_library(
|
||||||
|
os.path.join(tf.resource_loader.get_data_files_path(),
|
||||||
|
'zero_out_op_kernel_1.so'))
|
||||||
|
zero_out = _zero_out_module.zero_out
|
29
tensorflow/examples/adding_an_op/zero_out_op_2.py
Normal file
29
tensorflow/examples/adding_an_op/zero_out_op_2.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""ZeroOut ops Python library."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os.path
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
_zero_out_module = tf.load_op_library(
|
||||||
|
os.path.join(tf.resource_loader.get_data_files_path(),
|
||||||
|
'zero_out_op_kernel_2.so'))
|
||||||
|
zero_out = _zero_out_module.zero_out
|
||||||
|
zero_out2 = _zero_out_module.zero_out2
|
||||||
|
zero_out3 = _zero_out_module.zero_out3
|
27
tensorflow/examples/adding_an_op/zero_out_op_3.py
Normal file
27
tensorflow/examples/adding_an_op/zero_out_op_3.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""ZeroOut op Python library."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os.path
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
_zero_out_module = tf.load_op_library(
|
||||||
|
os.path.join(tf.resource_loader.get_data_files_path(),
|
||||||
|
'zero_out_op_kernel_3.so'))
|
||||||
|
zero_out = _zero_out_module.zero_out
|
62
tensorflow/examples/adding_an_op/zero_out_op_kernel_1.cc
Normal file
62
tensorflow/examples/adding_an_op/zero_out_op_kernel_1.cc
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/op.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/shape_inference.h"
|
||||||
|
|
||||||
|
using namespace tensorflow; // NOLINT(build/namespaces)
|
||||||
|
|
||||||
|
REGISTER_OP("ZeroOut")
|
||||||
|
.Input("to_zero: int32")
|
||||||
|
.Output("zeroed: int32")
|
||||||
|
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||||
|
c->set_output(0, c->input(0));
|
||||||
|
return Status::OK();
|
||||||
|
})
|
||||||
|
.Doc(R"doc(
|
||||||
|
Zeros out all but the first value of a Tensor.
|
||||||
|
|
||||||
|
zeroed: A Tensor whose first value is identical to `to_zero`, and 0
|
||||||
|
otherwise.
|
||||||
|
)doc");
|
||||||
|
|
||||||
|
class ZeroOutOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* context) override {
|
||||||
|
// Grab the input tensor
|
||||||
|
const Tensor& input_tensor = context->input(0);
|
||||||
|
auto input = input_tensor.flat<int32>();
|
||||||
|
|
||||||
|
// Create an output tensor
|
||||||
|
Tensor* output_tensor = NULL;
|
||||||
|
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
|
||||||
|
&output_tensor));
|
||||||
|
auto output = output_tensor->template flat<int32>();
|
||||||
|
|
||||||
|
// Set all but the first element of the output tensor to 0.
|
||||||
|
const int N = input.size();
|
||||||
|
for (int i = 1; i < N; i++) {
|
||||||
|
output(i) = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Preserve the first input value.
|
||||||
|
if (N > 0) output(0) = input(0);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
|
115
tensorflow/examples/adding_an_op/zero_out_op_kernel_2.cc
Normal file
115
tensorflow/examples/adding_an_op/zero_out_op_kernel_2.cc
Normal file
@ -0,0 +1,115 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
|
#include "tensorflow/core/framework/shape_inference.h"
|
||||||
|
|
||||||
|
using namespace tensorflow; // NOLINT(build/namespaces)
|
||||||
|
|
||||||
|
REGISTER_OP("ZeroOut")
|
||||||
|
.Attr("T: realnumbertype")
|
||||||
|
.Input("to_zero: T")
|
||||||
|
.Output("zeroed: T")
|
||||||
|
.SetShapeFn(shape_inference::UnchangedShape)
|
||||||
|
.Doc(R"doc(
|
||||||
|
Zeros out all but the first value of a Tensor.
|
||||||
|
|
||||||
|
zeroed: A Tensor whose first value is identical to `to_zero`, and 0
|
||||||
|
otherwise.
|
||||||
|
)doc");
|
||||||
|
|
||||||
|
REGISTER_OP("ZeroOut2")
|
||||||
|
.Attr("T: realnumbertype")
|
||||||
|
.Input("to_zero: T")
|
||||||
|
.Output("zeroed: T")
|
||||||
|
.Doc(R"doc(
|
||||||
|
Zeros out all but the first value of a Tensor.
|
||||||
|
|
||||||
|
zeroed: A Tensor whose first value is identical to `to_zero`, and 0
|
||||||
|
otherwise.
|
||||||
|
)doc");
|
||||||
|
|
||||||
|
REGISTER_OP("ZeroOut3")
|
||||||
|
.Attr("T: realnumbertype")
|
||||||
|
.Input("to_zero: T")
|
||||||
|
.Output("zeroed: T")
|
||||||
|
.Doc(R"doc(
|
||||||
|
Zeros out all but the first value of a Tensor.
|
||||||
|
|
||||||
|
zeroed: A Tensor whose first value is identical to `to_zero`, and 0
|
||||||
|
otherwise.
|
||||||
|
)doc");
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class ZeroOutOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* context) override {
|
||||||
|
// Grab the input tensor
|
||||||
|
const Tensor& input_tensor = context->input(0);
|
||||||
|
auto input = input_tensor.flat<T>();
|
||||||
|
|
||||||
|
// Create an output tensor
|
||||||
|
Tensor* output = NULL;
|
||||||
|
OP_REQUIRES_OK(context,
|
||||||
|
context->allocate_output(0, input_tensor.shape(), &output));
|
||||||
|
auto output_flat = output->template flat<T>();
|
||||||
|
|
||||||
|
// Set all the elements of the output tensor to 0
|
||||||
|
const int N = input.size();
|
||||||
|
for (int i = 0; i < N; i++) {
|
||||||
|
output_flat(i) = T(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Preserve the first input value
|
||||||
|
if (N > 0) output_flat(0) = input(0);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("ZeroOut")
|
||||||
|
.Device(DEVICE_CPU)
|
||||||
|
.TypeConstraint<float>("T"),
|
||||||
|
ZeroOutOp<float>);
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("ZeroOut")
|
||||||
|
.Device(DEVICE_CPU)
|
||||||
|
.TypeConstraint<double>("T"),
|
||||||
|
ZeroOutOp<double>);
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("ZeroOut")
|
||||||
|
.Device(DEVICE_CPU)
|
||||||
|
.TypeConstraint<int>("T"),
|
||||||
|
ZeroOutOp<int>);
|
||||||
|
|
||||||
|
#define REGISTER_KERNEL(type) \
|
||||||
|
REGISTER_KERNEL_BUILDER( \
|
||||||
|
Name("ZeroOut2").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
|
||||||
|
ZeroOutOp<type>)
|
||||||
|
|
||||||
|
REGISTER_KERNEL(float);
|
||||||
|
REGISTER_KERNEL(double);
|
||||||
|
REGISTER_KERNEL(int32);
|
||||||
|
|
||||||
|
#undef REGISTER_KERNEL
|
||||||
|
|
||||||
|
#define REGISTER_KERNEL(type) \
|
||||||
|
REGISTER_KERNEL_BUILDER( \
|
||||||
|
Name("ZeroOut3").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
|
||||||
|
ZeroOutOp<type>)
|
||||||
|
|
||||||
|
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
|
||||||
|
|
||||||
|
#undef REGISTER_KERNEL
|
72
tensorflow/examples/adding_an_op/zero_out_op_kernel_3.cc
Normal file
72
tensorflow/examples/adding_an_op/zero_out_op_kernel_3.cc
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/op.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/shape_inference.h"
|
||||||
|
|
||||||
|
using namespace tensorflow; // NOLINT(build/namespaces)
|
||||||
|
|
||||||
|
REGISTER_OP("ZeroOut")
|
||||||
|
.Attr("preserve_index: int = 0")
|
||||||
|
.Input("to_zero: int32")
|
||||||
|
.Output("zeroed: int32")
|
||||||
|
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||||
|
c->set_output(0, c->input(0));
|
||||||
|
return Status::OK();
|
||||||
|
});
|
||||||
|
|
||||||
|
class ZeroOutOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||||
|
// Get the index of the value to preserve
|
||||||
|
OP_REQUIRES_OK(context,
|
||||||
|
context->GetAttr("preserve_index", &preserve_index_));
|
||||||
|
// Check that preserve\_index is positive
|
||||||
|
OP_REQUIRES(context, preserve_index_ >= 0,
|
||||||
|
errors::InvalidArgument("Need preserve_index >= 0, got ",
|
||||||
|
preserve_index_));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* context) override {
|
||||||
|
// Grab the input tensor
|
||||||
|
const Tensor& input_tensor = context->input(0);
|
||||||
|
auto input = input_tensor.flat<int32>();
|
||||||
|
|
||||||
|
// Check that preserve_index is in range
|
||||||
|
OP_REQUIRES(context, preserve_index_ < input.dimension(0),
|
||||||
|
errors::InvalidArgument("preserve_index out of range"));
|
||||||
|
|
||||||
|
// Create an output tensor
|
||||||
|
Tensor* output_tensor = NULL;
|
||||||
|
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
|
||||||
|
&output_tensor));
|
||||||
|
auto output = output_tensor->template flat<int32>();
|
||||||
|
|
||||||
|
// Set all the elements of the output tensor to 0
|
||||||
|
const int N = input.size();
|
||||||
|
for (int i = 0; i < N; i++) {
|
||||||
|
output(i) = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Preserve the requested input value
|
||||||
|
output(preserve_index_) = input(preserve_index_);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int preserve_index_;
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
|
Loading…
Reference in New Issue
Block a user