Restore the adding_an_op code examples that used to live under

g3doc/, now under examples/.  Partial fix of #8029.
Change: 149142119
This commit is contained in:
A. Unique TensorFlower 2017-03-03 12:42:42 -08:00 committed by TensorFlower Gardener
parent 20af1b7baa
commit 88c9fb09bd
18 changed files with 912 additions and 0 deletions

View File

@ -0,0 +1,157 @@
# Description:
# Code examples referenced by adding_an_op
package(
default_visibility = ["//tensorflow:internal"],
features = [
"-layering_check",
"-parse_headers",
],
)
licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
load("//tensorflow:tensorflow.bzl", "tf_cuda_tests_tags")
exports_files(["LICENSE"])
tf_custom_op_library(
name = "zero_out_op_kernel_1.so",
srcs = ["zero_out_op_kernel_1.cc"],
)
py_library(
name = "zero_out_op_1",
srcs = ["zero_out_op_1.py"],
data = [":zero_out_op_kernel_1.so"],
srcs_version = "PY2AND3",
)
tf_custom_op_library(
name = "zero_out_op_kernel_2.so",
srcs = ["zero_out_op_kernel_2.cc"],
)
py_library(
name = "zero_out_op_2",
srcs = ["zero_out_op_2.py"],
data = [":zero_out_op_kernel_2.so"],
srcs_version = "PY2AND3",
)
tf_custom_op_library(
name = "zero_out_op_kernel_3.so",
srcs = ["zero_out_op_kernel_3.cc"],
)
py_library(
name = "zero_out_op_3",
srcs = ["zero_out_op_3.py"],
data = [":zero_out_op_kernel_3.so"],
srcs_version = "PY2AND3",
)
py_library(
name = "zero_out_grad_2",
srcs = ["zero_out_grad_2.py"],
srcs_version = "PY2AND3",
deps = [
":zero_out_op_2",
"//tensorflow:tensorflow_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:sparse_ops",
],
)
py_test(
name = "zero_out_1_test",
size = "small",
srcs = ["zero_out_1_test.py"],
srcs_version = "PY2AND3",
deps = [
":zero_out_op_1",
"//tensorflow:tensorflow_py",
"//third_party/py/numpy",
],
)
py_test(
name = "zero_out_2_test",
size = "small",
srcs = ["zero_out_2_test.py"],
srcs_version = "PY2AND3",
deps = [
":zero_out_grad_2",
":zero_out_op_2",
"//tensorflow:tensorflow_py",
],
)
py_test(
name = "zero_out_3_test",
size = "small",
srcs = ["zero_out_3_test.py"],
srcs_version = "PY2AND3",
deps = [
":zero_out_op_3",
"//tensorflow:tensorflow_py",
"//third_party/py/numpy",
],
)
tf_custom_op_library(
name = "cuda_op_kernel.so",
srcs = ["cuda_op_kernel.cc"],
gpu_srcs = ["cuda_op_kernel.cu.cc"],
)
py_library(
name = "cuda_op",
srcs = ["cuda_op.py"],
data = [":cuda_op_kernel.so"],
srcs_version = "PY2AND3",
)
py_test(
name = "cuda_op_test",
size = "small",
srcs = ["cuda_op_test.py"],
srcs_version = "PY2AND3",
tags = tf_cuda_tests_tags(),
deps = [
":cuda_op",
"//tensorflow:tensorflow_py",
],
)
py_test(
name = "fact_test",
size = "small",
srcs = ["fact_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",
],
)
cc_binary(
name = "attr_examples",
srcs = ["attr_examples.cc"],
deps = [
"//tensorflow/core",
],
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)

View File

@ -0,0 +1,46 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdio.h>
#include "tensorflow/core/framework/op.h"
REGISTER_OP("RestrictedTypeExample").Attr("t: {int32, float, bool}");
REGISTER_OP("NumberType").Attr("t: numbertype");
REGISTER_OP("EnumExample").Attr("e: {'apple', 'orange'}");
REGISTER_OP("MinIntExample").Attr("a: int >= 2");
REGISTER_OP("TypeListExample").Attr("a: list({int32, float}) >= 3");
REGISTER_OP("AttrDefaultExample").Attr("i: int = 0");
REGISTER_OP("AttrDefaultExampleForAllTypes")
.Attr("s: string = 'foo'")
.Attr("i: int = 0")
.Attr("f: float = 1.0")
.Attr("b: bool = true")
.Attr("ty: type = DT_INT32")
.Attr("sh: shape = { dim { size: 1 } dim { size: 2 } }")
.Attr("te: tensor = { dtype: DT_INT32 int_val: 5 }")
.Attr("l_empty: list(int) = []")
.Attr("l_int: list(int) = [2, 3, 5, 7]");
int main(int argc, char* argv[]) {
printf("All registered ops:\n%s\n",
tensorflow::OpRegistry::Global()->DebugString(false).c_str());
return 0;
}

View File

@ -0,0 +1,27 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Cuda op Python library."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import tensorflow as tf
if tf.test.is_built_with_cuda():
_cuda_op_module = tf.load_op_library(os.path.join(
tf.resource_loader.get_data_files_path(), 'cuda_op_kernel.so'))
add_one = _cuda_op_module.add_one

View File

@ -0,0 +1,55 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow; // NOLINT(build/namespaces)
REGISTER_OP("AddOne")
.Input("input: int32")
.Output("output: int32")
.Doc(R"doc(
Adds 1 to all elements of the tensor.
output: A Tensor.
output = input + 1
)doc");
void AddOneKernelLauncher(const int* in, const int N, int* out);
class AddOneOp : public OpKernel {
public:
explicit AddOneOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// Grab the input tensor
const Tensor& input_tensor = context->input(0);
auto input = input_tensor.flat<int32>();
// Create an output tensor
Tensor* output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
&output_tensor));
auto output = output_tensor->template flat<int32>();
// Set all but the first element of the output tensor to 0.
const int N = input.size();
// Call the cuda kernel launcher
AddOneKernelLauncher(input.data(), N, output.data());
}
};
REGISTER_KERNEL_BUILDER(Name("AddOne").Device(DEVICE_GPU), AddOneOp);

View File

@ -0,0 +1,31 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
__global__ void AddOneKernel(const int* in, const int N, int* out) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
out[i] = in[i] + 1;
}
}
void AddOneKernelLauncher(const int* in, const int N, int* out) {
AddOneKernel<<<32, 256>>>(in, N, out);
}
#endif

View File

@ -0,0 +1,35 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test for version 1 of the zero_out op."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.examples.adding_an_op import cuda_op
class AddOneTest(tf.test.TestCase):
def test(self):
if tf.test.is_built_with_cuda():
with self.test_session():
result = cuda_op.add_one([5, 4, 3, 2, 1])
self.assertAllEqual(result.eval(), [6, 5, 4, 3, 2])
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,32 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test that user ops can be used as expected."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
class FactTest(tf.test.TestCase):
def test(self):
with self.test_session():
print(tf.user_ops.my_fact().eval())
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,42 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test for version 1 of the zero_out op."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import tensorflow as tf
from tensorflow.examples.adding_an_op import zero_out_op_1
class ZeroOut1Test(tf.test.TestCase):
def test(self):
with self.test_session():
result = zero_out_op_1.zero_out([5, 4, 3, 2, 1])
self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
def testLoadTwice(self):
zero_out_loaded_again = tf.load_op_library(os.path.join(
tf.resource_loader.get_data_files_path(), 'zero_out_op_kernel_1.so'))
self.assertEqual(zero_out_loaded_again, zero_out_op_1._zero_out_module)
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,59 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test for version 2 of the zero_out op."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.examples.adding_an_op import zero_out_grad_2 # pylint: disable=unused-import
from tensorflow.examples.adding_an_op import zero_out_op_2
class ZeroOut2Test(tf.test.TestCase):
def test(self):
with self.test_session():
result = zero_out_op_2.zero_out([5, 4, 3, 2, 1])
self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
def test_2d(self):
with self.test_session():
result = zero_out_op_2.zero_out([[6, 5, 4], [3, 2, 1]])
self.assertAllEqual(result.eval(), [[6, 0, 0], [0, 0, 0]])
def test_grad(self):
with self.test_session():
shape = (5,)
x = tf.constant([5, 4, 3, 2, 1], dtype=tf.float32)
y = zero_out_op_2.zero_out(x)
err = tf.test.compute_gradient_error(x, shape, y, shape)
self.assertLess(err, 1e-4)
def test_grad_2d(self):
with self.test_session():
shape = (2, 3)
x = tf.constant([[6, 5, 4], [3, 2, 1]], dtype=tf.float32)
y = zero_out_op_2.zero_out(x)
err = tf.test.compute_gradient_error(x, shape, y, shape)
self.assertLess(err, 1e-4)
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,52 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test for version 3 of the zero_out op."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.examples.adding_an_op import zero_out_op_3
class ZeroOut3Test(tf.test.TestCase):
def test(self):
with self.test_session():
result = zero_out_op_3.zero_out([5, 4, 3, 2, 1])
self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
def testAttr(self):
with self.test_session():
result = zero_out_op_3.zero_out([5, 4, 3, 2, 1], preserve_index=3)
self.assertAllEqual(result.eval(), [0, 0, 0, 2, 0])
def testNegative(self):
with self.test_session():
result = zero_out_op_3.zero_out([5, 4, 3, 2, 1], preserve_index=-1)
with self.assertRaisesOpError("Need preserve_index >= 0, got -1"):
result.eval()
def testLarge(self):
with self.test_session():
result = zero_out_op_3.zero_out([5, 4, 3, 2, 1], preserve_index=17)
with self.assertRaisesOpError("preserve_index out of range"):
result.eval()
if __name__ == "__main__":
tf.test.main()

View File

@ -0,0 +1,44 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The gradient of the tutorial zero_out op."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops
@ops.RegisterGradient("ZeroOut")
def _zero_out_grad(op, grad):
"""The gradients for `zero_out`.
Args:
op: The `zero_out` `Operation` that we are differentiating, which we can use
to find the inputs and outputs of the original op.
grad: Gradient with respect to the output of the `zero_out` op.
Returns:
Gradients with respect to the input of `zero_out`.
"""
to_zero = op.inputs[0]
shape = array_ops.shape(to_zero)
index = array_ops.zeros_like(shape)
first_grad = array_ops.reshape(grad, [-1])[0]
to_zero_grad = sparse_ops.sparse_to_dense([index], shape, first_grad, 0)
return [to_zero_grad] # List of one Tensor, since we have one input

View File

@ -0,0 +1,27 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ZeroOut op Python library."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import tensorflow as tf
_zero_out_module = tf.load_op_library(
os.path.join(tf.resource_loader.get_data_files_path(),
'zero_out_op_kernel_1.so'))
zero_out = _zero_out_module.zero_out

View File

@ -0,0 +1,29 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ZeroOut ops Python library."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import tensorflow as tf
_zero_out_module = tf.load_op_library(
os.path.join(tf.resource_loader.get_data_files_path(),
'zero_out_op_kernel_2.so'))
zero_out = _zero_out_module.zero_out
zero_out2 = _zero_out_module.zero_out2
zero_out3 = _zero_out_module.zero_out3

View File

@ -0,0 +1,27 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ZeroOut op Python library."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import tensorflow as tf
_zero_out_module = tf.load_op_library(
os.path.join(tf.resource_loader.get_data_files_path(),
'zero_out_op_kernel_3.so'))
zero_out = _zero_out_module.zero_out

View File

@ -0,0 +1,62 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
using namespace tensorflow; // NOLINT(build/namespaces)
REGISTER_OP("ZeroOut")
.Input("to_zero: int32")
.Output("zeroed: int32")
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
return Status::OK();
})
.Doc(R"doc(
Zeros out all but the first value of a Tensor.
zeroed: A Tensor whose first value is identical to `to_zero`, and 0
otherwise.
)doc");
class ZeroOutOp : public OpKernel {
public:
explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// Grab the input tensor
const Tensor& input_tensor = context->input(0);
auto input = input_tensor.flat<int32>();
// Create an output tensor
Tensor* output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
&output_tensor));
auto output = output_tensor->template flat<int32>();
// Set all but the first element of the output tensor to 0.
const int N = input.size();
for (int i = 1; i < N; i++) {
output(i) = 0;
}
// Preserve the first input value.
if (N > 0) output(0) = input(0);
}
};
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);

View File

@ -0,0 +1,115 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
using namespace tensorflow; // NOLINT(build/namespaces)
REGISTER_OP("ZeroOut")
.Attr("T: realnumbertype")
.Input("to_zero: T")
.Output("zeroed: T")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Zeros out all but the first value of a Tensor.
zeroed: A Tensor whose first value is identical to `to_zero`, and 0
otherwise.
)doc");
REGISTER_OP("ZeroOut2")
.Attr("T: realnumbertype")
.Input("to_zero: T")
.Output("zeroed: T")
.Doc(R"doc(
Zeros out all but the first value of a Tensor.
zeroed: A Tensor whose first value is identical to `to_zero`, and 0
otherwise.
)doc");
REGISTER_OP("ZeroOut3")
.Attr("T: realnumbertype")
.Input("to_zero: T")
.Output("zeroed: T")
.Doc(R"doc(
Zeros out all but the first value of a Tensor.
zeroed: A Tensor whose first value is identical to `to_zero`, and 0
otherwise.
)doc");
template <typename T>
class ZeroOutOp : public OpKernel {
public:
explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// Grab the input tensor
const Tensor& input_tensor = context->input(0);
auto input = input_tensor.flat<T>();
// Create an output tensor
Tensor* output = NULL;
OP_REQUIRES_OK(context,
context->allocate_output(0, input_tensor.shape(), &output));
auto output_flat = output->template flat<T>();
// Set all the elements of the output tensor to 0
const int N = input.size();
for (int i = 0; i < N; i++) {
output_flat(i) = T(0);
}
// Preserve the first input value
if (N > 0) output_flat(0) = input(0);
}
};
REGISTER_KERNEL_BUILDER(Name("ZeroOut")
.Device(DEVICE_CPU)
.TypeConstraint<float>("T"),
ZeroOutOp<float>);
REGISTER_KERNEL_BUILDER(Name("ZeroOut")
.Device(DEVICE_CPU)
.TypeConstraint<double>("T"),
ZeroOutOp<double>);
REGISTER_KERNEL_BUILDER(Name("ZeroOut")
.Device(DEVICE_CPU)
.TypeConstraint<int>("T"),
ZeroOutOp<int>);
#define REGISTER_KERNEL(type) \
REGISTER_KERNEL_BUILDER( \
Name("ZeroOut2").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
ZeroOutOp<type>)
REGISTER_KERNEL(float);
REGISTER_KERNEL(double);
REGISTER_KERNEL(int32);
#undef REGISTER_KERNEL
#define REGISTER_KERNEL(type) \
REGISTER_KERNEL_BUILDER( \
Name("ZeroOut3").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
ZeroOutOp<type>)
TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
#undef REGISTER_KERNEL

View File

@ -0,0 +1,72 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
using namespace tensorflow; // NOLINT(build/namespaces)
REGISTER_OP("ZeroOut")
.Attr("preserve_index: int = 0")
.Input("to_zero: int32")
.Output("zeroed: int32")
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
return Status::OK();
});
class ZeroOutOp : public OpKernel {
public:
explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {
// Get the index of the value to preserve
OP_REQUIRES_OK(context,
context->GetAttr("preserve_index", &preserve_index_));
// Check that preserve\_index is positive
OP_REQUIRES(context, preserve_index_ >= 0,
errors::InvalidArgument("Need preserve_index >= 0, got ",
preserve_index_));
}
void Compute(OpKernelContext* context) override {
// Grab the input tensor
const Tensor& input_tensor = context->input(0);
auto input = input_tensor.flat<int32>();
// Check that preserve_index is in range
OP_REQUIRES(context, preserve_index_ < input.dimension(0),
errors::InvalidArgument("preserve_index out of range"));
// Create an output tensor
Tensor* output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
&output_tensor));
auto output = output_tensor->template flat<int32>();
// Set all the elements of the output tensor to 0
const int N = input.size();
for (int i = 0; i < N; i++) {
output(i) = 0;
}
// Preserve the requested input value
output(preserve_index_) = input(preserve_index_);
}
private:
int preserve_index_;
};
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);