Extend DataFormatDimMap to handle tensors.

PiperOrigin-RevId: 179726269
This commit is contained in:
Yao Zhang 2017-12-20 13:30:22 -08:00 committed by TensorFlower Gardener
parent 76db97fe39
commit 47249f349d
5 changed files with 19 additions and 18 deletions

View File

@@ -3,13 +3,14 @@ op {
in_arg { in_arg {
name: "x" name: "x"
description: <<END description: <<END
Scalar. Dimension index in source data format. Must be in the range [-4, 4). A Tensor with each element as a dimension index in source data format.
Must be in the range [-4, 4).
END END
} }
out_arg { out_arg {
name: "y" name: "y"
description: <<END description: <<END
Scalar. Dimension index in destination data format. A Tensor with each element as a dimension index in destination data format.
END END
} }
attr { attr {

View File

@@ -50,16 +50,11 @@ class DataFormatDimMapOp : public OpKernel {
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0); const Tensor& input = context->input(0);
OP_REQUIRES(
context, input.dims() == 0,
errors::InvalidArgument("input must be a scalar, but got shape ",
input.shape().DebugString()));
Tensor* output = nullptr; Tensor* output = nullptr;
OP_REQUIRES_OK(context, OP_REQUIRES_OK(context,
context->allocate_output(0, input.shape(), &output)); context->allocate_output(0, input.shape(), &output));
functor::DataFormatDimMap<Device, T>()(context->eigen_device<Device>(), functor::DataFormatDimMap<Device, T>()(context->eigen_device<Device>(),
input.scalar<T>(), input.flat<T>(), output->flat<T>());
output->scalar<T>());
} }
}; };
@@ -137,11 +132,11 @@ TF_CALL_int64(REGISTER_KERNEL);
#if GOOGLE_CUDA #if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU. // Forward declarations of the functor specializations for GPU.
namespace functor { namespace functor {
#define DECLARE_GPU_SPEC(T) \ #define DECLARE_GPU_SPEC(T) \
template <> \ template <> \
void DataFormatDimMap<GPUDevice, T>::operator()( \ void DataFormatDimMap<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::ConstScalar x, \ const GPUDevice& d, typename TTypes<T>::ConstFlat x, \
typename TTypes<T>::Scalar y); \ typename TTypes<T>::Flat y); \
extern template struct DataFormatDimMap<GPUDevice, T>; extern template struct DataFormatDimMap<GPUDevice, T>;
#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T); #define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);
TF_CALL_int32(DECLARE_GPU_SPECS); TF_CALL_int32(DECLARE_GPU_SPECS);

View File

@@ -26,8 +26,8 @@ namespace functor {
// Functor used by DataFormatDimMapOP to do the computations. // Functor used by DataFormatDimMapOP to do the computations.
template <typename Device, typename T> template <typename Device, typename T>
struct DataFormatDimMap { struct DataFormatDimMap {
void operator()(const Device& d, typename TTypes<T>::ConstScalar x, void operator()(const Device& d, typename TTypes<T>::ConstFlat x,
typename TTypes<T>::Scalar y) { typename TTypes<T>::Flat y) {
auto zero = x.constant(0); auto zero = x.constant(0);
auto one = x.constant(1); auto one = x.constant(1);
auto three = x.constant(3); auto three = x.constant(3);

View File

@@ -762,8 +762,9 @@ REGISTER_OP("DataFormatDimMap")
Returns the dimension index in the destination data format given the one in Returns the dimension index in the destination data format given the one in
the source data format. the source data format.
x: Scalar. Dimension index in source data format. Must be in the range [-4, 4). x: A Tensor with each element as a dimension index in source data format.
y: Scalar. Dimension index in destination data format. Must be in the range [-4, 4).
y: A Tensor with each element as a dimension index in destination data format.
src_format: source data format. src_format: source data format.
dst_format: destination data format. dst_format: destination data format.
)doc"); )doc");

View File

@@ -960,7 +960,7 @@ class DataFormatDimMapTest(test_lib.TestCase):
y = nn_ops.data_format_dim_map(x) y = nn_ops.data_format_dim_map(x)
with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess: with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess:
y_val = sess.run(y) y_val = sess.run(y)
self.assertEqual(y_val, y_val_expected) self.assertAllEqual(y_val, y_val_expected)
def test(self): def test(self):
self._test(0, 0) self._test(0, 0)
@@ -971,6 +971,10 @@ class DataFormatDimMapTest(test_lib.TestCase):
self._test(-2, 3) self._test(-2, 3)
self._test(-3, 2) self._test(-3, 2)
self._test(-4, 0) self._test(-4, 0)
self._test([1, 3], [2, 1])
self._test([1, 3, -2], [2, 1, 3])
self._test([1, -3, -2], [2, 2, 3])
self._test([[1, -3], [1, -1]], [[2, 2], [2, 1]])
class DataFormatVectorPermuteTest(test_lib.TestCase): class DataFormatVectorPermuteTest(test_lib.TestCase):