diff --git a/tensorflow/core/api_def/base_api/api_def_RiscConcat.pbtxt b/tensorflow/core/api_def/base_api/api_def_RiscConcat.pbtxt new file mode 100644 index 00000000000..edd4b6a0a57 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_RiscConcat.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "RiscConcat" + visibility: HIDDEN +} diff --git a/tensorflow/core/kernels/risc/experimental/BUILD b/tensorflow/core/kernels/risc/experimental/BUILD index cae69b49890..86e5e7c06ed 100644 --- a/tensorflow/core/kernels/risc/experimental/BUILD +++ b/tensorflow/core/kernels/risc/experimental/BUILD @@ -27,6 +27,16 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "risc_concat_op", + srcs = ["risc_concat_op.cc"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + tf_kernel_library( name = "risc_conv_op", srcs = ["risc_conv_op.cc"], diff --git a/tensorflow/core/kernels/risc/experimental/risc_concat_op.cc b/tensorflow/core/kernels/risc/experimental/risc_concat_op.cc new file mode 100644 index 00000000000..b01a9128d8b --- /dev/null +++ b/tensorflow/core/kernels/risc/experimental/risc_concat_op.cc @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" + +namespace tensorflow { +namespace risc { +namespace experimental { + +template +class RiscConcatOp : public OpKernel { + public: + explicit RiscConcatOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + // TODO(b/171294012): Implement RiscConcat op. + } +}; + +#define REGISTER_CPU(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("RiscConcat").Device(DEVICE_CPU).TypeConstraint("T"), \ + RiscConcatOp); + +REGISTER_CPU(bfloat16); +REGISTER_CPU(Eigen::half); +REGISTER_CPU(float); +REGISTER_CPU(double); + +} // namespace experimental +} // namespace risc +} // namespace tensorflow diff --git a/tensorflow/core/ops/risc_ops.cc b/tensorflow/core/ops/risc_ops.cc index 6eee82b0164..7b097021fc6 100644 --- a/tensorflow/core/ops/risc_ops.cc +++ b/tensorflow/core/ops/risc_ops.cc @@ -39,6 +39,15 @@ REGISTER_OP("RiscBroadcast") .Attr("Tidx: {int32, int64} = DT_INT32") .SetShapeFn(shape_inference::UnknownShape); +REGISTER_OP("RiscConcat") + .Input("values: N * T") + .Input("axis: Tidx") + .Output("output: T") + .Attr("N: int >= 2") + .Attr("T: type") + .Attr("Tidx: {int32, int64} = DT_INT32") + .SetShapeFn(shape_inference::ConcatV2Shape); + // TODO(b/171294012): change shape function. REGISTER_OP("RiscConv") .Input("input: T") diff --git a/tensorflow/python/ops/risc/risc_grad.py b/tensorflow/python/ops/risc/risc_grad.py index 8d93f7b01e7..035bd9bf093 100644 --- a/tensorflow/python/ops/risc/risc_grad.py +++ b/tensorflow/python/ops/risc/risc_grad.py @@ -35,6 +35,13 @@ def _RiscBroadcastGrad(_, grad): return None, None +@ops.RegisterGradient("RiscConcat") +def _RiscConcatGrad(_, grad): + # pylint: disable=unused-argument + # TODO(b/171294012): Implement gradient of RISC with RISC ops. + return None, None + + @ops.RegisterGradient("RiscConv") def _RiscConvGrad(_, grad): # pylint: disable=unused-argument diff --git a/tensorflow/python/ops/risc/risc_ops.py b/tensorflow/python/ops/risc/risc_ops.py index d3e92c90d50..14cdb9e4b44 100644 --- a/tensorflow/python/ops/risc/risc_ops.py +++ b/tensorflow/python/ops/risc/risc_ops.py @@ -38,6 +38,10 @@ def risc_broadcast(x, shape, name='RISC_BROADCAST'): return gen_risc_ops.risc_broadcast(x, shape, name=name) +def risc_concat(x, axis, name='RISC_CONCAT'): + return gen_risc_ops.risc_concat(x, axis, name=name) + + def risc_conv(x, kernel, strides,