Allow tile op to work on variant dtype.
PiperOrigin-RevId: 312382133 Change-Id: I3a0f95865ca0f782fa73f7ba55b3d987de006332
This commit is contained in:
parent
db573482f4
commit
f8a918ccf6
@ -1339,6 +1339,7 @@ tf_kernel_library(
|
||||
"tile_functor_cpu_int8.cc",
|
||||
"tile_functor_cpu_tstring.cc",
|
||||
"tile_functor_cpu_uint8.cc",
|
||||
"tile_functor_cpu_variant.cc",
|
||||
"tile_functor_sycl.cc",
|
||||
],
|
||||
hdrs = ["tile_functor.h"],
|
||||
@ -6907,6 +6908,7 @@ filegroup(
|
||||
"tile_functor_cpu_int8.cc",
|
||||
"tile_functor_cpu_tstring.cc",
|
||||
"tile_functor_cpu_uint8.cc",
|
||||
"tile_functor_cpu_variant.cc",
|
||||
"tile_ops.cc",
|
||||
"tile_ops_cpu_impl_1.cc",
|
||||
"tile_ops_cpu_impl_2.cc",
|
||||
|
30
tensorflow/core/kernels/tile_functor_cpu_variant.cc
Normal file
30
tensorflow/core/kernels/tile_functor_cpu_variant.cc
Normal file
@ -0,0 +1,30 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include "tensorflow/core/framework/variant.h"
|
||||
#include "tensorflow/core/kernels/tile_functor_cpu.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace functor {
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
|
||||
template struct Tile<CPUDevice, Variant, int32>;
|
||||
template struct Tile<CPUDevice, Variant, int64>;
|
||||
|
||||
} // end namespace functor
|
||||
} // end namespace tensorflow
|
@ -143,6 +143,7 @@ TF_CALL_half(DECLARE_TYPE);
|
||||
TF_CALL_complex64(DECLARE_TYPE);
|
||||
TF_CALL_complex128(DECLARE_TYPE);
|
||||
TF_CALL_tstring(DECLARE_TYPE);
|
||||
TF_CALL_variant(DECLARE_TYPE);
|
||||
#undef DECLARE_TYPE
|
||||
|
||||
#define DECLARE_DIM(T, NDIM) \
|
||||
@ -244,6 +245,7 @@ class TileOp : public OpKernel {
|
||||
TF_CALL_tstring(HANDLE_TYPE_NAME); // when DEVICE=CPUDevice.
|
||||
TF_CALL_complex64(HANDLE_TYPE_NAME);
|
||||
TF_CALL_complex128(HANDLE_TYPE_NAME);
|
||||
TF_CALL_variant(HANDLE_TYPE_NAME); // when DEVICE=CPUDevice
|
||||
|
||||
#undef HANDLE_TYPE_NAME
|
||||
#undef HANDLE_TYPE
|
||||
@ -323,6 +325,7 @@ TF_CALL_half(HANDLE_TYPE_NAME_CPU);
|
||||
TF_CALL_complex64(HANDLE_TYPE_NAME_CPU);
|
||||
TF_CALL_complex128(HANDLE_TYPE_NAME_CPU);
|
||||
TF_CALL_tstring(HANDLE_TYPE_NAME_CPU);
|
||||
TF_CALL_variant(HANDLE_TYPE_NAME_CPU);
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
TF_CALL_bool(HANDLE_TYPE_NAME_GPU);
|
||||
|
@ -42,6 +42,7 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import list_ops
|
||||
from tensorflow.python.ops import map_fn
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
@ -1994,5 +1995,32 @@ class RepeatTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
self.assertAllEqual(v_tf_fn, v_np)
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class TileVariantTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def test_tile_tensor_list(self):
|
||||
t = constant_op.constant(np.random.uniform(size=[2, 3, 4]))
|
||||
handle = list_ops.tensor_list_from_tensor(t, element_shape=None)
|
||||
with ops.device("CPU:0"):
|
||||
tiled_handles = array_ops.tile(array_ops.reshape(handle, [1]), [2])
|
||||
tiled_tensor_0 = list_ops.tensor_list_stack(tiled_handles[0], t.dtype, 2,
|
||||
[3, 4])
|
||||
tiled_tensor_1 = list_ops.tensor_list_stack(tiled_handles[1], t.dtype, 2,
|
||||
[3, 4])
|
||||
self.assertAllEqual(t, tiled_tensor_0)
|
||||
self.assertAllEqual(t, tiled_tensor_1)
|
||||
# Now mutate some of the lists and make sure the changes are not reflected
|
||||
# in the tiled handles.
|
||||
with ops.control_dependencies([
|
||||
list_ops.tensor_list_scatter([t[0] + 1], [0], input_handle=handle),
|
||||
list_ops.tensor_list_set_item(tiled_handles[0], 0, t[0] + 2)]):
|
||||
tiled_tensor_0 = list_ops.tensor_list_stack(tiled_handles[0], t.dtype, 2,
|
||||
[3, 4])
|
||||
tiled_tensor_1 = list_ops.tensor_list_stack(tiled_handles[1], t.dtype, 2,
|
||||
[3, 4])
|
||||
self.assertAllEqual(t, tiled_tensor_0)
|
||||
self.assertAllEqual(t, tiled_tensor_1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_lib.main()
|
||||
|
Loading…
Reference in New Issue
Block a user