Add RISC Conv Op register.

PiperOrigin-RevId: 341415820
Change-Id: Ibd5f4c939e22be2af61e434e6927898b74e523a5
This commit is contained in:
Jian Li 2020-11-09 08:58:46 -08:00 committed by TensorFlower Gardener
parent c86b90dc9e
commit 55a311cb73
6 changed files with 149 additions and 1 deletions

View File

@ -0,0 +1,54 @@
op {
graph_op_name: "RiscConv"
visibility: HIDDEN
in_arg {
name: "input"
description: <<END
A 4-D tensor. The dimension order is interpreted according to the value
of `data_format`, see below for details.
END
}
in_arg {
name: "filter"
description: <<END
A 4-D tensor of shape
`[filter_height, filter_width, in_channels, out_channels]`
END
}
out_arg {
name: "output"
description: <<END
A 4-D tensor. The dimension order is determined by the value of
`data_format`, see below for details.
END
}
attr {
name: "strides"
description: <<END
1-D tensor of length 4. The stride of the sliding window for each
dimension of `input`. The dimension order is determined by the value of
`data_format`, see below for details.
END
}
attr {
name: "data_format"
description: <<END
Specify the data format of the input and output data. With the
default format "NHWC", the data is stored in the order of:
[batch, height, width, channels].
Alternatively, the format could be "NCHW", the data storage order of:
[batch, channels, height, width].
END
}
attr {
name: "dilations"
description: <<END
1-D tensor of length 4. The dilation factor for each dimension of
`input`. If set to k > 1, there will be k-1 skipped cells between each
filter element on that dimension. The dimension order is determined by the
value of `data_format`, see above for details. Dilations in the batch and
depth dimensions must be 1.
END
}
summary: "Computes a 2-D convolution given 4-D `input` and `filter` tensors."
}

View File

@ -17,9 +17,20 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "risc_conv_op",
srcs = ["risc_conv_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "experimental",
deps = [
":risc_add_op",
":risc_conv_op",
],
)

View File

@ -0,0 +1,50 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
typedef Eigen::ThreadPoolDevice CPUDevice;
template <typename Device, typename T>
class RiscConvOp : public OpKernel {
public:
explicit RiscConvOp(OpKernelConstruction* context) : OpKernel(context) {
// TODO(b/171294012): Implement RiscConv op.
}
void Compute(OpKernelContext* context) override {
// TODO(b/171294012): Implement RiscConv op.
}
};
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("RiscConv").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
RiscConvOp<CPUDevice, T>);
REGISTER_CPU(float);
REGISTER_CPU(double);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -30,4 +30,15 @@ REGISTER_OP("RiscAdd")
.SetIsAggregate()
.SetIsCommutative();
// TODO(b/171294012): change shape function.
REGISTER_OP("RiscConv")
.Input("input: T")
.Input("filter: T")
.Output("output: T")
.Attr("T: {float, double}")
.Attr("strides: list(int)")
.Attr(GetConvnetDataFormatAttrString())
.SetShapeFn(shape_inference::UnknownShape)
.Attr("dilations: list(int) = [1, 1, 1, 1]");
} // namespace tensorflow

View File

@ -28,3 +28,10 @@ def _RiscAddGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscConv")
def _RiscConvGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
return None, None

View File

@ -30,5 +30,20 @@ from tensorflow.python.ops.risc_ops_gen import *
def risc_add(
input_lhs,
input_rhs,
name="RISC_ADD"):
name='RISC_ADD'):
return gen_risc_ops.risc_add(input_lhs, input_rhs, name=name)
def risc_conv(x,
kernel,
strides,
data_format='NHWC',
dilations=None,
name='RISC_CONV'):
return gen_risc_ops.risc_conv(
x,
kernel,
strides,
data_format=data_format,
dilations=dilations,
name=name)