Allow tile op to work on variant dtype.

PiperOrigin-RevId: 312382133
Change-Id: I3a0f95865ca0f782fa73f7ba55b3d987de006332
This commit is contained in:
A. Unique TensorFlower 2020-05-19 16:52:08 -07:00 committed by TensorFlower Gardener
parent db573482f4
commit f8a918ccf6
4 changed files with 63 additions and 0 deletions

View File

@ -1339,6 +1339,7 @@ tf_kernel_library(
"tile_functor_cpu_int8.cc",
"tile_functor_cpu_tstring.cc",
"tile_functor_cpu_uint8.cc",
"tile_functor_cpu_variant.cc",
"tile_functor_sycl.cc",
],
hdrs = ["tile_functor.h"],
@ -6907,6 +6908,7 @@ filegroup(
"tile_functor_cpu_int8.cc",
"tile_functor_cpu_tstring.cc",
"tile_functor_cpu_uint8.cc",
"tile_functor_cpu_variant.cc",
"tile_ops.cc",
"tile_ops_cpu_impl_1.cc",
"tile_ops_cpu_impl_2.cc",

View File

@ -0,0 +1,30 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/kernels/tile_functor_cpu.h"
namespace tensorflow {
namespace functor {
typedef Eigen::ThreadPoolDevice CPUDevice;
template struct Tile<CPUDevice, Variant, int32>;
template struct Tile<CPUDevice, Variant, int64>;
} // end namespace functor
} // end namespace tensorflow

View File

@ -143,6 +143,7 @@ TF_CALL_half(DECLARE_TYPE);
TF_CALL_complex64(DECLARE_TYPE);
TF_CALL_complex128(DECLARE_TYPE);
TF_CALL_tstring(DECLARE_TYPE);
TF_CALL_variant(DECLARE_TYPE);
#undef DECLARE_TYPE
#define DECLARE_DIM(T, NDIM) \
@ -244,6 +245,7 @@ class TileOp : public OpKernel {
TF_CALL_tstring(HANDLE_TYPE_NAME); // when DEVICE=CPUDevice.
TF_CALL_complex64(HANDLE_TYPE_NAME);
TF_CALL_complex128(HANDLE_TYPE_NAME);
TF_CALL_variant(HANDLE_TYPE_NAME); // when DEVICE=CPUDevice
#undef HANDLE_TYPE_NAME
#undef HANDLE_TYPE
@ -323,6 +325,7 @@ TF_CALL_half(HANDLE_TYPE_NAME_CPU);
TF_CALL_complex64(HANDLE_TYPE_NAME_CPU);
TF_CALL_complex128(HANDLE_TYPE_NAME_CPU);
TF_CALL_tstring(HANDLE_TYPE_NAME_CPU);
TF_CALL_variant(HANDLE_TYPE_NAME_CPU);
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
TF_CALL_bool(HANDLE_TYPE_NAME_GPU);

View File

@ -42,6 +42,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
@ -1994,5 +1995,32 @@ class RepeatTest(test_util.TensorFlowTestCase, parameterized.TestCase):
self.assertAllEqual(v_tf_fn, v_np)
@test_util.run_all_in_graph_and_eager_modes
class TileVariantTest(test_util.TensorFlowTestCase):
def test_tile_tensor_list(self):
t = constant_op.constant(np.random.uniform(size=[2, 3, 4]))
handle = list_ops.tensor_list_from_tensor(t, element_shape=None)
with ops.device("CPU:0"):
tiled_handles = array_ops.tile(array_ops.reshape(handle, [1]), [2])
tiled_tensor_0 = list_ops.tensor_list_stack(tiled_handles[0], t.dtype, 2,
[3, 4])
tiled_tensor_1 = list_ops.tensor_list_stack(tiled_handles[1], t.dtype, 2,
[3, 4])
self.assertAllEqual(t, tiled_tensor_0)
self.assertAllEqual(t, tiled_tensor_1)
# Now mutate some of the lists and make sure the changes are not reflected
# in the tiled handles.
with ops.control_dependencies([
list_ops.tensor_list_scatter([t[0] + 1], [0], input_handle=handle),
list_ops.tensor_list_set_item(tiled_handles[0], 0, t[0] + 2)]):
tiled_tensor_0 = list_ops.tensor_list_stack(tiled_handles[0], t.dtype, 2,
[3, 4])
tiled_tensor_1 = list_ops.tensor_list_stack(tiled_handles[1], t.dtype, 2,
[3, 4])
self.assertAllEqual(t, tiled_tensor_0)
self.assertAllEqual(t, tiled_tensor_1)
if __name__ == "__main__":
test_lib.main()