Extend DataFormatDimMap to handle tensors.
PiperOrigin-RevId: 179726269
This commit is contained in:
parent
76db97fe39
commit
47249f349d
@ -3,13 +3,14 @@ op {
|
||||
in_arg {
|
||||
name: "x"
|
||||
description: <<END
|
||||
Scalar. Dimension index in source data format. Must be in the range [-4, 4).
|
||||
A Tensor with each element as a dimension index in source data format.
|
||||
Must be in the range [-4, 4).
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "y"
|
||||
description: <<END
|
||||
Scalar. Dimension index in destination data format.
|
||||
A Tensor with each element as a dimension index in destination data format.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
|
@ -50,16 +50,11 @@ class DataFormatDimMapOp : public OpKernel {
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const Tensor& input = context->input(0);
|
||||
OP_REQUIRES(
|
||||
context, input.dims() == 0,
|
||||
errors::InvalidArgument("input must be a scalar, but got shape ",
|
||||
input.shape().DebugString()));
|
||||
Tensor* output = nullptr;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(0, input.shape(), &output));
|
||||
functor::DataFormatDimMap<Device, T>()(context->eigen_device<Device>(),
|
||||
input.scalar<T>(),
|
||||
output->scalar<T>());
|
||||
input.flat<T>(), output->flat<T>());
|
||||
}
|
||||
};
|
||||
|
||||
@ -140,8 +135,8 @@ namespace functor {
|
||||
#define DECLARE_GPU_SPEC(T) \
|
||||
template <> \
|
||||
void DataFormatDimMap<GPUDevice, T>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T>::ConstScalar x, \
|
||||
typename TTypes<T>::Scalar y); \
|
||||
const GPUDevice& d, typename TTypes<T>::ConstFlat x, \
|
||||
typename TTypes<T>::Flat y); \
|
||||
extern template struct DataFormatDimMap<GPUDevice, T>;
|
||||
#define DECLARE_GPU_SPECS(T) DECLARE_GPU_SPEC(T);
|
||||
TF_CALL_int32(DECLARE_GPU_SPECS);
|
||||
|
@ -26,8 +26,8 @@ namespace functor {
|
||||
// Functor used by DataFormatDimMapOP to do the computations.
|
||||
template <typename Device, typename T>
|
||||
struct DataFormatDimMap {
|
||||
void operator()(const Device& d, typename TTypes<T>::ConstScalar x,
|
||||
typename TTypes<T>::Scalar y) {
|
||||
void operator()(const Device& d, typename TTypes<T>::ConstFlat x,
|
||||
typename TTypes<T>::Flat y) {
|
||||
auto zero = x.constant(0);
|
||||
auto one = x.constant(1);
|
||||
auto three = x.constant(3);
|
||||
|
@ -762,8 +762,9 @@ REGISTER_OP("DataFormatDimMap")
|
||||
Returns the dimension index in the destination data format given the one in
|
||||
the source data format.
|
||||
|
||||
x: Scalar. Dimension index in source data format. Must be in the range [-4, 4).
|
||||
y: Scalar. Dimension index in destination data format.
|
||||
x: A Tensor with each element as a dimension index in source data format.
|
||||
Must be in the range [-4, 4).
|
||||
y: A Tensor with each element as a dimension index in destination data format.
|
||||
src_format: source data format.
|
||||
dst_format: destination data format.
|
||||
)doc");
|
||||
|
@ -960,7 +960,7 @@ class DataFormatDimMapTest(test_lib.TestCase):
|
||||
y = nn_ops.data_format_dim_map(x)
|
||||
with self.test_session(use_gpu=test_lib.is_gpu_available()) as sess:
|
||||
y_val = sess.run(y)
|
||||
self.assertEqual(y_val, y_val_expected)
|
||||
self.assertAllEqual(y_val, y_val_expected)
|
||||
|
||||
def test(self):
|
||||
self._test(0, 0)
|
||||
@ -971,6 +971,10 @@ class DataFormatDimMapTest(test_lib.TestCase):
|
||||
self._test(-2, 3)
|
||||
self._test(-3, 2)
|
||||
self._test(-4, 0)
|
||||
self._test([1, 3], [2, 1])
|
||||
self._test([1, 3, -2], [2, 1, 3])
|
||||
self._test([1, -3, -2], [2, 2, 3])
|
||||
self._test([[1, -3], [1, -1]], [[2, 2], [2, 1]])
|
||||
|
||||
|
||||
class DataFormatVectorPermuteTest(test_lib.TestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user