Extend DataFormatDimMap to handle tensors.
PiperOrigin-RevId: 179726269
This commit is contained in:
parent
76db97fe39
commit
47249f349d
@ -3,13 +3,14 @@ op {
|
|||||||
in_arg {
|
in_arg {
|
||||||
name: "x"
|
name: "x"
|
||||||
description: <<END
|
description: <<END
|
||||||
Scalar. Dimension index in source data format. Must be in the range [-4, 4).
|
A Tensor with each element as a dimension index in source data format.
|
||||||
|
Must be in the range [-4, 4).
|
||||||
END
|
END
|
||||||
}
|
}
|
||||||
out_arg {
|
out_arg {
|
||||||
name: "y"
|
name: "y"
|
||||||
description: <<END
|
description: <<END
|
||||||
Scalar. Dimension index in destination data format.
|
A Tensor with each element as a dimension index in destination data format.
|
||||||
END
|
END
|
||||||
}
|
}
|
||||||
attr {
|
attr {
|
||||||
|
@ -50,16 +50,11 @@ class DataFormatDimMapOp : public OpKernel {
|
|||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
const Tensor& input = context->input(0);
|
const Tensor& input = context->input(0);
|
||||||
OP_REQUIRES(
|
|
||||||
context, input.dims() == 0,
|
|
||||||
errors::InvalidArgument("input must be a scalar, but got shape ",
|
|
||||||
input.shape().DebugString()));
|
|
||||||
Tensor* output = nullptr;
|
Tensor* output = nullptr;
|
||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
context->allocate_output(0, input.shape(), &output));
|
context->allocate_output(0, input.shape(), &output));
|
||||||
functor::DataFormatDimMap<Device, T>()(context->eigen_device<Device>(),
|
functor::DataFormatDimMap<Device, T>()(context->eigen_device<Device>(),
|
||||||
input.scalar<T>(),
|
input.flat<T>(), output->flat<T>());
|
||||||
output->scalar<T>());
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -137,11 +132,11 @@ TF_CALL_int64(REGISTER_KERNEL);
|
|||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
// Forward declarations of the functor specializations for GPU.
|
// Forward declarations of the functor specializations for GPU.
|
||||||
namespace functor {
|
namespace functor {
|
||||||
#define DECLARE_GPU_SPEC(T) \
|
#define DECLARE_GPU_SPEC(T) \
|
||||||
template <> \
|
template <> \
|
||||||
void DataFormatDimMap<GPUDevice, T>::operator()( \
|
void DataFormatDimMap<GPUDevice, T>::operator()( \
|
||||||
const GPUDevice& d, typename TTypes<T>::ConstScalar x, \
|
const GPUDevice& d, typename TTypes<T>::ConstFlat x, \
|
||||||
typename TTypes<T>::Scalar y); \
|
typename TTypes<T>::Flat y); \
|
||||||
extern template struct DataFormatDimMap<GPUDevice, T>;
|
extern template struct DataFormatDimMap<GPUDevice, T>;
|
||||||
#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);
|
#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);
|
||||||
TF_CALL_int32(DECLARE_GPU_SPECS);
|
TF_CALL_int32(DECLARE_GPU_SPECS);
|
||||||
|
@ -26,8 +26,8 @@ namespace functor {
|
|||||||
// Functor used by DataFormatDimMapOP to do the computations.
|
// Functor used by DataFormatDimMapOP to do the computations.
|
||||||
template <typename Device, typename T>
|
template <typename Device, typename T>
|
||||||
struct DataFormatDimMap {
|
struct DataFormatDimMap {
|
||||||
void operator()(const Device& d, typename TTypes<T>::ConstScalar x,
|
void operator()(const Device& d, typename TTypes<T>::ConstFlat x,
|
||||||
typename TTypes<T>::Scalar y) {
|
typename TTypes<T>::Flat y) {
|
||||||
auto zero = x.constant(0);
|
auto zero = x.constant(0);
|
||||||
auto one = x.constant(1);
|
auto one = x.constant(1);
|
||||||
auto three = x.constant(3);
|
auto three = x.constant(3);
|
||||||
|
@ -762,8 +762,9 @@ REGISTER_OP("DataFormatDimMap")
|
|||||||
Returns the dimension index in the destination data format given the one in
|
Returns the dimension index in the destination data format given the one in
|
||||||
the source data format.
|
the source data format.
|
||||||
|
|
||||||
x: Scalar. Dimension index in source data format. Must be in the range [-4, 4).
|
x: A Tensor with each element as a dimension index in source data format.
|
||||||
y: Scalar. Dimension index in destination data format.
|
Must be in the range [-4, 4).
|
||||||
|
y: A Tensor with each element as a dimension index in destination data format.
|
||||||
src_format: source data format.
|
src_format: source data format.
|
||||||
dst_format: destination data format.
|
dst_format: destination data format.
|
||||||
)doc");
|
)doc");
|
||||||
|
@ -960,7 +960,7 @@ class DataFormatDimMapTest(test_lib.TestCase):
|
|||||||
y = nn_ops.data_format_dim_map(x)
|
y = nn_ops.data_format_dim_map(x)
|
||||||
with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess:
|
with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess:
|
||||||
y_val = sess.run(y)
|
y_val = sess.run(y)
|
||||||
self.assertEqual(y_val, y_val_expected)
|
self.assertAllEqual(y_val, y_val_expected)
|
||||||
|
|
||||||
def test(self):
|
def test(self):
|
||||||
self._test(0, 0)
|
self._test(0, 0)
|
||||||
@ -971,6 +971,10 @@ class DataFormatDimMapTest(test_lib.TestCase):
|
|||||||
self._test(-2, 3)
|
self._test(-2, 3)
|
||||||
self._test(-3, 2)
|
self._test(-3, 2)
|
||||||
self._test(-4, 0)
|
self._test(-4, 0)
|
||||||
|
self._test([1, 3], [2, 1])
|
||||||
|
self._test([1, 3, -2], [2, 1, 3])
|
||||||
|
self._test([1, -3, -2], [2, 2, 3])
|
||||||
|
self._test([[1, -3], [1, -1]], [[2, 2], [2, 1]])
|
||||||
|
|
||||||
|
|
||||||
class DataFormatVectorPermuteTest(test_lib.TestCase):
|
class DataFormatVectorPermuteTest(test_lib.TestCase):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user