diff --git a/tensorflow/examples/adding_an_op/BUILD b/tensorflow/examples/adding_an_op/BUILD new file mode 100644 index 00000000000..ffaf9349d2a --- /dev/null +++ b/tensorflow/examples/adding_an_op/BUILD @@ -0,0 +1,157 @@ +# Description: +# Code examples referenced by adding_an_op + +package( + default_visibility = ["//tensorflow:internal"], + features = [ + "-layering_check", + "-parse_headers", + ], +) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") +load("//tensorflow:tensorflow.bzl", "tf_cuda_tests_tags") + +exports_files(["LICENSE"]) + +tf_custom_op_library( + name = "zero_out_op_kernel_1.so", + srcs = ["zero_out_op_kernel_1.cc"], +) + +py_library( + name = "zero_out_op_1", + srcs = ["zero_out_op_1.py"], + data = [":zero_out_op_kernel_1.so"], + srcs_version = "PY2AND3", +) + +tf_custom_op_library( + name = "zero_out_op_kernel_2.so", + srcs = ["zero_out_op_kernel_2.cc"], +) + +py_library( + name = "zero_out_op_2", + srcs = ["zero_out_op_2.py"], + data = [":zero_out_op_kernel_2.so"], + srcs_version = "PY2AND3", +) + +tf_custom_op_library( + name = "zero_out_op_kernel_3.so", + srcs = ["zero_out_op_kernel_3.cc"], +) + +py_library( + name = "zero_out_op_3", + srcs = ["zero_out_op_3.py"], + data = [":zero_out_op_kernel_3.so"], + srcs_version = "PY2AND3", +) + +py_library( + name = "zero_out_grad_2", + srcs = ["zero_out_grad_2.py"], + srcs_version = "PY2AND3", + deps = [ + ":zero_out_op_2", + "//tensorflow:tensorflow_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:sparse_ops", + ], +) + +py_test( + name = "zero_out_1_test", + size = "small", + srcs = ["zero_out_1_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":zero_out_op_1", + "//tensorflow:tensorflow_py", + "//third_party/py/numpy", + ], +) + +py_test( + name = "zero_out_2_test", + size = "small", + srcs = ["zero_out_2_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":zero_out_grad_2", + ":zero_out_op_2", + "//tensorflow:tensorflow_py", + ], +) + +py_test( + name = "zero_out_3_test", + size = "small", + srcs = ["zero_out_3_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":zero_out_op_3", + "//tensorflow:tensorflow_py", + "//third_party/py/numpy", + ], +) + +tf_custom_op_library( + name = "cuda_op_kernel.so", + srcs = ["cuda_op_kernel.cc"], + gpu_srcs = ["cuda_op_kernel.cu.cc"], +) + +py_library( + name = "cuda_op", + srcs = ["cuda_op.py"], + data = [":cuda_op_kernel.so"], + srcs_version = "PY2AND3", +) + +py_test( + name = "cuda_op_test", + size = "small", + srcs = ["cuda_op_test.py"], + srcs_version = "PY2AND3", + tags = tf_cuda_tests_tags(), + deps = [ + ":cuda_op", + "//tensorflow:tensorflow_py", + ], +) + +py_test( + name = "fact_test", + size = "small", + srcs = ["fact_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +cc_binary( + name = "attr_examples", + srcs = ["attr_examples.cc"], + deps = [ + "//tensorflow/core", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/examples/adding_an_op/__init__.py b/tensorflow/examples/adding_an_op/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tensorflow/examples/adding_an_op/attr_examples.cc b/tensorflow/examples/adding_an_op/attr_examples.cc new file mode 100644 index 00000000000..4eb35668cee --- /dev/null +++ b/tensorflow/examples/adding_an_op/attr_examples.cc @@ -0,0 +1,46 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include "tensorflow/core/framework/op.h" + +REGISTER_OP("RestrictedTypeExample").Attr("t: {int32, float, bool}"); + +REGISTER_OP("NumberType").Attr("t: numbertype"); + +REGISTER_OP("EnumExample").Attr("e: {'apple', 'orange'}"); + +REGISTER_OP("MinIntExample").Attr("a: int >= 2"); + +REGISTER_OP("TypeListExample").Attr("a: list({int32, float}) >= 3"); + +REGISTER_OP("AttrDefaultExample").Attr("i: int = 0"); + +REGISTER_OP("AttrDefaultExampleForAllTypes") + .Attr("s: string = 'foo'") + .Attr("i: int = 0") + .Attr("f: float = 1.0") + .Attr("b: bool = true") + .Attr("ty: type = DT_INT32") + .Attr("sh: shape = { dim { size: 1 } dim { size: 2 } }") + .Attr("te: tensor = { dtype: DT_INT32 int_val: 5 }") + .Attr("l_empty: list(int) = []") + .Attr("l_int: list(int) = [2, 3, 5, 7]"); + +int main(int argc, char* argv[]) { + printf("All registered ops:\n%s\n", + tensorflow::OpRegistry::Global()->DebugString(false).c_str()); + return 0; +} diff --git a/tensorflow/examples/adding_an_op/cuda_op.py b/tensorflow/examples/adding_an_op/cuda_op.py new file mode 100644 index 00000000000..dd5428870bf --- /dev/null +++ b/tensorflow/examples/adding_an_op/cuda_op.py @@ -0,0 +1,27 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Cuda op Python library.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os.path + +import tensorflow as tf + +if tf.test.is_built_with_cuda(): + _cuda_op_module = tf.load_op_library(os.path.join( + tf.resource_loader.get_data_files_path(), 'cuda_op_kernel.so')) + add_one = _cuda_op_module.add_one diff --git a/tensorflow/examples/adding_an_op/cuda_op_kernel.cc b/tensorflow/examples/adding_an_op/cuda_op_kernel.cc new file mode 100644 index 00000000000..2f323b43d85 --- /dev/null +++ b/tensorflow/examples/adding_an_op/cuda_op_kernel.cc @@ -0,0 +1,55 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" + +using namespace tensorflow; // NOLINT(build/namespaces) + +REGISTER_OP("AddOne") + .Input("input: int32") + .Output("output: int32") + .Doc(R"doc( +Adds 1 to all elements of the tensor. + +output: A Tensor. + output = input + 1 +)doc"); + +void AddOneKernelLauncher(const int* in, const int N, int* out); + +class AddOneOp : public OpKernel { + public: + explicit AddOneOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + // Grab the input tensor + const Tensor& input_tensor = context->input(0); + auto input = input_tensor.flat(); + + // Create an output tensor + Tensor* output_tensor = NULL; + OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), + &output_tensor)); + auto output = output_tensor->template flat(); + + // Set all but the first element of the output tensor to 0. + const int N = input.size(); + // Call the cuda kernel launcher + AddOneKernelLauncher(input.data(), N, output.data()); + } +}; + +REGISTER_KERNEL_BUILDER(Name("AddOne").Device(DEVICE_GPU), AddOneOp); diff --git a/tensorflow/examples/adding_an_op/cuda_op_kernel.cu.cc b/tensorflow/examples/adding_an_op/cuda_op_kernel.cu.cc new file mode 100644 index 00000000000..65b50bd3ae9 --- /dev/null +++ b/tensorflow/examples/adding_an_op/cuda_op_kernel.cu.cc @@ -0,0 +1,31 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +__global__ void AddOneKernel(const int* in, const int N, int* out) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; + i += blockDim.x * gridDim.x) { + out[i] = in[i] + 1; + } +} + +void AddOneKernelLauncher(const int* in, const int N, int* out) { + AddOneKernel<<<32, 256>>>(in, N, out); +} + +#endif diff --git a/tensorflow/examples/adding_an_op/cuda_op_test.py b/tensorflow/examples/adding_an_op/cuda_op_test.py new file mode 100644 index 00000000000..07390bc3bf1 --- /dev/null +++ b/tensorflow/examples/adding_an_op/cuda_op_test.py @@ -0,0 +1,35 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test for version 1 of the zero_out op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow.examples.adding_an_op import cuda_op + + +class AddOneTest(tf.test.TestCase): + + def test(self): + if tf.test.is_built_with_cuda(): + with self.test_session(): + result = cuda_op.add_one([5, 4, 3, 2, 1]) + self.assertAllEqual(result.eval(), [6, 5, 4, 3, 2]) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/examples/adding_an_op/fact_test.py b/tensorflow/examples/adding_an_op/fact_test.py new file mode 100644 index 00000000000..f7f17e51803 --- /dev/null +++ b/tensorflow/examples/adding_an_op/fact_test.py @@ -0,0 +1,32 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Test that user ops can be used as expected.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + + +class FactTest(tf.test.TestCase): + + def test(self): + with self.test_session(): + print(tf.user_ops.my_fact().eval()) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/examples/adding_an_op/zero_out_1_test.py b/tensorflow/examples/adding_an_op/zero_out_1_test.py new file mode 100644 index 00000000000..fac486100d8 --- /dev/null +++ b/tensorflow/examples/adding_an_op/zero_out_1_test.py @@ -0,0 +1,42 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Test for version 1 of the zero_out op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os.path + +import tensorflow as tf +from tensorflow.examples.adding_an_op import zero_out_op_1 + + +class ZeroOut1Test(tf.test.TestCase): + + def test(self): + with self.test_session(): + result = zero_out_op_1.zero_out([5, 4, 3, 2, 1]) + self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0]) + + def testLoadTwice(self): + zero_out_loaded_again = tf.load_op_library(os.path.join( + tf.resource_loader.get_data_files_path(), 'zero_out_op_kernel_1.so')) + self.assertEqual(zero_out_loaded_again, zero_out_op_1._zero_out_module) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/examples/adding_an_op/zero_out_2_test.py b/tensorflow/examples/adding_an_op/zero_out_2_test.py new file mode 100644 index 00000000000..217bbbcffa3 --- /dev/null +++ b/tensorflow/examples/adding_an_op/zero_out_2_test.py @@ -0,0 +1,59 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Test for version 2 of the zero_out op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + + +from tensorflow.examples.adding_an_op import zero_out_grad_2 # pylint: disable=unused-import +from tensorflow.examples.adding_an_op import zero_out_op_2 + + +class ZeroOut2Test(tf.test.TestCase): + + def test(self): + with self.test_session(): + result = zero_out_op_2.zero_out([5, 4, 3, 2, 1]) + self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0]) + + def test_2d(self): + with self.test_session(): + result = zero_out_op_2.zero_out([[6, 5, 4], [3, 2, 1]]) + self.assertAllEqual(result.eval(), [[6, 0, 0], [0, 0, 0]]) + + def test_grad(self): + with self.test_session(): + shape = (5,) + x = tf.constant([5, 4, 3, 2, 1], dtype=tf.float32) + y = zero_out_op_2.zero_out(x) + err = tf.test.compute_gradient_error(x, shape, y, shape) + self.assertLess(err, 1e-4) + + def test_grad_2d(self): + with self.test_session(): + shape = (2, 3) + x = tf.constant([[6, 5, 4], [3, 2, 1]], dtype=tf.float32) + y = zero_out_op_2.zero_out(x) + err = tf.test.compute_gradient_error(x, shape, y, shape) + self.assertLess(err, 1e-4) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/examples/adding_an_op/zero_out_3_test.py b/tensorflow/examples/adding_an_op/zero_out_3_test.py new file mode 100644 index 00000000000..01280caf495 --- /dev/null +++ b/tensorflow/examples/adding_an_op/zero_out_3_test.py @@ -0,0 +1,52 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Test for version 3 of the zero_out op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow.examples.adding_an_op import zero_out_op_3 + + +class ZeroOut3Test(tf.test.TestCase): + + def test(self): + with self.test_session(): + result = zero_out_op_3.zero_out([5, 4, 3, 2, 1]) + self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0]) + + def testAttr(self): + with self.test_session(): + result = zero_out_op_3.zero_out([5, 4, 3, 2, 1], preserve_index=3) + self.assertAllEqual(result.eval(), [0, 0, 0, 2, 0]) + + def testNegative(self): + with self.test_session(): + result = zero_out_op_3.zero_out([5, 4, 3, 2, 1], preserve_index=-1) + with self.assertRaisesOpError("Need preserve_index >= 0, got -1"): + result.eval() + + def testLarge(self): + with self.test_session(): + result = zero_out_op_3.zero_out([5, 4, 3, 2, 1], preserve_index=17) + with self.assertRaisesOpError("preserve_index out of range"): + result.eval() + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/examples/adding_an_op/zero_out_grad_2.py b/tensorflow/examples/adding_an_op/zero_out_grad_2.py new file mode 100644 index 00000000000..dc24678e339 --- /dev/null +++ b/tensorflow/examples/adding_an_op/zero_out_grad_2.py @@ -0,0 +1,44 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""The gradient of the tutorial zero_out op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import sparse_ops + + +@ops.RegisterGradient("ZeroOut") +def _zero_out_grad(op, grad): + """The gradients for `zero_out`. + + Args: + op: The `zero_out` `Operation` that we are differentiating, which we can use + to find the inputs and outputs of the original op. + grad: Gradient with respect to the output of the `zero_out` op. + + Returns: + Gradients with respect to the input of `zero_out`. + """ + to_zero = op.inputs[0] + shape = array_ops.shape(to_zero) + index = array_ops.zeros_like(shape) + first_grad = array_ops.reshape(grad, [-1])[0] + to_zero_grad = sparse_ops.sparse_to_dense([index], shape, first_grad, 0) + return [to_zero_grad] # List of one Tensor, since we have one input diff --git a/tensorflow/examples/adding_an_op/zero_out_op_1.py b/tensorflow/examples/adding_an_op/zero_out_op_1.py new file mode 100644 index 00000000000..6bd98b1f06d --- /dev/null +++ b/tensorflow/examples/adding_an_op/zero_out_op_1.py @@ -0,0 +1,27 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""ZeroOut op Python library.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os.path + +import tensorflow as tf + +_zero_out_module = tf.load_op_library( + os.path.join(tf.resource_loader.get_data_files_path(), + 'zero_out_op_kernel_1.so')) +zero_out = _zero_out_module.zero_out diff --git a/tensorflow/examples/adding_an_op/zero_out_op_2.py b/tensorflow/examples/adding_an_op/zero_out_op_2.py new file mode 100644 index 00000000000..ba1e8b4d62c --- /dev/null +++ b/tensorflow/examples/adding_an_op/zero_out_op_2.py @@ -0,0 +1,29 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""ZeroOut ops Python library.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os.path + +import tensorflow as tf + +_zero_out_module = tf.load_op_library( + os.path.join(tf.resource_loader.get_data_files_path(), + 'zero_out_op_kernel_2.so')) +zero_out = _zero_out_module.zero_out +zero_out2 = _zero_out_module.zero_out2 +zero_out3 = _zero_out_module.zero_out3 diff --git a/tensorflow/examples/adding_an_op/zero_out_op_3.py b/tensorflow/examples/adding_an_op/zero_out_op_3.py new file mode 100644 index 00000000000..4354ecaf5a6 --- /dev/null +++ b/tensorflow/examples/adding_an_op/zero_out_op_3.py @@ -0,0 +1,27 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""ZeroOut op Python library.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os.path + +import tensorflow as tf + +_zero_out_module = tf.load_op_library( + os.path.join(tf.resource_loader.get_data_files_path(), + 'zero_out_op_kernel_3.so')) +zero_out = _zero_out_module.zero_out diff --git a/tensorflow/examples/adding_an_op/zero_out_op_kernel_1.cc b/tensorflow/examples/adding_an_op/zero_out_op_kernel_1.cc new file mode 100644 index 00000000000..cc8c719c1ac --- /dev/null +++ b/tensorflow/examples/adding_an_op/zero_out_op_kernel_1.cc @@ -0,0 +1,62 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" + +using namespace tensorflow; // NOLINT(build/namespaces) + +REGISTER_OP("ZeroOut") + .Input("to_zero: int32") + .Output("zeroed: int32") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->input(0)); + return Status::OK(); + }) + .Doc(R"doc( +Zeros out all but the first value of a Tensor. + +zeroed: A Tensor whose first value is identical to `to_zero`, and 0 + otherwise. +)doc"); + +class ZeroOutOp : public OpKernel { + public: + explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + // Grab the input tensor + const Tensor& input_tensor = context->input(0); + auto input = input_tensor.flat(); + + // Create an output tensor + Tensor* output_tensor = NULL; + OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), + &output_tensor)); + auto output = output_tensor->template flat(); + + // Set all but the first element of the output tensor to 0. + const int N = input.size(); + for (int i = 1; i < N; i++) { + output(i) = 0; + } + + // Preserve the first input value. + if (N > 0) output(0) = input(0); + } +}; + +REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp); diff --git a/tensorflow/examples/adding_an_op/zero_out_op_kernel_2.cc b/tensorflow/examples/adding_an_op/zero_out_op_kernel_2.cc new file mode 100644 index 00000000000..3aa18c73071 --- /dev/null +++ b/tensorflow/examples/adding_an_op/zero_out_op_kernel_2.cc @@ -0,0 +1,115 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/shape_inference.h" + +using namespace tensorflow; // NOLINT(build/namespaces) + +REGISTER_OP("ZeroOut") + .Attr("T: realnumbertype") + .Input("to_zero: T") + .Output("zeroed: T") + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +Zeros out all but the first value of a Tensor. + +zeroed: A Tensor whose first value is identical to `to_zero`, and 0 + otherwise. +)doc"); + +REGISTER_OP("ZeroOut2") + .Attr("T: realnumbertype") + .Input("to_zero: T") + .Output("zeroed: T") + .Doc(R"doc( +Zeros out all but the first value of a Tensor. + +zeroed: A Tensor whose first value is identical to `to_zero`, and 0 + otherwise. +)doc"); + +REGISTER_OP("ZeroOut3") + .Attr("T: realnumbertype") + .Input("to_zero: T") + .Output("zeroed: T") + .Doc(R"doc( +Zeros out all but the first value of a Tensor. + +zeroed: A Tensor whose first value is identical to `to_zero`, and 0 + otherwise. +)doc"); + +template +class ZeroOutOp : public OpKernel { + public: + explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + // Grab the input tensor + const Tensor& input_tensor = context->input(0); + auto input = input_tensor.flat(); + + // Create an output tensor + Tensor* output = NULL; + OP_REQUIRES_OK(context, + context->allocate_output(0, input_tensor.shape(), &output)); + auto output_flat = output->template flat(); + + // Set all the elements of the output tensor to 0 + const int N = input.size(); + for (int i = 0; i < N; i++) { + output_flat(i) = T(0); + } + + // Preserve the first input value + if (N > 0) output_flat(0) = input(0); + } +}; + +REGISTER_KERNEL_BUILDER(Name("ZeroOut") + .Device(DEVICE_CPU) + .TypeConstraint("T"), + ZeroOutOp); +REGISTER_KERNEL_BUILDER(Name("ZeroOut") + .Device(DEVICE_CPU) + .TypeConstraint("T"), + ZeroOutOp); +REGISTER_KERNEL_BUILDER(Name("ZeroOut") + .Device(DEVICE_CPU) + .TypeConstraint("T"), + ZeroOutOp); + +#define REGISTER_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("ZeroOut2").Device(DEVICE_CPU).TypeConstraint("T"), \ + ZeroOutOp) + +REGISTER_KERNEL(float); +REGISTER_KERNEL(double); +REGISTER_KERNEL(int32); + +#undef REGISTER_KERNEL + +#define REGISTER_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("ZeroOut3").Device(DEVICE_CPU).TypeConstraint("T"), \ + ZeroOutOp) + +TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL); + +#undef REGISTER_KERNEL diff --git a/tensorflow/examples/adding_an_op/zero_out_op_kernel_3.cc b/tensorflow/examples/adding_an_op/zero_out_op_kernel_3.cc new file mode 100644 index 00000000000..76f6efa3340 --- /dev/null +++ b/tensorflow/examples/adding_an_op/zero_out_op_kernel_3.cc @@ -0,0 +1,72 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" + +using namespace tensorflow; // NOLINT(build/namespaces) + +REGISTER_OP("ZeroOut") + .Attr("preserve_index: int = 0") + .Input("to_zero: int32") + .Output("zeroed: int32") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->input(0)); + return Status::OK(); + }); + +class ZeroOutOp : public OpKernel { + public: + explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) { + // Get the index of the value to preserve + OP_REQUIRES_OK(context, + context->GetAttr("preserve_index", &preserve_index_)); + // Check that preserve\_index is positive + OP_REQUIRES(context, preserve_index_ >= 0, + errors::InvalidArgument("Need preserve_index >= 0, got ", + preserve_index_)); + } + + void Compute(OpKernelContext* context) override { + // Grab the input tensor + const Tensor& input_tensor = context->input(0); + auto input = input_tensor.flat(); + + // Check that preserve_index is in range + OP_REQUIRES(context, preserve_index_ < input.dimension(0), + errors::InvalidArgument("preserve_index out of range")); + + // Create an output tensor + Tensor* output_tensor = NULL; + OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), + &output_tensor)); + auto output = output_tensor->template flat(); + + // Set all the elements of the output tensor to 0 + const int N = input.size(); + for (int i = 0; i < N; i++) { + output(i) = 0; + } + + // Preserve the requested input value + output(preserve_index_) = input(preserve_index_); + } + + private: + int preserve_index_; +}; + +REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);