diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 157b3f30b24..492cf0b9fd6 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -1339,6 +1339,7 @@ tf_kernel_library( "tile_functor_cpu_int8.cc", "tile_functor_cpu_tstring.cc", "tile_functor_cpu_uint8.cc", + "tile_functor_cpu_variant.cc", "tile_functor_sycl.cc", ], hdrs = ["tile_functor.h"], @@ -6907,6 +6908,7 @@ filegroup( "tile_functor_cpu_int8.cc", "tile_functor_cpu_tstring.cc", "tile_functor_cpu_uint8.cc", + "tile_functor_cpu_variant.cc", "tile_ops.cc", "tile_ops_cpu_impl_1.cc", "tile_ops_cpu_impl_2.cc", diff --git a/tensorflow/core/kernels/tile_functor_cpu_variant.cc b/tensorflow/core/kernels/tile_functor_cpu_variant.cc new file mode 100644 index 00000000000..9ecfb4e9fe1 --- /dev/null +++ b/tensorflow/core/kernels/tile_functor_cpu_variant.cc @@ -0,0 +1,30 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/kernels/tile_functor_cpu.h" + +namespace tensorflow { +namespace functor { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +template struct Tile<CPUDevice, Variant, int32>; +template struct Tile<CPUDevice, Variant, int64>; + +} // end namespace functor +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/tile_ops.cc b/tensorflow/core/kernels/tile_ops.cc index cd047ed9d4a..5000e3b0f12 100644 --- a/tensorflow/core/kernels/tile_ops.cc +++ b/tensorflow/core/kernels/tile_ops.cc @@ -143,6 +143,7 @@ TF_CALL_half(DECLARE_TYPE); TF_CALL_complex64(DECLARE_TYPE); TF_CALL_complex128(DECLARE_TYPE); TF_CALL_tstring(DECLARE_TYPE); +TF_CALL_variant(DECLARE_TYPE); #undef DECLARE_TYPE #define DECLARE_DIM(T, NDIM) \ @@ -244,6 +245,7 @@ class TileOp : public OpKernel { TF_CALL_tstring(HANDLE_TYPE_NAME); // when DEVICE=CPUDevice. TF_CALL_complex64(HANDLE_TYPE_NAME); TF_CALL_complex128(HANDLE_TYPE_NAME); + TF_CALL_variant(HANDLE_TYPE_NAME); // when DEVICE=CPUDevice #undef HANDLE_TYPE_NAME #undef HANDLE_TYPE @@ -323,6 +325,7 @@ TF_CALL_half(HANDLE_TYPE_NAME_CPU); TF_CALL_complex64(HANDLE_TYPE_NAME_CPU); TF_CALL_complex128(HANDLE_TYPE_NAME_CPU); TF_CALL_tstring(HANDLE_TYPE_NAME_CPU); +TF_CALL_variant(HANDLE_TYPE_NAME_CPU); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM TF_CALL_bool(HANDLE_TYPE_NAME_GPU); diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py index bea08ac70bf..9eb8bfcef41 100644 --- a/tensorflow/python/kernel_tests/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops_test.py @@ -42,6 +42,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops +from tensorflow.python.ops import list_ops from tensorflow.python.ops import map_fn from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops @@ -1994,5 +1995,32 @@ class RepeatTest(test_util.TensorFlowTestCase, parameterized.TestCase): self.assertAllEqual(v_tf_fn, v_np) +@test_util.run_all_in_graph_and_eager_modes +class TileVariantTest(test_util.TensorFlowTestCase): + + def test_tile_tensor_list(self): + t = constant_op.constant(np.random.uniform(size=[2, 3, 4])) + handle = list_ops.tensor_list_from_tensor(t, element_shape=None) + with ops.device("CPU:0"): + tiled_handles = array_ops.tile(array_ops.reshape(handle, [1]), [2]) + tiled_tensor_0 = list_ops.tensor_list_stack(tiled_handles[0], t.dtype, 2, + [3, 4]) + tiled_tensor_1 = list_ops.tensor_list_stack(tiled_handles[1], t.dtype, 2, + [3, 4]) + self.assertAllEqual(t, tiled_tensor_0) + self.assertAllEqual(t, tiled_tensor_1) + # Now mutate some of the lists and make sure the changes are not reflected + # in the tiled handles. + with ops.control_dependencies([ + list_ops.tensor_list_scatter([t[0] + 1], [0], input_handle=handle), + list_ops.tensor_list_set_item(tiled_handles[0], 0, t[0] + 2)]): + tiled_tensor_0 = list_ops.tensor_list_stack(tiled_handles[0], t.dtype, 2, + [3, 4]) + tiled_tensor_1 = list_ops.tensor_list_stack(tiled_handles[1], t.dtype, 2, + [3, 4]) + self.assertAllEqual(t, tiled_tensor_0) + self.assertAllEqual(t, tiled_tensor_1) + + if __name__ == "__main__": test_lib.main()