Add risc op register.
PiperOrigin-RevId: 346946572 Change-Id: I63580bb0491591439a928038185542d52cde5e3b
This commit is contained in:
parent
b2ad941e21
commit
44a4129b24
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "RiscBinaryArithmetic"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "RiscBinaryComparison"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "RiscBitcast"
|
||||
visibility: HIDDEN
|
||||
}
|
4
tensorflow/core/api_def/base_api/api_def_RiscCast.pbtxt
Normal file
4
tensorflow/core/api_def/base_api/api_def_RiscCast.pbtxt
Normal file
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "RiscCast"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "RiscCholesky"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "RiscCondition"
|
||||
visibility: HIDDEN
|
||||
}
|
4
tensorflow/core/api_def/base_api/api_def_RiscFft.pbtxt
Normal file
4
tensorflow/core/api_def/base_api/api_def_RiscFft.pbtxt
Normal file
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "RiscFft"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "RiscGather"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "RiscIsFinite"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "RiscLogicalAnd"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "RiscLogicalNot"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "RiscLogicalOr"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "RiscRandomUniform"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "RiscReduce"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "RiscReverse"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "RiscScatter"
|
||||
visibility: HIDDEN
|
||||
}
|
4
tensorflow/core/api_def/base_api/api_def_RiscSort.pbtxt
Normal file
4
tensorflow/core/api_def/base_api/api_def_RiscSort.pbtxt
Normal file
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "RiscSort"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "RiscSqueeze"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "RiscTranspose"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "RiscTriangularSolve"
|
||||
visibility: HIDDEN
|
||||
}
|
4
tensorflow/core/api_def/base_api/api_def_RiscUnary.pbtxt
Normal file
4
tensorflow/core/api_def/base_api/api_def_RiscUnary.pbtxt
Normal file
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "RiscUnary"
|
||||
visibility: HIDDEN
|
||||
}
|
4
tensorflow/core/api_def/base_api/api_def_RiscWhile.pbtxt
Normal file
4
tensorflow/core/api_def/base_api/api_def_RiscWhile.pbtxt
Normal file
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "RiscWhile"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -17,6 +17,36 @@ tf_kernel_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "risc_binary_arithmetic_op",
|
||||
srcs = ["risc_binary_arithmetic_op.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "risc_binary_comparison_op",
|
||||
srcs = ["risc_binary_comparison_op.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "risc_bitcast_op",
|
||||
srcs = ["risc_bitcast_op.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "risc_broadcast_op",
|
||||
srcs = ["risc_broadcast_op.cc"],
|
||||
@ -27,6 +57,26 @@ tf_kernel_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "risc_cast_op",
|
||||
srcs = ["risc_cast_op.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "risc_cholesky_op",
|
||||
srcs = ["risc_cholesky_op.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "risc_concat_op",
|
||||
srcs = ["risc_concat_op.cc"],
|
||||
@ -37,6 +87,16 @@ tf_kernel_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "risc_condition_op",
|
||||
srcs = ["risc_condition_op.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "risc_conv_op",
|
||||
srcs = ["risc_conv_op.cc"],
|
||||
@ -57,6 +117,66 @@ tf_kernel_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "risc_fft_op",
|
||||
srcs = ["risc_fft_op.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "risc_gather_op",
|
||||
srcs = ["risc_gather_op.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "risc_is_finite_op",
|
||||
srcs = ["risc_is_finite_op.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "risc_logical_and_op",
|
||||
srcs = ["risc_logical_and_op.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "risc_logical_not_op",
|
||||
srcs = ["risc_logical_not_op.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "risc_logical_or_op",
|
||||
srcs = ["risc_logical_or_op.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "risc_max_op",
|
||||
srcs = ["risc_max_op.cc"],
|
||||
@ -87,6 +207,26 @@ tf_kernel_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "risc_random_uniform_op",
|
||||
srcs = ["risc_random_uniform_op.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "risc_reduce_op",
|
||||
srcs = ["risc_reduce_op.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "risc_reshape_op",
|
||||
srcs = ["risc_reshape_op.cc"],
|
||||
@ -97,6 +237,26 @@ tf_kernel_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "risc_reverse_op",
|
||||
srcs = ["risc_reverse_op.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "risc_scatter_op",
|
||||
srcs = ["risc_scatter_op.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "risc_shape_op",
|
||||
srcs = ["risc_shape_op.cc"],
|
||||
@ -117,18 +277,100 @@ tf_kernel_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "risc_sort_op",
|
||||
srcs = ["risc_sort_op.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "risc_squeeze_op",
|
||||
srcs = ["risc_squeeze_op.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "risc_transpose_op",
|
||||
srcs = ["risc_transpose_op.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "risc_triangular_solve_op",
|
||||
srcs = ["risc_triangular_solve_op.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "risc_unary_op",
|
||||
srcs = ["risc_unary_op.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "risc_while_op",
|
||||
srcs = ["risc_while_op.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "experimental",
|
||||
deps = [
|
||||
":risc_add_op",
|
||||
":risc_binary_arithmetic_op",
|
||||
":risc_binary_comparison_op",
|
||||
":risc_bitcast_op",
|
||||
":risc_broadcast_op",
|
||||
":risc_cast_op",
|
||||
":risc_cholesky_op",
|
||||
":risc_condition_op",
|
||||
":risc_conv_op",
|
||||
":risc_dot_op",
|
||||
":risc_fft_op",
|
||||
":risc_gather_op",
|
||||
":risc_is_finite_op",
|
||||
":risc_logical_and_op",
|
||||
":risc_logical_not_op",
|
||||
":risc_logical_or_op",
|
||||
":risc_max_op",
|
||||
":risc_pad_op",
|
||||
":risc_pool_op",
|
||||
":risc_random_uniform_op",
|
||||
":risc_reduce_op",
|
||||
":risc_reshape_op",
|
||||
":risc_reverse_op",
|
||||
":risc_scatter_op",
|
||||
":risc_shape_op",
|
||||
":risc_slice_op",
|
||||
":risc_sort_op",
|
||||
":risc_squeeze_op",
|
||||
":risc_transpose_op",
|
||||
":risc_triangular_solve_op",
|
||||
":risc_unary_op",
|
||||
":risc_while_op",
|
||||
],
|
||||
)
|
||||
|
@ -0,0 +1,48 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace risc {
|
||||
namespace experimental {
|
||||
|
||||
template <typename T>
|
||||
class RiscBinaryArithmeticOp : public OpKernel {
|
||||
public:
|
||||
explicit RiscBinaryArithmeticOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
// TODO(b/171294012): Implement RiscBinaryArithmetic op.
|
||||
}
|
||||
};
|
||||
|
||||
#define REGISTER_CPU(T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("RiscBinaryArithmetic").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
|
||||
RiscBinaryArithmeticOp<T>);
|
||||
|
||||
REGISTER_CPU(bfloat16);
|
||||
REGISTER_CPU(Eigen::half);
|
||||
REGISTER_CPU(float);
|
||||
REGISTER_CPU(double);
|
||||
|
||||
} // namespace experimental
|
||||
} // namespace risc
|
||||
} // namespace tensorflow
|
@ -0,0 +1,48 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace risc {
|
||||
namespace experimental {
|
||||
|
||||
template <typename T>
|
||||
class RiscBinaryComparisonOp : public OpKernel {
|
||||
public:
|
||||
explicit RiscBinaryComparisonOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
// TODO(b/171294012): Implement RiscBinaryComparison op.
|
||||
}
|
||||
};
|
||||
|
||||
#define REGISTER_CPU(T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("RiscBinaryComparison").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
|
||||
RiscBinaryComparisonOp<T>);
|
||||
|
||||
REGISTER_CPU(bfloat16);
|
||||
REGISTER_CPU(Eigen::half);
|
||||
REGISTER_CPU(float);
|
||||
REGISTER_CPU(double);
|
||||
|
||||
} // namespace experimental
|
||||
} // namespace risc
|
||||
} // namespace tensorflow
|
39
tensorflow/core/kernels/risc/experimental/risc_bitcast_op.cc
Normal file
39
tensorflow/core/kernels/risc/experimental/risc_bitcast_op.cc
Normal file
@ -0,0 +1,39 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace risc {
|
||||
namespace experimental {
|
||||
|
||||
class RiscBitcastOp : public OpKernel {
|
||||
public:
|
||||
explicit RiscBitcastOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
// TODO(b/171294012): Implement RiscBitcast op.
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("RiscBitcast").Device(DEVICE_CPU), RiscBitcastOp);
|
||||
|
||||
} // namespace experimental
|
||||
} // namespace risc
|
||||
} // namespace tensorflow
|
39
tensorflow/core/kernels/risc/experimental/risc_cast_op.cc
Normal file
39
tensorflow/core/kernels/risc/experimental/risc_cast_op.cc
Normal file
@ -0,0 +1,39 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace risc {
|
||||
namespace experimental {
|
||||
|
||||
class RiscCastOp : public OpKernel {
|
||||
public:
|
||||
explicit RiscCastOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
// TODO(b/171294012): Implement RiscCast op.
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("RiscCast").Device(DEVICE_CPU), RiscCastOp);
|
||||
|
||||
} // namespace experimental
|
||||
} // namespace risc
|
||||
} // namespace tensorflow
|
@ -0,0 +1,48 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace risc {
|
||||
namespace experimental {
|
||||
|
||||
template <typename T>
|
||||
class RiscCholeskyOp : public OpKernel {
|
||||
public:
|
||||
explicit RiscCholeskyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
// TODO(b/171294012): Implement RiscCholesky op.
|
||||
}
|
||||
};
|
||||
|
||||
#define REGISTER_CPU(T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("RiscCholesky").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
|
||||
RiscCholeskyOp<T>);
|
||||
|
||||
REGISTER_CPU(bfloat16);
|
||||
REGISTER_CPU(Eigen::half);
|
||||
REGISTER_CPU(float);
|
||||
REGISTER_CPU(double);
|
||||
|
||||
} // namespace experimental
|
||||
} // namespace risc
|
||||
} // namespace tensorflow
|
@ -0,0 +1,40 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace risc {
|
||||
namespace experimental {
|
||||
|
||||
class RiscConditionOp : public OpKernel {
|
||||
public:
|
||||
explicit RiscConditionOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
// TODO(b/171294012): Implement RiscCondition op.
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("RiscCondition").Device(DEVICE_CPU),
|
||||
RiscConditionOp);
|
||||
|
||||
} // namespace experimental
|
||||
} // namespace risc
|
||||
} // namespace tensorflow
|
39
tensorflow/core/kernels/risc/experimental/risc_fft_op.cc
Normal file
39
tensorflow/core/kernels/risc/experimental/risc_fft_op.cc
Normal file
@ -0,0 +1,39 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace risc {
|
||||
namespace experimental {
|
||||
|
||||
class RiscFftOp : public OpKernel {
|
||||
public:
|
||||
explicit RiscFftOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
// TODO(b/171294012): Implement RiscFft op.
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("RiscFft").Device(DEVICE_CPU), RiscFftOp);
|
||||
|
||||
} // namespace experimental
|
||||
} // namespace risc
|
||||
} // namespace tensorflow
|
39
tensorflow/core/kernels/risc/experimental/risc_gather_op.cc
Normal file
39
tensorflow/core/kernels/risc/experimental/risc_gather_op.cc
Normal file
@ -0,0 +1,39 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace risc {
|
||||
namespace experimental {
|
||||
|
||||
class RiscGatherOp : public OpKernel {
|
||||
public:
|
||||
explicit RiscGatherOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
// TODO(b/171294012): Implement RiscGather op.
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("RiscGather").Device(DEVICE_CPU), RiscGatherOp);
|
||||
|
||||
} // namespace experimental
|
||||
} // namespace risc
|
||||
} // namespace tensorflow
|
@ -0,0 +1,40 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace risc {
|
||||
namespace experimental {
|
||||
|
||||
class RiscIsFiniteOp : public OpKernel {
|
||||
public:
|
||||
explicit RiscIsFiniteOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
// TODO(b/171294012): Implement RiscIsFinite op.
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("RiscIsFinite").Device(DEVICE_CPU),
|
||||
RiscIsFiniteOp);
|
||||
|
||||
} // namespace experimental
|
||||
} // namespace risc
|
||||
} // namespace tensorflow
|
@ -0,0 +1,40 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace risc {
|
||||
namespace experimental {
|
||||
|
||||
class RiscLogicalAndOp : public OpKernel {
|
||||
public:
|
||||
explicit RiscLogicalAndOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
// TODO(b/171294012): Implement RiscLogicalAnd op.
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("RiscLogicalAnd").Device(DEVICE_CPU),
|
||||
RiscLogicalAndOp);
|
||||
|
||||
} // namespace experimental
|
||||
} // namespace risc
|
||||
} // namespace tensorflow
|
@ -0,0 +1,40 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace risc {
|
||||
namespace experimental {
|
||||
|
||||
class RiscLogicalNotOp : public OpKernel {
|
||||
public:
|
||||
explicit RiscLogicalNotOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
// TODO(b/171294012): Implement RiscLogicalNot op.
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("RiscLogicalNot").Device(DEVICE_CPU),
|
||||
RiscLogicalNotOp);
|
||||
|
||||
} // namespace experimental
|
||||
} // namespace risc
|
||||
} // namespace tensorflow
|
@ -0,0 +1,40 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace risc {
|
||||
namespace experimental {
|
||||
|
||||
class RiscLogicalOrOp : public OpKernel {
|
||||
public:
|
||||
explicit RiscLogicalOrOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
// TODO(b/171294012): Implement RiscLogicalOr op.
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("RiscLogicalOr").Device(DEVICE_CPU),
|
||||
RiscLogicalOrOp);
|
||||
|
||||
} // namespace experimental
|
||||
} // namespace risc
|
||||
} // namespace tensorflow
|
@ -0,0 +1,40 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace risc {
|
||||
namespace experimental {
|
||||
|
||||
class RiscRandomUniformOp : public OpKernel {
|
||||
public:
|
||||
explicit RiscRandomUniformOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
// TODO(b/171294012): Implement RiscRandomUniform op.
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("RiscRandomUniform").Device(DEVICE_CPU),
|
||||
RiscRandomUniformOp);
|
||||
|
||||
} // namespace experimental
|
||||
} // namespace risc
|
||||
} // namespace tensorflow
|
39
tensorflow/core/kernels/risc/experimental/risc_reduce_op.cc
Normal file
39
tensorflow/core/kernels/risc/experimental/risc_reduce_op.cc
Normal file
@ -0,0 +1,39 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace risc {
|
||||
namespace experimental {
|
||||
|
||||
class RiscReduceOp : public OpKernel {
|
||||
public:
|
||||
explicit RiscReduceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
// TODO(b/171294012): Implement RiscReduce op.
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("RiscReduce").Device(DEVICE_CPU), RiscReduceOp);
|
||||
|
||||
} // namespace experimental
|
||||
} // namespace risc
|
||||
} // namespace tensorflow
|
39
tensorflow/core/kernels/risc/experimental/risc_reverse_op.cc
Normal file
39
tensorflow/core/kernels/risc/experimental/risc_reverse_op.cc
Normal file
@ -0,0 +1,39 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace risc {
|
||||
namespace experimental {
|
||||
|
||||
class RiscReverseOp : public OpKernel {
|
||||
public:
|
||||
explicit RiscReverseOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
// TODO(b/171294012): Implement RiscReverse op.
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("RiscReverse").Device(DEVICE_CPU), RiscReverseOp);
|
||||
|
||||
} // namespace experimental
|
||||
} // namespace risc
|
||||
} // namespace tensorflow
|
39
tensorflow/core/kernels/risc/experimental/risc_scatter_op.cc
Normal file
39
tensorflow/core/kernels/risc/experimental/risc_scatter_op.cc
Normal file
@ -0,0 +1,39 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace risc {
|
||||
namespace experimental {
|
||||
|
||||
class RiscScatterOp : public OpKernel {
|
||||
public:
|
||||
explicit RiscScatterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
// TODO(b/171294012): Implement RiscScatter op.
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("RiscScatter").Device(DEVICE_CPU), RiscScatterOp);
|
||||
|
||||
} // namespace experimental
|
||||
} // namespace risc
|
||||
} // namespace tensorflow
|
39
tensorflow/core/kernels/risc/experimental/risc_sort_op.cc
Normal file
39
tensorflow/core/kernels/risc/experimental/risc_sort_op.cc
Normal file
@ -0,0 +1,39 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace risc {
|
||||
namespace experimental {
|
||||
|
||||
class RiscSortOp : public OpKernel {
|
||||
public:
|
||||
explicit RiscSortOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
// TODO(b/171294012): Implement RiscSort op.
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("RiscSort").Device(DEVICE_CPU), RiscSortOp);
|
||||
|
||||
} // namespace experimental
|
||||
} // namespace risc
|
||||
} // namespace tensorflow
|
39
tensorflow/core/kernels/risc/experimental/risc_squeeze_op.cc
Normal file
39
tensorflow/core/kernels/risc/experimental/risc_squeeze_op.cc
Normal file
@ -0,0 +1,39 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace risc {
|
||||
namespace experimental {
|
||||
|
||||
class RiscSqueezeOp : public OpKernel {
|
||||
public:
|
||||
explicit RiscSqueezeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
// TODO(b/171294012): Implement RiscSqueeze op.
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("RiscSqueeze").Device(DEVICE_CPU), RiscSqueezeOp);
|
||||
|
||||
} // namespace experimental
|
||||
} // namespace risc
|
||||
} // namespace tensorflow
|
@ -0,0 +1,40 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace risc {
|
||||
namespace experimental {
|
||||
|
||||
class RiscTransposeOp : public OpKernel {
|
||||
public:
|
||||
explicit RiscTransposeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
// TODO(b/171294012): Implement RiscTranspose op.
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("RiscTranspose").Device(DEVICE_CPU),
|
||||
RiscTransposeOp);
|
||||
|
||||
} // namespace experimental
|
||||
} // namespace risc
|
||||
} // namespace tensorflow
|
@ -0,0 +1,40 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace risc {
|
||||
namespace experimental {
|
||||
|
||||
class RiscTriangularSolveOp : public OpKernel {
|
||||
public:
|
||||
explicit RiscTriangularSolveOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
// TODO(b/171294012): Implement RiscTriangularSolve op.
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("RiscTriangularSolve").Device(DEVICE_CPU),
|
||||
RiscTriangularSolveOp);
|
||||
|
||||
} // namespace experimental
|
||||
} // namespace risc
|
||||
} // namespace tensorflow
|
39
tensorflow/core/kernels/risc/experimental/risc_unary_op.cc
Normal file
39
tensorflow/core/kernels/risc/experimental/risc_unary_op.cc
Normal file
@ -0,0 +1,39 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace risc {
|
||||
namespace experimental {
|
||||
|
||||
class RiscUnaryOp : public OpKernel {
|
||||
public:
|
||||
explicit RiscUnaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
// TODO(b/171294012): Implement RiscUnary op.
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("RiscUnary").Device(DEVICE_CPU), RiscUnaryOp);
|
||||
|
||||
} // namespace experimental
|
||||
} // namespace risc
|
||||
} // namespace tensorflow
|
39
tensorflow/core/kernels/risc/experimental/risc_while_op.cc
Normal file
39
tensorflow/core/kernels/risc/experimental/risc_while_op.cc
Normal file
@ -0,0 +1,39 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace risc {
|
||||
namespace experimental {
|
||||
|
||||
class RiscWhileOp : public OpKernel {
|
||||
public:
|
||||
explicit RiscWhileOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
// TODO(b/171294012): Implement RiscWhile op.
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("RiscWhile").Device(DEVICE_CPU), RiscWhileOp);
|
||||
|
||||
} // namespace experimental
|
||||
} // namespace risc
|
||||
} // namespace tensorflow
|
@ -30,6 +30,31 @@ REGISTER_OP("RiscAdd")
|
||||
.SetIsAggregate()
|
||||
.SetIsCommutative();
|
||||
|
||||
// TODO(b/171294012): include RiscMax here as well.
|
||||
REGISTER_OP("RiscBinaryArithmetic")
|
||||
.Input("x: T")
|
||||
.Input("y: T")
|
||||
.Output("z: T")
|
||||
.Attr("op_type: {'ADD', 'SUB', 'MUL', 'DIV', 'REM', 'MIN', 'POW'}")
|
||||
.Attr("T: {bfloat16, half, float, double}")
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
|
||||
REGISTER_OP("RiscBinaryComparison")
|
||||
.Input("x: T")
|
||||
.Input("y: T")
|
||||
.Output("z: bool")
|
||||
.Attr("op_type: {'EQ', 'NE', 'GE', 'GT', 'LE', 'LT'}")
|
||||
.Attr("T: {bfloat16, half, float, double}")
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
|
||||
// TODO(b/171294012): change shape function.
|
||||
REGISTER_OP("RiscBitcast")
|
||||
.Input("x: SrcT")
|
||||
.Output("y: DstT")
|
||||
.Attr("SrcT: type")
|
||||
.Attr("DstT: type")
|
||||
.SetShapeFn(shape_inference::UnknownShape);
|
||||
|
||||
// TODO(b/171294012): change shape function.
|
||||
REGISTER_OP("RiscBroadcast")
|
||||
.Input("input: T")
|
||||
@ -39,6 +64,20 @@ REGISTER_OP("RiscBroadcast")
|
||||
.Attr("Tidx: {int32, int64} = DT_INT32")
|
||||
.SetShapeFn(shape_inference::UnknownShape);
|
||||
|
||||
REGISTER_OP("RiscCast")
|
||||
.Input("x: SrcT")
|
||||
.Output("y: DstT")
|
||||
.Attr("SrcT: type")
|
||||
.Attr("DstT: type")
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
|
||||
// TODO(b/171294012): change shape function.
|
||||
REGISTER_OP("RiscCholesky")
|
||||
.Input("input: T")
|
||||
.Output("output: T")
|
||||
.Attr("T: {bfloat16, half, float, double}")
|
||||
.SetShapeFn(shape_inference::UnknownShape);
|
||||
|
||||
REGISTER_OP("RiscConcat")
|
||||
.Input("values: N * T")
|
||||
.Input("axis: Tidx")
|
||||
@ -48,6 +87,18 @@ REGISTER_OP("RiscConcat")
|
||||
.Attr("Tidx: {int32, int64} = DT_INT32")
|
||||
.SetShapeFn(shape_inference::ConcatV2Shape);
|
||||
|
||||
// TODO(b/171294012): change shape function.
|
||||
REGISTER_OP("RiscCondition")
|
||||
.Input("pred: bool")
|
||||
.Input("input_true: SrcT")
|
||||
.Input("input_false: SrcT")
|
||||
.Output("output: DstT")
|
||||
.Attr("func_true: func")
|
||||
.Attr("func_false: func")
|
||||
.Attr("SrcT: {bfloat16, half, float, double}")
|
||||
.Attr("DstT: {bfloat16, half, float, double}")
|
||||
.SetShapeFn(shape_inference::UnknownShape);
|
||||
|
||||
// TODO(b/171294012): change shape function.
|
||||
REGISTER_OP("RiscConv")
|
||||
.Input("input: T")
|
||||
@ -68,6 +119,48 @@ REGISTER_OP("RiscDot")
|
||||
.Attr("T: {bfloat16, half, float, double}")
|
||||
.SetShapeFn(shape_inference::MatMulShape);
|
||||
|
||||
// TODO(b/171294012): change shape function.
|
||||
REGISTER_OP("RiscFft")
|
||||
.Input("input: Tcomplex")
|
||||
.Output("output: Tcomplex")
|
||||
.Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
|
||||
.SetShapeFn(shape_inference::UnknownShape);
|
||||
|
||||
// TODO(b/171294012): change shape function.
|
||||
REGISTER_OP("RiscGather")
|
||||
.Input("params: Tparams")
|
||||
.Input("indices: Tindices")
|
||||
.Input("axis: Taxis")
|
||||
.Attr("batch_dims: int = 0")
|
||||
.Output("output: Tparams")
|
||||
.Attr("Tparams: type")
|
||||
.Attr("Tindices: {int32,int64}")
|
||||
.Attr("Taxis: {int32,int64}")
|
||||
.SetShapeFn(shape_inference::UnknownShape);
|
||||
|
||||
REGISTER_OP("RiscIsFinite")
|
||||
.Input("x: T")
|
||||
.Output("y: bool")
|
||||
.Attr("T: {bfloat16, half, float, double}")
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
|
||||
REGISTER_OP("RiscLogicalAnd")
|
||||
.Input("x: bool")
|
||||
.Input("y: bool")
|
||||
.Output("z: bool")
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
|
||||
REGISTER_OP("RiscLogicalNot")
|
||||
.Input("x: bool")
|
||||
.Output("z: bool")
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
|
||||
REGISTER_OP("RiscLogicalOr")
|
||||
.Input("x: bool")
|
||||
.Input("y: bool")
|
||||
.Output("z: bool")
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
|
||||
REGISTER_OP("RiscMax")
|
||||
.Input("x: T")
|
||||
.Input("y: T")
|
||||
@ -96,6 +189,23 @@ REGISTER_OP("RiscPool")
|
||||
.Attr("T: {bfloat16, half, float, double}")
|
||||
.SetShapeFn(shape_inference::UnknownShape);
|
||||
|
||||
REGISTER_OP("RiscRandomUniform")
|
||||
.Input("shape: T")
|
||||
.Output("output: float")
|
||||
.Attr("seed: int = 0")
|
||||
.Attr("T: {int32, int64}")
|
||||
.SetShapeFn(shape_inference::RandomShape);
|
||||
|
||||
// TODO(b/171294012): change shape function.
|
||||
REGISTER_OP("RiscReduce")
|
||||
.Input("tensor: T")
|
||||
.Input("axis: Index")
|
||||
.Output("output: T")
|
||||
.Attr("reduce_type: {'MEAN', 'SUM'}")
|
||||
.Attr("Index: {int32,int64} = DT_INT32")
|
||||
.Attr("T: {bfloat16, half, float, double}")
|
||||
.SetShapeFn(shape_inference::UnknownShape);
|
||||
|
||||
// TODO(b/171294012): change shape function.
|
||||
REGISTER_OP("RiscReshape")
|
||||
.Input("tensor: T")
|
||||
@ -105,6 +215,24 @@ REGISTER_OP("RiscReshape")
|
||||
.Attr("Tshape: {int32, int64} = DT_INT32")
|
||||
.SetShapeFn(shape_inference::UnknownShape);
|
||||
|
||||
REGISTER_OP("RiscReverse")
|
||||
.Input("tensor: T")
|
||||
.Input("axis: Tidx")
|
||||
.Output("output: T")
|
||||
.Attr("Tidx: {int32, int64} = DT_INT32")
|
||||
.Attr("T: {bfloat16, half, float, double}")
|
||||
.SetShapeFn(shape_inference::UnknownShape);
|
||||
|
||||
// TODO(b/171294012): change shape function.
|
||||
REGISTER_OP("RiscScatter")
|
||||
.Input("indices: Tindices")
|
||||
.Input("updates: T")
|
||||
.Input("shape: Tindices")
|
||||
.Output("output: T")
|
||||
.Attr("T: {bfloat16, half, float, double}")
|
||||
.Attr("Tindices: {int32, int64}")
|
||||
.SetShapeFn(shape_inference::UnknownShape);
|
||||
|
||||
// TODO(b/171294012): change shape function.
|
||||
REGISTER_OP("RiscShape")
|
||||
.Input("input: T")
|
||||
@ -122,4 +250,61 @@ REGISTER_OP("RiscSlice")
|
||||
.Attr("Index: {int32,int64}")
|
||||
.SetShapeFn(shape_inference::SliceShape);
|
||||
|
||||
REGISTER_OP("RiscSort")
|
||||
.Input("input: T")
|
||||
.Input("axis: Index")
|
||||
.Output("output: T")
|
||||
.Attr("Index: {int32,int64} = DT_INT32")
|
||||
.Attr("T: {bfloat16, half, float, double}")
|
||||
.Attr("direction: {'ASCENDING', 'DESCENDING'}")
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
|
||||
// TODO(b/171294012): change shape function.
|
||||
REGISTER_OP("RiscSqueeze")
|
||||
.Input("input: T")
|
||||
.Output("output: T")
|
||||
.Attr("T: type")
|
||||
.Attr("squeeze_dims: list(int) >= 0 = []")
|
||||
.SetShapeFn(shape_inference::UnknownShape);
|
||||
|
||||
// TODO(b/171294012): change shape function.
|
||||
REGISTER_OP("RiscTranspose")
|
||||
.Input("x: T")
|
||||
.Input("perm: Tperm")
|
||||
.Output("y: T")
|
||||
.Attr("T: type")
|
||||
.Attr("Tperm: {int32, int64} = DT_INT32")
|
||||
.SetShapeFn(shape_inference::UnknownShape);
|
||||
|
||||
// TODO(b/171294012): change shape function.
|
||||
REGISTER_OP("RiscTriangularSolve")
|
||||
.Input("matrix: T")
|
||||
.Input("rhs: T")
|
||||
.Output("output: T")
|
||||
.Attr("lower: bool = True")
|
||||
.Attr("adjoint: bool = False")
|
||||
.Attr("T: {bfloat16, half, float, double}")
|
||||
.SetShapeFn(shape_inference::UnknownShape);
|
||||
|
||||
REGISTER_OP("RiscUnary")
|
||||
.Input("x: T")
|
||||
.Output("y: T")
|
||||
.Attr(
|
||||
"op_type: {'ABL', 'CEIL', 'COS', 'EXP', 'FLOOR', 'IMAG', 'LOG', 'NEG', "
|
||||
"'REAL', 'SIGN'}")
|
||||
.Attr("T: {bfloat16, half, float, double}")
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
|
||||
// TODO(b/171294012): change shape function.
|
||||
REGISTER_OP("RiscWhile")
|
||||
.Input("input: T")
|
||||
.Output("output: T")
|
||||
.Attr("T: list(type) >= 0")
|
||||
.Attr("cond: func")
|
||||
.Attr("body: func")
|
||||
.Attr("output_shapes: list(shape) = []")
|
||||
.Attr("parallel_iterations: int = 10")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::UnknownShape);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -28,6 +28,27 @@ def _RiscAddGrad(_, grad):
|
||||
return None, None
|
||||
|
||||
|
||||
@ops.RegisterGradient("RiscBinaryArithmetic")
|
||||
def _RiscBinaryArithmeticGrad(_, grad):
|
||||
# pylint: disable=unused-argument
|
||||
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
|
||||
return None, None
|
||||
|
||||
|
||||
@ops.RegisterGradient("RiscBinaryComparison")
|
||||
def _RiscBinaryComparisonGrad(_, grad):
|
||||
# pylint: disable=unused-argument
|
||||
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
|
||||
return None, None
|
||||
|
||||
|
||||
@ops.RegisterGradient("RiscBitcast")
|
||||
def _RiscBitcastGrad(_, grad):
|
||||
# pylint: disable=unused-argument
|
||||
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
|
||||
return None, None
|
||||
|
||||
|
||||
@ops.RegisterGradient("RiscBroadcast")
|
||||
def _RiscBroadcastGrad(_, grad):
|
||||
# pylint: disable=unused-argument
|
||||
@ -35,6 +56,20 @@ def _RiscBroadcastGrad(_, grad):
|
||||
return None, None
|
||||
|
||||
|
||||
@ops.RegisterGradient("RiscCast")
|
||||
def _RiscCastGrad(_, grad):
|
||||
# pylint: disable=unused-argument
|
||||
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
|
||||
return None, None
|
||||
|
||||
|
||||
@ops.RegisterGradient("RiscCholesky")
|
||||
def _RiscCholeskyGrad(_, grad):
|
||||
# pylint: disable=unused-argument
|
||||
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
|
||||
return None, None
|
||||
|
||||
|
||||
@ops.RegisterGradient("RiscConcat")
|
||||
def _RiscConcatGrad(_, grad):
|
||||
# pylint: disable=unused-argument
|
||||
@ -42,6 +77,13 @@ def _RiscConcatGrad(_, grad):
|
||||
return None, None
|
||||
|
||||
|
||||
@ops.RegisterGradient("RiscCondition")
|
||||
def _RiscConditionGrad(_, grad):
|
||||
# pylint: disable=unused-argument
|
||||
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
|
||||
return None, None
|
||||
|
||||
|
||||
@ops.RegisterGradient("RiscConv")
|
||||
def _RiscConvGrad(_, grad):
|
||||
# pylint: disable=unused-argument
|
||||
@ -56,6 +98,48 @@ def _RiscDotGrad(_, grad):
|
||||
return None, None
|
||||
|
||||
|
||||
@ops.RegisterGradient("RiscFft")
|
||||
def _RiscFftGrad(_, grad):
|
||||
# pylint: disable=unused-argument
|
||||
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
|
||||
return None, None
|
||||
|
||||
|
||||
@ops.RegisterGradient("RiscGather")
|
||||
def _RiscGatherGrad(_, grad):
|
||||
# pylint: disable=unused-argument
|
||||
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
|
||||
return None, None
|
||||
|
||||
|
||||
@ops.RegisterGradient("RiscIsFinite")
|
||||
def _RiscIsFiniteGrad(_, grad):
|
||||
# pylint: disable=unused-argument
|
||||
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
|
||||
return None, None
|
||||
|
||||
|
||||
@ops.RegisterGradient("RiscLogicalAnd")
|
||||
def _RiscLogicalAndGrad(_, grad):
|
||||
# pylint: disable=unused-argument
|
||||
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
|
||||
return None, None
|
||||
|
||||
|
||||
@ops.RegisterGradient("RiscLogicalNot")
|
||||
def _RiscLogicalNotGrad(_, grad):
|
||||
# pylint: disable=unused-argument
|
||||
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
|
||||
return None, None
|
||||
|
||||
|
||||
@ops.RegisterGradient("RiscLogicalOr")
|
||||
def _RiscLogicalOrGrad(_, grad):
|
||||
# pylint: disable=unused-argument
|
||||
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
|
||||
return None, None
|
||||
|
||||
|
||||
@ops.RegisterGradient("RiscMax")
|
||||
def _RiscMaxGrad(_, grad):
|
||||
# pylint: disable=unused-argument
|
||||
@ -77,6 +161,20 @@ def _RiscPoolGrad(_, grad):
|
||||
return None, None
|
||||
|
||||
|
||||
@ops.RegisterGradient("RiscRandomUniform")
|
||||
def _RiscRandomUniformGrad(_, grad):
|
||||
# pylint: disable=unused-argument
|
||||
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
|
||||
return None, None
|
||||
|
||||
|
||||
@ops.RegisterGradient("RiscReduce")
|
||||
def _RiscReduceGrad(_, grad):
|
||||
# pylint: disable=unused-argument
|
||||
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
|
||||
return None, None
|
||||
|
||||
|
||||
@ops.RegisterGradient("RiscReshape")
|
||||
def _RiscReshapeGrad(_, grad):
|
||||
# pylint: disable=unused-argument
|
||||
@ -84,6 +182,20 @@ def _RiscReshapeGrad(_, grad):
|
||||
return None, None
|
||||
|
||||
|
||||
@ops.RegisterGradient("RiscReverse")
|
||||
def _RiscReverseGrad(_, grad):
|
||||
# pylint: disable=unused-argument
|
||||
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
|
||||
return None, None
|
||||
|
||||
|
||||
@ops.RegisterGradient("RiscScatter")
|
||||
def _RiscScatterGrad(_, grad):
|
||||
# pylint: disable=unused-argument
|
||||
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
|
||||
return None, None
|
||||
|
||||
|
||||
@ops.RegisterGradient("RiscShape")
|
||||
def _RiscShapeGrad(_, grad):
|
||||
# pylint: disable=unused-argument
|
||||
@ -96,3 +208,45 @@ def _RiscSliceGrad(_, grad):
|
||||
# pylint: disable=unused-argument
|
||||
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
|
||||
return None, None
|
||||
|
||||
|
||||
@ops.RegisterGradient("RiscSort")
|
||||
def _RiscSortGrad(_, grad):
|
||||
# pylint: disable=unused-argument
|
||||
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
|
||||
return None, None
|
||||
|
||||
|
||||
@ops.RegisterGradient("RiscSqueeze")
|
||||
def _RiscSqueezeGrad(_, grad):
|
||||
# pylint: disable=unused-argument
|
||||
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
|
||||
return None, None
|
||||
|
||||
|
||||
@ops.RegisterGradient("RiscTranspose")
|
||||
def _RiscTransposeGrad(_, grad):
|
||||
# pylint: disable=unused-argument
|
||||
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
|
||||
return None, None
|
||||
|
||||
|
||||
@ops.RegisterGradient("RiscTriangularSolve")
|
||||
def _RiscTriangularSolvesGrad(_, grad):
|
||||
# pylint: disable=unused-argument
|
||||
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
|
||||
return None, None
|
||||
|
||||
|
||||
@ops.RegisterGradient("RiscUnary")
|
||||
def _RiscUnaryGrad(_, grad):
|
||||
# pylint: disable=unused-argument
|
||||
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
|
||||
return None, None
|
||||
|
||||
|
||||
@ops.RegisterGradient("RiscWhile")
|
||||
def _RiscWhileGrad(_, grad):
|
||||
# pylint: disable=unused-argument
|
||||
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
|
||||
return None, None
|
||||
|
@ -29,14 +29,49 @@ def risc_add(
|
||||
return gen_risc_ops.risc_add(input_lhs, input_rhs, name=name)
|
||||
|
||||
|
||||
def risc_binary_arithmetic(x, y, op_type, name='RISC_BinaryArithmetic'):
|
||||
return gen_risc_ops.risc_binary_arithmetic(x, y, op_type=op_type, name=name)
|
||||
|
||||
|
||||
def risc_binary_comparison(x, y, op_type, name='RISC_BinaryComparison'):
|
||||
return gen_risc_ops.risc_binary_comparison(x, y, op_type=op_type, name=name)
|
||||
|
||||
|
||||
def risc_bitcast(x, dtype, name='RISC_BITCAST'):
|
||||
return gen_risc_ops.risc_bitcast(x, dtype, name=name)
|
||||
|
||||
|
||||
def risc_broadcast(x, shape, name='RISC_BROADCAST'):
|
||||
return gen_risc_ops.risc_broadcast(x, shape, name=name)
|
||||
|
||||
|
||||
def risc_cast(x, dtype, name='RISC_CAST'):
|
||||
return gen_risc_ops.risc_cast(x, dtype, name=name)
|
||||
|
||||
|
||||
def risc_cholesky(x, name='RISC_CHOLESKY'):
|
||||
return gen_risc_ops.risc_cholesky(x, name=name)
|
||||
|
||||
|
||||
def risc_concat(x, axis, name='RISC_CONCAT'):
|
||||
return gen_risc_ops.risc_concat(x, axis, name=name)
|
||||
|
||||
|
||||
def risc_condition(pred,
|
||||
input_true,
|
||||
input_false,
|
||||
func_true,
|
||||
func_false,
|
||||
name='RISC_CONDITION'):
|
||||
return gen_risc_ops.risc_condition(
|
||||
pred,
|
||||
input_true,
|
||||
input_false,
|
||||
func_true=func_true,
|
||||
func_false=func_false,
|
||||
name=name)
|
||||
|
||||
|
||||
def risc_conv(x,
|
||||
kernel,
|
||||
strides,
|
||||
@ -65,6 +100,41 @@ def risc_dot(input_lhs,
|
||||
name=name)
|
||||
|
||||
|
||||
def risc_fft(x, name='RISC_FFT'):
|
||||
return gen_risc_ops.risc_fft(x, name=name)
|
||||
|
||||
|
||||
def risc_gather(params,
|
||||
indices,
|
||||
validate_indices=None,
|
||||
axis=None,
|
||||
batch_dims=0,
|
||||
name='RISC_GATHER'):
|
||||
return gen_risc_ops.risc_gather(
|
||||
params,
|
||||
indices,
|
||||
validate_indices=validate_indices,
|
||||
name=name,
|
||||
axis=axis,
|
||||
batch_dims=batch_dims)
|
||||
|
||||
|
||||
def risc_is_finite(x, name='RISC_IS_FINITE'):
|
||||
return gen_risc_ops.risc_is_finite(x, name=name)
|
||||
|
||||
|
||||
def risc_logical_and(a, b, name='RISC_LOGICAL_AND'):
|
||||
return gen_risc_ops.risc_logical_and(a, b, name=name)
|
||||
|
||||
|
||||
def risc_logical_not(a, b, name='RISC_LOGICAL_NOT'):
|
||||
return gen_risc_ops.risc_logical_not(a, b, name=name)
|
||||
|
||||
|
||||
def risc_logical_or(a, b, name='RISC_LOGICAL_OR'):
|
||||
return gen_risc_ops.risc_logical_or(a, b, name=name)
|
||||
|
||||
|
||||
def risc_max(input_lhs, input_rhs, name='RISC_MAX'):
|
||||
return gen_risc_ops.risc_max(input_lhs, input_rhs, name=name)
|
||||
|
||||
@ -78,13 +148,76 @@ def risc_pool(x, ksize, strides, pooling_type='MAX', name='RISC_POOL'):
|
||||
x, ksize, strides, pooling_type=pooling_type, name=name)
|
||||
|
||||
|
||||
def risc_random_uniform(shape, seed, name='RISC_RANDOM_UNIFORM'):
|
||||
return gen_risc_ops.risc_random_uniform(shape, seed, name=name)
|
||||
|
||||
|
||||
def risc_reduce(x, axis, reduce_type, name='RISC_REDUCE'):
|
||||
return gen_risc_ops.risc_reduce(x, axis, reduce_type=reduce_type, name=name)
|
||||
|
||||
|
||||
def risc_reshape(x, shape, name='RISC_RESHAPE'):
|
||||
return gen_risc_ops.risc_reshape(x, shape, name=name)
|
||||
|
||||
|
||||
def risc_reverse(x, axis, name='RISC_REVERSE'):
|
||||
return gen_risc_ops.risc_reverse(x, axis, name=name)
|
||||
|
||||
|
||||
def risc_scatter(indices, updates, shape, name='RISC_SCATTER'):
|
||||
return gen_risc_ops.risc_scatter(indices, updates, shape, name=name)
|
||||
|
||||
|
||||
def risc_shape(x, name='RISC_SHAPE'):
|
||||
return gen_risc_ops.risc_shape(x, name=name)
|
||||
|
||||
|
||||
def risc_slice(x, begin, size, name='RISC_SLICE'):
|
||||
return gen_risc_ops.risc_slice(x, begin, size, name=name)
|
||||
|
||||
|
||||
def risc_sort(x, axis, direction='ASCENDING', name='RISC_SORT'):
|
||||
return gen_risc_ops.risc_sort(x, axis, direction=direction, name=name)
|
||||
|
||||
|
||||
def risc_squeeze(x, axis=None, name='RISC_SQUEEZE'):
|
||||
return gen_risc_ops.risc_squeeze(x, axis, name=name)
|
||||
|
||||
|
||||
def risc_transpose(x, perm=None, name='RISC_TRANSPOSE'):
|
||||
return gen_risc_ops.risc_transpose(x, perm, name=name)
|
||||
|
||||
|
||||
def risc_triangular_solve(matrix,
|
||||
rhs,
|
||||
lower=True,
|
||||
adjoint=False,
|
||||
name='RISC_TRIANGULAR_SOLVE'):
|
||||
return gen_risc_ops.risc_triangular_solve(
|
||||
matrix, rhs, lower=lower, adjoint=adjoint, name=name)
|
||||
|
||||
|
||||
def risc_unary(x, op_type='ABL', name='RISC_UNARY'):
|
||||
return gen_risc_ops.risc_unary(x, op_type=op_type, name=name)
|
||||
|
||||
|
||||
def risc_while(cond,
|
||||
body,
|
||||
loop_vars,
|
||||
shape_invariants=None,
|
||||
parallel_iterations=10,
|
||||
back_prop=True,
|
||||
swap_memory=False,
|
||||
maximum_iterations=None,
|
||||
name='RISC_WHILE'):
|
||||
return gen_risc_ops.risc_while(
|
||||
cond=cond,
|
||||
body=body,
|
||||
loop_vars=loop_vars,
|
||||
shape_invariants=shape_invariants,
|
||||
parallel_iterations=parallel_iterations,
|
||||
back_prop=back_prop,
|
||||
swap_memory=swap_memory,
|
||||
name=name,
|
||||
maximum_iterations=maximum_iterations,
|
||||
return_same_structure=True)
|
||||
|
Loading…
x
Reference in New Issue
Block a user