Expand risc ops and update shape function for binary ops.

PiperOrigin-RevId: 353334713
Change-Id: If1eea4d6d3bc98c8d97faf5f2e3130e15052c6af
This commit is contained in:
Jian Li 2021-01-22 15:50:49 -08:00 committed by TensorFlower Gardener
parent 697f117f36
commit 21d0847f3c
37 changed files with 1353 additions and 105 deletions

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscAbs"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscCeil"
visibility: HIDDEN
}

View File

@ -1,54 +1,4 @@
op {
graph_op_name: "RiscConv"
visibility: HIDDEN
in_arg {
name: "input"
description: <<END
A 4-D tensor. The dimension order is interpreted according to the value
of `data_format`, see below for details.
END
}
in_arg {
name: "filter"
description: <<END
A 4-D tensor of shape
`[filter_height, filter_width, in_channels, out_channels]`
END
}
out_arg {
name: "output"
description: <<END
A 4-D tensor. The dimension order is determined by the value of
`data_format`, see below for details.
END
}
attr {
name: "strides"
description: <<END
1-D tensor of length 4. The stride of the sliding window for each
dimension of `input`. The dimension order is determined by the value of
`data_format`, see below for details.
END
}
attr {
name: "data_format"
description: <<END
Specify the data format of the input and output data. With the
default format "NHWC", the data is stored in the order of:
[batch, height, width, channels].
Alternatively, the format could be "NCHW", the data storage order of:
[batch, channels, height, width].
END
}
attr {
name: "dilations"
description: <<END
1-D tensor of length 4. The dilation factor for each dimension of
`input`. If set to k > 1, there will be k-1 skipped cells between each
filter element on that dimension. The dimension order is determined by the
value of `data_format`, see above for details. Dilations in the batch and
depth dimensions must be 1.
END
}
summary: "Computes a 2-D convolution given 4-D `input` and `filter` tensors."
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscCos"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscDiv"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscExp"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscFloor"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscImag"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscLog"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscMin"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscMul"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscNeg"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscPow"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscReal"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscRem"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscSign"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscSub"
visibility: HIDDEN
}

View File

@ -7,6 +7,16 @@ package(
licenses = ["notice"], # Apache 2.0
)
tf_kernel_library(
name = "risc_abs_op",
srcs = ["risc_abs_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_add_op",
srcs = ["risc_add_op.cc"],
@ -67,6 +77,16 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "risc_ceil_op",
srcs = ["risc_ceil_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_cholesky_op",
srcs = ["risc_cholesky_op.cc"],
@ -107,6 +127,26 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "risc_cos_op",
srcs = ["risc_cos_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_div_op",
srcs = ["risc_div_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_dot_op",
srcs = ["risc_dot_op.cc"],
@ -117,6 +157,16 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "risc_exp_op",
srcs = ["risc_exp_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_fft_op",
srcs = ["risc_fft_op.cc"],
@ -127,6 +177,16 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "risc_floor_op",
srcs = ["risc_floor_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_gather_op",
srcs = ["risc_gather_op.cc"],
@ -137,6 +197,16 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "risc_imag_op",
srcs = ["risc_imag_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_is_finite_op",
srcs = ["risc_is_finite_op.cc"],
@ -147,6 +217,16 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "risc_log_op",
srcs = ["risc_log_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_logical_and_op",
srcs = ["risc_logical_and_op.cc"],
@ -187,6 +267,36 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "risc_min_op",
srcs = ["risc_min_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_mul_op",
srcs = ["risc_mul_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_neg_op",
srcs = ["risc_neg_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_pad_op",
srcs = ["risc_pad_op.cc"],
@ -207,6 +317,16 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "risc_pow_op",
srcs = ["risc_pow_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_random_uniform_op",
srcs = ["risc_random_uniform_op.cc"],
@ -217,6 +337,16 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "risc_real_op",
srcs = ["risc_real_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_reduce_op",
srcs = ["risc_reduce_op.cc"],
@ -227,6 +357,16 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "risc_rem_op",
srcs = ["risc_rem_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_reshape_op",
srcs = ["risc_reshape_op.cc"],
@ -267,6 +407,16 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "risc_sign_op",
srcs = ["risc_sign_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_slice_op",
srcs = ["risc_slice_op.cc"],
@ -297,6 +447,16 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "risc_sub_op",
srcs = ["risc_sub_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_transpose_op",
srcs = ["risc_transpose_op.cc"],
@ -340,34 +500,50 @@ tf_kernel_library(
tf_kernel_library(
name = "experimental",
deps = [
":risc_abs_op",
":risc_add_op",
":risc_binary_arithmetic_op",
":risc_binary_comparison_op",
":risc_bitcast_op",
":risc_broadcast_op",
":risc_cast_op",
":risc_ceil_op",
":risc_cholesky_op",
":risc_condition_op",
":risc_conv_op",
":risc_cos_op",
":risc_div_op",
":risc_dot_op",
":risc_exp_op",
":risc_fft_op",
":risc_floor_op",
":risc_gather_op",
":risc_imag_op",
":risc_is_finite_op",
":risc_log_op",
":risc_logical_and_op",
":risc_logical_not_op",
":risc_logical_or_op",
":risc_max_op",
":risc_min_op",
":risc_mul_op",
":risc_neg_op",
":risc_pad_op",
":risc_pool_op",
":risc_pow_op",
":risc_random_uniform_op",
":risc_real_op",
":risc_reduce_op",
":risc_rem_op",
":risc_reshape_op",
":risc_reverse_op",
":risc_scatter_op",
":risc_shape_op",
":risc_sign_op",
":risc_slice_op",
":risc_sort_op",
":risc_squeeze_op",
":risc_sub_op",
":risc_transpose_op",
":risc_triangular_solve_op",
":risc_unary_op",

View File

@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
template <typename T>
class RiscAbsOp : public OpKernel {
public:
explicit RiscAbsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscAbs op.
}
};
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("RiscAbs").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
RiscAbsOp<T>);
REGISTER_CPU(bfloat16);
REGISTER_CPU(Eigen::half);
REGISTER_CPU(float);
REGISTER_CPU(double);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
template <typename T>
class RiscCeilOp : public OpKernel {
public:
explicit RiscCeilOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscCeil op.
}
};
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("RiscCeil").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
RiscCeilOp<T>);
REGISTER_CPU(bfloat16);
REGISTER_CPU(Eigen::half);
REGISTER_CPU(float);
REGISTER_CPU(double);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
template <typename T>
class RiscCosOp : public OpKernel {
public:
explicit RiscCosOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscCos op.
}
};
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("RiscCos").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
RiscCosOp<T>);
REGISTER_CPU(bfloat16);
REGISTER_CPU(Eigen::half);
REGISTER_CPU(float);
REGISTER_CPU(double);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
template <typename T>
class RiscDivOp : public OpKernel {
public:
explicit RiscDivOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscDiv op.
}
};
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("RiscDiv").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
RiscDivOp<T>);
REGISTER_CPU(bfloat16);
REGISTER_CPU(Eigen::half);
REGISTER_CPU(float);
REGISTER_CPU(double);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
template <typename T>
class RiscExpOp : public OpKernel {
public:
explicit RiscExpOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscExp op.
}
};
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("RiscExp").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
RiscExpOp<T>);
REGISTER_CPU(bfloat16);
REGISTER_CPU(Eigen::half);
REGISTER_CPU(float);
REGISTER_CPU(double);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
template <typename T>
class RiscFloorOp : public OpKernel {
public:
explicit RiscFloorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscFloor op.
}
};
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("RiscFloor").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
RiscFloorOp<T>);
REGISTER_CPU(bfloat16);
REGISTER_CPU(Eigen::half);
REGISTER_CPU(float);
REGISTER_CPU(double);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,39 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
class RiscImagOp : public OpKernel {
public:
explicit RiscImagOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscImag op.
}
};
REGISTER_KERNEL_BUILDER(Name("RiscImag").Device(DEVICE_CPU), RiscImagOp);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
template <typename T>
class RiscLogOp : public OpKernel {
public:
explicit RiscLogOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscLog op.
}
};
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("RiscLog").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
RiscLogOp<T>);
REGISTER_CPU(bfloat16);
REGISTER_CPU(Eigen::half);
REGISTER_CPU(float);
REGISTER_CPU(double);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
template <typename T>
class RiscMinOp : public OpKernel {
public:
explicit RiscMinOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscMin op.
}
};
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("RiscMin").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
RiscMinOp<T>);
REGISTER_CPU(bfloat16);
REGISTER_CPU(Eigen::half);
REGISTER_CPU(float);
REGISTER_CPU(double);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
template <typename T>
class RiscMulOp : public OpKernel {
public:
explicit RiscMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscMul op.
}
};
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("RiscMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
RiscMulOp<T>);
REGISTER_CPU(bfloat16);
REGISTER_CPU(Eigen::half);
REGISTER_CPU(float);
REGISTER_CPU(double);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
template <typename T>
class RiscNegOp : public OpKernel {
public:
explicit RiscNegOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscNeg op.
}
};
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("RiscNeg").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
RiscNegOp<T>);
REGISTER_CPU(bfloat16);
REGISTER_CPU(Eigen::half);
REGISTER_CPU(float);
REGISTER_CPU(double);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
template <typename T>
class RiscPowOp : public OpKernel {
public:
explicit RiscPowOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscPow op.
}
};
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("RiscPow").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
RiscPowOp<T>);
REGISTER_CPU(bfloat16);
REGISTER_CPU(Eigen::half);
REGISTER_CPU(float);
REGISTER_CPU(double);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,39 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
class RiscRealOp : public OpKernel {
public:
explicit RiscRealOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscReal op.
}
};
REGISTER_KERNEL_BUILDER(Name("RiscReal").Device(DEVICE_CPU), RiscRealOp);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
template <typename T>
class RiscRemOp : public OpKernel {
public:
explicit RiscRemOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscRem op.
}
};
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("RiscRem").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
RiscRemOp<T>);
REGISTER_CPU(bfloat16);
REGISTER_CPU(Eigen::half);
REGISTER_CPU(float);
REGISTER_CPU(double);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
template <typename T>
class RiscSignOp : public OpKernel {
public:
explicit RiscSignOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscSign op.
}
};
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("RiscSign").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
RiscSignOp<T>);
REGISTER_CPU(bfloat16);
REGISTER_CPU(Eigen::half);
REGISTER_CPU(float);
REGISTER_CPU(double);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
template <typename T>
class RiscMaxOp : public OpKernel {
public:
explicit RiscMaxOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscMax op.
}
};
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("RiscMax").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
RiscMaxOp<T>);
REGISTER_CPU(bfloat16);
REGISTER_CPU(Eigen::half);
REGISTER_CPU(float);
REGISTER_CPU(double);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -21,16 +21,46 @@ limitations under the License.
namespace tensorflow {
namespace {
Status RiscBinaryNonBroadcastOpShapeFn(shape_inference::InferenceContext* c) {
const auto rank = c->Rank(c->input(0));
if (rank != c->Rank(c->input(1))) {
return errors::InvalidArgument("Mismatch rank for input.");
}
for (int i = 0; i < rank; ++i) {
if (!c->ValueKnown(c->Dim(c->input(0), i)) ||
!c->ValueKnown(c->Dim(c->input(1), i))) {
continue;
}
if (c->Value(c->Dim(c->input(0), i)) != c->Value(c->Dim(c->input(1), i))) {
return errors::InvalidArgument("Mismatch shapes for input.");
}
}
c->set_output(0, c->input(0));
auto* handle_data = c->input_handle_shapes_and_types(0);
if (handle_data != nullptr) {
c->set_output_handle_shapes_and_types(0, *handle_data);
}
return Status::OK();
}
} // namespace
REGISTER_OP("RiscAbs")
.Input("x: T")
.Output("y: T")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("RiscAdd")
.Input("x: T")
.Input("y: T")
.Output("z: T")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnchangedShape)
.SetShapeFn(RiscBinaryNonBroadcastOpShapeFn)
.SetIsAggregate()
.SetIsCommutative();
// TODO(b/171294012): include RiscMax here as well.
// TODO(b/178234771): retire this.
REGISTER_OP("RiscBinaryArithmetic")
.Input("x: T")
.Input("y: T")
@ -45,9 +75,9 @@ REGISTER_OP("RiscBinaryComparison")
.Output("z: bool")
.Attr("op_type: {'EQ', 'NE', 'GE', 'GT', 'LE', 'LT'}")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnchangedShape);
.SetShapeFn(RiscBinaryNonBroadcastOpShapeFn);
// TODO(b/171294012): change shape function.
// TODO(b/178234771): change shape function.
REGISTER_OP("RiscBitcast")
.Input("x: SrcT")
.Output("y: DstT")
@ -55,7 +85,7 @@ REGISTER_OP("RiscBitcast")
.Attr("DstT: type")
.SetShapeFn(shape_inference::UnknownShape);
// TODO(b/171294012): change shape function.
// TODO(b/178234771): change shape function.
REGISTER_OP("RiscBroadcast")
.Input("input: T")
.Input("shape: Tidx")
@ -71,7 +101,13 @@ REGISTER_OP("RiscCast")
.Attr("DstT: type")
.SetShapeFn(shape_inference::UnchangedShape);
// TODO(b/171294012): change shape function.
REGISTER_OP("RiscCeil")
.Input("x: T")
.Output("y: T")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnchangedShape);
// TODO(b/178234771): change shape function.
REGISTER_OP("RiscCholesky")
.Input("input: T")
.Output("output: T")
@ -87,7 +123,7 @@ REGISTER_OP("RiscConcat")
.Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn(shape_inference::ConcatV2Shape);
// TODO(b/171294012): change shape function.
// TODO(b/178234771): change shape function.
REGISTER_OP("RiscCondition")
.Input("pred: bool")
.Input("input_true: SrcT")
@ -99,7 +135,7 @@ REGISTER_OP("RiscCondition")
.Attr("DstT: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnknownShape);
// TODO(b/171294012): change shape function.
// TODO(b/178234771): change shape function.
REGISTER_OP("RiscConv")
.Input("input: T")
.Input("filter: T")
@ -110,6 +146,19 @@ REGISTER_OP("RiscConv")
.SetShapeFn(shape_inference::UnknownShape)
.Attr("dilations: list(int) = [1, 1, 1, 1]");
REGISTER_OP("RiscCos")
.Input("x: T")
.Output("y: T")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("RiscDiv")
.Input("x: T")
.Input("y: T")
.Output("z: T")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(RiscBinaryNonBroadcastOpShapeFn);
REGISTER_OP("RiscDot")
.Input("a: T")
.Input("b: T")
@ -119,14 +168,26 @@ REGISTER_OP("RiscDot")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::MatMulShape);
// TODO(b/171294012): change shape function.
REGISTER_OP("RiscExp")
.Input("x: T")
.Output("y: T")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnchangedShape);
// TODO(b/178234771): change shape function.
REGISTER_OP("RiscFft")
.Input("input: Tcomplex")
.Output("output: Tcomplex")
.Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
.SetShapeFn(shape_inference::UnknownShape);
// TODO(b/171294012): change shape function.
REGISTER_OP("RiscFloor")
.Input("x: T")
.Output("y: T")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnchangedShape);
// TODO(b/178234771): change shape function.
REGISTER_OP("RiscGather")
.Input("params: Tparams")
.Input("indices: Tindices")
@ -138,37 +199,72 @@ REGISTER_OP("RiscGather")
.Attr("Taxis: {int32,int64}")
.SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("RiscImag")
.Input("input: T")
.Output("output: Tout")
.Attr("T: {complex64, complex128} = DT_COMPLEX64")
.Attr("Tout: {float, double} = DT_FLOAT")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("RiscIsFinite")
.Input("x: T")
.Output("y: bool")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("RiscLog")
.Input("x: T")
.Output("y: T")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnchangedShape);
// TODO(b/178234771): change shape function.
REGISTER_OP("RiscLogicalAnd")
.Input("x: bool")
.Input("y: bool")
.Output("z: bool")
.SetShapeFn(shape_inference::UnchangedShape);
.SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("RiscLogicalNot")
.Input("x: bool")
.Output("z: bool")
.SetShapeFn(shape_inference::UnchangedShape);
// TODO(b/178234771): change shape function.
REGISTER_OP("RiscLogicalOr")
.Input("x: bool")
.Input("y: bool")
.Output("z: bool")
.SetShapeFn(shape_inference::UnchangedShape);
.SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("RiscMax")
.Input("x: T")
.Input("y: T")
.Output("max: T")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(RiscBinaryNonBroadcastOpShapeFn);
REGISTER_OP("RiscMin")
.Input("x: T")
.Input("y: T")
.Output("z: T")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(RiscBinaryNonBroadcastOpShapeFn);
REGISTER_OP("RiscMul")
.Input("x: T")
.Input("y: T")
.Output("z: T")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(RiscBinaryNonBroadcastOpShapeFn);
REGISTER_OP("RiscNeg")
.Input("x: T")
.Output("y: T")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnchangedShape);
// TODO(b/171294012): change shape function.
// TODO(b/178234771): change shape function.
REGISTER_OP("RiscPad")
.Input("input: T")
.Input("paddings: Tpaddings")
@ -178,7 +274,7 @@ REGISTER_OP("RiscPad")
.Attr("Tpaddings: {int32, int64} = DT_INT32")
.SetShapeFn(shape_inference::UnknownShape);
// TODO(b/171294012): change shape function.
// TODO(b/178234771): change shape function.
REGISTER_OP("RiscPool")
.Input("value: T")
.Output("output: T")
@ -189,6 +285,13 @@ REGISTER_OP("RiscPool")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("RiscPow")
.Input("x: T")
.Input("y: T")
.Output("z: T")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(RiscBinaryNonBroadcastOpShapeFn);
REGISTER_OP("RiscRandomUniform")
.Input("shape: T")
.Output("output: float")
@ -196,7 +299,14 @@ REGISTER_OP("RiscRandomUniform")
.Attr("T: {int32, int64}")
.SetShapeFn(shape_inference::RandomShape);
// TODO(b/171294012): change shape function.
REGISTER_OP("RiscReal")
.Input("input: T")
.Output("output: Tout")
.Attr("T: {complex64, complex128} = DT_COMPLEX64")
.Attr("Tout: {float, double} = DT_FLOAT")
.SetShapeFn(shape_inference::UnchangedShape);
// TODO(b/178234771): change shape function.
REGISTER_OP("RiscReduce")
.Input("tensor: T")
.Input("axis: Index")
@ -206,7 +316,14 @@ REGISTER_OP("RiscReduce")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnknownShape);
// TODO(b/171294012): change shape function.
REGISTER_OP("RiscRem")
.Input("x: T")
.Input("y: T")
.Output("z: T")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(RiscBinaryNonBroadcastOpShapeFn);
// TODO(b/178234771): change shape function.
REGISTER_OP("RiscReshape")
.Input("tensor: T")
.Input("shape: Tshape")
@ -215,6 +332,7 @@ REGISTER_OP("RiscReshape")
.Attr("Tshape: {int32, int64} = DT_INT32")
.SetShapeFn(shape_inference::UnknownShape);
// TODO(b/178234771): change shape function.
REGISTER_OP("RiscReverse")
.Input("tensor: T")
.Input("axis: Tidx")
@ -223,7 +341,7 @@ REGISTER_OP("RiscReverse")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnknownShape);
// TODO(b/171294012): change shape function.
// TODO(b/178234771): change shape function.
REGISTER_OP("RiscScatter")
.Input("indices: Tindices")
.Input("updates: T")
@ -233,7 +351,7 @@ REGISTER_OP("RiscScatter")
.Attr("Tindices: {int32, int64}")
.SetShapeFn(shape_inference::UnknownShape);
// TODO(b/171294012): change shape function.
// TODO(b/178234771): change shape function.
REGISTER_OP("RiscShape")
.Input("input: T")
.Output("output: out_type")
@ -241,6 +359,12 @@ REGISTER_OP("RiscShape")
.Attr("out_type: {int32, int64} = DT_INT32")
.SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("RiscSign")
.Input("x: T")
.Output("y: T")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("RiscSlice")
.Input("input: T")
.Input("begin: Index")
@ -259,7 +383,7 @@ REGISTER_OP("RiscSort")
.Attr("direction: {'ASCENDING', 'DESCENDING'}")
.SetShapeFn(shape_inference::UnchangedShape);
// TODO(b/171294012): change shape function.
// TODO(b/178234771): change shape function.
REGISTER_OP("RiscSqueeze")
.Input("input: T")
.Output("output: T")
@ -267,7 +391,14 @@ REGISTER_OP("RiscSqueeze")
.Attr("squeeze_dims: list(int) >= 0 = []")
.SetShapeFn(shape_inference::UnknownShape);
// TODO(b/171294012): change shape function.
REGISTER_OP("RiscSub")
.Input("x: T")
.Input("y: T")
.Output("z: T")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(RiscBinaryNonBroadcastOpShapeFn);
// TODO(b/178234771): change shape function.
REGISTER_OP("RiscTranspose")
.Input("x: T")
.Input("perm: Tperm")
@ -276,7 +407,7 @@ REGISTER_OP("RiscTranspose")
.Attr("Tperm: {int32, int64} = DT_INT32")
.SetShapeFn(shape_inference::UnknownShape);
// TODO(b/171294012): change shape function.
// TODO(b/178234771): change shape function.
REGISTER_OP("RiscTriangularSolve")
.Input("matrix: T")
.Input("rhs: T")
@ -286,6 +417,7 @@ REGISTER_OP("RiscTriangularSolve")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnknownShape);
// TODO(b/178234771): retire this.
REGISTER_OP("RiscUnary")
.Input("x: T")
.Output("y: T")
@ -295,7 +427,7 @@ REGISTER_OP("RiscUnary")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnchangedShape);
// TODO(b/171294012): change shape function.
// TODO(b/178234771): change shape function.
REGISTER_OP("RiscWhile")
.Input("input: T")
.Output("output: T")

View File

@ -21,232 +21,344 @@ from __future__ import print_function
from tensorflow.python.framework import ops
@ops.RegisterGradient("RiscAbs")
def _RiscAbsGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscAdd")
def _RiscAddGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscBinaryArithmetic")
def _RiscBinaryArithmeticGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscBinaryComparison")
def _RiscBinaryComparisonGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscBitcast")
def _RiscBitcastGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscBroadcast")
def _RiscBroadcastGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscCast")
def _RiscCastGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscCholesky")
def _RiscCholeskyGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscCeil")
def _RiscCeilGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscConcat")
def _RiscConcatGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscCondition")
def _RiscConditionGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscConv")
def _RiscConvGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscCos")
def _RiscCosGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscDiv")
def _RiscDivGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscDot")
def _RiscDotGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscExp")
def _RiscExpGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscFft")
def _RiscFftGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscFloor")
def _RiscFloorGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscGather")
def _RiscGatherGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscImag")
def _RiscImagGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscIsFinite")
def _RiscIsFiniteGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscLog")
def _RiscLogGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscLogicalAnd")
def _RiscLogicalAndGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscLogicalNot")
def _RiscLogicalNotGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscLogicalOr")
def _RiscLogicalOrGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscMax")
def _RiscMaxGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscMin")
def _RiscMinGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscMul")
def _RiscMulGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscNeg")
def _RiscNegGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscPad")
def _RiscPadGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscPool")
def _RiscPoolGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscPow")
def _RiscPowGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscRandomUniform")
def _RiscRandomUniformGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscReal")
def _RiscRealGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscReduce")
def _RiscReduceGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscRem")
def _RiscRemGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscReshape")
def _RiscReshapeGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscReverse")
def _RiscReverseGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscScatter")
def _RiscScatterGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscShape")
def _RiscShapeGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscSign")
def _RiscSignGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscSlice")
def _RiscSliceGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscSort")
def _RiscSortGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscSqueeze")
def _RiscSqueezeGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscSub")
def _RiscSubGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscTranspose")
def _RiscTransposeGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscTriangularSolve")
def _RiscTriangularSolvesGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscUnary")
def _RiscUnaryGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscWhile")
def _RiscWhileGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
# TODO(b/178234771): Implement gradient of RISC with RISC ops.
return None, None

View File

@ -22,6 +22,10 @@ from __future__ import print_function
from tensorflow.python.ops import gen_risc_ops
def risc_abs(x, name='RISC_ABS'):
return gen_risc_ops.risc_abs(x, name=name)
def risc_add(
input_lhs,
input_rhs,
@ -49,6 +53,14 @@ def risc_cast(x, dtype, name='RISC_CAST'):
return gen_risc_ops.risc_cast(x, dtype, name=name)
def risc_ceil(x, name='RISC_CEIL'):
return gen_risc_ops.risc_ceil(x, name=name)
def risc_cos(x, name='RISC_COS'):
return gen_risc_ops.risc_cos(x, name=name)
def risc_cholesky(x, name='RISC_CHOLESKY'):
return gen_risc_ops.risc_cholesky(x, name=name)
@ -87,6 +99,10 @@ def risc_conv(x,
name=name)
def risc_div(input_lhs, input_rhs, name='RISC_DIV'):
return gen_risc_ops.risc_div(input_lhs, input_rhs, name=name)
def risc_dot(input_lhs,
input_rhs,
transpose_a=False,
@ -100,10 +116,18 @@ def risc_dot(input_lhs,
name=name)
def risc_exp(x, name='RISC_EXP'):
return gen_risc_ops.risc_exp(x, name=name)
def risc_fft(x, name='RISC_FFT'):
return gen_risc_ops.risc_fft(x, name=name)
def risc_floor(x, name='RISC_FLOOR'):
return gen_risc_ops.risc_floor(x, name=name)
def risc_gather(params,
indices,
validate_indices=None,
@ -119,10 +143,18 @@ def risc_gather(params,
batch_dims=batch_dims)
def risc_imag(x, name='RISC_IMAG'):
return gen_risc_ops.risc_imag(x, name=name)
def risc_is_finite(x, name='RISC_IS_FINITE'):
return gen_risc_ops.risc_is_finite(x, name=name)
def risc_log(x, name='RISC_LOG'):
return gen_risc_ops.risc_log(x, name=name)
def risc_logical_and(a, b, name='RISC_LOGICAL_AND'):
return gen_risc_ops.risc_logical_and(a, b, name=name)
@ -139,6 +171,18 @@ def risc_max(input_lhs, input_rhs, name='RISC_MAX'):
return gen_risc_ops.risc_max(input_lhs, input_rhs, name=name)
def risc_min(input_lhs, input_rhs, name='RISC_MIN'):
return gen_risc_ops.risc_min(input_lhs, input_rhs, name=name)
def risc_mul(input_lhs, input_rhs, name='RISC_MUL'):
return gen_risc_ops.risc_mul(input_lhs, input_rhs, name=name)
def risc_neg(x, name='RISC_NEG'):
return gen_risc_ops.risc_neg(x, name=name)
def risc_pad(x, padding, constant_values, name='RISC_PAD'):
return gen_risc_ops.risc_pad(x, padding, constant_values, name=name)
@ -148,14 +192,26 @@ def risc_pool(x, ksize, strides, pooling_type='MAX', name='RISC_POOL'):
x, ksize, strides, pooling_type=pooling_type, name=name)
def risc_pow(input_lhs, input_rhs, name='RISC_POW'):
return gen_risc_ops.risc_pow(input_lhs, input_rhs, name=name)
def risc_random_uniform(shape, seed, name='RISC_RANDOM_UNIFORM'):
return gen_risc_ops.risc_random_uniform(shape, seed, name=name)
def risc_real(x, name='RISC_REAL'):
return gen_risc_ops.risc_real(x, name=name)
def risc_reduce(x, axis, reduce_type, name='RISC_REDUCE'):
return gen_risc_ops.risc_reduce(x, axis, reduce_type=reduce_type, name=name)
def risc_rem(x, name='RISC_REM'):
return gen_risc_ops.risc_rem(x, name=name)
def risc_reshape(x, shape, name='RISC_RESHAPE'):
return gen_risc_ops.risc_reshape(x, shape, name=name)
@ -172,10 +228,18 @@ def risc_shape(x, name='RISC_SHAPE'):
return gen_risc_ops.risc_shape(x, name=name)
def risc_sign(x, name='RISC_SIGN'):
return gen_risc_ops.risc_sign(x, name=name)
def risc_slice(x, begin, size, name='RISC_SLICE'):
return gen_risc_ops.risc_slice(x, begin, size, name=name)
def risc_sub(input_lhs, input_rhs, name='RISC_SUB'):
return gen_risc_ops.risc_sub(input_lhs, input_rhs, name=name)
def risc_sort(x, axis, direction='ASCENDING', name='RISC_SORT'):
return gen_risc_ops.risc_sort(x, axis, direction=direction, name=name)