Break up core/kernels/BUILD (part 1 of N):
Move linear algebra kernels to subdirectory tensorflow/core/kernels/linalg with its own BUILD file. PiperOrigin-RevId: 324923762 Change-Id: Id17aac690729b62ae97525df5bb57d6a073d6b0c
This commit is contained in:
parent
79594069bb
commit
84d053187c
@ -1010,7 +1010,7 @@ cc_library(
|
|||||||
"//tensorflow/core/kernels:histogram_op",
|
"//tensorflow/core/kernels:histogram_op",
|
||||||
"//tensorflow/core/kernels:image",
|
"//tensorflow/core/kernels:image",
|
||||||
"//tensorflow/core/kernels:io",
|
"//tensorflow/core/kernels:io",
|
||||||
"//tensorflow/core/kernels:linalg",
|
"//tensorflow/core/kernels/linalg:linalg",
|
||||||
"//tensorflow/core/kernels:lookup",
|
"//tensorflow/core/kernels:lookup",
|
||||||
"//tensorflow/core/kernels:logging",
|
"//tensorflow/core/kernels:logging",
|
||||||
"//tensorflow/core/kernels:manip",
|
"//tensorflow/core/kernels:manip",
|
||||||
|
@ -1039,9 +1039,6 @@ cc_library(
|
|||||||
":immutable_constant_op",
|
":immutable_constant_op",
|
||||||
":inplace_ops",
|
":inplace_ops",
|
||||||
":listdiff_op",
|
":listdiff_op",
|
||||||
":matrix_band_part_op",
|
|
||||||
":matrix_diag_op",
|
|
||||||
":matrix_set_diag_op",
|
|
||||||
":mirror_pad_op",
|
":mirror_pad_op",
|
||||||
":one_hot_op",
|
":one_hot_op",
|
||||||
":pack_op",
|
":pack_op",
|
||||||
@ -1174,26 +1171,6 @@ tf_kernel_library(
|
|||||||
deps = ARRAY_DEPS,
|
deps = ARRAY_DEPS,
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_kernel_library(
|
|
||||||
name = "matrix_band_part_op",
|
|
||||||
prefix = "matrix_band_part_op",
|
|
||||||
deps = if_cuda([
|
|
||||||
":cuda_solvers",
|
|
||||||
]) + ARRAY_DEPS,
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_kernel_library(
|
|
||||||
name = "matrix_diag_op",
|
|
||||||
prefix = "matrix_diag_op",
|
|
||||||
deps = ARRAY_DEPS,
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_kernel_library(
|
|
||||||
name = "matrix_set_diag_op",
|
|
||||||
prefix = "matrix_set_diag_op",
|
|
||||||
deps = ARRAY_DEPS + [":matrix_diag_op"],
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_kernel_library(
|
tf_kernel_library(
|
||||||
name = "mirror_pad_op",
|
name = "mirror_pad_op",
|
||||||
prefix = "mirror_pad_op",
|
prefix = "mirror_pad_op",
|
||||||
@ -1405,7 +1382,7 @@ tf_kernel_library(
|
|||||||
"where_op_gpu_impl_8.cu.cc",
|
"where_op_gpu_impl_8.cu.cc",
|
||||||
],
|
],
|
||||||
deps = if_cuda_or_rocm([
|
deps = if_cuda_or_rocm([
|
||||||
":cuda_solvers",
|
"//tensorflow/core/util:cuda_solvers",
|
||||||
]) + [":gpu_prim_hdrs"] +
|
]) + [":gpu_prim_hdrs"] +
|
||||||
ARRAY_DEPS,
|
ARRAY_DEPS,
|
||||||
)
|
)
|
||||||
@ -2785,21 +2762,6 @@ tf_cuda_cc_tests(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_kernel_library(
|
|
||||||
name = "eye_functor",
|
|
||||||
hdrs = ["eye_functor.h"],
|
|
||||||
gpu_srcs = [
|
|
||||||
"eye_functor_gpu.cu.cc",
|
|
||||||
"eye_functor.h",
|
|
||||||
],
|
|
||||||
visibility = [":friends"],
|
|
||||||
deps = [
|
|
||||||
"//tensorflow/core:framework",
|
|
||||||
"//third_party/eigen3",
|
|
||||||
],
|
|
||||||
alwayslink = 0,
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "fifo_queue",
|
name = "fifo_queue",
|
||||||
srcs = ["fifo_queue.cc"],
|
srcs = ["fifo_queue.cc"],
|
||||||
@ -3558,289 +3520,6 @@ tf_cc_tests(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "linalg",
|
|
||||||
deps = [
|
|
||||||
":banded_triangular_solve_op",
|
|
||||||
":cholesky_grad",
|
|
||||||
":cholesky_op",
|
|
||||||
":determinant_op",
|
|
||||||
":eig_op",
|
|
||||||
":einsum_op",
|
|
||||||
":lu_op",
|
|
||||||
":matrix_exponential_op",
|
|
||||||
":matrix_inverse_op",
|
|
||||||
":matrix_logarithm_op",
|
|
||||||
":matrix_solve_ls_op",
|
|
||||||
":matrix_solve_op",
|
|
||||||
":matrix_square_root_op",
|
|
||||||
":matrix_triangular_solve_op",
|
|
||||||
":qr_op",
|
|
||||||
":self_adjoint_eig_op",
|
|
||||||
":self_adjoint_eig_v2_op",
|
|
||||||
":svd_op",
|
|
||||||
":tridiagonal_matmul_op",
|
|
||||||
":tridiagonal_solve_op",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_kernel_library(
|
|
||||||
name = "cuda_solvers",
|
|
||||||
srcs = ["cuda_solvers.cc"],
|
|
||||||
hdrs = ["cuda_solvers.h"],
|
|
||||||
# @local_config_cuda//cuda:cusolver_static, //third_party/eigen3:blas,
|
|
||||||
# and //third_party/libf2c all contain various parts of BLAS, LAPACK,
|
|
||||||
# and f2c helper functions in global namespace. Tell the compiler to
|
|
||||||
# allow multiple definitions when linking this.
|
|
||||||
linkopts = select({
|
|
||||||
"//tensorflow:macos": [],
|
|
||||||
"//tensorflow:windows": [],
|
|
||||||
"//conditions:default": ["-Wl,-z,muldefs"],
|
|
||||||
}),
|
|
||||||
visibility = [":friends"],
|
|
||||||
deps = [
|
|
||||||
"//tensorflow/core:framework",
|
|
||||||
"//tensorflow/core:lib",
|
|
||||||
"//tensorflow/core/platform/default/build_config:cublas_plugin",
|
|
||||||
"//tensorflow/stream_executor/cuda:cublas_lib",
|
|
||||||
"//tensorflow/stream_executor/cuda:cusolver_lib",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_kernel_library(
|
|
||||||
name = "rocm_solvers",
|
|
||||||
srcs = ["rocm_solvers.cc"],
|
|
||||||
hdrs = ["rocm_solvers.h"],
|
|
||||||
visibility = [":friends"],
|
|
||||||
deps = [
|
|
||||||
"//tensorflow/core:framework",
|
|
||||||
"//tensorflow/core:framework_internal",
|
|
||||||
"//tensorflow/core:lib",
|
|
||||||
"//tensorflow/stream_executor/lib",
|
|
||||||
"//tensorflow/stream_executor/platform:dso_loader",
|
|
||||||
"//tensorflow/stream_executor/rocm:rocblas_plugin",
|
|
||||||
"//tensorflow/stream_executor/rocm:rocm_gpu_executor",
|
|
||||||
] + if_rocm([
|
|
||||||
"@local_config_rocm//rocm:rocprim",
|
|
||||||
]),
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_kernel_library(
|
|
||||||
name = "cuda_sparse",
|
|
||||||
srcs = if_cuda(["cuda_sparse.cc"]) + if_rocm(["rocm_sparse.cc"]),
|
|
||||||
hdrs = ["cuda_sparse.h"],
|
|
||||||
deps = [
|
|
||||||
"//tensorflow/core:framework",
|
|
||||||
"//tensorflow/core:lib",
|
|
||||||
"//tensorflow/core/kernels:cuda_solvers",
|
|
||||||
] + if_cuda([
|
|
||||||
"//tensorflow/stream_executor/cuda:cusparse_lib",
|
|
||||||
"@cub_archive//:cub",
|
|
||||||
]) + if_rocm([
|
|
||||||
"@local_config_rocm//rocm:hipsparse",
|
|
||||||
]),
|
|
||||||
)
|
|
||||||
|
|
||||||
LINALG_DEPS = [
|
|
||||||
":linalg_ops_common",
|
|
||||||
"//third_party/eigen3",
|
|
||||||
"//tensorflow/core:framework",
|
|
||||||
"//tensorflow/core:lib",
|
|
||||||
] + if_cuda([
|
|
||||||
":cuda_solvers",
|
|
||||||
":transpose_functor",
|
|
||||||
]) + if_rocm([
|
|
||||||
":rocm_solvers",
|
|
||||||
])
|
|
||||||
|
|
||||||
tf_kernel_library(
|
|
||||||
name = "cholesky_op",
|
|
||||||
prefix = "cholesky_op",
|
|
||||||
deps = if_cuda([
|
|
||||||
":matrix_band_part_op",
|
|
||||||
]) + LINALG_DEPS,
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_kernel_library(
|
|
||||||
name = "cholesky_grad",
|
|
||||||
prefix = "cholesky_grad",
|
|
||||||
deps = LINALG_DEPS,
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_kernel_library(
|
|
||||||
name = "determinant_op",
|
|
||||||
prefix = "determinant_op",
|
|
||||||
deps = if_cuda([
|
|
||||||
":fill_functor",
|
|
||||||
]) + LINALG_DEPS,
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_kernel_library(
|
|
||||||
name = "matrix_exponential_op",
|
|
||||||
prefix = "matrix_exponential_op",
|
|
||||||
deps = LINALG_DEPS,
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_kernel_library(
|
|
||||||
name = "matrix_logarithm_op",
|
|
||||||
prefix = "matrix_logarithm_op",
|
|
||||||
deps = LINALG_DEPS,
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_kernel_library(
|
|
||||||
name = "self_adjoint_eig_op",
|
|
||||||
prefix = "self_adjoint_eig_op",
|
|
||||||
deps = LINALG_DEPS + ["//tensorflow/core:lib_internal"],
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_kernel_library(
|
|
||||||
name = "self_adjoint_eig_v2_op",
|
|
||||||
prefix = "self_adjoint_eig_v2_op",
|
|
||||||
deps = LINALG_DEPS + ["//tensorflow/core:lib_internal"] + if_cuda([
|
|
||||||
":cast_op",
|
|
||||||
":cwise_op",
|
|
||||||
]),
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_kernel_library(
|
|
||||||
name = "eig_op",
|
|
||||||
prefix = "eig_op",
|
|
||||||
deps = LINALG_DEPS + ["//tensorflow/core:lib_internal"] + if_cuda([
|
|
||||||
":cast_op",
|
|
||||||
":cwise_op",
|
|
||||||
]),
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_kernel_library(
|
|
||||||
name = "matrix_inverse_op",
|
|
||||||
prefix = "matrix_inverse_op",
|
|
||||||
deps = LINALG_DEPS + if_cuda([":eye_functor"]),
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_kernel_library(
|
|
||||||
name = "matrix_solve_ls_op",
|
|
||||||
prefix = "matrix_solve_ls_op",
|
|
||||||
deps = LINALG_DEPS,
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_kernel_library(
|
|
||||||
name = "matrix_solve_op",
|
|
||||||
prefix = "matrix_solve_op",
|
|
||||||
deps = LINALG_DEPS,
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_kernel_library(
|
|
||||||
name = "matrix_square_root_op",
|
|
||||||
prefix = "matrix_square_root_op",
|
|
||||||
deps = LINALG_DEPS,
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_kernel_library(
|
|
||||||
name = "banded_triangular_solve_op",
|
|
||||||
prefix = "banded_triangular_solve_op",
|
|
||||||
deps = LINALG_DEPS + [":fill_functor"],
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_kernel_library(
|
|
||||||
name = "matrix_triangular_solve_op",
|
|
||||||
hdrs = ["matrix_triangular_solve_op_impl.h"],
|
|
||||||
prefix = "matrix_triangular_solve_op",
|
|
||||||
deps = [
|
|
||||||
":linalg_ops_common",
|
|
||||||
"//third_party/eigen3",
|
|
||||||
"//tensorflow/core:framework",
|
|
||||||
"//tensorflow/core:lib",
|
|
||||||
":fill_functor",
|
|
||||||
"//tensorflow/core:stream_executor",
|
|
||||||
] + if_cuda([
|
|
||||||
"//tensorflow/core/platform/default/build_config:cublas_plugin",
|
|
||||||
":cuda_solvers",
|
|
||||||
]) + if_rocm([
|
|
||||||
"@local_config_rocm//rocm:rocprim",
|
|
||||||
":rocm_solvers",
|
|
||||||
]) + if_cuda_or_rocm([
|
|
||||||
":transpose_functor",
|
|
||||||
]),
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_kernel_library(
|
|
||||||
name = "tridiagonal_matmul_op",
|
|
||||||
srcs = ["tridiagonal_matmul_op.cc"],
|
|
||||||
gpu_srcs = ["tridiagonal_matmul_op_gpu.cu.cc"],
|
|
||||||
deps = LINALG_DEPS + if_cuda([
|
|
||||||
":cuda_sparse",
|
|
||||||
]),
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_kernel_library(
|
|
||||||
name = "tridiagonal_solve_op",
|
|
||||||
srcs = ["tridiagonal_solve_op.cc"],
|
|
||||||
gpu_srcs = ["tridiagonal_solve_op_gpu.cu.cc"],
|
|
||||||
deps = LINALG_DEPS + if_cuda([
|
|
||||||
":cuda_sparse",
|
|
||||||
]),
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_kernel_library(
|
|
||||||
name = "qr_op",
|
|
||||||
prefix = "qr_op",
|
|
||||||
deps = LINALG_DEPS + if_cuda([
|
|
||||||
":cwise_op",
|
|
||||||
":eye_functor",
|
|
||||||
":matrix_band_part_op",
|
|
||||||
]),
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_kernel_library(
|
|
||||||
name = "svd_op",
|
|
||||||
prefix = "svd_op",
|
|
||||||
deps = LINALG_DEPS + if_cuda([
|
|
||||||
":eye_functor",
|
|
||||||
]),
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_kernel_library(
|
|
||||||
name = "lu_op",
|
|
||||||
prefix = "lu_op",
|
|
||||||
deps = if_cuda([
|
|
||||||
":cuda_solvers",
|
|
||||||
":transpose_functor",
|
|
||||||
]) + [
|
|
||||||
"//third_party/eigen3",
|
|
||||||
"//tensorflow/core:framework",
|
|
||||||
"//tensorflow/core:lib",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_kernel_library(
|
|
||||||
name = "einsum_op",
|
|
||||||
prefix = "einsum_op",
|
|
||||||
deps = [
|
|
||||||
":batch_matmul_op",
|
|
||||||
":fill_functor",
|
|
||||||
":reduction_ops",
|
|
||||||
":transpose_functor",
|
|
||||||
"//tensorflow/core:framework",
|
|
||||||
"//tensorflow/core:lib",
|
|
||||||
"//tensorflow/core/profiler/lib:traceme",
|
|
||||||
"//third_party/eigen3",
|
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
|
||||||
"@com_google_absl//absl/strings",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "linalg_ops_common",
|
|
||||||
srcs = ["linalg_ops_common.cc"],
|
|
||||||
hdrs = ["linalg_ops_common.h"],
|
|
||||||
visibility = ["//visibility:private"],
|
|
||||||
deps = [
|
|
||||||
"//tensorflow/core:framework",
|
|
||||||
"//tensorflow/core:lib",
|
|
||||||
"//third_party/eigen3",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "logging",
|
name = "logging",
|
||||||
deps = [
|
deps = [
|
||||||
@ -4208,7 +3887,7 @@ tf_kernel_library(
|
|||||||
name = "segment_reduction_ops",
|
name = "segment_reduction_ops",
|
||||||
prefix = "segment_reduction_ops",
|
prefix = "segment_reduction_ops",
|
||||||
deps = MATH_DEPS + if_cuda_or_rocm([
|
deps = MATH_DEPS + if_cuda_or_rocm([
|
||||||
":cuda_solvers",
|
"//tensorflow/core/util:cuda_solvers",
|
||||||
]),
|
]),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -4405,45 +4084,6 @@ tf_cuda_cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_cuda_cc_test(
|
|
||||||
name = "banded_triangular_solve_op_test",
|
|
||||||
size = "small",
|
|
||||||
srcs = ["banded_triangular_solve_op_test.cc"],
|
|
||||||
deps = [
|
|
||||||
":banded_triangular_solve_op",
|
|
||||||
":matrix_set_diag_op",
|
|
||||||
":matrix_triangular_solve_op",
|
|
||||||
":ops_testutil",
|
|
||||||
":ops_util",
|
|
||||||
"//tensorflow/core:core_cpu",
|
|
||||||
"//tensorflow/core:framework",
|
|
||||||
"//tensorflow/core:lib",
|
|
||||||
"//tensorflow/core:protos_all_cc",
|
|
||||||
"//tensorflow/core:test",
|
|
||||||
"//tensorflow/core:test_main",
|
|
||||||
"//tensorflow/core:testlib",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_cuda_cc_test(
|
|
||||||
name = "matrix_triangular_solve_op_test",
|
|
||||||
size = "small",
|
|
||||||
srcs = ["matrix_triangular_solve_op_test.cc"],
|
|
||||||
deps = [
|
|
||||||
":broadcast_to_op",
|
|
||||||
":matrix_triangular_solve_op",
|
|
||||||
":ops_testutil",
|
|
||||||
":ops_util",
|
|
||||||
"//tensorflow/core:core_cpu",
|
|
||||||
"//tensorflow/core:framework",
|
|
||||||
"//tensorflow/core:lib",
|
|
||||||
"//tensorflow/core:protos_all_cc",
|
|
||||||
"//tensorflow/core:test",
|
|
||||||
"//tensorflow/core:test_main",
|
|
||||||
"//tensorflow/core:testlib",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_cuda_cc_test(
|
tf_cuda_cc_test(
|
||||||
name = "scan_ops_test",
|
name = "scan_ops_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
@ -6672,10 +6312,7 @@ filegroup(
|
|||||||
"lookup_table_init_op.h",
|
"lookup_table_init_op.h",
|
||||||
"lookup_table_op.h",
|
"lookup_table_op.h",
|
||||||
"lookup_util.h",
|
"lookup_util.h",
|
||||||
"linalg_ops_common.h",
|
|
||||||
"list_kernels.h",
|
"list_kernels.h",
|
||||||
"matrix_diag_op.h",
|
|
||||||
"matrix_set_diag_op.h",
|
|
||||||
"maxpooling_op.h",
|
"maxpooling_op.h",
|
||||||
"mfcc.h",
|
"mfcc.h",
|
||||||
"mfcc_dct.h",
|
"mfcc_dct.h",
|
||||||
@ -6723,6 +6360,9 @@ filegroup(
|
|||||||
"xent_op.h",
|
"xent_op.h",
|
||||||
] + [
|
] + [
|
||||||
"//tensorflow/core/kernels/boosted_trees/quantiles:weighted_quantiles_hdrs",
|
"//tensorflow/core/kernels/boosted_trees/quantiles:weighted_quantiles_hdrs",
|
||||||
|
"//tensorflow/core/kernels/linalg:linalg_ops_common.h",
|
||||||
|
"//tensorflow/core/kernels/linalg:matrix_diag_op.h",
|
||||||
|
"//tensorflow/core/kernels/linalg:matrix_set_diag_op.h",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -6823,16 +6463,6 @@ filegroup(
|
|||||||
"encode_wav_op.cc",
|
"encode_wav_op.cc",
|
||||||
"eigen_contraction_kernel.cc",
|
"eigen_contraction_kernel.cc",
|
||||||
"eigen_contraction_kernel.h",
|
"eigen_contraction_kernel.h",
|
||||||
"einsum_op_impl_half.cc",
|
|
||||||
"einsum_op_impl_bfloat16.cc",
|
|
||||||
"einsum_op_impl_int32.cc",
|
|
||||||
"einsum_op_impl_int64.cc",
|
|
||||||
"einsum_op_impl_float.cc",
|
|
||||||
"einsum_op_impl_double.cc",
|
|
||||||
"einsum_op_impl_complex64.cc",
|
|
||||||
"einsum_op_impl_complex128.cc",
|
|
||||||
"einsum_op_impl.h",
|
|
||||||
"einsum_op.h",
|
|
||||||
"fake_quant_ops.cc",
|
"fake_quant_ops.cc",
|
||||||
"fifo_queue.cc",
|
"fifo_queue.cc",
|
||||||
"fifo_queue_op.cc",
|
"fifo_queue_op.cc",
|
||||||
@ -6844,6 +6474,17 @@ filegroup(
|
|||||||
"population_count_op.h",
|
"population_count_op.h",
|
||||||
"winograd_transform.h",
|
"winograd_transform.h",
|
||||||
":android_extended_ops_headers",
|
":android_extended_ops_headers",
|
||||||
|
] + [
|
||||||
|
"//tensorflow/core/kernels/linalg:einsum_op_impl_half.cc",
|
||||||
|
"//tensorflow/core/kernels/linalg:einsum_op_impl_bfloat16.cc",
|
||||||
|
"//tensorflow/core/kernels/linalg:einsum_op_impl_int32.cc",
|
||||||
|
"//tensorflow/core/kernels/linalg:einsum_op_impl_int64.cc",
|
||||||
|
"//tensorflow/core/kernels/linalg:einsum_op_impl_float.cc",
|
||||||
|
"//tensorflow/core/kernels/linalg:einsum_op_impl_double.cc",
|
||||||
|
"//tensorflow/core/kernels/linalg:einsum_op_impl_complex64.cc",
|
||||||
|
"//tensorflow/core/kernels/linalg:einsum_op_impl_complex128.cc",
|
||||||
|
"//tensorflow/core/kernels/linalg:einsum_op_impl.h",
|
||||||
|
"//tensorflow/core/kernels/linalg:einsum_op.h",
|
||||||
] + select({
|
] + select({
|
||||||
":xsmm_convolutions": [
|
":xsmm_convolutions": [
|
||||||
"xsmm_conv2d.h",
|
"xsmm_conv2d.h",
|
||||||
@ -6874,7 +6515,6 @@ filegroup(
|
|||||||
"in_topk_op.cc",
|
"in_topk_op.cc",
|
||||||
"in_topk_op.h",
|
"in_topk_op.h",
|
||||||
"initializable_lookup_table.cc",
|
"initializable_lookup_table.cc",
|
||||||
"linalg_ops_common.cc",
|
|
||||||
"list_kernels.cc",
|
"list_kernels.cc",
|
||||||
"logging_ops.cc",
|
"logging_ops.cc",
|
||||||
"logging_ops.h",
|
"logging_ops.h",
|
||||||
@ -6882,9 +6522,6 @@ filegroup(
|
|||||||
"lookup_table_op.cc",
|
"lookup_table_op.cc",
|
||||||
"lookup_util.cc",
|
"lookup_util.cc",
|
||||||
"lrn_op.cc",
|
"lrn_op.cc",
|
||||||
"matrix_diag_op.cc",
|
|
||||||
"matrix_inverse_op.cc",
|
|
||||||
"matrix_set_diag_op.cc",
|
|
||||||
"maxpooling_op.cc",
|
"maxpooling_op.cc",
|
||||||
"mfcc.cc",
|
"mfcc.cc",
|
||||||
"mfcc_dct.cc",
|
"mfcc_dct.cc",
|
||||||
@ -7006,6 +6643,10 @@ filegroup(
|
|||||||
":android_extended_ops_headers",
|
":android_extended_ops_headers",
|
||||||
] + [
|
] + [
|
||||||
"//tensorflow/core/kernels/boosted_trees:quantile_ops.cc",
|
"//tensorflow/core/kernels/boosted_trees:quantile_ops.cc",
|
||||||
|
"//tensorflow/core/kernels/linalg:linalg_ops_common.cc",
|
||||||
|
"//tensorflow/core/kernels/linalg:matrix_diag_op.cc",
|
||||||
|
"//tensorflow/core/kernels/linalg:matrix_inverse_op.cc",
|
||||||
|
"//tensorflow/core/kernels/linalg:matrix_set_diag_op.cc",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -7059,6 +6700,7 @@ filegroup(
|
|||||||
srcs = [
|
srcs = [
|
||||||
"//tensorflow/c/kernels:android_all_op_kernels",
|
"//tensorflow/c/kernels:android_all_op_kernels",
|
||||||
"//tensorflow/core/kernels/data:android_all_op_kernels",
|
"//tensorflow/core/kernels/data:android_all_op_kernels",
|
||||||
|
"//tensorflow/core/kernels/linalg:android_all_op_kernels",
|
||||||
] + glob(
|
] + glob(
|
||||||
[
|
[
|
||||||
"*.cc",
|
"*.cc",
|
||||||
@ -8827,3 +8469,15 @@ tf_kernel_library(
|
|||||||
"@sobol_data",
|
"@sobol_data",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ---- temporary forwarding declaration for libraries in linalg
|
||||||
|
# TODO(b/160344057): Remove after updating dependencies.
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "matrix_inverse_op",
|
||||||
|
deps = ["//tensorflow/core/kernels/linalg:matrix_inverse_op"],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "einsum_op",
|
||||||
|
deps = ["//tensorflow/core/kernels/linalg:einsum_op"],
|
||||||
|
)
|
||||||
|
376
tensorflow/core/kernels/linalg/BUILD
Normal file
376
tensorflow/core/kernels/linalg/BUILD
Normal file
@ -0,0 +1,376 @@
|
|||||||
|
load(
|
||||||
|
"//tensorflow:tensorflow.bzl",
|
||||||
|
"if_cuda_or_rocm",
|
||||||
|
"tf_kernel_library",
|
||||||
|
)
|
||||||
|
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
||||||
|
load(
|
||||||
|
"@local_config_rocm//rocm:build_defs.bzl",
|
||||||
|
"if_rocm",
|
||||||
|
)
|
||||||
|
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
|
||||||
|
|
||||||
|
# Description:
|
||||||
|
# Op kernel implementations for TensorFlow.
|
||||||
|
#
|
||||||
|
# Note: Any test that uses GPU support and which we would like to
|
||||||
|
# benchmark should be linked statically so that it can be executed
|
||||||
|
# from a py_binary or cuda_py_test test logger. For such a test,
|
||||||
|
# append "_gpu" to the test name to invoke the GPU benchmarks. Example:
|
||||||
|
#
|
||||||
|
# # for CPU tests
|
||||||
|
# $ bazel test --config opt //third_party/tensorflow/core/kernels:my_op_test
|
||||||
|
# # for GPU benchmarks
|
||||||
|
# $ bazel run --config opt --config=cuda //third_party/tensorflow/core/kernels:my_op_test_gpu -- --benchmarks=..
|
||||||
|
#
|
||||||
|
package(
|
||||||
|
default_visibility = [
|
||||||
|
"//tensorflow:__subpackages__",
|
||||||
|
"//tensorflow:internal",
|
||||||
|
],
|
||||||
|
licenses = ["notice"], # Apache 2.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO(rmlarsen): Remove ASAP.
|
||||||
|
package_group(
|
||||||
|
name = "friends",
|
||||||
|
packages = ["//tensorflow/..."],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Export a few files for use on Android.
|
||||||
|
exports_files([
|
||||||
|
"einsum_op_impl_half.cc",
|
||||||
|
"einsum_op_impl_bfloat16.cc",
|
||||||
|
"einsum_op_impl_int32.cc",
|
||||||
|
"einsum_op_impl_int64.cc",
|
||||||
|
"einsum_op_impl_float.cc",
|
||||||
|
"einsum_op_impl_double.cc",
|
||||||
|
"einsum_op_impl_complex64.cc",
|
||||||
|
"einsum_op_impl_complex128.cc",
|
||||||
|
"einsum_op_impl.h",
|
||||||
|
"einsum_op.h",
|
||||||
|
"linalg_ops_common.h",
|
||||||
|
"linalg_ops_common.cc",
|
||||||
|
"matrix_diag_op.h",
|
||||||
|
"matrix_diag_op.cc",
|
||||||
|
"matrix_inverse_op.cc",
|
||||||
|
"matrix_set_diag_op.h",
|
||||||
|
"matrix_set_diag_op.cc",
|
||||||
|
])
|
||||||
|
|
||||||
|
# Public support libraries ----------------------------------------------------
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "linalg",
|
||||||
|
deps = [
|
||||||
|
":banded_triangular_solve_op",
|
||||||
|
":cholesky_grad",
|
||||||
|
":cholesky_op",
|
||||||
|
":determinant_op",
|
||||||
|
":eig_op",
|
||||||
|
":einsum_op",
|
||||||
|
":lu_op",
|
||||||
|
":matrix_band_part_op",
|
||||||
|
":matrix_diag_op",
|
||||||
|
":matrix_exponential_op",
|
||||||
|
":matrix_inverse_op",
|
||||||
|
":matrix_logarithm_op",
|
||||||
|
":matrix_set_diag_op",
|
||||||
|
":matrix_solve_ls_op",
|
||||||
|
":matrix_solve_op",
|
||||||
|
":matrix_square_root_op",
|
||||||
|
":matrix_triangular_solve_op",
|
||||||
|
":qr_op",
|
||||||
|
":self_adjoint_eig_op",
|
||||||
|
":self_adjoint_eig_v2_op",
|
||||||
|
":svd_op",
|
||||||
|
":tridiagonal_matmul_op",
|
||||||
|
":tridiagonal_solve_op",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
LINALG_DEPS = [
|
||||||
|
":linalg_ops_common",
|
||||||
|
"//third_party/eigen3",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core/kernels:cast_op",
|
||||||
|
"//tensorflow/core/kernels:fill_functor",
|
||||||
|
] + if_cuda([
|
||||||
|
":eye_functor",
|
||||||
|
"//tensorflow/core/util:cuda_solvers",
|
||||||
|
"//tensorflow/core/kernels:transpose_functor",
|
||||||
|
]) + if_rocm([
|
||||||
|
"//tensorflow/core/util:rocm_solvers",
|
||||||
|
])
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "matrix_band_part_op",
|
||||||
|
prefix = "matrix_band_part_op",
|
||||||
|
deps = LINALG_DEPS,
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "matrix_diag_op",
|
||||||
|
prefix = "matrix_diag_op",
|
||||||
|
deps = LINALG_DEPS,
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "matrix_set_diag_op",
|
||||||
|
prefix = "matrix_set_diag_op",
|
||||||
|
deps = LINALG_DEPS + [":matrix_diag_op"],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "cholesky_op",
|
||||||
|
prefix = "cholesky_op",
|
||||||
|
deps = if_cuda([
|
||||||
|
":matrix_band_part_op",
|
||||||
|
]) + LINALG_DEPS,
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "cholesky_grad",
|
||||||
|
prefix = "cholesky_grad",
|
||||||
|
deps = LINALG_DEPS,
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "determinant_op",
|
||||||
|
prefix = "determinant_op",
|
||||||
|
deps = LINALG_DEPS,
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "matrix_exponential_op",
|
||||||
|
prefix = "matrix_exponential_op",
|
||||||
|
deps = LINALG_DEPS,
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "matrix_logarithm_op",
|
||||||
|
prefix = "matrix_logarithm_op",
|
||||||
|
deps = LINALG_DEPS,
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "self_adjoint_eig_op",
|
||||||
|
prefix = "self_adjoint_eig_op",
|
||||||
|
deps = LINALG_DEPS + ["//tensorflow/core:lib_internal"],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "self_adjoint_eig_v2_op",
|
||||||
|
prefix = "self_adjoint_eig_v2_op",
|
||||||
|
deps = LINALG_DEPS + ["//tensorflow/core:lib_internal"] + if_cuda([
|
||||||
|
"//tensorflow/core/kernels:cwise_op",
|
||||||
|
]),
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "eig_op",
|
||||||
|
prefix = "eig_op",
|
||||||
|
deps = LINALG_DEPS + ["//tensorflow/core:lib_internal"] + if_cuda([
|
||||||
|
"//tensorflow/core/kernels:cwise_op",
|
||||||
|
]),
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "matrix_inverse_op",
|
||||||
|
prefix = "matrix_inverse_op",
|
||||||
|
visibility = [":friends"],
|
||||||
|
deps = LINALG_DEPS,
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "matrix_solve_ls_op",
|
||||||
|
prefix = "matrix_solve_ls_op",
|
||||||
|
deps = LINALG_DEPS,
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "matrix_solve_op",
|
||||||
|
prefix = "matrix_solve_op",
|
||||||
|
deps = LINALG_DEPS,
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "matrix_square_root_op",
|
||||||
|
prefix = "matrix_square_root_op",
|
||||||
|
deps = LINALG_DEPS,
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "banded_triangular_solve_op",
|
||||||
|
prefix = "banded_triangular_solve_op",
|
||||||
|
deps = LINALG_DEPS,
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "matrix_triangular_solve_op",
|
||||||
|
hdrs = ["matrix_triangular_solve_op_impl.h"],
|
||||||
|
prefix = "matrix_triangular_solve_op",
|
||||||
|
deps = [
|
||||||
|
":linalg_ops_common",
|
||||||
|
"//third_party/eigen3",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core/kernels:fill_functor",
|
||||||
|
"//tensorflow/core:stream_executor",
|
||||||
|
] + if_cuda([
|
||||||
|
"//tensorflow/core/platform/default/build_config:cublas_plugin",
|
||||||
|
"//tensorflow/core/util:cuda_solvers",
|
||||||
|
]) + if_rocm([
|
||||||
|
"@local_config_rocm//rocm:rocprim",
|
||||||
|
"//tensorflow/core/util:rocm_solvers",
|
||||||
|
]) + if_cuda_or_rocm([
|
||||||
|
"//tensorflow/core/kernels:transpose_functor",
|
||||||
|
]),
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "tridiagonal_matmul_op",
|
||||||
|
srcs = ["tridiagonal_matmul_op.cc"],
|
||||||
|
gpu_srcs = ["tridiagonal_matmul_op_gpu.cu.cc"],
|
||||||
|
deps = LINALG_DEPS + if_cuda([
|
||||||
|
"//tensorflow/core/util:cuda_sparse",
|
||||||
|
]),
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "tridiagonal_solve_op",
|
||||||
|
srcs = ["tridiagonal_solve_op.cc"],
|
||||||
|
gpu_srcs = ["tridiagonal_solve_op_gpu.cu.cc"],
|
||||||
|
deps = LINALG_DEPS + if_cuda([
|
||||||
|
"//tensorflow/core/util:cuda_sparse",
|
||||||
|
]),
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "qr_op",
|
||||||
|
prefix = "qr_op",
|
||||||
|
deps = LINALG_DEPS + if_cuda([
|
||||||
|
"//tensorflow/core/kernels:cwise_op",
|
||||||
|
":matrix_band_part_op",
|
||||||
|
]),
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "svd_op",
|
||||||
|
prefix = "svd_op",
|
||||||
|
deps = LINALG_DEPS,
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "lu_op",
|
||||||
|
prefix = "lu_op",
|
||||||
|
deps = if_cuda([
|
||||||
|
"//tensorflow/core/util:cuda_solvers",
|
||||||
|
"//tensorflow/core/kernels:transpose_functor",
|
||||||
|
]) + [
|
||||||
|
"//third_party/eigen3",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "einsum_op",
|
||||||
|
prefix = "einsum_op",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core/kernels:batch_matmul_op",
|
||||||
|
"//tensorflow/core/kernels:fill_functor",
|
||||||
|
"//tensorflow/core/kernels:reduction_ops",
|
||||||
|
"//tensorflow/core/kernels:transpose_functor",
|
||||||
|
"//tensorflow/core/profiler/lib:traceme",
|
||||||
|
"//third_party/eigen3",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "linalg_ops_common",
|
||||||
|
srcs = ["linalg_ops_common.cc"],
|
||||||
|
hdrs = ["linalg_ops_common.h"],
|
||||||
|
visibility = ["//visibility:private"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//third_party/eigen3",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_cuda_cc_test(
|
||||||
|
name = "banded_triangular_solve_op_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["banded_triangular_solve_op_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":banded_triangular_solve_op",
|
||||||
|
":matrix_set_diag_op",
|
||||||
|
":matrix_triangular_solve_op",
|
||||||
|
"//tensorflow/core:core_cpu",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core:testlib",
|
||||||
|
"//tensorflow/core/kernels:ops_testutil",
|
||||||
|
"//tensorflow/core/kernels:ops_util",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "eye_functor",
|
||||||
|
hdrs = ["eye_functor.h"],
|
||||||
|
gpu_srcs = [
|
||||||
|
"eye_functor_gpu.cu.cc",
|
||||||
|
"eye_functor.h",
|
||||||
|
],
|
||||||
|
visibility = ["//tensorflow/core/kernels:friends"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//third_party/eigen3",
|
||||||
|
],
|
||||||
|
alwayslink = 0,
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_cuda_cc_test(
|
||||||
|
name = "matrix_triangular_solve_op_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["matrix_triangular_solve_op_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":matrix_triangular_solve_op",
|
||||||
|
"//tensorflow/core:core_cpu",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core:testlib",
|
||||||
|
"//tensorflow/core/kernels:broadcast_to_op",
|
||||||
|
"//tensorflow/core/kernels:ops_testutil",
|
||||||
|
"//tensorflow/core/kernels:ops_util",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# A file group which contains all operators which are known to work on mobile.
|
||||||
|
filegroup(
|
||||||
|
name = "android_all_op_kernels",
|
||||||
|
srcs = glob(
|
||||||
|
[
|
||||||
|
"*.cc",
|
||||||
|
"*.h",
|
||||||
|
],
|
||||||
|
exclude = [
|
||||||
|
"*test.cc",
|
||||||
|
"*test.h",
|
||||||
|
"*_test_*",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
visibility = ["//tensorflow:__subpackages__"],
|
||||||
|
)
|
@ -20,7 +20,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/kernels/fill_functor.h"
|
#include "tensorflow/core/kernels/fill_functor.h"
|
||||||
#include "tensorflow/core/kernels/linalg_ops_common.h"
|
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
@ -21,7 +21,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
#include "tensorflow/core/graph/node_builder.h"
|
#include "tensorflow/core/graph/node_builder.h"
|
||||||
#include "tensorflow/core/graph/testlib.h"
|
#include "tensorflow/core/graph/testlib.h"
|
||||||
#include "tensorflow/core/kernels/matrix_set_diag_op.h"
|
#include "tensorflow/core/kernels/linalg/matrix_set_diag_op.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
#include "tensorflow/core/platform/test_benchmark.h"
|
#include "tensorflow/core/platform/test_benchmark.h"
|
@ -18,7 +18,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/tensor_types.h"
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/kernels/linalg_ops_common.h"
|
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -25,16 +25,16 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/kernels/linalg_ops_common.h"
|
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
#include "tensorflow/core/kernels/linalg/matrix_band_part_op.h"
|
||||||
#include "tensorflow/core/kernels/matrix_band_part_op.h"
|
|
||||||
#include "tensorflow/core/platform/stream_executor.h"
|
#include "tensorflow/core/platform/stream_executor.h"
|
||||||
|
#include "tensorflow/core/util/cuda_solvers.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
@ -20,7 +20,7 @@ limitations under the License.
|
|||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
#define EIGEN_USE_GPU
|
#define EIGEN_USE_GPU
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
#include "tensorflow/core/kernels/determinant_op.h"
|
#include "tensorflow/core/kernels/linalg/determinant_op.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include "third_party/eigen3/Eigen/LU"
|
#include "third_party/eigen3/Eigen/LU"
|
||||||
@ -28,14 +28,14 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/numeric_types.h"
|
#include "tensorflow/core/framework/numeric_types.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/kernels/linalg_ops_common.h"
|
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
|
||||||
#include "tensorflow/core/kernels/fill_functor.h"
|
#include "tensorflow/core/kernels/fill_functor.h"
|
||||||
|
#include "tensorflow/core/util/cuda_solvers.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#ifndef TENSORFLOW_CORE_KERNELS_DETERMINANT_OP_H_
|
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_DETERMINANT_OP_H_
|
||||||
#define TENSORFLOW_CORE_KERNELS_DETERMINANT_OP_H_
|
#define TENSORFLOW_CORE_KERNELS_LINALG_DETERMINANT_OP_H_
|
||||||
|
|
||||||
#include "tensorflow/core/framework/tensor_types.h"
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
|
|
||||||
@ -44,4 +44,4 @@ struct LogDeterminantFromPivotedLUFunctor {
|
|||||||
} // namespace functor
|
} // namespace functor
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_KERNELS_DETERMINANT_OP_H_
|
#endif // TENSORFLOW_CORE_KERNELS_LINALG_DETERMINANT_OP_H_
|
@ -21,8 +21,8 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
#include "tensorflow/core/framework/tensor_types.h"
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
#include "tensorflow/core/kernels/linalg/determinant_op.h"
|
||||||
#include "tensorflow/core/kernels/determinant_op.h"
|
#include "tensorflow/core/util/cuda_solvers.h"
|
||||||
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/eig_op_impl.h"
|
#include "tensorflow/core/kernels/linalg/eig_op_impl.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/eig_op_impl.h"
|
#include "tensorflow/core/kernels/linalg/eig_op_impl.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/eig_op_impl.h"
|
#include "tensorflow/core/kernels/linalg/eig_op_impl.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/eig_op_impl.h"
|
#include "tensorflow/core/kernels/linalg/eig_op_impl.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#ifndef TENSORFLOW_CORE_KERNELS_EIG_OP_IMPL_H_
|
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_EIG_OP_IMPL_H_
|
||||||
#define TENSORFLOW_CORE_KERNELS_EIG_OP_IMPL_H_
|
#define TENSORFLOW_CORE_KERNELS_LINALG_EIG_OP_IMPL_H_
|
||||||
|
|
||||||
// See docs in ../ops/linalg_ops.cc.
|
// See docs in ../ops/linalg_ops.cc.
|
||||||
|
|
||||||
@ -23,7 +23,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/kernels/linalg_ops_common.h"
|
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/platform/denormal.h"
|
#include "tensorflow/core/platform/denormal.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
@ -95,4 +95,4 @@ class EigOp : public LinearAlgebraOp<InputScalar, OutputScalar> {
|
|||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_KERNELS_EIG_OP_IMPL_H_
|
#endif // TENSORFLOW_CORE_KERNELS_LINALG_EIG_OP_IMPL_H_
|
@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#ifndef TENSORFLOW_CORE_KERNELS_EINSUM_OP_H_
|
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_EINSUM_OP_H_
|
||||||
#define TENSORFLOW_CORE_KERNELS_EINSUM_OP_H_
|
#define TENSORFLOW_CORE_KERNELS_LINALG_EINSUM_OP_H_
|
||||||
|
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
#include "tensorflow/core/framework/tensor_types.h"
|
#include "tensorflow/core/framework/tensor_types.h"
|
@ -17,7 +17,7 @@ limitations under the License.
|
|||||||
#define EIGEN_USE_GPU
|
#define EIGEN_USE_GPU
|
||||||
|
|
||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
#include "tensorflow/core/kernels/einsum_op.h"
|
#include "tensorflow/core/kernels/linalg/einsum_op.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#ifndef TENSORFLOW_CORE_KERNELS_EINSUM_OP_IMPL_H_
|
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_EINSUM_OP_IMPL_H_
|
||||||
#define TENSORFLOW_CORE_KERNELS_EINSUM_OP_IMPL_H_
|
#define TENSORFLOW_CORE_KERNELS_LINALG_EINSUM_OP_IMPL_H_
|
||||||
|
|
||||||
#define EIGEN_USE_THREADS
|
#define EIGEN_USE_THREADS
|
||||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
@ -31,8 +31,8 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/tensor_types.h"
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
#include "tensorflow/core/kernels/batch_matmul_op_impl.h"
|
#include "tensorflow/core/kernels/batch_matmul_op_impl.h"
|
||||||
#include "tensorflow/core/kernels/einsum_op.h"
|
|
||||||
#include "tensorflow/core/kernels/fill_functor.h"
|
#include "tensorflow/core/kernels/fill_functor.h"
|
||||||
|
#include "tensorflow/core/kernels/linalg/einsum_op.h"
|
||||||
#include "tensorflow/core/kernels/reduction_ops_common.h"
|
#include "tensorflow/core/kernels/reduction_ops_common.h"
|
||||||
#include "tensorflow/core/kernels/transpose_functor.h"
|
#include "tensorflow/core/kernels/transpose_functor.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
@ -780,4 +780,4 @@ DECLARE_GPU_SPECS(complex128);
|
|||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_KERNELS_EINSUM_OP_IMPL_H_
|
#endif // TENSORFLOW_CORE_KERNELS_LINALG_EINSUM_OP_IMPL_H_
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/einsum_op_impl.h"
|
#include "tensorflow/core/kernels/linalg/einsum_op_impl.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/einsum_op_impl.h"
|
#include "tensorflow/core/kernels/linalg/einsum_op_impl.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/einsum_op_impl.h"
|
#include "tensorflow/core/kernels/linalg/einsum_op_impl.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/einsum_op_impl.h"
|
#include "tensorflow/core/kernels/linalg/einsum_op_impl.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/einsum_op_impl.h"
|
#include "tensorflow/core/kernels/linalg/einsum_op_impl.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/einsum_op_impl.h"
|
#include "tensorflow/core/kernels/linalg/einsum_op_impl.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/einsum_op_impl.h"
|
#include "tensorflow/core/kernels/linalg/einsum_op_impl.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/einsum_op_impl.h"
|
#include "tensorflow/core/kernels/linalg/einsum_op_impl.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#ifndef TENSORFLOW_CORE_KERNELS_EYE_FUNCTOR_H_
|
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_EYE_FUNCTOR_H_
|
||||||
#define TENSORFLOW_CORE_KERNELS_EYE_FUNCTOR_H_
|
#define TENSORFLOW_CORE_KERNELS_LINALG_EYE_FUNCTOR_H_
|
||||||
|
|
||||||
#include "tensorflow/core/framework/tensor_types.h"
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
|
|
@ -20,7 +20,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
#include "tensorflow/core/framework/tensor_types.h"
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
#include "tensorflow/core/framework/type_traits.h"
|
#include "tensorflow/core/framework/type_traits.h"
|
||||||
#include "tensorflow/core/kernels/eye_functor.h"
|
#include "tensorflow/core/kernels/linalg/eye_functor.h"
|
||||||
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/linalg_ops_common.h"
|
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
|
||||||
|
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
221
tensorflow/core/kernels/linalg/linalg_ops_common.h
Normal file
221
tensorflow/core/kernels/linalg/linalg_ops_common.h
Normal file
@ -0,0 +1,221 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_LINALG_OPS_COMMON_H_
|
||||||
|
#define TENSORFLOW_CORE_KERNELS_LINALG_LINALG_OPS_COMMON_H_
|
||||||
|
|
||||||
|
// Classes to support linear algebra functionality, similar to the numpy.linalg
|
||||||
|
// module. Supports batch computation on several matrices at once, sharding the
|
||||||
|
// computations across different threads if necessary.
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
#include "third_party/eigen3/Eigen/Core"
|
||||||
|
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
|
#include "tensorflow/core/framework/types.h"
|
||||||
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
#include "tensorflow/core/util/work_sharder.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Base class for linear algebra operators.
|
||||||
|
template <class InputScalar, class OutputScalar = InputScalar>
|
||||||
|
class LinearAlgebraOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit LinearAlgebraOp(OpKernelConstruction* context) : OpKernel(context) {}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* context) override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
using TensorShapes = gtl::InlinedVector<TensorShape, 4>;
|
||||||
|
// Returns the number of leading inputs that are to be treated as matrix
|
||||||
|
// inputs. By default this is all the inputs. Derived classes can override
|
||||||
|
// this to tell the base class to ignore one or more trailing inputs.
|
||||||
|
virtual int NumMatrixInputs(const OpKernelContext* context) const {
|
||||||
|
return context->num_inputs();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns true if the number of inputs and their shapes are as expected.
|
||||||
|
// Many ops take a single square input matrix, so we provide that as a default
|
||||||
|
// implementation for convenience.
|
||||||
|
virtual void ValidateInputMatrixShapes(
|
||||||
|
OpKernelContext* context, const TensorShapes& input_matrix_shapes) const {
|
||||||
|
ValidateSingleSquareMatrix(context, input_matrix_shapes);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convenience validators for common cases:
|
||||||
|
//
|
||||||
|
// Validate op taking a single matrix A.
|
||||||
|
static void ValidateSingleMatrix(OpKernelContext* context,
|
||||||
|
const TensorShapes& input_matrix_shapes);
|
||||||
|
// Validate op taking a single square matrix A.
|
||||||
|
static void ValidateSingleSquareMatrix(
|
||||||
|
OpKernelContext* context, const TensorShapes& input_matrix_shapes);
|
||||||
|
// Validate op taking two matrices A and B that have the same number of rows.
|
||||||
|
static void ValidateSolver(OpKernelContext* context,
|
||||||
|
const TensorShapes& input_matrix_shapes);
|
||||||
|
// Validate op taking two matrices A and B that have the same number of rows
|
||||||
|
// and A is square.
|
||||||
|
static void ValidateSquareSolver(OpKernelContext* context,
|
||||||
|
const TensorShapes& input_matrix_shapes);
|
||||||
|
|
||||||
|
// Returns the output shapes of each individual matrix operation. Output
|
||||||
|
// matrices shapes must be rank 0, 1, or 2. Scalar outputs are rank 0.
|
||||||
|
//
|
||||||
|
// The derived class may return a number of shapes (N) less than
|
||||||
|
// context->num_outputs() (M) to indicate that a only leading subset of
|
||||||
|
// the outputs will be populated. In this case, a dummy scalar tensor with
|
||||||
|
// value zero will be return for the last M-N outputs.
|
||||||
|
//
|
||||||
|
// For many ops, the output dimensions are the same as the input dimensions,
|
||||||
|
// so we provide that as a default implementation for convenience.
|
||||||
|
virtual TensorShapes GetOutputMatrixShapes(
|
||||||
|
const TensorShapes& input_matrix_shapes) const {
|
||||||
|
return input_matrix_shapes;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the cost per matrix operation. This is used to determine the
|
||||||
|
// number of threads to use for parallelizing calls to ComputeMatrix in
|
||||||
|
// batch mode. Cost per unit is assumed to be roughly 1ns, based on comments
|
||||||
|
// in core/util/work_sharder.cc. Many linear algebra ops take roughly max(m,n)
|
||||||
|
// * min(m,n)^2, where the first input matrix is m-by-n. We provide that as a
|
||||||
|
// default implementation for convenience.
|
||||||
|
virtual int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const {
|
||||||
|
double m = static_cast<double>(input_matrix_shapes[0].dim_size(0));
|
||||||
|
double n = static_cast<double>(input_matrix_shapes[0].dim_size(1));
|
||||||
|
double cost = std::max(m, n) * std::min(m, n) * std::min(m, n);
|
||||||
|
return cost >= static_cast<double>(kint64max) ? kint64max
|
||||||
|
: static_cast<int64>(cost);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns true if it is safe to forward (alias) input to output buffer
|
||||||
|
// and expect the kernel to perform the computation inplace.
|
||||||
|
virtual bool EnableInputForwarding() const { return true; }
|
||||||
|
|
||||||
|
using InputMatrix = Eigen::Matrix<InputScalar, Eigen::Dynamic, Eigen::Dynamic,
|
||||||
|
Eigen::RowMajor>;
|
||||||
|
using InputConstMatrixMap = Eigen::Map<const InputMatrix>;
|
||||||
|
using InputMatrixMap = Eigen::Map<InputMatrix>;
|
||||||
|
using InputConstVectorMap =
|
||||||
|
Eigen::Map<const Eigen::Matrix<InputScalar, 1, Eigen::Dynamic>>;
|
||||||
|
using InputConstMatrixMaps = gtl::InlinedVector<InputConstMatrixMap, 4>;
|
||||||
|
using InputMatrixMaps = gtl::InlinedVector<InputMatrixMap, 4>;
|
||||||
|
using InputRealScalar = typename Eigen::NumTraits<InputScalar>::Real;
|
||||||
|
|
||||||
|
using OutputMatrix = Eigen::Matrix<OutputScalar, Eigen::Dynamic,
|
||||||
|
Eigen::Dynamic, Eigen::RowMajor>;
|
||||||
|
using OutputConstMatrixMap = Eigen::Map<const OutputMatrix>;
|
||||||
|
using OutputMatrixMap = Eigen::Map<OutputMatrix>;
|
||||||
|
using OutputConstVectorMap =
|
||||||
|
Eigen::Map<const Eigen::Matrix<OutputScalar, 1, Eigen::Dynamic>>;
|
||||||
|
using OutputConstMatrixMaps = gtl::InlinedVector<OutputConstMatrixMap, 4>;
|
||||||
|
using OutputMatrixMaps = gtl::InlinedVector<OutputMatrixMap, 4>;
|
||||||
|
using OutputRealScalar = typename Eigen::NumTraits<OutputScalar>::Real;
|
||||||
|
|
||||||
|
// backward compatibility
|
||||||
|
using Scalar = OutputScalar;
|
||||||
|
using Matrix =
|
||||||
|
Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
|
||||||
|
using ConstMatrixMap = Eigen::Map<const Matrix>;
|
||||||
|
using MatrixMap = Eigen::Map<Matrix>;
|
||||||
|
using ConstVectorMap =
|
||||||
|
Eigen::Map<const Eigen::Matrix<Scalar, 1, Eigen::Dynamic>>;
|
||||||
|
using ConstMatrixMaps = gtl::InlinedVector<ConstMatrixMap, 4>;
|
||||||
|
using MatrixMaps = gtl::InlinedVector<MatrixMap, 4>;
|
||||||
|
using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
|
||||||
|
|
||||||
|
// Performs a single matrix computation given input matrices, and
|
||||||
|
// stores the result in outputs. For batch operations, this will be called
|
||||||
|
// repeatedly for a single call to Compute() when multiple matrices exist in
|
||||||
|
// input Tensors with rank > 2. In this case the calls to ComputeMatrix are
|
||||||
|
// parallelized. The number of threads used is determined by a cost model from
|
||||||
|
// the value returned by GetCostPerUnit().
|
||||||
|
virtual void ComputeMatrix(OpKernelContext* context,
|
||||||
|
const InputConstMatrixMaps& inputs,
|
||||||
|
OutputMatrixMaps* outputs) = 0;
|
||||||
|
|
||||||
|
private:
|
||||||
|
using TensorInputs = gtl::InlinedVector<const Tensor*, 4>;
|
||||||
|
using TensorOutputs = gtl::InlinedVector<Tensor*, 4>;
|
||||||
|
// This function maps 2-d slices (matrices) of the input and output tensors
|
||||||
|
// using Eigen::Map and calls ComputeMatrix implemented in terms of the
|
||||||
|
// Eigen::MatrixBase API by the derived class.
|
||||||
|
//
|
||||||
|
// The 'matrix_index' parameter specifies the index of the matrix to be used
|
||||||
|
// from each input tensor, and the index of the matrix to be written to each
|
||||||
|
// output tensor. The input matrices are in row major order, and located at
|
||||||
|
// the memory addresses
|
||||||
|
// inputs[i].flat<Scalar>().data() +
|
||||||
|
// matrix_index * input_matrix_shapes[i].num_elements()
|
||||||
|
// for i in 0...inputs.size()-1.
|
||||||
|
// The output matrices are in row major order, and located at the memory
|
||||||
|
// address
|
||||||
|
// outputs[i]->flat<Scalar>().data() +
|
||||||
|
// matrix_index * output_matrix_shapes[i].num_elements().
|
||||||
|
// for i in 0...outputs.size()-1.
|
||||||
|
//
|
||||||
|
void ComputeTensorSlice(OpKernelContext* context, int64 matrix_index,
|
||||||
|
const TensorInputs& inputs,
|
||||||
|
const TensorShapes& input_matrix_shapes,
|
||||||
|
const TensorOutputs& outputs,
|
||||||
|
const TensorShapes& output_matrix_shapes);
|
||||||
|
|
||||||
|
void AnalyzeInputs(OpKernelContext* context, TensorInputs* inputs,
|
||||||
|
TensorShapes* input_matrix_shapes,
|
||||||
|
TensorShape* batch_shape);
|
||||||
|
|
||||||
|
void PrepareOutputs(OpKernelContext* context,
|
||||||
|
const TensorShapes& input_matrix_shapes,
|
||||||
|
const TensorShape& batch_shape, TensorOutputs* outputs,
|
||||||
|
TensorShapes* output_matrix_shapes);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Declare LinearAlgebraOp, which is explicitly instantiated in
|
||||||
|
// linalg_ops_common.cc for float, double, complex64, and complex128.
|
||||||
|
extern template class LinearAlgebraOp<float>;
|
||||||
|
extern template class LinearAlgebraOp<double>;
|
||||||
|
extern template class LinearAlgebraOp<complex64>;
|
||||||
|
extern template class LinearAlgebraOp<complex128>;
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#define INHERIT_LINALG_TYPEDEFS(Scalar) \
|
||||||
|
typedef LinearAlgebraOp<Scalar> Base; \
|
||||||
|
using RealScalar = typename Eigen::NumTraits<Scalar>::Real; \
|
||||||
|
using Matrix = typename Base::Matrix; \
|
||||||
|
using MatrixMap = typename Base::MatrixMap; \
|
||||||
|
using MatrixMaps = typename Base::MatrixMaps; \
|
||||||
|
using ConstMatrixMap = typename Base::ConstMatrixMap; \
|
||||||
|
using ConstMatrixMaps = typename Base::ConstMatrixMaps; \
|
||||||
|
using ConstVectorMap = typename Base::ConstVectorMap; \
|
||||||
|
using TensorShapes = typename Base::TensorShapes;
|
||||||
|
|
||||||
|
#define REGISTER_LINALG_OP_CPU(OpName, OpClass, Scalar) \
|
||||||
|
REGISTER_KERNEL_BUILDER( \
|
||||||
|
Name(OpName).Device(DEVICE_CPU).TypeConstraint<Scalar>("T"), OpClass)
|
||||||
|
|
||||||
|
#define REGISTER_LINALG_OP_GPU(OpName, OpClass, Scalar) \
|
||||||
|
REGISTER_KERNEL_BUILDER( \
|
||||||
|
Name(OpName).Device(DEVICE_GPU).TypeConstraint<Scalar>("T"), OpClass)
|
||||||
|
|
||||||
|
// Deprecated, use one of the device-specific macros above.
|
||||||
|
#define REGISTER_LINALG_OP(OpName, OpClass, Scalar) \
|
||||||
|
REGISTER_LINALG_OP_CPU(OpName, OpClass, Scalar)
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CORE_KERNELS_LINALG_LINALG_OPS_COMMON_H_
|
@ -25,9 +25,9 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
|
||||||
#include "tensorflow/core/kernels/transpose_functor.h"
|
#include "tensorflow/core/kernels/transpose_functor.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
#include "tensorflow/core/util/cuda_solvers.h"
|
||||||
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
@ -21,11 +21,12 @@ limitations under the License.
|
|||||||
#define EIGEN_USE_GPU
|
#define EIGEN_USE_GPU
|
||||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/matrix_band_part_op.h"
|
#include "tensorflow/core/kernels/linalg/matrix_band_part_op.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#ifndef TENSORFLOW_CORE_KERNELS_MATRIX_BAND_PART_OP_H_
|
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_BAND_PART_OP_H_
|
||||||
#define TENSORFLOW_CORE_KERNELS_MATRIX_BAND_PART_OP_H_
|
#define TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_BAND_PART_OP_H_
|
||||||
|
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/tensor_types.h"
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
@ -34,4 +34,4 @@ struct MatrixBandPartFunctor {
|
|||||||
} // namespace functor
|
} // namespace functor
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_KERNELS_MATRIX_BAND_PART_OP_H_
|
#endif // TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_BAND_PART_OP_H_
|
@ -21,7 +21,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
#include "tensorflow/core/kernels/matrix_band_part_op.h"
|
#include "tensorflow/core/kernels/linalg/matrix_band_part_op.h"
|
||||||
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
@ -20,7 +20,7 @@ limitations under the License.
|
|||||||
#define EIGEN_USE_GPU
|
#define EIGEN_USE_GPU
|
||||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/matrix_diag_op.h"
|
#include "tensorflow/core/kernels/linalg/matrix_diag_op.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <memory>
|
#include <memory>
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#ifndef TENSORFLOW_CORE_KERNELS_MATRIX_DIAG_OP_H_
|
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_DIAG_OP_H_
|
||||||
#define TENSORFLOW_CORE_KERNELS_MATRIX_DIAG_OP_H_
|
#define TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_DIAG_OP_H_
|
||||||
|
|
||||||
// Generator definition for MatrixDiagOp, must be compilable by nvcc.
|
// Generator definition for MatrixDiagOp, must be compilable by nvcc.
|
||||||
|
|
||||||
@ -69,4 +69,4 @@ struct MatrixDiag {
|
|||||||
} // namespace functor
|
} // namespace functor
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_KERNELS_MATRIX_DIAG_OP_H_
|
#endif // TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_DIAG_OP_H_
|
@ -18,7 +18,7 @@ limitations under the License.
|
|||||||
#define EIGEN_USE_GPU
|
#define EIGEN_USE_GPU
|
||||||
|
|
||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
#include "tensorflow/core/kernels/matrix_diag_op.h"
|
#include "tensorflow/core/kernels/linalg/matrix_diag_op.h"
|
||||||
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
@ -20,7 +20,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/kernels/linalg_ops_common.h"
|
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
@ -24,7 +24,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/kernels/linalg_ops_common.h"
|
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
@ -32,9 +32,9 @@ limitations under the License.
|
|||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
#include "tensorflow/core/kernels/linalg/eye_functor.h"
|
||||||
#include "tensorflow/core/kernels/eye_functor.h"
|
|
||||||
#include "tensorflow/core/kernels/transpose_functor.h"
|
#include "tensorflow/core/kernels/transpose_functor.h"
|
||||||
|
#include "tensorflow/core/util/cuda_solvers.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
@ -20,7 +20,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/kernels/linalg_ops_common.h"
|
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
@ -21,7 +21,7 @@ limitations under the License.
|
|||||||
#define EIGEN_USE_GPU
|
#define EIGEN_USE_GPU
|
||||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/matrix_set_diag_op.h"
|
#include "tensorflow/core/kernels/linalg/matrix_set_diag_op.h"
|
||||||
|
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
@ -30,7 +30,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/tensor_types.h"
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/kernels/matrix_diag_op.h"
|
#include "tensorflow/core/kernels/linalg/matrix_diag_op.h"
|
||||||
#include "tensorflow/core/lib/core/threadpool.h"
|
#include "tensorflow/core/lib/core/threadpool.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#ifndef TENSORFLOW_CORE_KERNELS_MATRIX_SET_DIAG_OP_H_
|
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_SET_DIAG_OP_H_
|
||||||
#define TENSORFLOW_CORE_KERNELS_MATRIX_SET_DIAG_OP_H_
|
#define TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_SET_DIAG_OP_H_
|
||||||
|
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/tensor_types.h"
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
@ -39,4 +39,4 @@ struct MatrixSetDiag {
|
|||||||
} // namespace functor
|
} // namespace functor
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_KERNELS_MATRIX_SET_DIAG_OP_H_
|
#endif // TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_SET_DIAG_OP_H_
|
@ -18,7 +18,7 @@ limitations under the License.
|
|||||||
#define EIGEN_USE_GPU
|
#define EIGEN_USE_GPU
|
||||||
|
|
||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
#include "tensorflow/core/kernels/matrix_set_diag_op.h"
|
#include "tensorflow/core/kernels/linalg/matrix_set_diag_op.h"
|
||||||
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/matrix_solve_ls_op_impl.h"
|
#include "tensorflow/core/kernels/linalg/matrix_solve_ls_op_impl.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/matrix_solve_ls_op_impl.h"
|
#include "tensorflow/core/kernels/linalg/matrix_solve_ls_op_impl.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/matrix_solve_ls_op_impl.h"
|
#include "tensorflow/core/kernels/linalg/matrix_solve_ls_op_impl.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/matrix_solve_ls_op_impl.h"
|
#include "tensorflow/core/kernels/linalg/matrix_solve_ls_op_impl.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#ifndef TENSORFLOW_CORE_KERNELS_MATRIX_SOLVE_LS_OP_IMPL_H_
|
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_SOLVE_LS_OP_IMPL_H_
|
||||||
#define TENSORFLOW_CORE_KERNELS_MATRIX_SOLVE_LS_OP_IMPL_H_
|
#define TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_SOLVE_LS_OP_IMPL_H_
|
||||||
|
|
||||||
// See docs in ../ops/linalg_ops.cc.
|
// See docs in ../ops/linalg_ops.cc.
|
||||||
|
|
||||||
@ -24,7 +24,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/kernels/linalg_ops_common.h"
|
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
@ -163,4 +163,4 @@ class MatrixSolveLsOp : public LinearAlgebraOp<Scalar> {
|
|||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_KERNELS_MATRIX_SOLVE_LS_OP_IMPL_H_
|
#endif // TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_SOLVE_LS_OP_IMPL_H_
|
@ -25,7 +25,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/kernels/linalg_ops_common.h"
|
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
@ -33,8 +33,8 @@ limitations under the License.
|
|||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
|
||||||
#include "tensorflow/core/kernels/transpose_functor.h"
|
#include "tensorflow/core/kernels/transpose_functor.h"
|
||||||
|
#include "tensorflow/core/util/cuda_solvers.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
@ -20,7 +20,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/kernels/linalg_ops_common.h"
|
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
@ -14,7 +14,7 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
#include "tensorflow/core/kernels/matrix_triangular_solve_op_impl.h"
|
#include "tensorflow/core/kernels/linalg/matrix_triangular_solve_op_impl.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -15,8 +15,8 @@ limitations under the License.
|
|||||||
|
|
||||||
// See docs in ../ops/linalg_ops.cc.
|
// See docs in ../ops/linalg_ops.cc.
|
||||||
//
|
//
|
||||||
#ifndef TENSORFLOW_CORE_KERNELS_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_
|
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_
|
||||||
#define TENSORFLOW_CORE_KERNELS_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_
|
#define TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_
|
||||||
|
|
||||||
#include "third_party/eigen3/Eigen/Core"
|
#include "third_party/eigen3/Eigen/Core"
|
||||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||||
@ -24,7 +24,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/kernels/fill_functor.h"
|
#include "tensorflow/core/kernels/fill_functor.h"
|
||||||
#include "tensorflow/core/kernels/linalg_ops_common.h"
|
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
@ -38,9 +38,9 @@ limitations under the License.
|
|||||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
#include "tensorflow/core/util/cuda_solvers.h"
|
||||||
#elif TENSORFLOW_USE_ROCM
|
#elif TENSORFLOW_USE_ROCM
|
||||||
#include "tensorflow/core/kernels/rocm_solvers.h"
|
#include "tensorflow/core/util/rocm_solvers.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -434,4 +434,4 @@ struct LaunchBatchMatrixTriangularSolve<GPUDevice, Scalar> {
|
|||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_KERNELS_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_
|
#endif // TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/matrix_triangular_solve_op_impl.h"
|
#include "tensorflow/core/kernels/linalg/matrix_triangular_solve_op_impl.h"
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
#include "third_party/gpus/cuda/include/cuda.h"
|
#include "third_party/gpus/cuda/include/cuda.h"
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/qr_op_impl.h"
|
#include "tensorflow/core/kernels/linalg/qr_op_impl.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/qr_op_impl.h"
|
#include "tensorflow/core/kernels/linalg/qr_op_impl.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/qr_op_impl.h"
|
#include "tensorflow/core/kernels/linalg/qr_op_impl.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/qr_op_impl.h"
|
#include "tensorflow/core/kernels/linalg/qr_op_impl.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#ifndef TENSORFLOW_CORE_KERNELS_QR_OP_IMPL_H_
|
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_QR_OP_IMPL_H_
|
||||||
#define TENSORFLOW_CORE_KERNELS_QR_OP_IMPL_H_
|
#define TENSORFLOW_CORE_KERNELS_LINALG_QR_OP_IMPL_H_
|
||||||
|
|
||||||
// See docs in ../ops/linalg_ops.cc.
|
// See docs in ../ops/linalg_ops.cc.
|
||||||
//
|
//
|
||||||
@ -33,7 +33,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/kernels/linalg_ops_common.h"
|
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
@ -41,11 +41,11 @@ limitations under the License.
|
|||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
|
||||||
#include "tensorflow/core/kernels/cwise_ops.h"
|
#include "tensorflow/core/kernels/cwise_ops.h"
|
||||||
#include "tensorflow/core/kernels/eye_functor.h"
|
#include "tensorflow/core/kernels/linalg/eye_functor.h"
|
||||||
#include "tensorflow/core/kernels/matrix_band_part_op.h"
|
#include "tensorflow/core/kernels/linalg/matrix_band_part_op.h"
|
||||||
#include "tensorflow/core/kernels/transpose_functor.h"
|
#include "tensorflow/core/kernels/transpose_functor.h"
|
||||||
|
#include "tensorflow/core/util/cuda_solvers.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -299,4 +299,4 @@ class QrOpGpu : public AsyncOpKernel {
|
|||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_KERNELS_QR_OP_IMPL_H_
|
#endif // TENSORFLOW_CORE_KERNELS_LINALG_QR_OP_IMPL_H_
|
@ -20,7 +20,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/kernels/linalg_ops_common.h"
|
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/platform/denormal.h"
|
#include "tensorflow/core/platform/denormal.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/self_adjoint_eig_v2_op_impl.h"
|
#include "tensorflow/core/kernels/linalg/self_adjoint_eig_v2_op_impl.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/self_adjoint_eig_v2_op_impl.h"
|
#include "tensorflow/core/kernels/linalg/self_adjoint_eig_v2_op_impl.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/self_adjoint_eig_v2_op_impl.h"
|
#include "tensorflow/core/kernels/linalg/self_adjoint_eig_v2_op_impl.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/self_adjoint_eig_v2_op_impl.h"
|
#include "tensorflow/core/kernels/linalg/self_adjoint_eig_v2_op_impl.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -26,12 +26,12 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/kernels/cast_op.h"
|
#include "tensorflow/core/kernels/cast_op.h"
|
||||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
|
||||||
#include "tensorflow/core/kernels/cwise_ops.h"
|
#include "tensorflow/core/kernels/cwise_ops.h"
|
||||||
#include "tensorflow/core/kernels/transpose_functor.h"
|
#include "tensorflow/core/kernels/transpose_functor.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
#include "tensorflow/core/util/cuda_solvers.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#ifndef TENSORFLOW_CORE_KERNELS_SELF_ADJOINT_EIG_V2_OP_IMPL_H_
|
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_SELF_ADJOINT_EIG_V2_OP_IMPL_H_
|
||||||
#define TENSORFLOW_CORE_KERNELS_SELF_ADJOINT_EIG_V2_OP_IMPL_H_
|
#define TENSORFLOW_CORE_KERNELS_LINALG_SELF_ADJOINT_EIG_V2_OP_IMPL_H_
|
||||||
|
|
||||||
// See docs in ../ops/linalg_ops.cc.
|
// See docs in ../ops/linalg_ops.cc.
|
||||||
|
|
||||||
@ -23,7 +23,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/kernels/linalg_ops_common.h"
|
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/platform/denormal.h"
|
#include "tensorflow/core/platform/denormal.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
@ -89,4 +89,4 @@ class SelfAdjointEigV2Op : public LinearAlgebraOp<Scalar> {
|
|||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_KERNELS_SELF_ADJOINT_EIG_V2_OP_IMPL_H_
|
#endif // TENSORFLOW_CORE_KERNELS_LINALG_SELF_ADJOINT_EIG_V2_OP_IMPL_H_
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/svd_op_impl.h"
|
#include "tensorflow/core/kernels/linalg/svd_op_impl.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/svd_op_impl.h"
|
#include "tensorflow/core/kernels/linalg/svd_op_impl.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/svd_op_impl.h"
|
#include "tensorflow/core/kernels/linalg/svd_op_impl.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/svd_op_impl.h"
|
#include "tensorflow/core/kernels/linalg/svd_op_impl.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
@ -36,14 +36,14 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
#include "tensorflow/core/kernels/linalg/eye_functor.h"
|
||||||
#include "tensorflow/core/kernels/eye_functor.h"
|
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
|
||||||
#include "tensorflow/core/kernels/linalg_ops_common.h"
|
|
||||||
#include "tensorflow/core/kernels/transpose_functor.h"
|
#include "tensorflow/core/kernels/transpose_functor.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/stream_executor.h"
|
#include "tensorflow/core/platform/stream_executor.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
#include "tensorflow/core/util/cuda_solvers.h"
|
||||||
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#ifndef TENSORFLOW_CORE_KERNELS_SVD_OP_IMPL_H_
|
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_SVD_OP_IMPL_H_
|
||||||
#define TENSORFLOW_CORE_KERNELS_SVD_OP_IMPL_H_
|
#define TENSORFLOW_CORE_KERNELS_LINALG_SVD_OP_IMPL_H_
|
||||||
|
|
||||||
// See docs in ../ops/linalg_ops.cc.
|
// See docs in ../ops/linalg_ops.cc.
|
||||||
//
|
//
|
||||||
@ -27,7 +27,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/kernels/linalg_ops_common.h"
|
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
@ -118,4 +118,4 @@ class SvdOp : public LinearAlgebraOp<Scalar> {
|
|||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_KERNELS_SVD_OP_IMPL_H_
|
#endif // TENSORFLOW_CORE_KERNELS_LINALG_SVD_OP_IMPL_H_
|
@ -19,7 +19,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/kernels/linalg_ops_common.h"
|
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
@ -22,11 +22,11 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
|
||||||
#include "tensorflow/core/kernels/cuda_sparse.h"
|
|
||||||
#include "tensorflow/core/kernels/linalg_ops_common.h"
|
|
||||||
#include "tensorflow/core/kernels/transpose_functor.h"
|
#include "tensorflow/core/kernels/transpose_functor.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/util/cuda_solvers.h"
|
||||||
|
#include "tensorflow/core/util/cuda_sparse.h"
|
||||||
#include "tensorflow/core/util/gpu_device_functions.h"
|
#include "tensorflow/core/util/gpu_device_functions.h"
|
||||||
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
||||||
#include "tensorflow/core/util/gpu_launch_config.h"
|
#include "tensorflow/core/util/gpu_launch_config.h"
|
@ -19,7 +19,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/kernels/linalg_ops_common.h"
|
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
@ -23,11 +23,11 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
|
||||||
#include "tensorflow/core/kernels/cuda_sparse.h"
|
|
||||||
#include "tensorflow/core/kernels/linalg_ops_common.h"
|
|
||||||
#include "tensorflow/core/kernels/transpose_functor.h"
|
#include "tensorflow/core/kernels/transpose_functor.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/util/cuda_solvers.h"
|
||||||
|
#include "tensorflow/core/util/cuda_sparse.h"
|
||||||
#include "tensorflow/core/util/gpu_device_functions.h"
|
#include "tensorflow/core/util/gpu_device_functions.h"
|
||||||
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
||||||
#include "tensorflow/core/util/gpu_launch_config.h"
|
#include "tensorflow/core/util/gpu_launch_config.h"
|
@ -12,211 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_OPS_COMMON_H_
|
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_OPS_COMMON_H_
|
||||||
#define TENSORFLOW_CORE_KERNELS_LINALG_OPS_COMMON_H_
|
#define TENSORFLOW_CORE_KERNELS_LINALG_OPS_COMMON_H_
|
||||||
|
|
||||||
// Classes to support linear algebra functionality, similar to the numpy.linalg
|
// Temporary forwarding header.
|
||||||
// module. Supports batch computation on several matrices at once, sharding the
|
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
|
||||||
// computations across different threads if necessary.
|
|
||||||
#include <algorithm>
|
|
||||||
|
|
||||||
#include "third_party/eigen3/Eigen/Core"
|
|
||||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
|
||||||
#include "tensorflow/core/framework/tensor_types.h"
|
|
||||||
#include "tensorflow/core/framework/types.h"
|
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
|
||||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
|
||||||
#include "tensorflow/core/platform/types.h"
|
|
||||||
#include "tensorflow/core/util/work_sharder.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
|
||||||
|
|
||||||
// Base class for linear algebra operators.
|
|
||||||
template <class InputScalar, class OutputScalar = InputScalar>
|
|
||||||
class LinearAlgebraOp : public OpKernel {
|
|
||||||
public:
|
|
||||||
explicit LinearAlgebraOp(OpKernelConstruction* context) : OpKernel(context) {}
|
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override;
|
|
||||||
|
|
||||||
protected:
|
|
||||||
using TensorShapes = gtl::InlinedVector<TensorShape, 4>;
|
|
||||||
// Returns the number of leading inputs that are to be treated as matrix
|
|
||||||
// inputs. By default this is all the inputs. Derived classes can override
|
|
||||||
// this to tell the base class to ignore one or more trailing inputs.
|
|
||||||
virtual int NumMatrixInputs(const OpKernelContext* context) const {
|
|
||||||
return context->num_inputs();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns true if the number of inputs and their shapes are as expected.
|
|
||||||
// Many ops take a single square input matrix, so we provide that as a default
|
|
||||||
// implementation for convenience.
|
|
||||||
virtual void ValidateInputMatrixShapes(
|
|
||||||
OpKernelContext* context, const TensorShapes& input_matrix_shapes) const {
|
|
||||||
ValidateSingleSquareMatrix(context, input_matrix_shapes);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convenience validators for common cases:
|
|
||||||
//
|
|
||||||
// Validate op taking a single matrix A.
|
|
||||||
static void ValidateSingleMatrix(OpKernelContext* context,
|
|
||||||
const TensorShapes& input_matrix_shapes);
|
|
||||||
// Validate op taking a single square matrix A.
|
|
||||||
static void ValidateSingleSquareMatrix(
|
|
||||||
OpKernelContext* context, const TensorShapes& input_matrix_shapes);
|
|
||||||
// Validate op taking two matrices A and B that have the same number of rows.
|
|
||||||
static void ValidateSolver(OpKernelContext* context,
|
|
||||||
const TensorShapes& input_matrix_shapes);
|
|
||||||
// Validate op taking two matrices A and B that have the same number of rows
|
|
||||||
// and A is square.
|
|
||||||
static void ValidateSquareSolver(OpKernelContext* context,
|
|
||||||
const TensorShapes& input_matrix_shapes);
|
|
||||||
|
|
||||||
// Returns the output shapes of each individual matrix operation. Output
|
|
||||||
// matrices shapes must be rank 0, 1, or 2. Scalar outputs are rank 0.
|
|
||||||
//
|
|
||||||
// The derived class may return a number of shapes (N) less than
|
|
||||||
// context->num_outputs() (M) to indicate that a only leading subset of
|
|
||||||
// the outputs will be populated. In this case, a dummy scalar tensor with
|
|
||||||
// value zero will be return for the last M-N outputs.
|
|
||||||
//
|
|
||||||
// For many ops, the output dimensions are the same as the input dimensions,
|
|
||||||
// so we provide that as a default implementation for convenience.
|
|
||||||
virtual TensorShapes GetOutputMatrixShapes(
|
|
||||||
const TensorShapes& input_matrix_shapes) const {
|
|
||||||
return input_matrix_shapes;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns the cost per matrix operation. This is used to determine the
|
|
||||||
// number of threads to use for parallelizing calls to ComputeMatrix in
|
|
||||||
// batch mode. Cost per unit is assumed to be roughly 1ns, based on comments
|
|
||||||
// in core/util/work_sharder.cc. Many linear algebra ops take roughly max(m,n)
|
|
||||||
// * min(m,n)^2, where the first input matrix is m-by-n. We provide that as a
|
|
||||||
// default implementation for convenience.
|
|
||||||
virtual int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const {
|
|
||||||
double m = static_cast<double>(input_matrix_shapes[0].dim_size(0));
|
|
||||||
double n = static_cast<double>(input_matrix_shapes[0].dim_size(1));
|
|
||||||
double cost = std::max(m, n) * std::min(m, n) * std::min(m, n);
|
|
||||||
return cost >= static_cast<double>(kint64max) ? kint64max
|
|
||||||
: static_cast<int64>(cost);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns true if it is safe to forward (alias) input to output buffer
|
|
||||||
// and expect the kernel to perform the computation inplace.
|
|
||||||
virtual bool EnableInputForwarding() const { return true; }
|
|
||||||
|
|
||||||
using InputMatrix = Eigen::Matrix<InputScalar, Eigen::Dynamic, Eigen::Dynamic,
|
|
||||||
Eigen::RowMajor>;
|
|
||||||
using InputConstMatrixMap = Eigen::Map<const InputMatrix>;
|
|
||||||
using InputMatrixMap = Eigen::Map<InputMatrix>;
|
|
||||||
using InputConstVectorMap =
|
|
||||||
Eigen::Map<const Eigen::Matrix<InputScalar, 1, Eigen::Dynamic>>;
|
|
||||||
using InputConstMatrixMaps = gtl::InlinedVector<InputConstMatrixMap, 4>;
|
|
||||||
using InputMatrixMaps = gtl::InlinedVector<InputMatrixMap, 4>;
|
|
||||||
using InputRealScalar = typename Eigen::NumTraits<InputScalar>::Real;
|
|
||||||
|
|
||||||
using OutputMatrix = Eigen::Matrix<OutputScalar, Eigen::Dynamic,
|
|
||||||
Eigen::Dynamic, Eigen::RowMajor>;
|
|
||||||
using OutputConstMatrixMap = Eigen::Map<const OutputMatrix>;
|
|
||||||
using OutputMatrixMap = Eigen::Map<OutputMatrix>;
|
|
||||||
using OutputConstVectorMap =
|
|
||||||
Eigen::Map<const Eigen::Matrix<OutputScalar, 1, Eigen::Dynamic>>;
|
|
||||||
using OutputConstMatrixMaps = gtl::InlinedVector<OutputConstMatrixMap, 4>;
|
|
||||||
using OutputMatrixMaps = gtl::InlinedVector<OutputMatrixMap, 4>;
|
|
||||||
using OutputRealScalar = typename Eigen::NumTraits<OutputScalar>::Real;
|
|
||||||
|
|
||||||
// backward compatibility
|
|
||||||
using Scalar = OutputScalar;
|
|
||||||
using Matrix =
|
|
||||||
Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
|
|
||||||
using ConstMatrixMap = Eigen::Map<const Matrix>;
|
|
||||||
using MatrixMap = Eigen::Map<Matrix>;
|
|
||||||
using ConstVectorMap =
|
|
||||||
Eigen::Map<const Eigen::Matrix<Scalar, 1, Eigen::Dynamic>>;
|
|
||||||
using ConstMatrixMaps = gtl::InlinedVector<ConstMatrixMap, 4>;
|
|
||||||
using MatrixMaps = gtl::InlinedVector<MatrixMap, 4>;
|
|
||||||
using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
|
|
||||||
|
|
||||||
// Performs a single matrix computation given input matrices, and
|
|
||||||
// stores the result in outputs. For batch operations, this will be called
|
|
||||||
// repeatedly for a single call to Compute() when multiple matrices exist in
|
|
||||||
// input Tensors with rank > 2. In this case the calls to ComputeMatrix are
|
|
||||||
// parallelized. The number of threads used is determined by a cost model from
|
|
||||||
// the value returned by GetCostPerUnit().
|
|
||||||
virtual void ComputeMatrix(OpKernelContext* context,
|
|
||||||
const InputConstMatrixMaps& inputs,
|
|
||||||
OutputMatrixMaps* outputs) = 0;
|
|
||||||
|
|
||||||
private:
|
|
||||||
using TensorInputs = gtl::InlinedVector<const Tensor*, 4>;
|
|
||||||
using TensorOutputs = gtl::InlinedVector<Tensor*, 4>;
|
|
||||||
// This function maps 2-d slices (matrices) of the input and output tensors
|
|
||||||
// using Eigen::Map and calls ComputeMatrix implemented in terms of the
|
|
||||||
// Eigen::MatrixBase API by the derived class.
|
|
||||||
//
|
|
||||||
// The 'matrix_index' parameter specifies the index of the matrix to be used
|
|
||||||
// from each input tensor, and the index of the matrix to be written to each
|
|
||||||
// output tensor. The input matrices are in row major order, and located at
|
|
||||||
// the memory addresses
|
|
||||||
// inputs[i].flat<Scalar>().data() +
|
|
||||||
// matrix_index * input_matrix_shapes[i].num_elements()
|
|
||||||
// for i in 0...inputs.size()-1.
|
|
||||||
// The output matrices are in row major order, and located at the memory
|
|
||||||
// address
|
|
||||||
// outputs[i]->flat<Scalar>().data() +
|
|
||||||
// matrix_index * output_matrix_shapes[i].num_elements().
|
|
||||||
// for i in 0...outputs.size()-1.
|
|
||||||
//
|
|
||||||
void ComputeTensorSlice(OpKernelContext* context, int64 matrix_index,
|
|
||||||
const TensorInputs& inputs,
|
|
||||||
const TensorShapes& input_matrix_shapes,
|
|
||||||
const TensorOutputs& outputs,
|
|
||||||
const TensorShapes& output_matrix_shapes);
|
|
||||||
|
|
||||||
void AnalyzeInputs(OpKernelContext* context, TensorInputs* inputs,
|
|
||||||
TensorShapes* input_matrix_shapes,
|
|
||||||
TensorShape* batch_shape);
|
|
||||||
|
|
||||||
void PrepareOutputs(OpKernelContext* context,
|
|
||||||
const TensorShapes& input_matrix_shapes,
|
|
||||||
const TensorShape& batch_shape, TensorOutputs* outputs,
|
|
||||||
TensorShapes* output_matrix_shapes);
|
|
||||||
};
|
|
||||||
|
|
||||||
// Declare LinearAlgebraOp, which is explicitly instantiated in
|
|
||||||
// linalg_ops_common.cc for float, double, complex64, and complex128.
|
|
||||||
extern template class LinearAlgebraOp<float>;
|
|
||||||
extern template class LinearAlgebraOp<double>;
|
|
||||||
extern template class LinearAlgebraOp<complex64>;
|
|
||||||
extern template class LinearAlgebraOp<complex128>;
|
|
||||||
|
|
||||||
} // namespace tensorflow
|
|
||||||
|
|
||||||
#define INHERIT_LINALG_TYPEDEFS(Scalar) \
|
|
||||||
typedef LinearAlgebraOp<Scalar> Base; \
|
|
||||||
using RealScalar = typename Eigen::NumTraits<Scalar>::Real; \
|
|
||||||
using Matrix = typename Base::Matrix; \
|
|
||||||
using MatrixMap = typename Base::MatrixMap; \
|
|
||||||
using MatrixMaps = typename Base::MatrixMaps; \
|
|
||||||
using ConstMatrixMap = typename Base::ConstMatrixMap; \
|
|
||||||
using ConstMatrixMaps = typename Base::ConstMatrixMaps; \
|
|
||||||
using ConstVectorMap = typename Base::ConstVectorMap; \
|
|
||||||
using TensorShapes = typename Base::TensorShapes;
|
|
||||||
|
|
||||||
#define REGISTER_LINALG_OP_CPU(OpName, OpClass, Scalar) \
|
|
||||||
REGISTER_KERNEL_BUILDER( \
|
|
||||||
Name(OpName).Device(DEVICE_CPU).TypeConstraint<Scalar>("T"), OpClass)
|
|
||||||
|
|
||||||
#define REGISTER_LINALG_OP_GPU(OpName, OpClass, Scalar) \
|
|
||||||
REGISTER_KERNEL_BUILDER( \
|
|
||||||
Name(OpName).Device(DEVICE_GPU).TypeConstraint<Scalar>("T"), OpClass)
|
|
||||||
|
|
||||||
// Deprecated, use one of the device-specific macros above.
|
|
||||||
#define REGISTER_LINALG_OP(OpName, OpClass, Scalar) \
|
|
||||||
REGISTER_LINALG_OP_CPU(OpName, OpClass, Scalar)
|
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_KERNELS_LINALG_OPS_COMMON_H_
|
#endif // TENSORFLOW_CORE_KERNELS_LINALG_OPS_COMMON_H_
|
||||||
|
@ -45,13 +45,13 @@ limitations under the License.
|
|||||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
#include "tensorflow/core/util/cuda_solvers.h"
|
||||||
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
|
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
|
||||||
|
|
||||||
using stream_executor::cuda::ScopedActivateExecutorContext;
|
using stream_executor::cuda::ScopedActivateExecutorContext;
|
||||||
#elif TENSORFLOW_USE_ROCM
|
#elif TENSORFLOW_USE_ROCM
|
||||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
|
||||||
#include "tensorflow/core/platform/rocm.h"
|
#include "tensorflow/core/platform/rocm.h"
|
||||||
|
#include "tensorflow/core/util/cuda_solvers.h"
|
||||||
using stream_executor::rocm::ScopedActivateExecutorContext;
|
using stream_executor::rocm::ScopedActivateExecutorContext;
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
|
@ -80,8 +80,8 @@ tf_kernel_library(
|
|||||||
"//tensorflow/core/kernels:transpose_functor",
|
"//tensorflow/core/kernels:transpose_functor",
|
||||||
"//tensorflow/core/kernels:gpu_prim_hdrs",
|
"//tensorflow/core/kernels:gpu_prim_hdrs",
|
||||||
] + if_cuda_or_rocm([
|
] + if_cuda_or_rocm([
|
||||||
"//tensorflow/core/kernels:cuda_solvers",
|
"//tensorflow/core/util:cuda_solvers",
|
||||||
"//tensorflow/core/kernels:cuda_sparse",
|
"//tensorflow/core/util:cuda_sparse",
|
||||||
]),
|
]),
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
@ -32,8 +32,8 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/kernels/fill_functor.h"
|
#include "tensorflow/core/kernels/fill_functor.h"
|
||||||
|
|
||||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
#include "tensorflow/core/util/cuda_solvers.h"
|
||||||
#include "tensorflow/core/kernels/cuda_sparse.h"
|
#include "tensorflow/core/util/cuda_sparse.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
@ -32,8 +32,8 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
|
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
|
||||||
|
|
||||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
#include "tensorflow/core/util/cuda_solvers.h"
|
||||||
#include "tensorflow/core/kernels/cuda_sparse.h"
|
#include "tensorflow/core/util/cuda_sparse.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
@ -34,8 +34,8 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/util/work_sharder.h"
|
#include "tensorflow/core/util/work_sharder.h"
|
||||||
|
|
||||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
#include "tensorflow/core/util/cuda_solvers.h"
|
||||||
#include "tensorflow/core/kernels/cuda_sparse.h"
|
#include "tensorflow/core/util/cuda_sparse.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
@ -32,8 +32,8 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/util/work_sharder.h"
|
#include "tensorflow/core/util/work_sharder.h"
|
||||||
|
|
||||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
#include "tensorflow/core/util/cuda_solvers.h"
|
||||||
#include "tensorflow/core/kernels/cuda_sparse.h"
|
#include "tensorflow/core/util/cuda_sparse.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
@ -35,8 +35,8 @@ limitations under the License.
|
|||||||
|
|
||||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
|
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
|
||||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
#include "tensorflow/core/util/cuda_solvers.h"
|
||||||
#include "tensorflow/core/kernels/cuda_sparse.h"
|
#include "tensorflow/core/util/cuda_sparse.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
|
@ -20,13 +20,13 @@ limitations under the License.
|
|||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
#include "tensorflow/core/framework/tensor_types.h"
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
#include "tensorflow/core/kernels/cuda_sparse.h"
|
|
||||||
#include "tensorflow/core/kernels/gpu_device_array.h"
|
#include "tensorflow/core/kernels/gpu_device_array.h"
|
||||||
#include "tensorflow/core/kernels/gpu_device_array_gpu.h"
|
#include "tensorflow/core/kernels/gpu_device_array_gpu.h"
|
||||||
#include "tensorflow/core/kernels/gpu_prim.h"
|
#include "tensorflow/core/kernels/gpu_prim.h"
|
||||||
#include "tensorflow/core/kernels/sparse/kernels.h"
|
#include "tensorflow/core/kernels/sparse/kernels.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
#include "tensorflow/core/util/cuda_sparse.h"
|
||||||
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
@ -37,8 +37,8 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/threadpool.h"
|
#include "tensorflow/core/platform/threadpool.h"
|
||||||
|
|
||||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
#include "tensorflow/core/util/cuda_solvers.h"
|
||||||
#include "tensorflow/core/kernels/cuda_sparse.h"
|
#include "tensorflow/core/util/cuda_sparse.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
@ -29,7 +29,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
|
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
|
||||||
|
|
||||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
#include "tensorflow/core/kernels/cuda_sparse.h"
|
#include "tensorflow/core/util/cuda_sparse.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
@ -29,8 +29,8 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
|
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
|
||||||
|
|
||||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
#include "tensorflow/core/util/cuda_solvers.h"
|
||||||
#include "tensorflow/core/kernels/cuda_sparse.h"
|
#include "tensorflow/core/util/cuda_sparse.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
|||||||
#define EIGEN_USE_THREADS
|
#define EIGEN_USE_THREADS
|
||||||
|
|
||||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
#include "tensorflow/core/kernels/cuda_sparse.h"
|
#include "tensorflow/core/util/cuda_sparse.h"
|
||||||
#define EIGEN_USE_GPU
|
#define EIGEN_USE_GPU
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@ -36,8 +36,8 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/util/work_sharder.h"
|
#include "tensorflow/core/util/work_sharder.h"
|
||||||
|
|
||||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
#include "tensorflow/core/util/cuda_solvers.h"
|
||||||
#include "tensorflow/core/kernels/cuda_sparse.h"
|
#include "tensorflow/core/util/cuda_sparse.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
@ -30,8 +30,8 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
|
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
|
||||||
|
|
||||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
#include "tensorflow/core/util/cuda_solvers.h"
|
||||||
#include "tensorflow/core/kernels/cuda_sparse.h"
|
#include "tensorflow/core/util/cuda_sparse.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
@ -33,8 +33,8 @@ limitations under the License.
|
|||||||
|
|
||||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
|
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
|
||||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
#include "tensorflow/core/util/cuda_solvers.h"
|
||||||
#include "tensorflow/core/kernels/cuda_sparse.h"
|
#include "tensorflow/core/util/cuda_sparse.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
|||||||
#define EIGEN_USE_THREADS
|
#define EIGEN_USE_THREADS
|
||||||
|
|
||||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
#include "tensorflow/core/kernels/cuda_sparse.h"
|
#include "tensorflow/core/util/cuda_sparse.h"
|
||||||
#define EIGEN_USE_GPU
|
#define EIGEN_USE_GPU
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@ -39,7 +39,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
|
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
|
||||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
#include "tensorflow/core/util/cuda_solvers.h"
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
|
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
|
||||||
using stream_executor::cuda::ScopedActivateExecutorContext;
|
using stream_executor::cuda::ScopedActivateExecutorContext;
|
||||||
|
@ -14,6 +14,7 @@ load(
|
|||||||
"tf_copts",
|
"tf_copts",
|
||||||
"tf_cuda_library",
|
"tf_cuda_library",
|
||||||
"tf_cuda_only_cc_test",
|
"tf_cuda_only_cc_test",
|
||||||
|
"tf_kernel_library",
|
||||||
)
|
)
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_version_info_genrule")
|
load("//tensorflow:tensorflow.bzl", "tf_version_info_genrule")
|
||||||
load(
|
load(
|
||||||
@ -24,6 +25,11 @@ load(
|
|||||||
"//tensorflow/core/platform:build_config_root.bzl",
|
"//tensorflow/core/platform:build_config_root.bzl",
|
||||||
"if_static",
|
"if_static",
|
||||||
)
|
)
|
||||||
|
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
||||||
|
load(
|
||||||
|
"@local_config_rocm//rocm:build_defs.bzl",
|
||||||
|
"if_rocm",
|
||||||
|
)
|
||||||
|
|
||||||
default_package_visibility = [
|
default_package_visibility = [
|
||||||
"//tensorflow/core:__subpackages__",
|
"//tensorflow/core:__subpackages__",
|
||||||
@ -567,6 +573,63 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "cuda_solvers",
|
||||||
|
srcs = ["cuda_solvers.cc"],
|
||||||
|
hdrs = ["cuda_solvers.h"],
|
||||||
|
# @local_config_cuda//cuda:cusolver_static, //third_party/eigen3:blas,
|
||||||
|
# and //third_party/libf2c all contain various parts of BLAS, LAPACK,
|
||||||
|
# and f2c helper functions in global namespace. Tell the compiler to
|
||||||
|
# allow multiple definitions when linking this.
|
||||||
|
linkopts = select({
|
||||||
|
"//tensorflow:macos": [],
|
||||||
|
"//tensorflow:windows": [],
|
||||||
|
"//conditions:default": ["-Wl,-z,muldefs"],
|
||||||
|
}),
|
||||||
|
visibility = ["//tensorflow/core/kernels:friends"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core/platform/default/build_config:cublas_plugin",
|
||||||
|
"//tensorflow/stream_executor/cuda:cublas_lib",
|
||||||
|
"//tensorflow/stream_executor/cuda:cusolver_lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "rocm_solvers",
|
||||||
|
srcs = ["rocm_solvers.cc"],
|
||||||
|
hdrs = ["rocm_solvers.h"],
|
||||||
|
visibility = ["//tensorflow/core/kernels:friends"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:framework_internal",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/stream_executor/lib",
|
||||||
|
"//tensorflow/stream_executor/platform:dso_loader",
|
||||||
|
"//tensorflow/stream_executor/rocm:rocblas_plugin",
|
||||||
|
"//tensorflow/stream_executor/rocm:rocm_gpu_executor",
|
||||||
|
] + if_rocm([
|
||||||
|
"@local_config_rocm//rocm:rocprim",
|
||||||
|
]),
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_kernel_library(
|
||||||
|
name = "cuda_sparse",
|
||||||
|
srcs = if_cuda(["cuda_sparse.cc"]) + if_rocm(["rocm_sparse.cc"]),
|
||||||
|
hdrs = ["cuda_sparse.h"],
|
||||||
|
deps = [
|
||||||
|
":cuda_solvers",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
] + if_cuda([
|
||||||
|
"//tensorflow/stream_executor/cuda:cusparse_lib",
|
||||||
|
"@cub_archive//:cub",
|
||||||
|
]) + if_rocm([
|
||||||
|
"@local_config_rocm//rocm:hipsparse",
|
||||||
|
]),
|
||||||
|
)
|
||||||
|
|
||||||
# Tests.
|
# Tests.
|
||||||
|
|
||||||
tf_cc_test(
|
tf_cc_test(
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
==============================================================================
|
==============================================================================
|
||||||
*/
|
*/
|
||||||
#ifdef GOOGLE_CUDA
|
#ifdef GOOGLE_CUDA
|
||||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
#include "tensorflow/core/util/cuda_solvers.h"
|
||||||
|
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
#include <complex>
|
#include <complex>
|
@ -14,8 +14,8 @@ limitations under the License.
|
|||||||
==============================================================================
|
==============================================================================
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef TENSORFLOW_CORE_KERNELS_CUDA_SOLVERS_H_
|
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_CUDA_SOLVERS_H_
|
||||||
#define TENSORFLOW_CORE_KERNELS_CUDA_SOLVERS_H_
|
#define TENSORFLOW_CORE_KERNELS_LINALG_CUDA_SOLVERS_H_
|
||||||
|
|
||||||
// This header declares the class CudaSolver, which contains wrappers of linear
|
// This header declares the class CudaSolver, which contains wrappers of linear
|
||||||
// algebra solvers in the cuBlas and cuSolverDN libraries for use in TensorFlow
|
// algebra solvers in the cuBlas and cuSolverDN libraries for use in TensorFlow
|
||||||
@ -435,7 +435,7 @@ class HostLapackInfo : public ScratchSpace<int> {
|
|||||||
public:
|
public:
|
||||||
HostLapackInfo(OpKernelContext* context, int64 size,
|
HostLapackInfo(OpKernelContext* context, int64 size,
|
||||||
const std::string& debug_info)
|
const std::string& debug_info)
|
||||||
: ScratchSpace<int>(context, size, debug_info, /* on_host */ true){};
|
: ScratchSpace<int>(context, size, debug_info, /* on_host */ true) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
class DeviceLapackInfo : public ScratchSpace<int> {
|
class DeviceLapackInfo : public ScratchSpace<int> {
|
||||||
@ -489,4 +489,4 @@ inline DeviceLapackInfo CudaSolver::GetDeviceLapackInfo(
|
|||||||
|
|
||||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_KERNELS_CUDA_SOLVERS_H_
|
#endif // TENSORFLOW_CORE_KERNELS_LINALG_CUDA_SOLVERS_H_
|
@ -15,7 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#ifdef GOOGLE_CUDA
|
#ifdef GOOGLE_CUDA
|
||||||
|
|
||||||
#include "tensorflow/core/kernels/cuda_sparse.h"
|
#include "tensorflow/core/util/cuda_sparse.h"
|
||||||
|
|
||||||
#include <complex>
|
#include <complex>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
@ -28,7 +28,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
|
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/kernels/cuda_solvers.h"
|
|
||||||
#include "tensorflow/core/lib/core/blocking_counter.h"
|
#include "tensorflow/core/lib/core/blocking_counter.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||||
@ -38,6 +37,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/mutex.h"
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
#include "tensorflow/core/platform/stream_executor.h"
|
#include "tensorflow/core/platform/stream_executor.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
#include "tensorflow/core/util/cuda_solvers.h"
|
||||||
|
|
||||||
// TODO(rmlarsen,penporn): Investigate using newer kernels in CUDA 10.1+.
|
// TODO(rmlarsen,penporn): Investigate using newer kernels in CUDA 10.1+.
|
||||||
|
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#ifndef TENSORFLOW_CORE_KERNELS_CUDA_SPARSE_H_
|
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_CUDA_SPARSE_H_
|
||||||
#define TENSORFLOW_CORE_KERNELS_CUDA_SPARSE_H_
|
#define TENSORFLOW_CORE_KERNELS_LINALG_CUDA_SPARSE_H_
|
||||||
|
|
||||||
// This header declares the class GpuSparse, which contains wrappers of
|
// This header declares the class GpuSparse, which contains wrappers of
|
||||||
// cuSparse libraries for use in TensorFlow kernels.
|
// cuSparse libraries for use in TensorFlow kernels.
|
||||||
@ -75,8 +75,7 @@ using gpuStream_t = hipStream_t;
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
inline std::string ConvertGPUSparseErrorToString(
|
inline string ConvertGPUSparseErrorToString(const gpusparseStatus_t status) {
|
||||||
const gpusparseStatus_t status) {
|
|
||||||
switch (status) {
|
switch (status) {
|
||||||
#define STRINGIZE(q) #q
|
#define STRINGIZE(q) #q
|
||||||
#define RETURN_IF_STATUS(err) \
|
#define RETURN_IF_STATUS(err) \
|
||||||
@ -206,49 +205,49 @@ class GpuSparse {
|
|||||||
// Solves tridiagonal system of equations.
|
// Solves tridiagonal system of equations.
|
||||||
// See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2
|
// See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2
|
||||||
template <typename Scalar>
|
template <typename Scalar>
|
||||||
Status Gtsv2(int m, int n, const Scalar *dl, const Scalar *d,
|
Status Gtsv2(int m, int n, const Scalar* dl, const Scalar* d,
|
||||||
const Scalar *du, Scalar *B, int ldb, void *pBuffer) const;
|
const Scalar* du, Scalar* B, int ldb, void* pBuffer) const;
|
||||||
|
|
||||||
// Computes the size of a temporary buffer used by Gtsv2.
|
// Computes the size of a temporary buffer used by Gtsv2.
|
||||||
// See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2_bufferSize
|
// See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2_bufferSize
|
||||||
template <typename Scalar>
|
template <typename Scalar>
|
||||||
Status Gtsv2BufferSizeExt(int m, int n, const Scalar *dl, const Scalar *d,
|
Status Gtsv2BufferSizeExt(int m, int n, const Scalar* dl, const Scalar* d,
|
||||||
const Scalar *du, const Scalar *B, int ldb,
|
const Scalar* du, const Scalar* B, int ldb,
|
||||||
size_t *bufferSizeInBytes) const;
|
size_t* bufferSizeInBytes) const;
|
||||||
|
|
||||||
// Solves tridiagonal system of equations without partial pivoting.
|
// Solves tridiagonal system of equations without partial pivoting.
|
||||||
// See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2_nopivot
|
// See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2_nopivot
|
||||||
template <typename Scalar>
|
template <typename Scalar>
|
||||||
Status Gtsv2NoPivot(int m, int n, const Scalar *dl, const Scalar *d,
|
Status Gtsv2NoPivot(int m, int n, const Scalar* dl, const Scalar* d,
|
||||||
const Scalar *du, Scalar *B, int ldb,
|
const Scalar* du, Scalar* B, int ldb,
|
||||||
void *pBuffer) const;
|
void* pBuffer) const;
|
||||||
|
|
||||||
// Computes the size of a temporary buffer used by Gtsv2NoPivot.
|
// Computes the size of a temporary buffer used by Gtsv2NoPivot.
|
||||||
// See:
|
// See:
|
||||||
// https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2_nopivot_bufferSize
|
// https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2_nopivot_bufferSize
|
||||||
template <typename Scalar>
|
template <typename Scalar>
|
||||||
Status Gtsv2NoPivotBufferSizeExt(int m, int n, const Scalar *dl,
|
Status Gtsv2NoPivotBufferSizeExt(int m, int n, const Scalar* dl,
|
||||||
const Scalar *d, const Scalar *du,
|
const Scalar* d, const Scalar* du,
|
||||||
const Scalar *B, int ldb,
|
const Scalar* B, int ldb,
|
||||||
size_t *bufferSizeInBytes) const;
|
size_t* bufferSizeInBytes) const;
|
||||||
|
|
||||||
// Solves a batch of tridiagonal systems of equations. Doesn't support
|
// Solves a batch of tridiagonal systems of equations. Doesn't support
|
||||||
// multiple right-hand sides per each system. Doesn't do pivoting.
|
// multiple right-hand sides per each system. Doesn't do pivoting.
|
||||||
// See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2stridedbatch
|
// See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2stridedbatch
|
||||||
template <typename Scalar>
|
template <typename Scalar>
|
||||||
Status Gtsv2StridedBatch(int m, const Scalar *dl, const Scalar *d,
|
Status Gtsv2StridedBatch(int m, const Scalar* dl, const Scalar* d,
|
||||||
const Scalar *du, Scalar *x, int batchCount,
|
const Scalar* du, Scalar* x, int batchCount,
|
||||||
int batchStride, void *pBuffer) const;
|
int batchStride, void* pBuffer) const;
|
||||||
|
|
||||||
// Computes the size of a temporary buffer used by Gtsv2StridedBatch.
|
// Computes the size of a temporary buffer used by Gtsv2StridedBatch.
|
||||||
// See:
|
// See:
|
||||||
// https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2stridedbatch_bufferSize
|
// https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2stridedbatch_bufferSize
|
||||||
template <typename Scalar>
|
template <typename Scalar>
|
||||||
Status Gtsv2StridedBatchBufferSizeExt(int m, const Scalar *dl,
|
Status Gtsv2StridedBatchBufferSizeExt(int m, const Scalar* dl,
|
||||||
const Scalar *d, const Scalar *du,
|
const Scalar* d, const Scalar* du,
|
||||||
const Scalar *x, int batchCount,
|
const Scalar* x, int batchCount,
|
||||||
int batchStride,
|
int batchStride,
|
||||||
size_t *bufferSizeInBytes) const;
|
size_t* bufferSizeInBytes) const;
|
||||||
|
|
||||||
// Compresses the indices of rows or columns. It can be interpreted as a
|
// Compresses the indices of rows or columns. It can be interpreted as a
|
||||||
// conversion from COO to CSR sparse storage format. See:
|
// conversion from COO to CSR sparse storage format. See:
|
||||||
@ -449,7 +448,7 @@ class GpuSparse {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
bool initialized_;
|
bool initialized_;
|
||||||
OpKernelContext *context_; // not owned.
|
OpKernelContext* context_; // not owned.
|
||||||
gpuStream_t gpu_stream_;
|
gpuStream_t gpu_stream_;
|
||||||
gpusparseHandle_t* gpusparse_handle_; // not owned.
|
gpusparseHandle_t* gpusparse_handle_; // not owned.
|
||||||
|
|
||||||
@ -585,4 +584,4 @@ class GpuSparseCsrSortingConversionInfo {
|
|||||||
|
|
||||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_KERNELS_CUDA_SPARSE_H_
|
#endif // TENSORFLOW_CORE_KERNELS_LINALG_CUDA_SPARSE_H_
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user