Merge pull request #39424 from yongtang:39405-tile-uint32-uint64

PiperOrigin-RevId: 313291326
Change-Id: I99b509b3c6337f84c2abef769e4decad699f9f70
This commit is contained in:
TensorFlower Gardener 2020-05-26 16:43:07 -07:00
commit 49c82c0929
5 changed files with 70 additions and 0 deletions

View File

@ -1338,6 +1338,8 @@ tf_kernel_library(
"tile_functor_cpu_int64.cc",
"tile_functor_cpu_int8.cc",
"tile_functor_cpu_tstring.cc",
"tile_functor_cpu_uint32.cc",
"tile_functor_cpu_uint64.cc",
"tile_functor_cpu_uint8.cc",
"tile_functor_cpu_variant.cc",
"tile_functor_sycl.cc",
@ -6911,6 +6913,8 @@ filegroup(
"tile_functor_cpu_int64.cc",
"tile_functor_cpu_int8.cc",
"tile_functor_cpu_tstring.cc",
"tile_functor_cpu_uint32.cc",
"tile_functor_cpu_uint64.cc",
"tile_functor_cpu_uint8.cc",
"tile_functor_cpu_variant.cc",
"tile_ops.cc",

View File

@ -0,0 +1,29 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#include "tensorflow/core/kernels/tile_functor_cpu.h"
namespace tensorflow {
namespace functor {
typedef Eigen::ThreadPoolDevice CPUDevice;
template struct Tile<CPUDevice, uint32, int32>;
template struct Tile<CPUDevice, uint32, int64>;
} // end namespace functor
} // end namespace tensorflow

View File

@ -0,0 +1,29 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#include "tensorflow/core/kernels/tile_functor_cpu.h"
namespace tensorflow {
namespace functor {
typedef Eigen::ThreadPoolDevice CPUDevice;
template struct Tile<CPUDevice, uint64, int32>;
template struct Tile<CPUDevice, uint64, int64>;
} // end namespace functor
} // end namespace tensorflow

View File

@ -139,6 +139,8 @@ TF_CALL_uint8(DECLARE_TYPE);
TF_CALL_int32(DECLARE_TYPE);
TF_CALL_int16(DECLARE_TYPE);
TF_CALL_int64(DECLARE_TYPE);
TF_CALL_uint32(DECLARE_TYPE);
TF_CALL_uint64(DECLARE_TYPE);
TF_CALL_half(DECLARE_TYPE);
TF_CALL_complex64(DECLARE_TYPE);
TF_CALL_complex128(DECLARE_TYPE);
@ -241,6 +243,8 @@ class TileOp : public OpKernel {
TF_CALL_int32(HANDLE_TYPE_NAME);
TF_CALL_int16(HANDLE_TYPE_NAME);
TF_CALL_int64(HANDLE_TYPE_NAME);
TF_CALL_uint32(HANDLE_TYPE_NAME);
TF_CALL_uint64(HANDLE_TYPE_NAME);
TF_CALL_half(HANDLE_TYPE_NAME);
TF_CALL_tstring(HANDLE_TYPE_NAME); // when DEVICE=CPUDevice.
TF_CALL_complex64(HANDLE_TYPE_NAME);
@ -321,6 +325,8 @@ TF_CALL_int8(HANDLE_TYPE_NAME_CPU);
TF_CALL_int32(HANDLE_TYPE_NAME_CPU);
TF_CALL_int16(HANDLE_TYPE_NAME_CPU);
TF_CALL_int64(HANDLE_TYPE_NAME_CPU);
TF_CALL_uint32(HANDLE_TYPE_NAME_CPU);
TF_CALL_uint64(HANDLE_TYPE_NAME_CPU);
TF_CALL_half(HANDLE_TYPE_NAME_CPU);
TF_CALL_complex64(HANDLE_TYPE_NAME_CPU);
TF_CALL_complex128(HANDLE_TYPE_NAME_CPU);

View File

@ -500,6 +500,8 @@ class TileTest(test.TestCase, parameterized.TestCase):
"int16": (dtypes.int16, int),
"int32": (dtypes.int32, int),
"int64": (dtypes.int64, int),
"uint32": (dtypes.uint32, int),
"uint64": (dtypes.uint64, int),
bytes: (dtypes.string, bytes)
}
for dtype_np, (dtype_tf, cast) in types_to_test.items():