Extend DataFormatDimMap to handle tensors.

PiperOrigin-RevId: 179726269
This commit is contained in:
Yao Zhang 2017-12-20 13:30:22 -08:00 committed by TensorFlower Gardener
parent 76db97fe39
commit 47249f349d
5 changed files with 19 additions and 18 deletions

View File

@ -3,13 +3,14 @@ op {
in_arg {
name: "x"
description: <<END
Scalar. Dimension index in source data format. Must be in the range [-4, 4).
A Tensor with each element as a dimension index in source data format.
Must be in the range [-4, 4).
END
}
out_arg {
name: "y"
description: <<END
Scalar. Dimension index in destination data format.
A Tensor with each element as a dimension index in destination data format.
END
}
attr {

View File

@ -50,16 +50,11 @@ class DataFormatDimMapOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
OP_REQUIRES(
context, input.dims() == 0,
errors::InvalidArgument("input must be a scalar, but got shape ",
input.shape().DebugString()));
Tensor* output = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, input.shape(), &output));
functor::DataFormatDimMap<Device, T>()(context->eigen_device<Device>(),
input.scalar<T>(),
output->scalar<T>());
input.flat<T>(), output->flat<T>());
}
};
@ -137,11 +132,11 @@ TF_CALL_int64(REGISTER_KERNEL);
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void DataFormatDimMap<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::ConstScalar x, \
typename TTypes<T>::Scalar y); \
#define DECLARE_GPU_SPEC(T) \
template <> \
void DataFormatDimMap<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::ConstFlat x, \
typename TTypes<T>::Flat y); \
extern template struct DataFormatDimMap<GPUDevice, T>;
#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);
TF_CALL_int32(DECLARE_GPU_SPECS);

View File

@ -26,8 +26,8 @@ namespace functor {
// Functor used by DataFormatDimMapOP to do the computations.
template <typename Device, typename T>
struct DataFormatDimMap {
void operator()(const Device& d, typename TTypes<T>::ConstScalar x,
typename TTypes<T>::Scalar y) {
void operator()(const Device& d, typename TTypes<T>::ConstFlat x,
typename TTypes<T>::Flat y) {
auto zero = x.constant(0);
auto one = x.constant(1);
auto three = x.constant(3);

View File

@ -762,8 +762,9 @@ REGISTER_OP("DataFormatDimMap")
Returns the dimension index in the destination data format given the one in
the source data format.
x: Scalar. Dimension index in source data format. Must be in the range [-4, 4).
y: Scalar. Dimension index in destination data format.
x: A Tensor with each element as a dimension index in source data format.
Must be in the range [-4, 4).
y: A Tensor with each element as a dimension index in destination data format.
src_format: source data format.
dst_format: destination data format.
)doc");

View File

@ -960,7 +960,7 @@ class DataFormatDimMapTest(test_lib.TestCase):
y = nn_ops.data_format_dim_map(x)
with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess:
y_val = sess.run(y)
self.assertEqual(y_val, y_val_expected)
self.assertAllEqual(y_val, y_val_expected)
def test(self):
self._test(0, 0)
@ -971,6 +971,10 @@ class DataFormatDimMapTest(test_lib.TestCase):
self._test(-2, 3)
self._test(-3, 2)
self._test(-4, 0)
self._test([1, 3], [2, 1])
self._test([1, 3, -2], [2, 1, 3])
self._test([1, -3, -2], [2, 2, 3])
self._test([[1, -3], [1, -1]], [[2, 2], [2, 1]])
class DataFormatVectorPermuteTest(test_lib.TestCase):