Add risc op register.

PiperOrigin-RevId: 346946572
Change-Id: I63580bb0491591439a928038185542d52cde5e3b
This commit is contained in:
Jian Li 2020-12-11 00:16:40 -08:00 committed by TensorFlower Gardener
parent b2ad941e21
commit 44a4129b24
48 changed files with 1695 additions and 0 deletions

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscBinaryArithmetic"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscBinaryComparison"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscBitcast"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscCast"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscCholesky"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscCondition"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscFft"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscGather"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscIsFinite"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscLogicalAnd"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscLogicalNot"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscLogicalOr"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscRandomUniform"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscReduce"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscReverse"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscScatter"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscSort"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscSqueeze"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscTranspose"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscTriangularSolve"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscUnary"
visibility: HIDDEN
}

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "RiscWhile"
visibility: HIDDEN
}

View File

@ -17,6 +17,36 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "risc_binary_arithmetic_op",
srcs = ["risc_binary_arithmetic_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_binary_comparison_op",
srcs = ["risc_binary_comparison_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_bitcast_op",
srcs = ["risc_bitcast_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_broadcast_op",
srcs = ["risc_broadcast_op.cc"],
@ -27,6 +57,26 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "risc_cast_op",
srcs = ["risc_cast_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_cholesky_op",
srcs = ["risc_cholesky_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_concat_op",
srcs = ["risc_concat_op.cc"],
@ -37,6 +87,16 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "risc_condition_op",
srcs = ["risc_condition_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_conv_op",
srcs = ["risc_conv_op.cc"],
@ -57,6 +117,66 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "risc_fft_op",
srcs = ["risc_fft_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_gather_op",
srcs = ["risc_gather_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_is_finite_op",
srcs = ["risc_is_finite_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_logical_and_op",
srcs = ["risc_logical_and_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_logical_not_op",
srcs = ["risc_logical_not_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_logical_or_op",
srcs = ["risc_logical_or_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_max_op",
srcs = ["risc_max_op.cc"],
@ -87,6 +207,26 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "risc_random_uniform_op",
srcs = ["risc_random_uniform_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_reduce_op",
srcs = ["risc_reduce_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_reshape_op",
srcs = ["risc_reshape_op.cc"],
@ -97,6 +237,26 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "risc_reverse_op",
srcs = ["risc_reverse_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_scatter_op",
srcs = ["risc_scatter_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_shape_op",
srcs = ["risc_shape_op.cc"],
@ -117,18 +277,100 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "risc_sort_op",
srcs = ["risc_sort_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_squeeze_op",
srcs = ["risc_squeeze_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_transpose_op",
srcs = ["risc_transpose_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_triangular_solve_op",
srcs = ["risc_triangular_solve_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_unary_op",
srcs = ["risc_unary_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "risc_while_op",
srcs = ["risc_while_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "experimental",
deps = [
":risc_add_op",
":risc_binary_arithmetic_op",
":risc_binary_comparison_op",
":risc_bitcast_op",
":risc_broadcast_op",
":risc_cast_op",
":risc_cholesky_op",
":risc_condition_op",
":risc_conv_op",
":risc_dot_op",
":risc_fft_op",
":risc_gather_op",
":risc_is_finite_op",
":risc_logical_and_op",
":risc_logical_not_op",
":risc_logical_or_op",
":risc_max_op",
":risc_pad_op",
":risc_pool_op",
":risc_random_uniform_op",
":risc_reduce_op",
":risc_reshape_op",
":risc_reverse_op",
":risc_scatter_op",
":risc_shape_op",
":risc_slice_op",
":risc_sort_op",
":risc_squeeze_op",
":risc_transpose_op",
":risc_triangular_solve_op",
":risc_unary_op",
":risc_while_op",
],
)

View File

@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
template <typename T>
class RiscBinaryArithmeticOp : public OpKernel {
public:
explicit RiscBinaryArithmeticOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscBinaryArithmetic op.
}
};
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("RiscBinaryArithmetic").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
RiscBinaryArithmeticOp<T>);
REGISTER_CPU(bfloat16);
REGISTER_CPU(Eigen::half);
REGISTER_CPU(float);
REGISTER_CPU(double);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
template <typename T>
class RiscBinaryComparisonOp : public OpKernel {
public:
explicit RiscBinaryComparisonOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscBinaryComparison op.
}
};
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("RiscBinaryComparison").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
RiscBinaryComparisonOp<T>);
REGISTER_CPU(bfloat16);
REGISTER_CPU(Eigen::half);
REGISTER_CPU(float);
REGISTER_CPU(double);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,39 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
class RiscBitcastOp : public OpKernel {
public:
explicit RiscBitcastOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscBitcast op.
}
};
REGISTER_KERNEL_BUILDER(Name("RiscBitcast").Device(DEVICE_CPU), RiscBitcastOp);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,39 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
class RiscCastOp : public OpKernel {
public:
explicit RiscCastOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscCast op.
}
};
REGISTER_KERNEL_BUILDER(Name("RiscCast").Device(DEVICE_CPU), RiscCastOp);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,48 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
template <typename T>
class RiscCholeskyOp : public OpKernel {
public:
explicit RiscCholeskyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscCholesky op.
}
};
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("RiscCholesky").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
RiscCholeskyOp<T>);
REGISTER_CPU(bfloat16);
REGISTER_CPU(Eigen::half);
REGISTER_CPU(float);
REGISTER_CPU(double);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
class RiscConditionOp : public OpKernel {
public:
explicit RiscConditionOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscCondition op.
}
};
REGISTER_KERNEL_BUILDER(Name("RiscCondition").Device(DEVICE_CPU),
RiscConditionOp);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,39 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
class RiscFftOp : public OpKernel {
public:
explicit RiscFftOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscFft op.
}
};
REGISTER_KERNEL_BUILDER(Name("RiscFft").Device(DEVICE_CPU), RiscFftOp);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,39 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
class RiscGatherOp : public OpKernel {
public:
explicit RiscGatherOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscGather op.
}
};
REGISTER_KERNEL_BUILDER(Name("RiscGather").Device(DEVICE_CPU), RiscGatherOp);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
class RiscIsFiniteOp : public OpKernel {
public:
explicit RiscIsFiniteOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscIsFinite op.
}
};
REGISTER_KERNEL_BUILDER(Name("RiscIsFinite").Device(DEVICE_CPU),
RiscIsFiniteOp);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
class RiscLogicalAndOp : public OpKernel {
public:
explicit RiscLogicalAndOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscLogicalAnd op.
}
};
REGISTER_KERNEL_BUILDER(Name("RiscLogicalAnd").Device(DEVICE_CPU),
RiscLogicalAndOp);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
class RiscLogicalNotOp : public OpKernel {
public:
explicit RiscLogicalNotOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscLogicalNot op.
}
};
REGISTER_KERNEL_BUILDER(Name("RiscLogicalNot").Device(DEVICE_CPU),
RiscLogicalNotOp);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
class RiscLogicalOrOp : public OpKernel {
public:
explicit RiscLogicalOrOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscLogicalOr op.
}
};
REGISTER_KERNEL_BUILDER(Name("RiscLogicalOr").Device(DEVICE_CPU),
RiscLogicalOrOp);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
class RiscRandomUniformOp : public OpKernel {
public:
explicit RiscRandomUniformOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscRandomUniform op.
}
};
REGISTER_KERNEL_BUILDER(Name("RiscRandomUniform").Device(DEVICE_CPU),
RiscRandomUniformOp);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,39 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
class RiscReduceOp : public OpKernel {
public:
explicit RiscReduceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscReduce op.
}
};
REGISTER_KERNEL_BUILDER(Name("RiscReduce").Device(DEVICE_CPU), RiscReduceOp);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,39 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
class RiscReverseOp : public OpKernel {
public:
explicit RiscReverseOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscReverse op.
}
};
REGISTER_KERNEL_BUILDER(Name("RiscReverse").Device(DEVICE_CPU), RiscReverseOp);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,39 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
class RiscScatterOp : public OpKernel {
public:
explicit RiscScatterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscScatter op.
}
};
REGISTER_KERNEL_BUILDER(Name("RiscScatter").Device(DEVICE_CPU), RiscScatterOp);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,39 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
class RiscSortOp : public OpKernel {
public:
explicit RiscSortOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscSort op.
}
};
REGISTER_KERNEL_BUILDER(Name("RiscSort").Device(DEVICE_CPU), RiscSortOp);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,39 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
class RiscSqueezeOp : public OpKernel {
public:
explicit RiscSqueezeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscSqueeze op.
}
};
REGISTER_KERNEL_BUILDER(Name("RiscSqueeze").Device(DEVICE_CPU), RiscSqueezeOp);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
class RiscTransposeOp : public OpKernel {
public:
explicit RiscTransposeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscTranspose op.
}
};
REGISTER_KERNEL_BUILDER(Name("RiscTranspose").Device(DEVICE_CPU),
RiscTransposeOp);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
class RiscTriangularSolveOp : public OpKernel {
public:
explicit RiscTriangularSolveOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscTriangularSolve op.
}
};
REGISTER_KERNEL_BUILDER(Name("RiscTriangularSolve").Device(DEVICE_CPU),
RiscTriangularSolveOp);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,39 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
class RiscUnaryOp : public OpKernel {
public:
explicit RiscUnaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscUnary op.
}
};
REGISTER_KERNEL_BUILDER(Name("RiscUnary").Device(DEVICE_CPU), RiscUnaryOp);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -0,0 +1,39 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
class RiscWhileOp : public OpKernel {
public:
explicit RiscWhileOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscWhile op.
}
};
REGISTER_KERNEL_BUILDER(Name("RiscWhile").Device(DEVICE_CPU), RiscWhileOp);
} // namespace experimental
} // namespace risc
} // namespace tensorflow

View File

@ -30,6 +30,31 @@ REGISTER_OP("RiscAdd")
.SetIsAggregate()
.SetIsCommutative();
// TODO(b/171294012): include RiscMax here as well.
REGISTER_OP("RiscBinaryArithmetic")
.Input("x: T")
.Input("y: T")
.Output("z: T")
.Attr("op_type: {'ADD', 'SUB', 'MUL', 'DIV', 'REM', 'MIN', 'POW'}")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("RiscBinaryComparison")
.Input("x: T")
.Input("y: T")
.Output("z: bool")
.Attr("op_type: {'EQ', 'NE', 'GE', 'GT', 'LE', 'LT'}")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnchangedShape);
// TODO(b/171294012): change shape function.
REGISTER_OP("RiscBitcast")
.Input("x: SrcT")
.Output("y: DstT")
.Attr("SrcT: type")
.Attr("DstT: type")
.SetShapeFn(shape_inference::UnknownShape);
// TODO(b/171294012): change shape function.
REGISTER_OP("RiscBroadcast")
.Input("input: T")
@ -39,6 +64,20 @@ REGISTER_OP("RiscBroadcast")
.Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("RiscCast")
.Input("x: SrcT")
.Output("y: DstT")
.Attr("SrcT: type")
.Attr("DstT: type")
.SetShapeFn(shape_inference::UnchangedShape);
// TODO(b/171294012): change shape function.
REGISTER_OP("RiscCholesky")
.Input("input: T")
.Output("output: T")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("RiscConcat")
.Input("values: N * T")
.Input("axis: Tidx")
@ -48,6 +87,18 @@ REGISTER_OP("RiscConcat")
.Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn(shape_inference::ConcatV2Shape);
// TODO(b/171294012): change shape function.
REGISTER_OP("RiscCondition")
.Input("pred: bool")
.Input("input_true: SrcT")
.Input("input_false: SrcT")
.Output("output: DstT")
.Attr("func_true: func")
.Attr("func_false: func")
.Attr("SrcT: {bfloat16, half, float, double}")
.Attr("DstT: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnknownShape);
// TODO(b/171294012): change shape function.
REGISTER_OP("RiscConv")
.Input("input: T")
@ -68,6 +119,48 @@ REGISTER_OP("RiscDot")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::MatMulShape);
// TODO(b/171294012): change shape function.
REGISTER_OP("RiscFft")
.Input("input: Tcomplex")
.Output("output: Tcomplex")
.Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
.SetShapeFn(shape_inference::UnknownShape);
// TODO(b/171294012): change shape function.
REGISTER_OP("RiscGather")
.Input("params: Tparams")
.Input("indices: Tindices")
.Input("axis: Taxis")
.Attr("batch_dims: int = 0")
.Output("output: Tparams")
.Attr("Tparams: type")
.Attr("Tindices: {int32,int64}")
.Attr("Taxis: {int32,int64}")
.SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("RiscIsFinite")
.Input("x: T")
.Output("y: bool")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("RiscLogicalAnd")
.Input("x: bool")
.Input("y: bool")
.Output("z: bool")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("RiscLogicalNot")
.Input("x: bool")
.Output("z: bool")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("RiscLogicalOr")
.Input("x: bool")
.Input("y: bool")
.Output("z: bool")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("RiscMax")
.Input("x: T")
.Input("y: T")
@ -96,6 +189,23 @@ REGISTER_OP("RiscPool")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("RiscRandomUniform")
.Input("shape: T")
.Output("output: float")
.Attr("seed: int = 0")
.Attr("T: {int32, int64}")
.SetShapeFn(shape_inference::RandomShape);
// TODO(b/171294012): change shape function.
REGISTER_OP("RiscReduce")
.Input("tensor: T")
.Input("axis: Index")
.Output("output: T")
.Attr("reduce_type: {'MEAN', 'SUM'}")
.Attr("Index: {int32,int64} = DT_INT32")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnknownShape);
// TODO(b/171294012): change shape function.
REGISTER_OP("RiscReshape")
.Input("tensor: T")
@ -105,6 +215,24 @@ REGISTER_OP("RiscReshape")
.Attr("Tshape: {int32, int64} = DT_INT32")
.SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("RiscReverse")
.Input("tensor: T")
.Input("axis: Tidx")
.Output("output: T")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnknownShape);
// TODO(b/171294012): change shape function.
REGISTER_OP("RiscScatter")
.Input("indices: Tindices")
.Input("updates: T")
.Input("shape: Tindices")
.Output("output: T")
.Attr("T: {bfloat16, half, float, double}")
.Attr("Tindices: {int32, int64}")
.SetShapeFn(shape_inference::UnknownShape);
// TODO(b/171294012): change shape function.
REGISTER_OP("RiscShape")
.Input("input: T")
@ -122,4 +250,61 @@ REGISTER_OP("RiscSlice")
.Attr("Index: {int32,int64}")
.SetShapeFn(shape_inference::SliceShape);
REGISTER_OP("RiscSort")
.Input("input: T")
.Input("axis: Index")
.Output("output: T")
.Attr("Index: {int32,int64} = DT_INT32")
.Attr("T: {bfloat16, half, float, double}")
.Attr("direction: {'ASCENDING', 'DESCENDING'}")
.SetShapeFn(shape_inference::UnchangedShape);
// TODO(b/171294012): change shape function.
REGISTER_OP("RiscSqueeze")
.Input("input: T")
.Output("output: T")
.Attr("T: type")
.Attr("squeeze_dims: list(int) >= 0 = []")
.SetShapeFn(shape_inference::UnknownShape);
// TODO(b/171294012): change shape function.
REGISTER_OP("RiscTranspose")
.Input("x: T")
.Input("perm: Tperm")
.Output("y: T")
.Attr("T: type")
.Attr("Tperm: {int32, int64} = DT_INT32")
.SetShapeFn(shape_inference::UnknownShape);
// TODO(b/171294012): change shape function.
REGISTER_OP("RiscTriangularSolve")
.Input("matrix: T")
.Input("rhs: T")
.Output("output: T")
.Attr("lower: bool = True")
.Attr("adjoint: bool = False")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("RiscUnary")
.Input("x: T")
.Output("y: T")
.Attr(
"op_type: {'ABL', 'CEIL', 'COS', 'EXP', 'FLOOR', 'IMAG', 'LOG', 'NEG', "
"'REAL', 'SIGN'}")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnchangedShape);
// TODO(b/171294012): change shape function.
REGISTER_OP("RiscWhile")
.Input("input: T")
.Output("output: T")
.Attr("T: list(type) >= 0")
.Attr("cond: func")
.Attr("body: func")
.Attr("output_shapes: list(shape) = []")
.Attr("parallel_iterations: int = 10")
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape);
} // namespace tensorflow

View File

@ -28,6 +28,27 @@ def _RiscAddGrad(_, grad):
return None, None
@ops.RegisterGradient("RiscBinaryArithmetic")
def _RiscBinaryArithmeticGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscBinaryComparison")
def _RiscBinaryComparisonGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscBitcast")
def _RiscBitcastGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscBroadcast")
def _RiscBroadcastGrad(_, grad):
# pylint: disable=unused-argument
@ -35,6 +56,20 @@ def _RiscBroadcastGrad(_, grad):
return None, None
@ops.RegisterGradient("RiscCast")
def _RiscCastGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscCholesky")
def _RiscCholeskyGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscConcat")
def _RiscConcatGrad(_, grad):
# pylint: disable=unused-argument
@ -42,6 +77,13 @@ def _RiscConcatGrad(_, grad):
return None, None
@ops.RegisterGradient("RiscCondition")
def _RiscConditionGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscConv")
def _RiscConvGrad(_, grad):
# pylint: disable=unused-argument
@ -56,6 +98,48 @@ def _RiscDotGrad(_, grad):
return None, None
@ops.RegisterGradient("RiscFft")
def _RiscFftGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscGather")
def _RiscGatherGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscIsFinite")
def _RiscIsFiniteGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscLogicalAnd")
def _RiscLogicalAndGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscLogicalNot")
def _RiscLogicalNotGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscLogicalOr")
def _RiscLogicalOrGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscMax")
def _RiscMaxGrad(_, grad):
# pylint: disable=unused-argument
@ -77,6 +161,20 @@ def _RiscPoolGrad(_, grad):
return None, None
@ops.RegisterGradient("RiscRandomUniform")
def _RiscRandomUniformGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscReduce")
def _RiscReduceGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscReshape")
def _RiscReshapeGrad(_, grad):
# pylint: disable=unused-argument
@ -84,6 +182,20 @@ def _RiscReshapeGrad(_, grad):
return None, None
@ops.RegisterGradient("RiscReverse")
def _RiscReverseGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscScatter")
def _RiscScatterGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscShape")
def _RiscShapeGrad(_, grad):
# pylint: disable=unused-argument
@ -96,3 +208,45 @@ def _RiscSliceGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscSort")
def _RiscSortGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscSqueeze")
def _RiscSqueezeGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscTranspose")
def _RiscTransposeGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscTriangularSolve")
def _RiscTriangularSolvesGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscUnary")
def _RiscUnaryGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
return None, None
@ops.RegisterGradient("RiscWhile")
def _RiscWhileGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
return None, None

View File

@ -29,14 +29,49 @@ def risc_add(
return gen_risc_ops.risc_add(input_lhs, input_rhs, name=name)
def risc_binary_arithmetic(x, y, op_type, name='RISC_BinaryArithmetic'):
return gen_risc_ops.risc_binary_arithmetic(x, y, op_type=op_type, name=name)
def risc_binary_comparison(x, y, op_type, name='RISC_BinaryComparison'):
return gen_risc_ops.risc_binary_comparison(x, y, op_type=op_type, name=name)
def risc_bitcast(x, dtype, name='RISC_BITCAST'):
return gen_risc_ops.risc_bitcast(x, dtype, name=name)
def risc_broadcast(x, shape, name='RISC_BROADCAST'):
return gen_risc_ops.risc_broadcast(x, shape, name=name)
def risc_cast(x, dtype, name='RISC_CAST'):
return gen_risc_ops.risc_cast(x, dtype, name=name)
def risc_cholesky(x, name='RISC_CHOLESKY'):
return gen_risc_ops.risc_cholesky(x, name=name)
def risc_concat(x, axis, name='RISC_CONCAT'):
return gen_risc_ops.risc_concat(x, axis, name=name)
def risc_condition(pred,
input_true,
input_false,
func_true,
func_false,
name='RISC_CONDITION'):
return gen_risc_ops.risc_condition(
pred,
input_true,
input_false,
func_true=func_true,
func_false=func_false,
name=name)
def risc_conv(x,
kernel,
strides,
@ -65,6 +100,41 @@ def risc_dot(input_lhs,
name=name)
def risc_fft(x, name='RISC_FFT'):
return gen_risc_ops.risc_fft(x, name=name)
def risc_gather(params,
indices,
validate_indices=None,
axis=None,
batch_dims=0,
name='RISC_GATHER'):
return gen_risc_ops.risc_gather(
params,
indices,
validate_indices=validate_indices,
name=name,
axis=axis,
batch_dims=batch_dims)
def risc_is_finite(x, name='RISC_IS_FINITE'):
return gen_risc_ops.risc_is_finite(x, name=name)
def risc_logical_and(a, b, name='RISC_LOGICAL_AND'):
return gen_risc_ops.risc_logical_and(a, b, name=name)
def risc_logical_not(a, b, name='RISC_LOGICAL_NOT'):
return gen_risc_ops.risc_logical_not(a, b, name=name)
def risc_logical_or(a, b, name='RISC_LOGICAL_OR'):
return gen_risc_ops.risc_logical_or(a, b, name=name)
def risc_max(input_lhs, input_rhs, name='RISC_MAX'):
return gen_risc_ops.risc_max(input_lhs, input_rhs, name=name)
@ -78,13 +148,76 @@ def risc_pool(x, ksize, strides, pooling_type='MAX', name='RISC_POOL'):
x, ksize, strides, pooling_type=pooling_type, name=name)
def risc_random_uniform(shape, seed, name='RISC_RANDOM_UNIFORM'):
return gen_risc_ops.risc_random_uniform(shape, seed, name=name)
def risc_reduce(x, axis, reduce_type, name='RISC_REDUCE'):
return gen_risc_ops.risc_reduce(x, axis, reduce_type=reduce_type, name=name)
def risc_reshape(x, shape, name='RISC_RESHAPE'):
return gen_risc_ops.risc_reshape(x, shape, name=name)
def risc_reverse(x, axis, name='RISC_REVERSE'):
return gen_risc_ops.risc_reverse(x, axis, name=name)
def risc_scatter(indices, updates, shape, name='RISC_SCATTER'):
return gen_risc_ops.risc_scatter(indices, updates, shape, name=name)
def risc_shape(x, name='RISC_SHAPE'):
return gen_risc_ops.risc_shape(x, name=name)
def risc_slice(x, begin, size, name='RISC_SLICE'):
return gen_risc_ops.risc_slice(x, begin, size, name=name)
def risc_sort(x, axis, direction='ASCENDING', name='RISC_SORT'):
return gen_risc_ops.risc_sort(x, axis, direction=direction, name=name)
def risc_squeeze(x, axis=None, name='RISC_SQUEEZE'):
return gen_risc_ops.risc_squeeze(x, axis, name=name)
def risc_transpose(x, perm=None, name='RISC_TRANSPOSE'):
return gen_risc_ops.risc_transpose(x, perm, name=name)
def risc_triangular_solve(matrix,
rhs,
lower=True,
adjoint=False,
name='RISC_TRIANGULAR_SOLVE'):
return gen_risc_ops.risc_triangular_solve(
matrix, rhs, lower=lower, adjoint=adjoint, name=name)
def risc_unary(x, op_type='ABL', name='RISC_UNARY'):
return gen_risc_ops.risc_unary(x, op_type=op_type, name=name)
def risc_while(cond,
body,
loop_vars,
shape_invariants=None,
parallel_iterations=10,
back_prop=True,
swap_memory=False,
maximum_iterations=None,
name='RISC_WHILE'):
return gen_risc_ops.risc_while(
cond=cond,
body=body,
loop_vars=loop_vars,
shape_invariants=shape_invariants,
parallel_iterations=parallel_iterations,
back_prop=back_prop,
swap_memory=swap_memory,
name=name,
maximum_iterations=maximum_iterations,
return_same_structure=True)