Allow tile op to work on variant dtype.

PiperOrigin-RevId: 312382133
Change-Id: I3a0f95865ca0f782fa73f7ba55b3d987de006332
Parent: db573482f4    Commit: f8a918ccf6
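The diff below touches the kernels BUILD rules, adds a new tile_functor_cpu_variant.cc, registers the variant dtype in the tile op source (presumably tile_ops.cc), and extends the Python array ops test. As a minimal sketch of what the change enables (not part of the commit; it reuses the same internal modules the new test imports and assumes a build that includes this change), tiling a DT_VARIANT tensor such as a TensorList handle now works on CPU:

import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import list_ops

t = constant_op.constant(np.random.uniform(size=[2, 3, 4]))
# A TensorList handle is a scalar variant tensor; reshape it to rank 1 so
# that `multiples` ([2]) has one entry per input dimension.
handle = list_ops.tensor_list_from_tensor(t, element_shape=None)
with ops.device("CPU:0"):  # only CPU instantiations are added in this commit
  tiled = array_ops.tile(array_ops.reshape(handle, [1]), [2])
  # Each element of `tiled` holds a copy of the original list's contents.
  copy0 = list_ops.tensor_list_stack(tiled[0], t.dtype, 2, [3, 4])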
@@ -1339,6 +1339,7 @@ tf_kernel_library(
         "tile_functor_cpu_int8.cc",
         "tile_functor_cpu_tstring.cc",
         "tile_functor_cpu_uint8.cc",
+        "tile_functor_cpu_variant.cc",
         "tile_functor_sycl.cc",
     ],
     hdrs = ["tile_functor.h"],
@@ -6907,6 +6908,7 @@ filegroup(
         "tile_functor_cpu_int8.cc",
         "tile_functor_cpu_tstring.cc",
         "tile_functor_cpu_uint8.cc",
+        "tile_functor_cpu_variant.cc",
         "tile_ops.cc",
         "tile_ops_cpu_impl_1.cc",
         "tile_ops_cpu_impl_2.cc",
tensorflow/core/kernels/tile_functor_cpu_variant.cc (new file, 30 lines)
@@ -0,0 +1,30 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/variant.h"
+#include "tensorflow/core/kernels/tile_functor_cpu.h"
+
+namespace tensorflow {
+namespace functor {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+template struct Tile<CPUDevice, Variant, int32>;
+template struct Tile<CPUDevice, Variant, int64>;
+
+}  // end namespace functor
+}  // end namespace tensorflow
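The two explicit instantiations above parameterize the functor on the index type of `multiples` (int32 and int64), mirror the other per-dtype tile_functor_cpu_*.cc translation units, and cover CPUDevice only. A hedged sketch (not from the commit; it assumes the same internal modules as the test) showing that either a 32-bit or a 64-bit `multiples` tensor should be accepted:

import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import list_ops

t = constant_op.constant(np.zeros([2, 3, 4]))
handle = list_ops.tensor_list_from_tensor(t, element_shape=None)
with ops.device("CPU:0"):
  # An int64 `multiples` tensor should exercise Tile<CPUDevice, Variant, int64>;
  # a plain Python list defaults to the int32 path.
  multiples = constant_op.constant([3], dtype=dtypes.int64)
  tiled = array_ops.tile(array_ops.reshape(handle, [1]), multiples)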
@@ -143,6 +143,7 @@ TF_CALL_half(DECLARE_TYPE);
 TF_CALL_complex64(DECLARE_TYPE);
 TF_CALL_complex128(DECLARE_TYPE);
 TF_CALL_tstring(DECLARE_TYPE);
+TF_CALL_variant(DECLARE_TYPE);
 #undef DECLARE_TYPE
 
 #define DECLARE_DIM(T, NDIM) \
@@ -244,6 +245,7 @@ class TileOp : public OpKernel {
     TF_CALL_tstring(HANDLE_TYPE_NAME);  // when DEVICE=CPUDevice.
     TF_CALL_complex64(HANDLE_TYPE_NAME);
     TF_CALL_complex128(HANDLE_TYPE_NAME);
+    TF_CALL_variant(HANDLE_TYPE_NAME);  // when DEVICE=CPUDevice
 
 #undef HANDLE_TYPE_NAME
 #undef HANDLE_TYPE
@@ -323,6 +325,7 @@ TF_CALL_half(HANDLE_TYPE_NAME_CPU);
 TF_CALL_complex64(HANDLE_TYPE_NAME_CPU);
 TF_CALL_complex128(HANDLE_TYPE_NAME_CPU);
 TF_CALL_tstring(HANDLE_TYPE_NAME_CPU);
+TF_CALL_variant(HANDLE_TYPE_NAME_CPU);
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 TF_CALL_bool(HANDLE_TYPE_NAME_GPU);
@@ -42,6 +42,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import list_ops
 from tensorflow.python.ops import map_fn
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
@@ -1994,5 +1995,32 @@ class RepeatTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     self.assertAllEqual(v_tf_fn, v_np)
 
 
+@test_util.run_all_in_graph_and_eager_modes
+class TileVariantTest(test_util.TensorFlowTestCase):
+
+  def test_tile_tensor_list(self):
+    t = constant_op.constant(np.random.uniform(size=[2, 3, 4]))
+    handle = list_ops.tensor_list_from_tensor(t, element_shape=None)
+    with ops.device("CPU:0"):
+      tiled_handles = array_ops.tile(array_ops.reshape(handle, [1]), [2])
+      tiled_tensor_0 = list_ops.tensor_list_stack(tiled_handles[0], t.dtype, 2,
+                                                  [3, 4])
+      tiled_tensor_1 = list_ops.tensor_list_stack(tiled_handles[1], t.dtype, 2,
+                                                  [3, 4])
+      self.assertAllEqual(t, tiled_tensor_0)
+      self.assertAllEqual(t, tiled_tensor_1)
+      # Now mutate some of the lists and make sure the changes are not reflected
+      # in the tiled handles.
+      with ops.control_dependencies([
+          list_ops.tensor_list_scatter([t[0] + 1], [0], input_handle=handle),
+          list_ops.tensor_list_set_item(tiled_handles[0], 0, t[0] + 2)]):
+        tiled_tensor_0 = list_ops.tensor_list_stack(tiled_handles[0], t.dtype, 2,
+                                                    [3, 4])
+        tiled_tensor_1 = list_ops.tensor_list_stack(tiled_handles[1], t.dtype, 2,
+                                                    [3, 4])
+      self.assertAllEqual(t, tiled_tensor_0)
+      self.assertAllEqual(t, tiled_tensor_1)
+
+
 if __name__ == "__main__":
   test_lib.main()
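For reference, a hedged standalone sketch of the same check outside the test harness (eager mode and the same internal module paths as the imports above are assumed): TensorLists have value semantics, so writing to the source list after tiling should leave the tiled copies untouched.

import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import list_ops

t = constant_op.constant(np.random.uniform(size=[2, 3, 4]))
handle = list_ops.tensor_list_from_tensor(t, element_shape=None)
with ops.device("CPU:0"):
  tiled = array_ops.tile(array_ops.reshape(handle, [1]), [2])
  # tensor_list_set_item returns a new handle; the copies held in `tiled`
  # keep the original contents.
  _ = list_ops.tensor_list_set_item(handle, 0, t[0] + 1.0)
  copy0 = list_ops.tensor_list_stack(tiled[0], t.dtype, 2, [3, 4])
  copy1 = list_ops.tensor_list_stack(tiled[1], t.dtype, 2, [3, 4])
  # copy0 and copy1 are both expected to equal t.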