145 lines
		
	
	
		
			6.2 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			145 lines
		
	
	
		
			6.2 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 | |
| 
 | |
| Licensed under the Apache License, Version 2.0 (the "License");
 | |
| you may not use this file except in compliance with the License.
 | |
| You may obtain a copy of the License at
 | |
| 
 | |
|     http://www.apache.org/licenses/LICENSE-2.0
 | |
| 
 | |
| Unless required by applicable law or agreed to in writing, software
 | |
| distributed under the License is distributed on an "AS IS" BASIS,
 | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| See the License for the specific language governing permissions and
 | |
| limitations under the License.
 | |
| ==============================================================================*/
 | |
| 
 | |
| #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
 | |
| 
 | |
| #define EIGEN_USE_THREADS
 | |
| 
 | |
| #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 | |
| #include "tensorflow/compiler/xla/executable_run_options.h"
 | |
| #include "tensorflow/compiler/xla/service/cpu/runtime_lightweight_check.h"
 | |
| #include "tensorflow/core/platform/dynamic_annotations.h"
 | |
| #include "tensorflow/core/platform/types.h"
 | |
| 
 | |
| #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
 | |
| #include "tensorflow/core/kernels/eigen_contraction_kernel.h"
 | |
| #endif
 | |
| 
 | |
| namespace {
 | |
| 
 | |
// Returns true if `ptr` points to an address that is a multiple of 16 bytes,
// i.e. the buffer may be viewed through an Eigen::Aligned16 TensorMap.
// The parameter is pointer-to-const because only the address is inspected;
// any object pointer (const-qualified or not) converts to it implicitly.
bool Is16BytesAligned(const void* ptr) {
  return reinterpret_cast<uintptr_t>(ptr) % 16 == 0;
}
 | |
| 
 | |
| template <typename T, Eigen::AlignmentType Alignment>
 | |
| void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs,
 | |
|             tensorflow::int64 m, tensorflow::int64 n, tensorflow::int64 k,
 | |
|             tensorflow::int32 transpose_lhs, tensorflow::int32 transpose_rhs) {
 | |
|   const xla::ExecutableRunOptions* run_options =
 | |
|       static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);
 | |
| 
 | |
|   tensorflow::int64 lhs_rows = m;
 | |
|   tensorflow::int64 lhs_cols = k;
 | |
|   if (transpose_lhs) {
 | |
|     std::swap(lhs_rows, lhs_cols);
 | |
|   }
 | |
| 
 | |
|   tensorflow::int64 rhs_rows = k;
 | |
|   tensorflow::int64 rhs_cols = n;
 | |
|   if (transpose_rhs) {
 | |
|     std::swap(rhs_rows, rhs_cols);
 | |
|   }
 | |
| 
 | |
|   const Eigen::TensorMap<Eigen::Tensor<const T, 2>, Alignment> A(lhs, lhs_rows,
 | |
|                                                                  lhs_cols);
 | |
|   const Eigen::TensorMap<Eigen::Tensor<const T, 2>, Alignment> B(rhs, rhs_rows,
 | |
|                                                                  rhs_cols);
 | |
|   Eigen::TensorMap<Eigen::Tensor<T, 2>, Alignment> C(out, m, n);
 | |
| 
 | |
|   typedef typename Eigen::Tensor<T, 2>::DimensionPair DimPair;
 | |
|   int lhs_contract_dim = transpose_lhs ? 0 : 1;
 | |
|   int rhs_contract_dim = transpose_rhs ? 1 : 0;
 | |
|   const Eigen::array<DimPair, 1> dims(
 | |
|       {DimPair(lhs_contract_dim, rhs_contract_dim)});
 | |
| 
 | |
|   // Matrix multiply is a special case of the "contract" operation where
 | |
|   // the contraction is performed along dimension 1 of the lhs and dimension
 | |
|   // 0 of the rhs.
 | |
|   XLA_LIGHTWEIGHT_CHECK(run_options->intra_op_thread_pool() != nullptr);
 | |
|   C.device(*run_options->intra_op_thread_pool()) = A.contract(B, dims);
 | |
| }
 | |
| 
 | |
| template <typename T>
 | |
| void MatMulDispatch(const void* run_options_ptr, T* out, T* lhs, T* rhs,
 | |
|                     tensorflow::int64 m, tensorflow::int64 n,
 | |
|                     tensorflow::int64 k, tensorflow::int32 transpose_lhs,
 | |
|                     tensorflow::int32 transpose_rhs) {
 | |
|   bool all_buffers_16b_aligned =
 | |
|       Is16BytesAligned(out) && Is16BytesAligned(lhs) && Is16BytesAligned(rhs);
 | |
| 
 | |
|   if (!all_buffers_16b_aligned) {
 | |
|     MatMul<T, Eigen::Unaligned>(run_options_ptr, out, lhs, rhs, m, n, k,
 | |
|                                 transpose_lhs, transpose_rhs);
 | |
|     return;
 | |
|   }
 | |
| 
 | |
|   MatMul<T, Eigen::Aligned16>(run_options_ptr, out, lhs, rhs, m, n, k,
 | |
|                               transpose_lhs, transpose_rhs);
 | |
| }
 | |
| 
 | |
| }  // namespace
 | |
| 
 | |
| TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF16(
 | |
|     const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs,
 | |
|     Eigen::half* rhs, tensorflow::int64 m, tensorflow::int64 n,
 | |
|     tensorflow::int64 k, tensorflow::int32 transpose_lhs,
 | |
|     tensorflow::int32 transpose_rhs) {
 | |
|   MatMulDispatch<Eigen::half>(run_options_ptr, out, lhs, rhs, m, n, k,
 | |
|                               transpose_lhs, transpose_rhs);
 | |
| }
 | |
| 
 | |
| TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF32(
 | |
|     const void* run_options_ptr, float* out, float* lhs, float* rhs,
 | |
|     tensorflow::int64 m, tensorflow::int64 n, tensorflow::int64 k,
 | |
|     tensorflow::int32 transpose_lhs, tensorflow::int32 transpose_rhs) {
 | |
|   MatMulDispatch<float>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
 | |
|                         transpose_rhs);
 | |
| }
 | |
| 
 | |
| TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF64(
 | |
|     const void* run_options_ptr, double* out, double* lhs, double* rhs,
 | |
|     tensorflow::int64 m, tensorflow::int64 n, tensorflow::int64 k,
 | |
|     tensorflow::int32 transpose_lhs, tensorflow::int32 transpose_rhs) {
 | |
|   MatMulDispatch<double>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
 | |
|                          transpose_rhs);
 | |
| }
 | |
| 
 | |
| TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulC64(
 | |
|     const void* run_options_ptr, std::complex<float>* out,
 | |
|     std::complex<float>* lhs, std::complex<float>* rhs, tensorflow::int64 m,
 | |
|     tensorflow::int64 n, tensorflow::int64 k, tensorflow::int32 transpose_lhs,
 | |
|     tensorflow::int32 transpose_rhs) {
 | |
|   MatMulDispatch<std::complex<float>>(run_options_ptr, out, lhs, rhs, m, n, k,
 | |
|                                       transpose_lhs, transpose_rhs);
 | |
| }
 | |
| 
 | |
| TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulC128(
 | |
|     const void* run_options_ptr, std::complex<double>* out,
 | |
|     std::complex<double>* lhs, std::complex<double>* rhs, tensorflow::int64 m,
 | |
|     tensorflow::int64 n, tensorflow::int64 k, tensorflow::int32 transpose_lhs,
 | |
|     tensorflow::int32 transpose_rhs) {
 | |
|   MatMulDispatch<std::complex<double>>(run_options_ptr, out, lhs, rhs, m, n, k,
 | |
|                                        transpose_lhs, transpose_rhs);
 | |
| }
 | |
| 
 | |
| TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulS32(
 | |
|     const void* run_options_ptr, tensorflow::int32* out, tensorflow::int32* lhs,
 | |
|     tensorflow::int32* rhs, tensorflow::int64 m, tensorflow::int64 n,
 | |
|     tensorflow::int64 k, tensorflow::int32 transpose_lhs,
 | |
|     tensorflow::int32 transpose_rhs) {
 | |
|   MatMulDispatch<tensorflow::int32>(run_options_ptr, out, lhs, rhs, m, n, k,
 | |
|                                     transpose_lhs, transpose_rhs);
 | |
| }
 |