Shard tile_ops.cc for cpu, to reduce compilation times.
Change: 130411083
This commit is contained in:
parent
9c56588283
commit
79b1489dcd
@ -7,6 +7,11 @@ tensorflow/core/kernels/transpose_functor_cpu.cc
|
||||
tensorflow/core/kernels/training_ops.cc
|
||||
tensorflow/core/kernels/topk_op.cc
|
||||
tensorflow/core/kernels/tile_ops.cc
|
||||
tensorflow/core/kernels/tile_ops_cpu_impl_1.cc
|
||||
tensorflow/core/kernels/tile_ops_cpu_impl_2.cc
|
||||
tensorflow/core/kernels/tile_ops_cpu_impl_3.cc
|
||||
tensorflow/core/kernels/tile_ops_cpu_impl_4.cc
|
||||
tensorflow/core/kernels/tile_ops_cpu_impl_5.cc
|
||||
tensorflow/core/kernels/strided_slice_op_inst_6.cc
|
||||
tensorflow/core/kernels/strided_slice_op_inst_5.cc
|
||||
tensorflow/core/kernels/strided_slice_op_inst_4.cc
|
||||
|
@ -1940,7 +1940,8 @@ filegroup(
|
||||
"save_restore_tensor.h",
|
||||
"softplus_op.h",
|
||||
"softsign_op.h",
|
||||
"tile_ops.h",
|
||||
"tile_ops_cpu_impl.h",
|
||||
"tile_ops_impl.h",
|
||||
"training_ops.h",
|
||||
"transpose_functor.h",
|
||||
"transpose_op.h",
|
||||
@ -2019,6 +2020,11 @@ filegroup(
|
||||
"stack_ops.cc",
|
||||
"summary_op.cc",
|
||||
"tile_ops.cc",
|
||||
"tile_ops_cpu_impl_1.cc",
|
||||
"tile_ops_cpu_impl_2.cc",
|
||||
"tile_ops_cpu_impl_3.cc",
|
||||
"tile_ops_cpu_impl_4.cc",
|
||||
"tile_ops_cpu_impl_5.cc",
|
||||
"topk_op.cc",
|
||||
"training_ops.cc",
|
||||
"transpose_functor_cpu.cc",
|
||||
|
@ -21,23 +21,70 @@ limitations under the License.
|
||||
#define EIGEN_USE_GPU
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
#include "tensorflow/core/kernels/tile_ops.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
|
||||
#include "tensorflow/core/framework/numeric_op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/framework/type_index.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
||||
// Forward declarations of functors that will be defined in
|
||||
// tile_ops_cpu_impl*.cc and tile_ops_gpu.cu.cc.
|
||||
namespace functor {
|
||||
template <typename Device, typename T, int NDIM>
|
||||
struct Tile {
|
||||
void operator()(const Device& d, typename TTypes<T, NDIM>::Tensor out,
|
||||
typename TTypes<T, NDIM>::ConstTensor in,
|
||||
const Eigen::array<int32, NDIM>& broadcast_array) const;
|
||||
};
|
||||
|
||||
template <typename Device, typename T>
|
||||
struct Tile<Device, T, 0> {
|
||||
void operator()(const Device& d, typename TTypes<T, 0>::Tensor out,
|
||||
typename TTypes<T, 0>::ConstTensor in,
|
||||
const Eigen::array<int32, 0>&) const;
|
||||
};
|
||||
|
||||
template <typename Device, typename T, int NDIM>
|
||||
struct TileGrad {
|
||||
void operator()(const Device& d, typename TTypes<T, NDIM>::Tensor out,
|
||||
typename TTypes<T, NDIM>::ConstTensor in,
|
||||
const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices,
|
||||
const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes,
|
||||
bool first) const;
|
||||
};
|
||||
|
||||
template <typename Device, typename T>
|
||||
struct TileGrad<Device, T, 0> {
|
||||
void operator()(const Device& d, typename TTypes<T, 0>::Tensor out,
|
||||
typename TTypes<T, 0>::ConstTensor in,
|
||||
const Eigen::DSizes<Eigen::DenseIndex, 0>&,
|
||||
const Eigen::DSizes<Eigen::DenseIndex, 0>&, bool first) const;
|
||||
};
|
||||
|
||||
template <typename Device, typename T, int NDIM, int REDUCEDNDIM>
|
||||
struct ReduceAndReshape {
|
||||
void operator()(
|
||||
const Device& d, typename TTypes<T, NDIM>::Tensor out,
|
||||
typename TTypes<T, NDIM>::ConstTensor in,
|
||||
const Eigen::DSizes<Eigen::DenseIndex, REDUCEDNDIM>& reduce_dim,
|
||||
const Eigen::DSizes<Eigen::DenseIndex, NDIM>& reshape_dim) const;
|
||||
};
|
||||
} // namespace functor
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
template <typename Device>
|
||||
class TileOp : public OpKernel {
|
||||
@ -153,7 +200,7 @@ inline void TileOp<Device>::HandleCase(
|
||||
<< DataTypeString(DT) << ", " << NDIM;
|
||||
}
|
||||
|
||||
#define HANDLE_CASE(device, dtype, ndim) \
|
||||
#define HANDLE_CASE(device, T, dtype, ndim) \
|
||||
template <> \
|
||||
template <> \
|
||||
void TileOp<device>::HandleCase<dtype, ndim>( \
|
||||
@ -163,15 +210,18 @@ inline void TileOp<Device>::HandleCase(
|
||||
}
|
||||
|
||||
// 0-D handled above
|
||||
#define HANDLE_CASE_DIM(device, dtype) \
|
||||
HANDLE_CASE(device, dtype, 1); \
|
||||
HANDLE_CASE(device, dtype, 2); \
|
||||
HANDLE_CASE(device, dtype, 3); \
|
||||
HANDLE_CASE(device, dtype, 4); \
|
||||
HANDLE_CASE(device, dtype, 5);
|
||||
#define HANDLE_CASE_DIM(device, T, dtype) \
|
||||
HANDLE_CASE(device, T, dtype, 1); \
|
||||
HANDLE_CASE(device, T, dtype, 2); \
|
||||
HANDLE_CASE(device, T, dtype, 3); \
|
||||
HANDLE_CASE(device, T, dtype, 4); \
|
||||
HANDLE_CASE(device, T, dtype, 5);
|
||||
|
||||
#define HANDLE_TYPE_NAME_CPU(T) \
|
||||
HANDLE_CASE_DIM(CPUDevice, DataTypeToEnum<T>::value);
|
||||
HANDLE_CASE_DIM(CPUDevice, T, DataTypeToEnum<T>::value);
|
||||
|
||||
#define HANDLE_TYPE_NAME_GPU(T) \
|
||||
HANDLE_CASE_DIM(GPUDevice, T, DataTypeToEnum<T>::value);
|
||||
|
||||
TF_CALL_bool(HANDLE_TYPE_NAME_CPU);
|
||||
TF_CALL_float(HANDLE_TYPE_NAME_CPU);
|
||||
@ -186,15 +236,16 @@ TF_CALL_complex128(HANDLE_TYPE_NAME_CPU);
|
||||
TF_CALL_string(HANDLE_TYPE_NAME_CPU);
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
HANDLE_CASE_DIM(GPUDevice, DT_FLOAT);
|
||||
HANDLE_CASE_DIM(GPUDevice, DT_DOUBLE);
|
||||
HANDLE_CASE_DIM(GPUDevice, DT_INT16);
|
||||
HANDLE_CASE_DIM(GPUDevice, DT_INT32);
|
||||
HANDLE_CASE_DIM(GPUDevice, DT_INT64);
|
||||
HANDLE_CASE_DIM(GPUDevice, DT_HALF);
|
||||
TF_CALL_float(HANDLE_TYPE_NAME_GPU);
|
||||
TF_CALL_double(HANDLE_TYPE_NAME_GPU);
|
||||
TF_CALL_int16(HANDLE_TYPE_NAME_GPU);
|
||||
TF_CALL_int32(HANDLE_TYPE_NAME_GPU);
|
||||
TF_CALL_int64(HANDLE_TYPE_NAME_GPU);
|
||||
TF_CALL_half(HANDLE_TYPE_NAME_GPU);
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
#undef HANDLE_TYPE_NAME_CPU
|
||||
#undef HANDLE_TYPE_NAME_GPU
|
||||
#undef HANDLE_CASE_DIM
|
||||
#undef HANDLE_CASE
|
||||
|
||||
@ -385,7 +436,7 @@ inline void TileGradientOp<Device>::HandleCase(
|
||||
<< ", " << NDIM;
|
||||
}
|
||||
|
||||
#define HANDLE_CASE(device, dtype, ndim) \
|
||||
#define HANDLE_CASE(device, T, dtype, ndim) \
|
||||
template <> \
|
||||
template <> \
|
||||
void TileGradientOp<device>::HandleCase<dtype, ndim>( \
|
||||
@ -395,15 +446,18 @@ inline void TileGradientOp<Device>::HandleCase(
|
||||
}
|
||||
|
||||
// 0-D handled specially above
|
||||
#define HANDLE_CASE_DIM(device, dtype) \
|
||||
HANDLE_CASE(device, dtype, 1); \
|
||||
HANDLE_CASE(device, dtype, 2); \
|
||||
HANDLE_CASE(device, dtype, 3); \
|
||||
HANDLE_CASE(device, dtype, 4); \
|
||||
HANDLE_CASE(device, dtype, 5);
|
||||
#define HANDLE_CASE_DIM(device, T, dtype) \
|
||||
HANDLE_CASE(device, T, dtype, 1); \
|
||||
HANDLE_CASE(device, T, dtype, 2); \
|
||||
HANDLE_CASE(device, T, dtype, 3); \
|
||||
HANDLE_CASE(device, T, dtype, 4); \
|
||||
HANDLE_CASE(device, T, dtype, 5);
|
||||
|
||||
#define HANDLE_TYPE_NAME_CPU(T) \
|
||||
HANDLE_CASE_DIM(CPUDevice, DataTypeToEnum<T>::value);
|
||||
HANDLE_CASE_DIM(CPUDevice, T, DataTypeToEnum<T>::value);
|
||||
|
||||
#define HANDLE_TYPE_NAME_GPU(T) \
|
||||
HANDLE_CASE_DIM(GPUDevice, T, DataTypeToEnum<T>::value);
|
||||
|
||||
TF_CALL_float(HANDLE_TYPE_NAME_CPU);
|
||||
TF_CALL_double(HANDLE_TYPE_NAME_CPU);
|
||||
@ -415,16 +469,16 @@ TF_CALL_complex64(HANDLE_TYPE_NAME_CPU);
|
||||
TF_CALL_complex128(HANDLE_TYPE_NAME_CPU);
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
HANDLE_CASE_DIM(GPUDevice, DT_FLOAT);
|
||||
HANDLE_CASE_DIM(GPUDevice, DT_DOUBLE);
|
||||
HANDLE_CASE_DIM(GPUDevice, DT_INT16);
|
||||
HANDLE_CASE_DIM(GPUDevice, DT_INT32);
|
||||
HANDLE_CASE_DIM(GPUDevice, DT_INT64);
|
||||
HANDLE_CASE_DIM(GPUDevice, DT_HALF);
|
||||
|
||||
TF_CALL_float(HANDLE_TYPE_NAME_GPU);
|
||||
TF_CALL_double(HANDLE_TYPE_NAME_GPU);
|
||||
TF_CALL_int16(HANDLE_TYPE_NAME_GPU);
|
||||
TF_CALL_int32(HANDLE_TYPE_NAME_GPU);
|
||||
TF_CALL_int64(HANDLE_TYPE_NAME_GPU);
|
||||
TF_CALL_half(HANDLE_TYPE_NAME_GPU);
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
#undef HANDLE_TYPE_NAME_CPU
|
||||
#undef HANDLE_TYPE_NAME_GPU
|
||||
#undef HANDLE_CASE_DIM
|
||||
#undef HANDLE_CASE
|
||||
|
||||
@ -436,46 +490,6 @@ REGISTER_KERNEL_BUILDER(Name("TileGrad")
|
||||
TileGradientOp<CPUDevice>);
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#define DEFINE_GPU_TYPE(T) \
|
||||
DEFINE_GPU_DIM(T, 1) \
|
||||
DEFINE_GPU_DIM(T, 2) \
|
||||
DEFINE_GPU_DIM(T, 3) \
|
||||
DEFINE_GPU_DIM(T, 4) \
|
||||
DEFINE_GPU_DIM(T, 5)
|
||||
|
||||
#define DEFINE_GPU_DIM(T, NDIM) \
|
||||
template <> \
|
||||
void Tile<GPUDevice, T, NDIM>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, NDIM>::Tensor out, \
|
||||
typename TTypes<T, NDIM>::ConstTensor in, \
|
||||
const Eigen::array<int32, NDIM>& broadcast_array) const; \
|
||||
extern template struct Tile<GPUDevice, T, NDIM>; \
|
||||
template <> \
|
||||
void TileGrad<GPUDevice, T, NDIM>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, NDIM>::Tensor out, \
|
||||
typename TTypes<T, NDIM>::ConstTensor in, \
|
||||
const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices, \
|
||||
const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes, bool first) const; \
|
||||
extern template struct TileGrad<GPUDevice, T, NDIM>; \
|
||||
template <> \
|
||||
void ReduceAndReshape<GPUDevice, T, NDIM, 1>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, NDIM>::Tensor out, \
|
||||
typename TTypes<T, NDIM>::ConstTensor in, \
|
||||
const Eigen::DSizes<Eigen::DenseIndex, 1>& reduce_dim, \
|
||||
const Eigen::DSizes<Eigen::DenseIndex, NDIM>& reshape_dim) const; \
|
||||
extern template struct ReduceAndReshape<GPUDevice, T, NDIM, 1>;
|
||||
|
||||
namespace functor {
|
||||
DEFINE_GPU_TYPE(float);
|
||||
DEFINE_GPU_TYPE(double);
|
||||
DEFINE_GPU_TYPE(int64);
|
||||
DEFINE_GPU_TYPE(int32);
|
||||
DEFINE_GPU_TYPE(int16);
|
||||
DEFINE_GPU_TYPE(Eigen::half);
|
||||
} // end namespace functor
|
||||
|
||||
#undef DEFINE_GPU_DIM
|
||||
#undef DEFINE_GPU_TYPE
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("Tile")
|
||||
.Device(DEVICE_GPU)
|
||||
|
68
tensorflow/core/kernels/tile_ops_cpu_impl.h
Normal file
68
tensorflow/core/kernels/tile_ops_cpu_impl.h
Normal file
@ -0,0 +1,68 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_TILE_OPS_CPU_IMPL_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_TILE_OPS_CPU_IMPL_H_
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/kernels/tile_ops_impl.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace functor {
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
|
||||
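// Explicitly instantiate the Tile functor used by TileOp for each type listed
// below, at the single rank selected by CPU_PROVIDED_IXDIM.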
|
||||
#define DEFINE_DIM(T, NDIM) template struct Tile<CPUDevice, T, NDIM>;
|
||||
#define DEFINE_TYPE(T) DEFINE_DIM(T, CPU_PROVIDED_IXDIM)
|
||||
|
||||
TF_CALL_bool(DEFINE_TYPE);
|
||||
TF_CALL_float(DEFINE_TYPE);
|
||||
TF_CALL_double(DEFINE_TYPE);
|
||||
TF_CALL_uint8(DEFINE_TYPE);
|
||||
TF_CALL_int32(DEFINE_TYPE);
|
||||
TF_CALL_int16(DEFINE_TYPE);
|
||||
TF_CALL_int64(DEFINE_TYPE);
|
||||
TF_CALL_half(DEFINE_TYPE);
|
||||
TF_CALL_complex64(DEFINE_TYPE);
|
||||
TF_CALL_complex128(DEFINE_TYPE);
|
||||
TF_CALL_string(DEFINE_TYPE);
|
||||
|
||||
#undef DEFINE_DIM
|
||||
#undef DEFINE_TYPE
|
||||
|
||||
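// Explicitly instantiate the TileGrad and ReduceAndReshape functors used by
// TileGradientOp for each type listed below, at the rank selected by
// CPU_PROVIDED_IXDIM.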
|
||||
#define DEFINE_DIM(T, NDIM) \
|
||||
template struct TileGrad<CPUDevice, T, NDIM>; \
|
||||
template struct ReduceAndReshape<CPUDevice, T, NDIM, 1>;
|
||||
#define DEFINE_TYPE(T) DEFINE_DIM(T, CPU_PROVIDED_IXDIM)
|
||||
|
||||
TF_CALL_float(DEFINE_TYPE);
|
||||
TF_CALL_double(DEFINE_TYPE);
|
||||
TF_CALL_int16(DEFINE_TYPE);
|
||||
TF_CALL_int32(DEFINE_TYPE);
|
||||
TF_CALL_int64(DEFINE_TYPE);
|
||||
TF_CALL_half(DEFINE_TYPE);
|
||||
TF_CALL_complex64(DEFINE_TYPE);
|
||||
TF_CALL_complex128(DEFINE_TYPE);
|
||||
|
||||
#undef DEFINE_DIM
|
||||
#undef DEFINE_TYPE
|
||||
|
||||
} // end namespace functor
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_TILE_OPS_CPU_IMPL_H_
|
18
tensorflow/core/kernels/tile_ops_cpu_impl_1.cc
Normal file
18
tensorflow/core/kernels/tile_ops_cpu_impl_1.cc
Normal file
@ -0,0 +1,18 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define CPU_PROVIDED_IXDIM 1
|
||||
#include "tensorflow/core/kernels/tile_ops_cpu_impl.h"
|
||||
#undef CPU_PROVIDED_IXDIM
|
18
tensorflow/core/kernels/tile_ops_cpu_impl_2.cc
Normal file
18
tensorflow/core/kernels/tile_ops_cpu_impl_2.cc
Normal file
@ -0,0 +1,18 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define CPU_PROVIDED_IXDIM 2
|
||||
#include "tensorflow/core/kernels/tile_ops_cpu_impl.h"
|
||||
#undef CPU_PROVIDED_IXDIM
|
18
tensorflow/core/kernels/tile_ops_cpu_impl_3.cc
Normal file
18
tensorflow/core/kernels/tile_ops_cpu_impl_3.cc
Normal file
@ -0,0 +1,18 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define CPU_PROVIDED_IXDIM 3
|
||||
#include "tensorflow/core/kernels/tile_ops_cpu_impl.h"
|
||||
#undef CPU_PROVIDED_IXDIM
|
18
tensorflow/core/kernels/tile_ops_cpu_impl_4.cc
Normal file
18
tensorflow/core/kernels/tile_ops_cpu_impl_4.cc
Normal file
@ -0,0 +1,18 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define CPU_PROVIDED_IXDIM 4
|
||||
#include "tensorflow/core/kernels/tile_ops_cpu_impl.h"
|
||||
#undef CPU_PROVIDED_IXDIM
|
18
tensorflow/core/kernels/tile_ops_cpu_impl_5.cc
Normal file
18
tensorflow/core/kernels/tile_ops_cpu_impl_5.cc
Normal file
@ -0,0 +1,18 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define CPU_PROVIDED_IXDIM 5
|
||||
#include "tensorflow/core/kernels/tile_ops_cpu_impl.h"
|
||||
#undef CPU_PROVIDED_IXDIM
|
@ -17,8 +17,8 @@ limitations under the License.
|
||||
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "tensorflow/core/kernels/tile_ops.h"
|
||||
#include <stdio.h>
|
||||
#include "tensorflow/core/kernels/tile_ops_impl.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace functor {
|
||||
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_KERNELS_TILE_OPS_H_
|
||||
#define TENSORFLOW_KERNELS_TILE_OPS_H_
|
||||
#ifndef TENSORFLOW_KERNELS_TILE_IMPL_OPS_H_
|
||||
#define TENSORFLOW_KERNELS_TILE_IMPL_OPS_H_
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
@ -91,4 +91,4 @@ struct ReduceAndReshape {
|
||||
} // end namespace functor
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_KERNELS_TILE_OPS_H_
|
||||
#endif // TENSORFLOW_KERNELS_TILE_IMPL_OPS_H_
|
Loading…
Reference in New Issue
Block a user