Break up core/kernels/BUILD (part 1 of N):

Move linear algebra kernels to subdirectory tensorflow/core/kernels/linalg with its own BUILD file.

PiperOrigin-RevId: 324923762
Change-Id: Id17aac690729b62ae97525df5bb57d6a073d6b0c
This commit is contained in:
A. Unique TensorFlower 2020-08-04 17:19:50 -07:00 committed by TensorFlower Gardener
parent 79594069bb
commit 84d053187c
103 changed files with 885 additions and 772 deletions
tensorflow/core
BUILD
kernels
BUILD
linalg
BUILDbanded_triangular_solve_op.ccbanded_triangular_solve_op_test.cccholesky_grad.cccholesky_op.ccdeterminant_op.ccdeterminant_op.hdeterminant_op_gpu.cu.cceig_op_complex128.cceig_op_complex64.cceig_op_double.cceig_op_float.cceig_op_impl.heinsum_op.heinsum_op_gpu.cu.cceinsum_op_impl.heinsum_op_impl_bfloat16.cceinsum_op_impl_complex128.cceinsum_op_impl_complex64.cceinsum_op_impl_double.cceinsum_op_impl_float.cceinsum_op_impl_half.cceinsum_op_impl_int32.cceinsum_op_impl_int64.cceye_functor.heye_functor_gpu.cu.cclinalg_ops_common.cclinalg_ops_common.hlu_op.cclu_op_gpu.cu.ccmatrix_band_part_op.ccmatrix_band_part_op.hmatrix_band_part_op_gpu.cu.ccmatrix_diag_op.ccmatrix_diag_op.hmatrix_diag_op_gpu.cu.ccmatrix_exponential_op.ccmatrix_inverse_op.ccmatrix_logarithm_op.ccmatrix_set_diag_op.ccmatrix_set_diag_op.hmatrix_set_diag_op_gpu.cu.ccmatrix_solve_ls_op_complex128.ccmatrix_solve_ls_op_complex64.ccmatrix_solve_ls_op_double.ccmatrix_solve_ls_op_float.ccmatrix_solve_ls_op_impl.hmatrix_solve_op.ccmatrix_square_root_op.ccmatrix_triangular_solve_op_complex.ccmatrix_triangular_solve_op_impl.hmatrix_triangular_solve_op_real.ccmatrix_triangular_solve_op_test.ccqr_op_complex128.ccqr_op_complex64.ccqr_op_double.ccqr_op_float.ccqr_op_impl.hself_adjoint_eig_op.ccself_adjoint_eig_v2_op_complex128.ccself_adjoint_eig_v2_op_complex64.ccself_adjoint_eig_v2_op_double.ccself_adjoint_eig_v2_op_float.ccself_adjoint_eig_v2_op_gpu.ccself_adjoint_eig_v2_op_impl.hsvd_op_complex128.ccsvd_op_complex64.ccsvd_op_double.ccsvd_op_float.ccsvd_op_gpu.cu.ccsvd_op_impl.htridiagonal_matmul_op.cctridiagonal_matmul_op_gpu.cu.cctridiagonal_solve_op.cctridiagonal_solve_op_gpu.cu.cc
linalg_ops_common.hsegment_reduction_ops_impl.h
sparse
where_op.cc
util

View File

@ -1010,7 +1010,7 @@ cc_library(
"//tensorflow/core/kernels:histogram_op",
"//tensorflow/core/kernels:image",
"//tensorflow/core/kernels:io",
"//tensorflow/core/kernels:linalg",
"//tensorflow/core/kernels/linalg:linalg",
"//tensorflow/core/kernels:lookup",
"//tensorflow/core/kernels:logging",
"//tensorflow/core/kernels:manip",

View File

@ -1039,9 +1039,6 @@ cc_library(
":immutable_constant_op",
":inplace_ops",
":listdiff_op",
":matrix_band_part_op",
":matrix_diag_op",
":matrix_set_diag_op",
":mirror_pad_op",
":one_hot_op",
":pack_op",
@ -1174,26 +1171,6 @@ tf_kernel_library(
deps = ARRAY_DEPS,
)
tf_kernel_library(
name = "matrix_band_part_op",
prefix = "matrix_band_part_op",
deps = if_cuda([
":cuda_solvers",
]) + ARRAY_DEPS,
)
tf_kernel_library(
name = "matrix_diag_op",
prefix = "matrix_diag_op",
deps = ARRAY_DEPS,
)
tf_kernel_library(
name = "matrix_set_diag_op",
prefix = "matrix_set_diag_op",
deps = ARRAY_DEPS + [":matrix_diag_op"],
)
tf_kernel_library(
name = "mirror_pad_op",
prefix = "mirror_pad_op",
@ -1405,7 +1382,7 @@ tf_kernel_library(
"where_op_gpu_impl_8.cu.cc",
],
deps = if_cuda_or_rocm([
":cuda_solvers",
"//tensorflow/core/util:cuda_solvers",
]) + [":gpu_prim_hdrs"] +
ARRAY_DEPS,
)
@ -2785,21 +2762,6 @@ tf_cuda_cc_tests(
],
)
tf_kernel_library(
name = "eye_functor",
hdrs = ["eye_functor.h"],
gpu_srcs = [
"eye_functor_gpu.cu.cc",
"eye_functor.h",
],
visibility = [":friends"],
deps = [
"//tensorflow/core:framework",
"//third_party/eigen3",
],
alwayslink = 0,
)
cc_library(
name = "fifo_queue",
srcs = ["fifo_queue.cc"],
@ -3558,289 +3520,6 @@ tf_cc_tests(
],
)
cc_library(
name = "linalg",
deps = [
":banded_triangular_solve_op",
":cholesky_grad",
":cholesky_op",
":determinant_op",
":eig_op",
":einsum_op",
":lu_op",
":matrix_exponential_op",
":matrix_inverse_op",
":matrix_logarithm_op",
":matrix_solve_ls_op",
":matrix_solve_op",
":matrix_square_root_op",
":matrix_triangular_solve_op",
":qr_op",
":self_adjoint_eig_op",
":self_adjoint_eig_v2_op",
":svd_op",
":tridiagonal_matmul_op",
":tridiagonal_solve_op",
],
)
tf_kernel_library(
name = "cuda_solvers",
srcs = ["cuda_solvers.cc"],
hdrs = ["cuda_solvers.h"],
# @local_config_cuda//cuda:cusolver_static, //third_party/eigen3:blas,
# and //third_party/libf2c all contain various parts of BLAS, LAPACK,
# and f2c helper functions in global namespace. Tell the compiler to
# allow multiple definitions when linking this.
linkopts = select({
"//tensorflow:macos": [],
"//tensorflow:windows": [],
"//conditions:default": ["-Wl,-z,muldefs"],
}),
visibility = [":friends"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/platform/default/build_config:cublas_plugin",
"//tensorflow/stream_executor/cuda:cublas_lib",
"//tensorflow/stream_executor/cuda:cusolver_lib",
],
)
tf_kernel_library(
name = "rocm_solvers",
srcs = ["rocm_solvers.cc"],
hdrs = ["rocm_solvers.h"],
visibility = [":friends"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/stream_executor/lib",
"//tensorflow/stream_executor/platform:dso_loader",
"//tensorflow/stream_executor/rocm:rocblas_plugin",
"//tensorflow/stream_executor/rocm:rocm_gpu_executor",
] + if_rocm([
"@local_config_rocm//rocm:rocprim",
]),
)
tf_kernel_library(
name = "cuda_sparse",
srcs = if_cuda(["cuda_sparse.cc"]) + if_rocm(["rocm_sparse.cc"]),
hdrs = ["cuda_sparse.h"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/kernels:cuda_solvers",
] + if_cuda([
"//tensorflow/stream_executor/cuda:cusparse_lib",
"@cub_archive//:cub",
]) + if_rocm([
"@local_config_rocm//rocm:hipsparse",
]),
)
LINALG_DEPS = [
":linalg_ops_common",
"//third_party/eigen3",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
] + if_cuda([
":cuda_solvers",
":transpose_functor",
]) + if_rocm([
":rocm_solvers",
])
tf_kernel_library(
name = "cholesky_op",
prefix = "cholesky_op",
deps = if_cuda([
":matrix_band_part_op",
]) + LINALG_DEPS,
)
tf_kernel_library(
name = "cholesky_grad",
prefix = "cholesky_grad",
deps = LINALG_DEPS,
)
tf_kernel_library(
name = "determinant_op",
prefix = "determinant_op",
deps = if_cuda([
":fill_functor",
]) + LINALG_DEPS,
)
tf_kernel_library(
name = "matrix_exponential_op",
prefix = "matrix_exponential_op",
deps = LINALG_DEPS,
)
tf_kernel_library(
name = "matrix_logarithm_op",
prefix = "matrix_logarithm_op",
deps = LINALG_DEPS,
)
tf_kernel_library(
name = "self_adjoint_eig_op",
prefix = "self_adjoint_eig_op",
deps = LINALG_DEPS + ["//tensorflow/core:lib_internal"],
)
tf_kernel_library(
name = "self_adjoint_eig_v2_op",
prefix = "self_adjoint_eig_v2_op",
deps = LINALG_DEPS + ["//tensorflow/core:lib_internal"] + if_cuda([
":cast_op",
":cwise_op",
]),
)
tf_kernel_library(
name = "eig_op",
prefix = "eig_op",
deps = LINALG_DEPS + ["//tensorflow/core:lib_internal"] + if_cuda([
":cast_op",
":cwise_op",
]),
)
tf_kernel_library(
name = "matrix_inverse_op",
prefix = "matrix_inverse_op",
deps = LINALG_DEPS + if_cuda([":eye_functor"]),
)
tf_kernel_library(
name = "matrix_solve_ls_op",
prefix = "matrix_solve_ls_op",
deps = LINALG_DEPS,
)
tf_kernel_library(
name = "matrix_solve_op",
prefix = "matrix_solve_op",
deps = LINALG_DEPS,
)
tf_kernel_library(
name = "matrix_square_root_op",
prefix = "matrix_square_root_op",
deps = LINALG_DEPS,
)
tf_kernel_library(
name = "banded_triangular_solve_op",
prefix = "banded_triangular_solve_op",
deps = LINALG_DEPS + [":fill_functor"],
)
tf_kernel_library(
name = "matrix_triangular_solve_op",
hdrs = ["matrix_triangular_solve_op_impl.h"],
prefix = "matrix_triangular_solve_op",
deps = [
":linalg_ops_common",
"//third_party/eigen3",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
":fill_functor",
"//tensorflow/core:stream_executor",
] + if_cuda([
"//tensorflow/core/platform/default/build_config:cublas_plugin",
":cuda_solvers",
]) + if_rocm([
"@local_config_rocm//rocm:rocprim",
":rocm_solvers",
]) + if_cuda_or_rocm([
":transpose_functor",
]),
)
tf_kernel_library(
name = "tridiagonal_matmul_op",
srcs = ["tridiagonal_matmul_op.cc"],
gpu_srcs = ["tridiagonal_matmul_op_gpu.cu.cc"],
deps = LINALG_DEPS + if_cuda([
":cuda_sparse",
]),
)
tf_kernel_library(
name = "tridiagonal_solve_op",
srcs = ["tridiagonal_solve_op.cc"],
gpu_srcs = ["tridiagonal_solve_op_gpu.cu.cc"],
deps = LINALG_DEPS + if_cuda([
":cuda_sparse",
]),
)
tf_kernel_library(
name = "qr_op",
prefix = "qr_op",
deps = LINALG_DEPS + if_cuda([
":cwise_op",
":eye_functor",
":matrix_band_part_op",
]),
)
tf_kernel_library(
name = "svd_op",
prefix = "svd_op",
deps = LINALG_DEPS + if_cuda([
":eye_functor",
]),
)
tf_kernel_library(
name = "lu_op",
prefix = "lu_op",
deps = if_cuda([
":cuda_solvers",
":transpose_functor",
]) + [
"//third_party/eigen3",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],
)
tf_kernel_library(
name = "einsum_op",
prefix = "einsum_op",
deps = [
":batch_matmul_op",
":fill_functor",
":reduction_ops",
":transpose_functor",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/profiler/lib:traceme",
"//third_party/eigen3",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "linalg_ops_common",
srcs = ["linalg_ops_common.cc"],
hdrs = ["linalg_ops_common.h"],
visibility = ["//visibility:private"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//third_party/eigen3",
],
)
cc_library(
name = "logging",
deps = [
@ -4208,7 +3887,7 @@ tf_kernel_library(
name = "segment_reduction_ops",
prefix = "segment_reduction_ops",
deps = MATH_DEPS + if_cuda_or_rocm([
":cuda_solvers",
"//tensorflow/core/util:cuda_solvers",
]),
)
@ -4405,45 +4084,6 @@ tf_cuda_cc_test(
],
)
tf_cuda_cc_test(
name = "banded_triangular_solve_op_test",
size = "small",
srcs = ["banded_triangular_solve_op_test.cc"],
deps = [
":banded_triangular_solve_op",
":matrix_set_diag_op",
":matrix_triangular_solve_op",
":ops_testutil",
":ops_util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
tf_cuda_cc_test(
name = "matrix_triangular_solve_op_test",
size = "small",
srcs = ["matrix_triangular_solve_op_test.cc"],
deps = [
":broadcast_to_op",
":matrix_triangular_solve_op",
":ops_testutil",
":ops_util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
tf_cuda_cc_test(
name = "scan_ops_test",
size = "small",
@ -6672,10 +6312,7 @@ filegroup(
"lookup_table_init_op.h",
"lookup_table_op.h",
"lookup_util.h",
"linalg_ops_common.h",
"list_kernels.h",
"matrix_diag_op.h",
"matrix_set_diag_op.h",
"maxpooling_op.h",
"mfcc.h",
"mfcc_dct.h",
@ -6723,6 +6360,9 @@ filegroup(
"xent_op.h",
] + [
"//tensorflow/core/kernels/boosted_trees/quantiles:weighted_quantiles_hdrs",
"//tensorflow/core/kernels/linalg:linalg_ops_common.h",
"//tensorflow/core/kernels/linalg:matrix_diag_op.h",
"//tensorflow/core/kernels/linalg:matrix_set_diag_op.h",
],
)
@ -6823,16 +6463,6 @@ filegroup(
"encode_wav_op.cc",
"eigen_contraction_kernel.cc",
"eigen_contraction_kernel.h",
"einsum_op_impl_half.cc",
"einsum_op_impl_bfloat16.cc",
"einsum_op_impl_int32.cc",
"einsum_op_impl_int64.cc",
"einsum_op_impl_float.cc",
"einsum_op_impl_double.cc",
"einsum_op_impl_complex64.cc",
"einsum_op_impl_complex128.cc",
"einsum_op_impl.h",
"einsum_op.h",
"fake_quant_ops.cc",
"fifo_queue.cc",
"fifo_queue_op.cc",
@ -6844,6 +6474,17 @@ filegroup(
"population_count_op.h",
"winograd_transform.h",
":android_extended_ops_headers",
] + [
"//tensorflow/core/kernels/linalg:einsum_op_impl_half.cc",
"//tensorflow/core/kernels/linalg:einsum_op_impl_bfloat16.cc",
"//tensorflow/core/kernels/linalg:einsum_op_impl_int32.cc",
"//tensorflow/core/kernels/linalg:einsum_op_impl_int64.cc",
"//tensorflow/core/kernels/linalg:einsum_op_impl_float.cc",
"//tensorflow/core/kernels/linalg:einsum_op_impl_double.cc",
"//tensorflow/core/kernels/linalg:einsum_op_impl_complex64.cc",
"//tensorflow/core/kernels/linalg:einsum_op_impl_complex128.cc",
"//tensorflow/core/kernels/linalg:einsum_op_impl.h",
"//tensorflow/core/kernels/linalg:einsum_op.h",
] + select({
":xsmm_convolutions": [
"xsmm_conv2d.h",
@ -6874,7 +6515,6 @@ filegroup(
"in_topk_op.cc",
"in_topk_op.h",
"initializable_lookup_table.cc",
"linalg_ops_common.cc",
"list_kernels.cc",
"logging_ops.cc",
"logging_ops.h",
@ -6882,9 +6522,6 @@ filegroup(
"lookup_table_op.cc",
"lookup_util.cc",
"lrn_op.cc",
"matrix_diag_op.cc",
"matrix_inverse_op.cc",
"matrix_set_diag_op.cc",
"maxpooling_op.cc",
"mfcc.cc",
"mfcc_dct.cc",
@ -7006,6 +6643,10 @@ filegroup(
":android_extended_ops_headers",
] + [
"//tensorflow/core/kernels/boosted_trees:quantile_ops.cc",
"//tensorflow/core/kernels/linalg:linalg_ops_common.cc",
"//tensorflow/core/kernels/linalg:matrix_diag_op.cc",
"//tensorflow/core/kernels/linalg:matrix_inverse_op.cc",
"//tensorflow/core/kernels/linalg:matrix_set_diag_op.cc",
],
)
@ -7059,6 +6700,7 @@ filegroup(
srcs = [
"//tensorflow/c/kernels:android_all_op_kernels",
"//tensorflow/core/kernels/data:android_all_op_kernels",
"//tensorflow/core/kernels/linalg:android_all_op_kernels",
] + glob(
[
"*.cc",
@ -8827,3 +8469,15 @@ tf_kernel_library(
"@sobol_data",
],
)
# ---- temporary forwarding declaration for libraries in linalg
# TODO(b/160344057): Remove after updating dependencies.
tf_kernel_library(
name = "matrix_inverse_op",
deps = ["//tensorflow/core/kernels/linalg:matrix_inverse_op"],
)
tf_kernel_library(
name = "einsum_op",
deps = ["//tensorflow/core/kernels/linalg:einsum_op"],
)

View File

@ -0,0 +1,376 @@
load(
"//tensorflow:tensorflow.bzl",
"if_cuda_or_rocm",
"tf_kernel_library",
)
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load(
"@local_config_rocm//rocm:build_defs.bzl",
"if_rocm",
)
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
# Description:
# Op kernel implementations for TensorFlow.
#
# Note: Any test that uses GPU support and which we would like to
# benchmark should be linked statically so that it can be executed
# from a py_binary or cuda_py_test test logger. For such a test,
# append "_gpu" to the test name to invoke the GPU benchmarks. Example:
#
# # for CPU tests
# $ bazel test --config opt //third_party/tensorflow/core/kernels:my_op_test
# # for GPU benchmarks
# $ bazel run --config opt --config=cuda //third_party/tensorflow/core/kernels:my_op_test_gpu -- --benchmarks=..
#
package(
default_visibility = [
"//tensorflow:__subpackages__",
"//tensorflow:internal",
],
licenses = ["notice"], # Apache 2.0
)
# TODO(rmlarsen): Remove ASAP.
package_group(
name = "friends",
packages = ["//tensorflow/..."],
)
# Export a few files for use on Android.
exports_files([
"einsum_op_impl_half.cc",
"einsum_op_impl_bfloat16.cc",
"einsum_op_impl_int32.cc",
"einsum_op_impl_int64.cc",
"einsum_op_impl_float.cc",
"einsum_op_impl_double.cc",
"einsum_op_impl_complex64.cc",
"einsum_op_impl_complex128.cc",
"einsum_op_impl.h",
"einsum_op.h",
"linalg_ops_common.h",
"linalg_ops_common.cc",
"matrix_diag_op.h",
"matrix_diag_op.cc",
"matrix_inverse_op.cc",
"matrix_set_diag_op.h",
"matrix_set_diag_op.cc",
])
# Public support libraries ----------------------------------------------------
cc_library(
name = "linalg",
deps = [
":banded_triangular_solve_op",
":cholesky_grad",
":cholesky_op",
":determinant_op",
":eig_op",
":einsum_op",
":lu_op",
":matrix_band_part_op",
":matrix_diag_op",
":matrix_exponential_op",
":matrix_inverse_op",
":matrix_logarithm_op",
":matrix_set_diag_op",
":matrix_solve_ls_op",
":matrix_solve_op",
":matrix_square_root_op",
":matrix_triangular_solve_op",
":qr_op",
":self_adjoint_eig_op",
":self_adjoint_eig_v2_op",
":svd_op",
":tridiagonal_matmul_op",
":tridiagonal_solve_op",
],
)
LINALG_DEPS = [
":linalg_ops_common",
"//third_party/eigen3",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/kernels:cast_op",
"//tensorflow/core/kernels:fill_functor",
] + if_cuda([
":eye_functor",
"//tensorflow/core/util:cuda_solvers",
"//tensorflow/core/kernels:transpose_functor",
]) + if_rocm([
"//tensorflow/core/util:rocm_solvers",
])
tf_kernel_library(
name = "matrix_band_part_op",
prefix = "matrix_band_part_op",
deps = LINALG_DEPS,
)
tf_kernel_library(
name = "matrix_diag_op",
prefix = "matrix_diag_op",
deps = LINALG_DEPS,
)
tf_kernel_library(
name = "matrix_set_diag_op",
prefix = "matrix_set_diag_op",
deps = LINALG_DEPS + [":matrix_diag_op"],
)
tf_kernel_library(
name = "cholesky_op",
prefix = "cholesky_op",
deps = if_cuda([
":matrix_band_part_op",
]) + LINALG_DEPS,
)
tf_kernel_library(
name = "cholesky_grad",
prefix = "cholesky_grad",
deps = LINALG_DEPS,
)
tf_kernel_library(
name = "determinant_op",
prefix = "determinant_op",
deps = LINALG_DEPS,
)
tf_kernel_library(
name = "matrix_exponential_op",
prefix = "matrix_exponential_op",
deps = LINALG_DEPS,
)
tf_kernel_library(
name = "matrix_logarithm_op",
prefix = "matrix_logarithm_op",
deps = LINALG_DEPS,
)
tf_kernel_library(
name = "self_adjoint_eig_op",
prefix = "self_adjoint_eig_op",
deps = LINALG_DEPS + ["//tensorflow/core:lib_internal"],
)
tf_kernel_library(
name = "self_adjoint_eig_v2_op",
prefix = "self_adjoint_eig_v2_op",
deps = LINALG_DEPS + ["//tensorflow/core:lib_internal"] + if_cuda([
"//tensorflow/core/kernels:cwise_op",
]),
)
tf_kernel_library(
name = "eig_op",
prefix = "eig_op",
deps = LINALG_DEPS + ["//tensorflow/core:lib_internal"] + if_cuda([
"//tensorflow/core/kernels:cwise_op",
]),
)
tf_kernel_library(
name = "matrix_inverse_op",
prefix = "matrix_inverse_op",
visibility = [":friends"],
deps = LINALG_DEPS,
)
tf_kernel_library(
name = "matrix_solve_ls_op",
prefix = "matrix_solve_ls_op",
deps = LINALG_DEPS,
)
tf_kernel_library(
name = "matrix_solve_op",
prefix = "matrix_solve_op",
deps = LINALG_DEPS,
)
tf_kernel_library(
name = "matrix_square_root_op",
prefix = "matrix_square_root_op",
deps = LINALG_DEPS,
)
tf_kernel_library(
name = "banded_triangular_solve_op",
prefix = "banded_triangular_solve_op",
deps = LINALG_DEPS,
)
tf_kernel_library(
name = "matrix_triangular_solve_op",
hdrs = ["matrix_triangular_solve_op_impl.h"],
prefix = "matrix_triangular_solve_op",
deps = [
":linalg_ops_common",
"//third_party/eigen3",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/kernels:fill_functor",
"//tensorflow/core:stream_executor",
] + if_cuda([
"//tensorflow/core/platform/default/build_config:cublas_plugin",
"//tensorflow/core/util:cuda_solvers",
]) + if_rocm([
"@local_config_rocm//rocm:rocprim",
"//tensorflow/core/util:rocm_solvers",
]) + if_cuda_or_rocm([
"//tensorflow/core/kernels:transpose_functor",
]),
)
tf_kernel_library(
name = "tridiagonal_matmul_op",
srcs = ["tridiagonal_matmul_op.cc"],
gpu_srcs = ["tridiagonal_matmul_op_gpu.cu.cc"],
deps = LINALG_DEPS + if_cuda([
"//tensorflow/core/util:cuda_sparse",
]),
)
tf_kernel_library(
name = "tridiagonal_solve_op",
srcs = ["tridiagonal_solve_op.cc"],
gpu_srcs = ["tridiagonal_solve_op_gpu.cu.cc"],
deps = LINALG_DEPS + if_cuda([
"//tensorflow/core/util:cuda_sparse",
]),
)
tf_kernel_library(
name = "qr_op",
prefix = "qr_op",
deps = LINALG_DEPS + if_cuda([
"//tensorflow/core/kernels:cwise_op",
":matrix_band_part_op",
]),
)
tf_kernel_library(
name = "svd_op",
prefix = "svd_op",
deps = LINALG_DEPS,
)
tf_kernel_library(
name = "lu_op",
prefix = "lu_op",
deps = if_cuda([
"//tensorflow/core/util:cuda_solvers",
"//tensorflow/core/kernels:transpose_functor",
]) + [
"//third_party/eigen3",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],
)
tf_kernel_library(
name = "einsum_op",
prefix = "einsum_op",
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/kernels:batch_matmul_op",
"//tensorflow/core/kernels:fill_functor",
"//tensorflow/core/kernels:reduction_ops",
"//tensorflow/core/kernels:transpose_functor",
"//tensorflow/core/profiler/lib:traceme",
"//third_party/eigen3",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "linalg_ops_common",
srcs = ["linalg_ops_common.cc"],
hdrs = ["linalg_ops_common.h"],
visibility = ["//visibility:private"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//third_party/eigen3",
],
)
tf_cuda_cc_test(
name = "banded_triangular_solve_op_test",
size = "small",
srcs = ["banded_triangular_solve_op_test.cc"],
deps = [
":banded_triangular_solve_op",
":matrix_set_diag_op",
":matrix_triangular_solve_op",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels:ops_testutil",
"//tensorflow/core/kernels:ops_util",
],
)
tf_kernel_library(
name = "eye_functor",
hdrs = ["eye_functor.h"],
gpu_srcs = [
"eye_functor_gpu.cu.cc",
"eye_functor.h",
],
visibility = ["//tensorflow/core/kernels:friends"],
deps = [
"//tensorflow/core:framework",
"//third_party/eigen3",
],
alwayslink = 0,
)
tf_cuda_cc_test(
name = "matrix_triangular_solve_op_test",
size = "small",
srcs = ["matrix_triangular_solve_op_test.cc"],
deps = [
":matrix_triangular_solve_op",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels:broadcast_to_op",
"//tensorflow/core/kernels:ops_testutil",
"//tensorflow/core/kernels:ops_util",
],
)
# A file group which contains all operators which are known to work on mobile.
filegroup(
name = "android_all_op_kernels",
srcs = glob(
[
"*.cc",
"*.h",
],
exclude = [
"*test.cc",
"*test.h",
"*_test_*",
],
),
visibility = ["//tensorflow:__subpackages__"],
)

View File

@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"

View File

@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/kernels/matrix_set_diag_op.h"
#include "tensorflow/core/kernels/linalg/matrix_set_diag_op.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"

View File

@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
namespace tensorflow {

View File

@ -25,16 +25,16 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#if GOOGLE_CUDA
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/matrix_band_part_op.h"
#include "tensorflow/core/kernels/linalg/matrix_band_part_op.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/util/cuda_solvers.h"
#endif
namespace tensorflow {

View File

@ -20,7 +20,7 @@ limitations under the License.
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/determinant_op.h"
#include "tensorflow/core/kernels/linalg/determinant_op.h"
#endif
#include "third_party/eigen3/Eigen/LU"
@ -28,14 +28,14 @@ limitations under the License.
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#if GOOGLE_CUDA
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/util/cuda_solvers.h"
#endif
namespace tensorflow {

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DETERMINANT_OP_H_
#define TENSORFLOW_CORE_KERNELS_DETERMINANT_OP_H_
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_DETERMINANT_OP_H_
#define TENSORFLOW_CORE_KERNELS_LINALG_DETERMINANT_OP_H_
#include "tensorflow/core/framework/tensor_types.h"
@ -44,4 +44,4 @@ struct LogDeterminantFromPivotedLUFunctor {
} // namespace functor
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DETERMINANT_OP_H_
#endif // TENSORFLOW_CORE_KERNELS_LINALG_DETERMINANT_OP_H_

View File

@ -21,8 +21,8 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/determinant_op.h"
#include "tensorflow/core/kernels/linalg/determinant_op.h"
#include "tensorflow/core/util/cuda_solvers.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
namespace tensorflow {

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/eig_op_impl.h"
#include "tensorflow/core/kernels/linalg/eig_op_impl.h"
namespace tensorflow {

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/eig_op_impl.h"
#include "tensorflow/core/kernels/linalg/eig_op_impl.h"
namespace tensorflow {

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/eig_op_impl.h"
#include "tensorflow/core/kernels/linalg/eig_op_impl.h"
namespace tensorflow {

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/eig_op_impl.h"
#include "tensorflow/core/kernels/linalg/eig_op_impl.h"
namespace tensorflow {

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_EIG_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_EIG_OP_IMPL_H_
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_EIG_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_LINALG_EIG_OP_IMPL_H_
// See docs in ../ops/linalg_ops.cc.
@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/logging.h"
@ -95,4 +95,4 @@ class EigOp : public LinearAlgebraOp<InputScalar, OutputScalar> {
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_EIG_OP_IMPL_H_
#endif // TENSORFLOW_CORE_KERNELS_LINALG_EIG_OP_IMPL_H_

View File

@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_EINSUM_OP_H_
#define TENSORFLOW_CORE_KERNELS_EINSUM_OP_H_
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_EINSUM_OP_H_
#define TENSORFLOW_CORE_KERNELS_LINALG_EINSUM_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"

View File

@ -17,7 +17,7 @@ limitations under the License.
#define EIGEN_USE_GPU
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/einsum_op.h"
#include "tensorflow/core/kernels/linalg/einsum_op.h"
namespace tensorflow {

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_EINSUM_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_EINSUM_OP_IMPL_H_
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_EINSUM_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_LINALG_EINSUM_OP_IMPL_H_
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@ -31,8 +31,8 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/batch_matmul_op_impl.h"
#include "tensorflow/core/kernels/einsum_op.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/linalg/einsum_op.h"
#include "tensorflow/core/kernels/reduction_ops_common.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/lib/core/errors.h"
@ -780,4 +780,4 @@ DECLARE_GPU_SPECS(complex128);
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_EINSUM_OP_IMPL_H_
#endif // TENSORFLOW_CORE_KERNELS_LINALG_EINSUM_OP_IMPL_H_

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/einsum_op_impl.h"
#include "tensorflow/core/kernels/linalg/einsum_op_impl.h"
namespace tensorflow {

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/einsum_op_impl.h"
#include "tensorflow/core/kernels/linalg/einsum_op_impl.h"
namespace tensorflow {

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/einsum_op_impl.h"
#include "tensorflow/core/kernels/linalg/einsum_op_impl.h"
namespace tensorflow {

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/einsum_op_impl.h"
#include "tensorflow/core/kernels/linalg/einsum_op_impl.h"
namespace tensorflow {

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/einsum_op_impl.h"
#include "tensorflow/core/kernels/linalg/einsum_op_impl.h"
namespace tensorflow {

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/einsum_op_impl.h"
#include "tensorflow/core/kernels/linalg/einsum_op_impl.h"
namespace tensorflow {

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/einsum_op_impl.h"
#include "tensorflow/core/kernels/linalg/einsum_op_impl.h"
namespace tensorflow {

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/einsum_op_impl.h"
#include "tensorflow/core/kernels/linalg/einsum_op_impl.h"
namespace tensorflow {

View File

@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_EYE_FUNCTOR_H_
#define TENSORFLOW_CORE_KERNELS_EYE_FUNCTOR_H_
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_EYE_FUNCTOR_H_
#define TENSORFLOW_CORE_KERNELS_LINALG_EYE_FUNCTOR_H_
#include "tensorflow/core/framework/tensor_types.h"

View File

@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/kernels/eye_functor.h"
#include "tensorflow/core/kernels/linalg/eye_functor.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
namespace tensorflow {

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include <utility>

View File

@ -0,0 +1,221 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_LINALG_OPS_COMMON_H_
#define TENSORFLOW_CORE_KERNELS_LINALG_LINALG_OPS_COMMON_H_
// Classes to support linear algebra functionality, similar to the numpy.linalg
// module. Supports batch computation on several matrices at once, sharding the
// computations across different threads if necessary.
#include <algorithm>
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"
namespace tensorflow {
// Base class for linear algebra operators.
template <class InputScalar, class OutputScalar = InputScalar>
class LinearAlgebraOp : public OpKernel {
public:
explicit LinearAlgebraOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override;
protected:
using TensorShapes = gtl::InlinedVector<TensorShape, 4>;
// Returns the number of leading inputs that are to be treated as matrix
// inputs. By default this is all the inputs. Derived classes can override
// this to tell the base class to ignore one or more trailing inputs.
virtual int NumMatrixInputs(const OpKernelContext* context) const {
return context->num_inputs();
}
// Returns true if the number of inputs and their shapes are as expected.
// Many ops take a single square input matrix, so we provide that as a default
// implementation for convenience.
virtual void ValidateInputMatrixShapes(
OpKernelContext* context, const TensorShapes& input_matrix_shapes) const {
ValidateSingleSquareMatrix(context, input_matrix_shapes);
}
// Convenience validators for common cases:
//
// Validate op taking a single matrix A.
static void ValidateSingleMatrix(OpKernelContext* context,
const TensorShapes& input_matrix_shapes);
// Validate op taking a single square matrix A.
static void ValidateSingleSquareMatrix(
OpKernelContext* context, const TensorShapes& input_matrix_shapes);
// Validate op taking two matrices A and B that have the same number of rows.
static void ValidateSolver(OpKernelContext* context,
const TensorShapes& input_matrix_shapes);
// Validate op taking two matrices A and B that have the same number of rows
// and A is square.
static void ValidateSquareSolver(OpKernelContext* context,
const TensorShapes& input_matrix_shapes);
// Returns the output shapes of each individual matrix operation. Output
// matrices shapes must be rank 0, 1, or 2. Scalar outputs are rank 0.
//
// The derived class may return a number of shapes (N) less than
// context->num_outputs() (M) to indicate that a only leading subset of
// the outputs will be populated. In this case, a dummy scalar tensor with
// value zero will be return for the last M-N outputs.
//
// For many ops, the output dimensions are the same as the input dimensions,
// so we provide that as a default implementation for convenience.
virtual TensorShapes GetOutputMatrixShapes(
const TensorShapes& input_matrix_shapes) const {
return input_matrix_shapes;
}
// Returns the cost per matrix operation. This is used to determine the
// number of threads to use for parallelizing calls to ComputeMatrix in
// batch mode. Cost per unit is assumed to be roughly 1ns, based on comments
// in core/util/work_sharder.cc. Many linear algebra ops take roughly max(m,n)
// * min(m,n)^2, where the first input matrix is m-by-n. We provide that as a
// default implementation for convenience.
virtual int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const {
double m = static_cast<double>(input_matrix_shapes[0].dim_size(0));
double n = static_cast<double>(input_matrix_shapes[0].dim_size(1));
double cost = std::max(m, n) * std::min(m, n) * std::min(m, n);
return cost >= static_cast<double>(kint64max) ? kint64max
: static_cast<int64>(cost);
}
// Returns true if it is safe to forward (alias) input to output buffer
// and expect the kernel to perform the computation inplace.
virtual bool EnableInputForwarding() const { return true; }
using InputMatrix = Eigen::Matrix<InputScalar, Eigen::Dynamic, Eigen::Dynamic,
Eigen::RowMajor>;
using InputConstMatrixMap = Eigen::Map<const InputMatrix>;
using InputMatrixMap = Eigen::Map<InputMatrix>;
using InputConstVectorMap =
Eigen::Map<const Eigen::Matrix<InputScalar, 1, Eigen::Dynamic>>;
using InputConstMatrixMaps = gtl::InlinedVector<InputConstMatrixMap, 4>;
using InputMatrixMaps = gtl::InlinedVector<InputMatrixMap, 4>;
using InputRealScalar = typename Eigen::NumTraits<InputScalar>::Real;
using OutputMatrix = Eigen::Matrix<OutputScalar, Eigen::Dynamic,
Eigen::Dynamic, Eigen::RowMajor>;
using OutputConstMatrixMap = Eigen::Map<const OutputMatrix>;
using OutputMatrixMap = Eigen::Map<OutputMatrix>;
using OutputConstVectorMap =
Eigen::Map<const Eigen::Matrix<OutputScalar, 1, Eigen::Dynamic>>;
using OutputConstMatrixMaps = gtl::InlinedVector<OutputConstMatrixMap, 4>;
using OutputMatrixMaps = gtl::InlinedVector<OutputMatrixMap, 4>;
using OutputRealScalar = typename Eigen::NumTraits<OutputScalar>::Real;
// backward compatibility
using Scalar = OutputScalar;
using Matrix =
Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using ConstMatrixMap = Eigen::Map<const Matrix>;
using MatrixMap = Eigen::Map<Matrix>;
using ConstVectorMap =
Eigen::Map<const Eigen::Matrix<Scalar, 1, Eigen::Dynamic>>;
using ConstMatrixMaps = gtl::InlinedVector<ConstMatrixMap, 4>;
using MatrixMaps = gtl::InlinedVector<MatrixMap, 4>;
using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
// Performs a single matrix computation given input matrices, and
// stores the result in outputs. For batch operations, this will be called
// repeatedly for a single call to Compute() when multiple matrices exist in
// input Tensors with rank > 2. In this case the calls to ComputeMatrix are
// parallelized. The number of threads used is determined by a cost model from
// the value returned by GetCostPerUnit().
virtual void ComputeMatrix(OpKernelContext* context,
const InputConstMatrixMaps& inputs,
OutputMatrixMaps* outputs) = 0;
private:
using TensorInputs = gtl::InlinedVector<const Tensor*, 4>;
using TensorOutputs = gtl::InlinedVector<Tensor*, 4>;
// This function maps 2-d slices (matrices) of the input and output tensors
// using Eigen::Map and calls ComputeMatrix implemented in terms of the
// Eigen::MatrixBase API by the derived class.
//
// The 'matrix_index' parameter specifies the index of the matrix to be used
// from each input tensor, and the index of the matrix to be written to each
// output tensor. The input matrices are in row major order, and located at
// the memory addresses
// inputs[i].flat<Scalar>().data() +
// matrix_index * input_matrix_shapes[i].num_elements()
// for i in 0...inputs.size()-1.
// The output matrices are in row major order, and located at the memory
// address
// outputs[i]->flat<Scalar>().data() +
// matrix_index * output_matrix_shapes[i].num_elements().
// for i in 0...outputs.size()-1.
//
void ComputeTensorSlice(OpKernelContext* context, int64 matrix_index,
const TensorInputs& inputs,
const TensorShapes& input_matrix_shapes,
const TensorOutputs& outputs,
const TensorShapes& output_matrix_shapes);
void AnalyzeInputs(OpKernelContext* context, TensorInputs* inputs,
TensorShapes* input_matrix_shapes,
TensorShape* batch_shape);
void PrepareOutputs(OpKernelContext* context,
const TensorShapes& input_matrix_shapes,
const TensorShape& batch_shape, TensorOutputs* outputs,
TensorShapes* output_matrix_shapes);
};
// Declare LinearAlgebraOp, which is explicitly instantiated in
// linalg_ops_common.cc for float, double, complex64, and complex128.
extern template class LinearAlgebraOp<float>;
extern template class LinearAlgebraOp<double>;
extern template class LinearAlgebraOp<complex64>;
extern template class LinearAlgebraOp<complex128>;
} // namespace tensorflow
#define INHERIT_LINALG_TYPEDEFS(Scalar) \
typedef LinearAlgebraOp<Scalar> Base; \
using RealScalar = typename Eigen::NumTraits<Scalar>::Real; \
using Matrix = typename Base::Matrix; \
using MatrixMap = typename Base::MatrixMap; \
using MatrixMaps = typename Base::MatrixMaps; \
using ConstMatrixMap = typename Base::ConstMatrixMap; \
using ConstMatrixMaps = typename Base::ConstMatrixMaps; \
using ConstVectorMap = typename Base::ConstVectorMap; \
using TensorShapes = typename Base::TensorShapes;
#define REGISTER_LINALG_OP_CPU(OpName, OpClass, Scalar) \
REGISTER_KERNEL_BUILDER( \
Name(OpName).Device(DEVICE_CPU).TypeConstraint<Scalar>("T"), OpClass)
#define REGISTER_LINALG_OP_GPU(OpName, OpClass, Scalar) \
REGISTER_KERNEL_BUILDER( \
Name(OpName).Device(DEVICE_GPU).TypeConstraint<Scalar>("T"), OpClass)
// Deprecated, use one of the device-specific macros above.
#define REGISTER_LINALG_OP(OpName, OpClass, Scalar) \
REGISTER_LINALG_OP_CPU(OpName, OpClass, Scalar)
#endif // TENSORFLOW_CORE_KERNELS_LINALG_LINALG_OPS_COMMON_H_

View File

@ -25,9 +25,9 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_solvers.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
namespace tensorflow {

View File

@ -21,11 +21,12 @@ limitations under the License.
#define EIGEN_USE_GPU
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/matrix_band_part_op.h"
#include "tensorflow/core/kernels/linalg/matrix_band_part_op.h"
#include <algorithm>
#include <memory>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_MATRIX_BAND_PART_OP_H_
#define TENSORFLOW_CORE_KERNELS_MATRIX_BAND_PART_OP_H_
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_BAND_PART_OP_H_
#define TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_BAND_PART_OP_H_
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
@ -34,4 +34,4 @@ struct MatrixBandPartFunctor {
} // namespace functor
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_MATRIX_BAND_PART_OP_H_
#endif // TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_BAND_PART_OP_H_

View File

@ -21,7 +21,7 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/matrix_band_part_op.h"
#include "tensorflow/core/kernels/linalg/matrix_band_part_op.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
namespace tensorflow {

View File

@ -20,7 +20,7 @@ limitations under the License.
#define EIGEN_USE_GPU
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/matrix_diag_op.h"
#include "tensorflow/core/kernels/linalg/matrix_diag_op.h"
#include <algorithm>
#include <memory>

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_MATRIX_DIAG_OP_H_
#define TENSORFLOW_CORE_KERNELS_MATRIX_DIAG_OP_H_
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_DIAG_OP_H_
#define TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_DIAG_OP_H_
// Generator definition for MatrixDiagOp, must be compilable by nvcc.
@ -69,4 +69,4 @@ struct MatrixDiag {
} // namespace functor
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_MATRIX_DIAG_OP_H_
#endif // TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_DIAG_OP_H_

View File

@ -18,7 +18,7 @@ limitations under the License.
#define EIGEN_USE_GPU
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/matrix_diag_op.h"
#include "tensorflow/core/kernels/linalg/matrix_diag_op.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
namespace tensorflow {

View File

@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"

View File

@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@ -32,9 +32,9 @@ limitations under the License.
#if GOOGLE_CUDA
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/eye_functor.h"
#include "tensorflow/core/kernels/linalg/eye_functor.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/util/cuda_solvers.h"
#endif
namespace tensorflow {

View File

@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"

View File

@ -21,7 +21,7 @@ limitations under the License.
#define EIGEN_USE_GPU
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/matrix_set_diag_op.h"
#include "tensorflow/core/kernels/linalg/matrix_set_diag_op.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
@ -30,7 +30,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/matrix_diag_op.h"
#include "tensorflow/core/kernels/linalg/matrix_diag_op.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_MATRIX_SET_DIAG_OP_H_
#define TENSORFLOW_CORE_KERNELS_MATRIX_SET_DIAG_OP_H_
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_SET_DIAG_OP_H_
#define TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_SET_DIAG_OP_H_
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
@ -39,4 +39,4 @@ struct MatrixSetDiag {
} // namespace functor
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_MATRIX_SET_DIAG_OP_H_
#endif // TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_SET_DIAG_OP_H_

View File

@ -18,7 +18,7 @@ limitations under the License.
#define EIGEN_USE_GPU
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/matrix_set_diag_op.h"
#include "tensorflow/core/kernels/linalg/matrix_set_diag_op.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
namespace tensorflow {

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/matrix_solve_ls_op_impl.h"
#include "tensorflow/core/kernels/linalg/matrix_solve_ls_op_impl.h"
namespace tensorflow {

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/matrix_solve_ls_op_impl.h"
#include "tensorflow/core/kernels/linalg/matrix_solve_ls_op_impl.h"
namespace tensorflow {

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/matrix_solve_ls_op_impl.h"
#include "tensorflow/core/kernels/linalg/matrix_solve_ls_op_impl.h"
namespace tensorflow {

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/matrix_solve_ls_op_impl.h"
#include "tensorflow/core/kernels/linalg/matrix_solve_ls_op_impl.h"
namespace tensorflow {

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_MATRIX_SOLVE_LS_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_MATRIX_SOLVE_LS_OP_IMPL_H_
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_SOLVE_LS_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_SOLVE_LS_OP_IMPL_H_
// See docs in ../ops/linalg_ops.cc.
@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@ -163,4 +163,4 @@ class MatrixSolveLsOp : public LinearAlgebraOp<Scalar> {
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_MATRIX_SOLVE_LS_OP_IMPL_H_
#endif // TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_SOLVE_LS_OP_IMPL_H_

View File

@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@ -33,8 +33,8 @@ limitations under the License.
#if GOOGLE_CUDA
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/util/cuda_solvers.h"
#endif
namespace tensorflow {

View File

@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"

View File

@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/matrix_triangular_solve_op_impl.h"
#include "tensorflow/core/kernels/linalg/matrix_triangular_solve_op_impl.h"
namespace tensorflow {

View File

@ -15,8 +15,8 @@ limitations under the License.
// See docs in ../ops/linalg_ops.cc.
//
#ifndef TENSORFLOW_CORE_KERNELS_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/kernel_def_builder.h"
@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@ -38,9 +38,9 @@ limitations under the License.
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/util/cuda_solvers.h"
#elif TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/rocm_solvers.h"
#include "tensorflow/core/util/rocm_solvers.h"
#endif
namespace tensorflow {
@ -434,4 +434,4 @@ struct LaunchBatchMatrixTriangularSolve<GPUDevice, Scalar> {
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_
#endif // TENSORFLOW_CORE_KERNELS_LINALG_MATRIX_TRIANGULAR_SOLVE_OP_IMPL_H_

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/matrix_triangular_solve_op_impl.h"
#include "tensorflow/core/kernels/linalg/matrix_triangular_solve_op_impl.h"
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/qr_op_impl.h"
#include "tensorflow/core/kernels/linalg/qr_op_impl.h"
namespace tensorflow {

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/qr_op_impl.h"
#include "tensorflow/core/kernels/linalg/qr_op_impl.h"
namespace tensorflow {

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/qr_op_impl.h"
#include "tensorflow/core/kernels/linalg/qr_op_impl.h"
namespace tensorflow {

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/qr_op_impl.h"
#include "tensorflow/core/kernels/linalg/qr_op_impl.h"
namespace tensorflow {

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_QR_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_QR_OP_IMPL_H_
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_QR_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_LINALG_QR_OP_IMPL_H_
// See docs in ../ops/linalg_ops.cc.
//
@ -33,7 +33,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@ -41,11 +41,11 @@ limitations under the License.
#if GOOGLE_CUDA
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cwise_ops.h"
#include "tensorflow/core/kernels/eye_functor.h"
#include "tensorflow/core/kernels/matrix_band_part_op.h"
#include "tensorflow/core/kernels/linalg/eye_functor.h"
#include "tensorflow/core/kernels/linalg/matrix_band_part_op.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/util/cuda_solvers.h"
#endif
namespace tensorflow {
@ -299,4 +299,4 @@ class QrOpGpu : public AsyncOpKernel {
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_QR_OP_IMPL_H_
#endif // TENSORFLOW_CORE_KERNELS_LINALG_QR_OP_IMPL_H_

View File

@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/logging.h"

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/self_adjoint_eig_v2_op_impl.h"
#include "tensorflow/core/kernels/linalg/self_adjoint_eig_v2_op_impl.h"
namespace tensorflow {

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/self_adjoint_eig_v2_op_impl.h"
#include "tensorflow/core/kernels/linalg/self_adjoint_eig_v2_op_impl.h"
namespace tensorflow {

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/self_adjoint_eig_v2_op_impl.h"
#include "tensorflow/core/kernels/linalg/self_adjoint_eig_v2_op_impl.h"
namespace tensorflow {

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/self_adjoint_eig_v2_op_impl.h"
#include "tensorflow/core/kernels/linalg/self_adjoint_eig_v2_op_impl.h"
namespace tensorflow {

View File

@ -26,12 +26,12 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/cast_op.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cwise_ops.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_solvers.h"
namespace tensorflow {

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_SELF_ADJOINT_EIG_V2_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_SELF_ADJOINT_EIG_V2_OP_IMPL_H_
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_SELF_ADJOINT_EIG_V2_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_LINALG_SELF_ADJOINT_EIG_V2_OP_IMPL_H_
// See docs in ../ops/linalg_ops.cc.
@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/logging.h"
@ -89,4 +89,4 @@ class SelfAdjointEigV2Op : public LinearAlgebraOp<Scalar> {
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_SELF_ADJOINT_EIG_V2_OP_IMPL_H_
#endif // TENSORFLOW_CORE_KERNELS_LINALG_SELF_ADJOINT_EIG_V2_OP_IMPL_H_

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/svd_op_impl.h"
#include "tensorflow/core/kernels/linalg/svd_op_impl.h"
namespace tensorflow {

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/svd_op_impl.h"
#include "tensorflow/core/kernels/linalg/svd_op_impl.h"
namespace tensorflow {

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/svd_op_impl.h"
#include "tensorflow/core/kernels/linalg/svd_op_impl.h"
namespace tensorflow {

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/svd_op_impl.h"
#include "tensorflow/core/kernels/linalg/svd_op_impl.h"
namespace tensorflow {

View File

@ -36,14 +36,14 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/eye_functor.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/kernels/linalg/eye_functor.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_solvers.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
namespace tensorflow {

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_SVD_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_SVD_OP_IMPL_H_
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_SVD_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_LINALG_SVD_OP_IMPL_H_
// See docs in ../ops/linalg_ops.cc.
//
@ -27,7 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@ -118,4 +118,4 @@ class SvdOp : public LinearAlgebraOp<Scalar> {
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_SVD_OP_IMPL_H_
#endif // TENSORFLOW_CORE_KERNELS_LINALG_SVD_OP_IMPL_H_

View File

@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/types.h"

View File

@ -22,11 +22,11 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/cuda_solvers.h"
#include "tensorflow/core/util/cuda_sparse.h"
#include "tensorflow/core/util/gpu_device_functions.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/core/util/gpu_launch_config.h"

View File

@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/types.h"

View File

@ -23,11 +23,11 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/cuda_solvers.h"
#include "tensorflow/core/util/cuda_sparse.h"
#include "tensorflow/core/util/gpu_device_functions.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/core/util/gpu_launch_config.h"

View File

@ -12,211 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_OPS_COMMON_H_
#define TENSORFLOW_CORE_KERNELS_LINALG_OPS_COMMON_H_
// Classes to support linear algebra functionality, similar to the numpy.linalg
// module. Supports batch computation on several matrices at once, sharding the
// computations across different threads if necessary.
#include <algorithm>
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"
namespace tensorflow {
// Base class for linear algebra operators.
template <class InputScalar, class OutputScalar = InputScalar>
class LinearAlgebraOp : public OpKernel {
public:
explicit LinearAlgebraOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override;
protected:
using TensorShapes = gtl::InlinedVector<TensorShape, 4>;
// Returns the number of leading inputs that are to be treated as matrix
// inputs. By default this is all the inputs. Derived classes can override
// this to tell the base class to ignore one or more trailing inputs.
virtual int NumMatrixInputs(const OpKernelContext* context) const {
return context->num_inputs();
}
// Returns true if the number of inputs and their shapes are as expected.
// Many ops take a single square input matrix, so we provide that as a default
// implementation for convenience.
virtual void ValidateInputMatrixShapes(
OpKernelContext* context, const TensorShapes& input_matrix_shapes) const {
ValidateSingleSquareMatrix(context, input_matrix_shapes);
}
// Convenience validators for common cases:
//
// Validate op taking a single matrix A.
static void ValidateSingleMatrix(OpKernelContext* context,
const TensorShapes& input_matrix_shapes);
// Validate op taking a single square matrix A.
static void ValidateSingleSquareMatrix(
OpKernelContext* context, const TensorShapes& input_matrix_shapes);
// Validate op taking two matrices A and B that have the same number of rows.
static void ValidateSolver(OpKernelContext* context,
const TensorShapes& input_matrix_shapes);
// Validate op taking two matrices A and B that have the same number of rows
// and A is square.
static void ValidateSquareSolver(OpKernelContext* context,
const TensorShapes& input_matrix_shapes);
// Returns the output shapes of each individual matrix operation. Output
// matrices shapes must be rank 0, 1, or 2. Scalar outputs are rank 0.
//
// The derived class may return a number of shapes (N) less than
// context->num_outputs() (M) to indicate that a only leading subset of
// the outputs will be populated. In this case, a dummy scalar tensor with
// value zero will be return for the last M-N outputs.
//
// For many ops, the output dimensions are the same as the input dimensions,
// so we provide that as a default implementation for convenience.
virtual TensorShapes GetOutputMatrixShapes(
const TensorShapes& input_matrix_shapes) const {
return input_matrix_shapes;
}
// Returns the cost per matrix operation. This is used to determine the
// number of threads to use for parallelizing calls to ComputeMatrix in
// batch mode. Cost per unit is assumed to be roughly 1ns, based on comments
// in core/util/work_sharder.cc. Many linear algebra ops take roughly max(m,n)
// * min(m,n)^2, where the first input matrix is m-by-n. We provide that as a
// default implementation for convenience.
virtual int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const {
double m = static_cast<double>(input_matrix_shapes[0].dim_size(0));
double n = static_cast<double>(input_matrix_shapes[0].dim_size(1));
double cost = std::max(m, n) * std::min(m, n) * std::min(m, n);
return cost >= static_cast<double>(kint64max) ? kint64max
: static_cast<int64>(cost);
}
// Returns true if it is safe to forward (alias) input to output buffer
// and expect the kernel to perform the computation inplace.
virtual bool EnableInputForwarding() const { return true; }
using InputMatrix = Eigen::Matrix<InputScalar, Eigen::Dynamic, Eigen::Dynamic,
Eigen::RowMajor>;
using InputConstMatrixMap = Eigen::Map<const InputMatrix>;
using InputMatrixMap = Eigen::Map<InputMatrix>;
using InputConstVectorMap =
Eigen::Map<const Eigen::Matrix<InputScalar, 1, Eigen::Dynamic>>;
using InputConstMatrixMaps = gtl::InlinedVector<InputConstMatrixMap, 4>;
using InputMatrixMaps = gtl::InlinedVector<InputMatrixMap, 4>;
using InputRealScalar = typename Eigen::NumTraits<InputScalar>::Real;
using OutputMatrix = Eigen::Matrix<OutputScalar, Eigen::Dynamic,
Eigen::Dynamic, Eigen::RowMajor>;
using OutputConstMatrixMap = Eigen::Map<const OutputMatrix>;
using OutputMatrixMap = Eigen::Map<OutputMatrix>;
using OutputConstVectorMap =
Eigen::Map<const Eigen::Matrix<OutputScalar, 1, Eigen::Dynamic>>;
using OutputConstMatrixMaps = gtl::InlinedVector<OutputConstMatrixMap, 4>;
using OutputMatrixMaps = gtl::InlinedVector<OutputMatrixMap, 4>;
using OutputRealScalar = typename Eigen::NumTraits<OutputScalar>::Real;
// backward compatibility
using Scalar = OutputScalar;
using Matrix =
Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using ConstMatrixMap = Eigen::Map<const Matrix>;
using MatrixMap = Eigen::Map<Matrix>;
using ConstVectorMap =
Eigen::Map<const Eigen::Matrix<Scalar, 1, Eigen::Dynamic>>;
using ConstMatrixMaps = gtl::InlinedVector<ConstMatrixMap, 4>;
using MatrixMaps = gtl::InlinedVector<MatrixMap, 4>;
using RealScalar = typename Eigen::NumTraits<Scalar>::Real;
// Performs a single matrix computation given input matrices, and
// stores the result in outputs. For batch operations, this will be called
// repeatedly for a single call to Compute() when multiple matrices exist in
// input Tensors with rank > 2. In this case the calls to ComputeMatrix are
// parallelized. The number of threads used is determined by a cost model from
// the value returned by GetCostPerUnit().
virtual void ComputeMatrix(OpKernelContext* context,
const InputConstMatrixMaps& inputs,
OutputMatrixMaps* outputs) = 0;
private:
using TensorInputs = gtl::InlinedVector<const Tensor*, 4>;
using TensorOutputs = gtl::InlinedVector<Tensor*, 4>;
// This function maps 2-d slices (matrices) of the input and output tensors
// using Eigen::Map and calls ComputeMatrix implemented in terms of the
// Eigen::MatrixBase API by the derived class.
//
// The 'matrix_index' parameter specifies the index of the matrix to be used
// from each input tensor, and the index of the matrix to be written to each
// output tensor. The input matrices are in row major order, and located at
// the memory addresses
// inputs[i].flat<Scalar>().data() +
// matrix_index * input_matrix_shapes[i].num_elements()
// for i in 0...inputs.size()-1.
// The output matrices are in row major order, and located at the memory
// address
// outputs[i]->flat<Scalar>().data() +
// matrix_index * output_matrix_shapes[i].num_elements().
// for i in 0...outputs.size()-1.
//
void ComputeTensorSlice(OpKernelContext* context, int64 matrix_index,
const TensorInputs& inputs,
const TensorShapes& input_matrix_shapes,
const TensorOutputs& outputs,
const TensorShapes& output_matrix_shapes);
void AnalyzeInputs(OpKernelContext* context, TensorInputs* inputs,
TensorShapes* input_matrix_shapes,
TensorShape* batch_shape);
void PrepareOutputs(OpKernelContext* context,
const TensorShapes& input_matrix_shapes,
const TensorShape& batch_shape, TensorOutputs* outputs,
TensorShapes* output_matrix_shapes);
};
// Declare LinearAlgebraOp, which is explicitly instantiated in
// linalg_ops_common.cc for float, double, complex64, and complex128.
extern template class LinearAlgebraOp<float>;
extern template class LinearAlgebraOp<double>;
extern template class LinearAlgebraOp<complex64>;
extern template class LinearAlgebraOp<complex128>;
} // namespace tensorflow
#define INHERIT_LINALG_TYPEDEFS(Scalar) \
typedef LinearAlgebraOp<Scalar> Base; \
using RealScalar = typename Eigen::NumTraits<Scalar>::Real; \
using Matrix = typename Base::Matrix; \
using MatrixMap = typename Base::MatrixMap; \
using MatrixMaps = typename Base::MatrixMaps; \
using ConstMatrixMap = typename Base::ConstMatrixMap; \
using ConstMatrixMaps = typename Base::ConstMatrixMaps; \
using ConstVectorMap = typename Base::ConstVectorMap; \
using TensorShapes = typename Base::TensorShapes;
#define REGISTER_LINALG_OP_CPU(OpName, OpClass, Scalar) \
REGISTER_KERNEL_BUILDER( \
Name(OpName).Device(DEVICE_CPU).TypeConstraint<Scalar>("T"), OpClass)
#define REGISTER_LINALG_OP_GPU(OpName, OpClass, Scalar) \
REGISTER_KERNEL_BUILDER( \
Name(OpName).Device(DEVICE_GPU).TypeConstraint<Scalar>("T"), OpClass)
// Deprecated, use one of the device-specific macros above.
#define REGISTER_LINALG_OP(OpName, OpClass, Scalar) \
REGISTER_LINALG_OP_CPU(OpName, OpClass, Scalar)
// Temporary forwarding header.
#include "tensorflow/core/kernels/linalg/linalg_ops_common.h"
#endif // TENSORFLOW_CORE_KERNELS_LINALG_OPS_COMMON_H_

View File

@ -45,13 +45,13 @@ limitations under the License.
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/util/cuda_solvers.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
using stream_executor::cuda::ScopedActivateExecutorContext;
#elif TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/platform/rocm.h"
#include "tensorflow/core/util/cuda_solvers.h"
using stream_executor::rocm::ScopedActivateExecutorContext;
#endif // GOOGLE_CUDA

View File

@ -80,8 +80,8 @@ tf_kernel_library(
"//tensorflow/core/kernels:transpose_functor",
"//tensorflow/core/kernels:gpu_prim_hdrs",
] + if_cuda_or_rocm([
"//tensorflow/core/kernels:cuda_solvers",
"//tensorflow/core/kernels:cuda_sparse",
"//tensorflow/core/util:cuda_solvers",
"//tensorflow/core/util:cuda_sparse",
]),
alwayslink = 1,
)

View File

@ -32,8 +32,8 @@ limitations under the License.
#include "tensorflow/core/kernels/fill_functor.h"
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#include "tensorflow/core/util/cuda_solvers.h"
#include "tensorflow/core/util/cuda_sparse.h"
#endif
namespace tensorflow {

View File

@ -32,8 +32,8 @@ limitations under the License.
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#include "tensorflow/core/util/cuda_solvers.h"
#include "tensorflow/core/util/cuda_sparse.h"
#endif
namespace tensorflow {

View File

@ -34,8 +34,8 @@ limitations under the License.
#include "tensorflow/core/util/work_sharder.h"
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#include "tensorflow/core/util/cuda_solvers.h"
#include "tensorflow/core/util/cuda_sparse.h"
#endif
namespace tensorflow {

View File

@ -32,8 +32,8 @@ limitations under the License.
#include "tensorflow/core/util/work_sharder.h"
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#include "tensorflow/core/util/cuda_solvers.h"
#include "tensorflow/core/util/cuda_sparse.h"
#endif
namespace tensorflow {

View File

@ -35,8 +35,8 @@ limitations under the License.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#include "tensorflow/core/util/cuda_solvers.h"
#include "tensorflow/core/util/cuda_sparse.h"
#endif
#if GOOGLE_CUDA

View File

@ -20,13 +20,13 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#include "tensorflow/core/kernels/gpu_device_array.h"
#include "tensorflow/core/kernels/gpu_device_array_gpu.h"
#include "tensorflow/core/kernels/gpu_prim.h"
#include "tensorflow/core/kernels/sparse/kernels.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_sparse.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
namespace tensorflow {

View File

@ -37,8 +37,8 @@ limitations under the License.
#include "tensorflow/core/platform/threadpool.h"
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#include "tensorflow/core/util/cuda_solvers.h"
#include "tensorflow/core/util/cuda_sparse.h"
#endif
namespace tensorflow {

View File

@ -29,7 +29,7 @@ limitations under the License.
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_sparse.h"
#include "tensorflow/core/util/cuda_sparse.h"
#endif
namespace tensorflow {

View File

@ -29,8 +29,8 @@ limitations under the License.
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#include "tensorflow/core/util/cuda_solvers.h"
#include "tensorflow/core/util/cuda_sparse.h"
#endif
namespace tensorflow {

View File

@ -20,7 +20,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_sparse.h"
#include "tensorflow/core/util/cuda_sparse.h"
#define EIGEN_USE_GPU
#endif

View File

@ -36,8 +36,8 @@ limitations under the License.
#include "tensorflow/core/util/work_sharder.h"
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#include "tensorflow/core/util/cuda_solvers.h"
#include "tensorflow/core/util/cuda_sparse.h"
#endif
namespace tensorflow {

View File

@ -30,8 +30,8 @@ limitations under the License.
#include "tensorflow/core/kernels/sparse/sparse_matrix.h"
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#include "tensorflow/core/util/cuda_solvers.h"
#include "tensorflow/core/util/cuda_sparse.h"
#endif
namespace tensorflow {

View File

@ -33,8 +33,8 @@ limitations under the License.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/kernels/cuda_sparse.h"
#include "tensorflow/core/util/cuda_solvers.h"
#include "tensorflow/core/util/cuda_sparse.h"
#endif
#if GOOGLE_CUDA

View File

@ -20,7 +20,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/cuda_sparse.h"
#include "tensorflow/core/util/cuda_sparse.h"
#define EIGEN_USE_GPU
#endif

View File

@ -39,7 +39,7 @@ limitations under the License.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/util/cuda_solvers.h"
#if GOOGLE_CUDA
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
using stream_executor::cuda::ScopedActivateExecutorContext;

View File

@ -14,6 +14,7 @@ load(
"tf_copts",
"tf_cuda_library",
"tf_cuda_only_cc_test",
"tf_kernel_library",
)
load("//tensorflow:tensorflow.bzl", "tf_version_info_genrule")
load(
@ -24,6 +25,11 @@ load(
"//tensorflow/core/platform:build_config_root.bzl",
"if_static",
)
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load(
"@local_config_rocm//rocm:build_defs.bzl",
"if_rocm",
)
default_package_visibility = [
"//tensorflow/core:__subpackages__",
@ -567,6 +573,63 @@ cc_library(
],
)
tf_kernel_library(
name = "cuda_solvers",
srcs = ["cuda_solvers.cc"],
hdrs = ["cuda_solvers.h"],
# @local_config_cuda//cuda:cusolver_static, //third_party/eigen3:blas,
# and //third_party/libf2c all contain various parts of BLAS, LAPACK,
# and f2c helper functions in global namespace. Tell the compiler to
# allow multiple definitions when linking this.
linkopts = select({
"//tensorflow:macos": [],
"//tensorflow:windows": [],
"//conditions:default": ["-Wl,-z,muldefs"],
}),
visibility = ["//tensorflow/core/kernels:friends"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/platform/default/build_config:cublas_plugin",
"//tensorflow/stream_executor/cuda:cublas_lib",
"//tensorflow/stream_executor/cuda:cusolver_lib",
],
)
tf_kernel_library(
name = "rocm_solvers",
srcs = ["rocm_solvers.cc"],
hdrs = ["rocm_solvers.h"],
visibility = ["//tensorflow/core/kernels:friends"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/stream_executor/lib",
"//tensorflow/stream_executor/platform:dso_loader",
"//tensorflow/stream_executor/rocm:rocblas_plugin",
"//tensorflow/stream_executor/rocm:rocm_gpu_executor",
] + if_rocm([
"@local_config_rocm//rocm:rocprim",
]),
)
tf_kernel_library(
name = "cuda_sparse",
srcs = if_cuda(["cuda_sparse.cc"]) + if_rocm(["rocm_sparse.cc"]),
hdrs = ["cuda_sparse.h"],
deps = [
":cuda_solvers",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
] + if_cuda([
"//tensorflow/stream_executor/cuda:cusparse_lib",
"@cub_archive//:cub",
]) + if_rocm([
"@local_config_rocm//rocm:hipsparse",
]),
)
# Tests.
tf_cc_test(

View File

@ -14,7 +14,7 @@
==============================================================================
*/
#ifdef GOOGLE_CUDA
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/util/cuda_solvers.h"
#include <chrono>
#include <complex>

View File

@ -14,8 +14,8 @@ limitations under the License.
==============================================================================
*/
#ifndef TENSORFLOW_CORE_KERNELS_CUDA_SOLVERS_H_
#define TENSORFLOW_CORE_KERNELS_CUDA_SOLVERS_H_
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_CUDA_SOLVERS_H_
#define TENSORFLOW_CORE_KERNELS_LINALG_CUDA_SOLVERS_H_
// This header declares the class CudaSolver, which contains wrappers of linear
// algebra solvers in the cuBlas and cuSolverDN libraries for use in TensorFlow
@ -435,7 +435,7 @@ class HostLapackInfo : public ScratchSpace<int> {
public:
HostLapackInfo(OpKernelContext* context, int64 size,
const std::string& debug_info)
: ScratchSpace<int>(context, size, debug_info, /* on_host */ true){};
: ScratchSpace<int>(context, size, debug_info, /* on_host */ true) {}
};
class DeviceLapackInfo : public ScratchSpace<int> {
@ -489,4 +489,4 @@ inline DeviceLapackInfo CudaSolver::GetDeviceLapackInfo(
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif // TENSORFLOW_CORE_KERNELS_CUDA_SOLVERS_H_
#endif // TENSORFLOW_CORE_KERNELS_LINALG_CUDA_SOLVERS_H_

View File

@ -15,7 +15,7 @@ limitations under the License.
#ifdef GOOGLE_CUDA
#include "tensorflow/core/kernels/cuda_sparse.h"
#include "tensorflow/core/util/cuda_sparse.h"
#include <complex>
#include <memory>
@ -28,7 +28,6 @@ limitations under the License.
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@ -38,6 +37,7 @@ limitations under the License.
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_solvers.h"
// TODO(rmlarsen,penporn): Investigate using newer kernels in CUDA 10.1+.

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_CUDA_SPARSE_H_
#define TENSORFLOW_CORE_KERNELS_CUDA_SPARSE_H_
#ifndef TENSORFLOW_CORE_KERNELS_LINALG_CUDA_SPARSE_H_
#define TENSORFLOW_CORE_KERNELS_LINALG_CUDA_SPARSE_H_
// This header declares the class GpuSparse, which contains wrappers of
// cuSparse libraries for use in TensorFlow kernels.
@ -75,8 +75,7 @@ using gpuStream_t = hipStream_t;
namespace tensorflow {
inline std::string ConvertGPUSparseErrorToString(
const gpusparseStatus_t status) {
inline string ConvertGPUSparseErrorToString(const gpusparseStatus_t status) {
switch (status) {
#define STRINGIZE(q) #q
#define RETURN_IF_STATUS(err) \
@ -206,49 +205,49 @@ class GpuSparse {
// Solves tridiagonal system of equations.
// See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2
template <typename Scalar>
Status Gtsv2(int m, int n, const Scalar *dl, const Scalar *d,
const Scalar *du, Scalar *B, int ldb, void *pBuffer) const;
Status Gtsv2(int m, int n, const Scalar* dl, const Scalar* d,
const Scalar* du, Scalar* B, int ldb, void* pBuffer) const;
// Computes the size of a temporary buffer used by Gtsv2.
// See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2_bufferSize
template <typename Scalar>
Status Gtsv2BufferSizeExt(int m, int n, const Scalar *dl, const Scalar *d,
const Scalar *du, const Scalar *B, int ldb,
size_t *bufferSizeInBytes) const;
Status Gtsv2BufferSizeExt(int m, int n, const Scalar* dl, const Scalar* d,
const Scalar* du, const Scalar* B, int ldb,
size_t* bufferSizeInBytes) const;
// Solves tridiagonal system of equations without partial pivoting.
// See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2_nopivot
template <typename Scalar>
Status Gtsv2NoPivot(int m, int n, const Scalar *dl, const Scalar *d,
const Scalar *du, Scalar *B, int ldb,
void *pBuffer) const;
Status Gtsv2NoPivot(int m, int n, const Scalar* dl, const Scalar* d,
const Scalar* du, Scalar* B, int ldb,
void* pBuffer) const;
// Computes the size of a temporary buffer used by Gtsv2NoPivot.
// See:
// https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2_nopivot_bufferSize
template <typename Scalar>
Status Gtsv2NoPivotBufferSizeExt(int m, int n, const Scalar *dl,
const Scalar *d, const Scalar *du,
const Scalar *B, int ldb,
size_t *bufferSizeInBytes) const;
Status Gtsv2NoPivotBufferSizeExt(int m, int n, const Scalar* dl,
const Scalar* d, const Scalar* du,
const Scalar* B, int ldb,
size_t* bufferSizeInBytes) const;
// Solves a batch of tridiagonal systems of equations. Doesn't support
// multiple right-hand sides per each system. Doesn't do pivoting.
// See: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2stridedbatch
template <typename Scalar>
Status Gtsv2StridedBatch(int m, const Scalar *dl, const Scalar *d,
const Scalar *du, Scalar *x, int batchCount,
int batchStride, void *pBuffer) const;
Status Gtsv2StridedBatch(int m, const Scalar* dl, const Scalar* d,
const Scalar* du, Scalar* x, int batchCount,
int batchStride, void* pBuffer) const;
// Computes the size of a temporary buffer used by Gtsv2StridedBatch.
// See:
// https://docs.nvidia.com/cuda/cusparse/index.html#gtsv2stridedbatch_bufferSize
template <typename Scalar>
Status Gtsv2StridedBatchBufferSizeExt(int m, const Scalar *dl,
const Scalar *d, const Scalar *du,
const Scalar *x, int batchCount,
Status Gtsv2StridedBatchBufferSizeExt(int m, const Scalar* dl,
const Scalar* d, const Scalar* du,
const Scalar* x, int batchCount,
int batchStride,
size_t *bufferSizeInBytes) const;
size_t* bufferSizeInBytes) const;
// Compresses the indices of rows or columns. It can be interpreted as a
// conversion from COO to CSR sparse storage format. See:
@ -449,7 +448,7 @@ class GpuSparse {
private:
bool initialized_;
OpKernelContext *context_; // not owned.
OpKernelContext* context_; // not owned.
gpuStream_t gpu_stream_;
gpusparseHandle_t* gpusparse_handle_; // not owned.
@ -585,4 +584,4 @@ class GpuSparseCsrSortingConversionInfo {
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif // TENSORFLOW_CORE_KERNELS_CUDA_SPARSE_H_
#endif // TENSORFLOW_CORE_KERNELS_LINALG_CUDA_SPARSE_H_

Some files were not shown because too many files have changed in this diff Show More