From 55a311cb735689a431c6aa9a6c765c5c5c034ede Mon Sep 17 00:00:00 2001 From: Jian Li Date: Mon, 9 Nov 2020 08:58:46 -0800 Subject: [PATCH] Add RISC Conv Op register. PiperOrigin-RevId: 341415820 Change-Id: Ibd5f4c939e22be2af61e434e6927898b74e523a5 --- .../api_def/base_api/api_def_RiscConv.pbtxt | 54 +++++++++++++++++++ .../core/kernels/risc/experimental/BUILD | 11 ++++ .../kernels/risc/experimental/risc_conv_op.cc | 50 +++++++++++++++++ tensorflow/core/ops/risc_ops.cc | 11 ++++ tensorflow/python/ops/risc/risc_grad.py | 7 +++ tensorflow/python/ops/risc/risc_ops.py | 17 +++++- 6 files changed, 149 insertions(+), 1 deletion(-) create mode 100644 tensorflow/core/api_def/base_api/api_def_RiscConv.pbtxt create mode 100644 tensorflow/core/kernels/risc/experimental/risc_conv_op.cc diff --git a/tensorflow/core/api_def/base_api/api_def_RiscConv.pbtxt b/tensorflow/core/api_def/base_api/api_def_RiscConv.pbtxt new file mode 100644 index 00000000000..a78ee1d2b89 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_RiscConv.pbtxt @@ -0,0 +1,54 @@ +op { + graph_op_name: "RiscConv" + visibility: HIDDEN + in_arg { + name: "input" + description: < 1, there will be k-1 skipped cells between each +filter element on that dimension. The dimension order is determined by the +value of `data_format`, see above for details. Dilations in the batch and +depth dimensions must be 1. +END + } + summary: "Computes a 2-D convolution given 4-D `input` and `filter` tensors." +} diff --git a/tensorflow/core/kernels/risc/experimental/BUILD b/tensorflow/core/kernels/risc/experimental/BUILD index a16c0b66271..d0e94be3120 100644 --- a/tensorflow/core/kernels/risc/experimental/BUILD +++ b/tensorflow/core/kernels/risc/experimental/BUILD @@ -17,9 +17,20 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "risc_conv_op", + srcs = ["risc_conv_op.cc"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + tf_kernel_library( name = "experimental", deps = [ ":risc_add_op", + ":risc_conv_op", ], ) diff --git a/tensorflow/core/kernels/risc/experimental/risc_conv_op.cc b/tensorflow/core/kernels/risc/experimental/risc_conv_op.cc new file mode 100644 index 00000000000..58c5ee98eae --- /dev/null +++ b/tensorflow/core/kernels/risc/experimental/risc_conv_op.cc @@ -0,0 +1,50 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { +namespace risc { +namespace experimental { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +template +class RiscConvOp : public OpKernel { + public: + explicit RiscConvOp(OpKernelConstruction* context) : OpKernel(context) { + // TODO(b/171294012): Implement RiscConv op. + } + + void Compute(OpKernelContext* context) override { + // TODO(b/171294012): Implement RiscConv op. + } +}; + +#define REGISTER_CPU(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("RiscConv").Device(DEVICE_CPU).TypeConstraint("T"), \ + RiscConvOp); + +REGISTER_CPU(float); +REGISTER_CPU(double); + +} // namespace experimental +} // namespace risc +} // namespace tensorflow diff --git a/tensorflow/core/ops/risc_ops.cc b/tensorflow/core/ops/risc_ops.cc index 1d90a645965..a5b1e37fa84 100644 --- a/tensorflow/core/ops/risc_ops.cc +++ b/tensorflow/core/ops/risc_ops.cc @@ -30,4 +30,15 @@ REGISTER_OP("RiscAdd") .SetIsAggregate() .SetIsCommutative(); +// TODO(b/171294012): change shape function. +REGISTER_OP("RiscConv") + .Input("input: T") + .Input("filter: T") + .Output("output: T") + .Attr("T: {float, double}") + .Attr("strides: list(int)") + .Attr(GetConvnetDataFormatAttrString()) + .SetShapeFn(shape_inference::UnknownShape) + .Attr("dilations: list(int) = [1, 1, 1, 1]"); + } // namespace tensorflow diff --git a/tensorflow/python/ops/risc/risc_grad.py b/tensorflow/python/ops/risc/risc_grad.py index b125aab895a..5c0f76ba3a4 100644 --- a/tensorflow/python/ops/risc/risc_grad.py +++ b/tensorflow/python/ops/risc/risc_grad.py @@ -28,3 +28,10 @@ def _RiscAddGrad(_, grad): # pylint: disable=unused-argument # TODO(b/171294012): Implement gradient of RISC with RISC ops. return None, None + + +@ops.RegisterGradient("RiscConv") +def _RiscConvGrad(_, grad): + # pylint: disable=unused-argument + # TODO(b/171294012): Implement gradient of RISC with RISC ops. + return None, None diff --git a/tensorflow/python/ops/risc/risc_ops.py b/tensorflow/python/ops/risc/risc_ops.py index 8682ebdd269..f59e42dbf6e 100644 --- a/tensorflow/python/ops/risc/risc_ops.py +++ b/tensorflow/python/ops/risc/risc_ops.py @@ -30,5 +30,20 @@ from tensorflow.python.ops.risc_ops_gen import * def risc_add( input_lhs, input_rhs, - name="RISC_ADD"): + name='RISC_ADD'): return gen_risc_ops.risc_add(input_lhs, input_rhs, name=name) + + +def risc_conv(x, + kernel, + strides, + data_format='NHWC', + dilations=None, + name='RISC_CONV'): + return gen_risc_ops.risc_conv( + x, + kernel, + strides, + data_format=data_format, + dilations=dilations, + name=name)