Merge pull request #43226 from kaixih:reduce_ops_layout
PiperOrigin-RevId: 333786383
Change-Id: Ifefb0a3d23cf7779858e7c011fd4195024ab9dc5
commit ab55c62645
@@ -63,6 +63,22 @@ class XlaDataFormatDimMapTest(xla_test.XLATestCase):
     self._test([-4, -3, -2, -1, 0, 1, 2, 3], "qwer", "rewq",
                [3, 2, 1, 0, 3, 2, 1, 0])
 
+    self._test(0, "NDHWC", "NCDHW", 0)
+    self._test(1, "NDHWC", "NCDHW", 2)
+    self._test(2, "NDHWC", "NCDHW", 3)
+    self._test(3, "NDHWC", "NCDHW", 4)
+    self._test(4, "NDHWC", "NCDHW", 1)
+    self._test([1, 4], "NDHWC", "NCDHW", [2, 1])
+    self._test([1, 4, -2], "NDHWC", "NCDHW", [2, 1, 4])
+    self._test([1, -3, -2], "NDHWC", "NCDHW", [2, 3, 4])
+    self._test([[1, -4], [1, -1]], "NDHWC", "NCDHW", [[2, 2], [2, 1]])
+
+    self._test([1, -3, -2], "NDHWC", "NCDHW", [2, 3, 4])
+    self._test([-5, -4, -3, -2, -1, 0, 1, 2, 3, 4], "NDHWC", "DHWNC",
+               [3, 0, 1, 2, 4, 3, 0, 1, 2, 4])
+    self._test([-5, -4, -3, -2, -1, 0, 1, 2, 3, 4], "NDHWC", "WHDCN",
+               [4, 2, 1, 0, 3, 4, 2, 1, 0, 3])
+
 
 class XlaPermuteOpTest(xla_test.XLATestCase):
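Note: the op treats a negative dimension index d as d mod rank, which is why each case above lists the same mapping for negative and non-negative inputs. A minimal sketch of the NDHWC-to-NCDHW case these tests exercise, assuming a TensorFlow build that includes this change (eager mode):

    import tensorflow as tf

    # Maps NDHWC dimension indices to their NCDHW positions; negative
    # indices are first wrapped modulo the rank (5 here).
    x = tf.constant([1, -4, -3, -2])
    y = tf.raw_ops.DataFormatDimMap(x=x, src_format="NDHWC", dst_format="NCDHW")
    print(y.numpy())  # -> [2 2 3 4], matching self._test(1, "NDHWC", "NCDHW", 2) etc.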
@@ -35,15 +35,19 @@ class DataFormatDimMapOp : public XlaOpKernel {
     OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format));
     string dst_format;
     OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
-    OP_REQUIRES(context, src_format.size() == 4,
-                errors::InvalidArgument(absl::StrCat(
-                    "Source format must of length 4, received src_format = ",
+    OP_REQUIRES(context, src_format.size() == 4 || src_format.size() == 5,
+                errors::InvalidArgument(
+                    absl::StrCat("Source format must be of length 4 or 5, "
+                                 "received src_format = ",
                     src_format)));
     OP_REQUIRES(
-        context, dst_format.size() == 4,
+        context, dst_format.size() == 4 || dst_format.size() == 5,
         errors::InvalidArgument(absl::StrCat(
-            "Destination format must of length 4, received dst_format = ",
+            "Destination format must be of length 4 or 5, received dst_format = ",
             dst_format)));
+    for (int i = 0; i < src_format.size(); ++i) {
+      dst_idx_.push_back(-1);
+    }
     for (int i = 0; i < src_format.size(); ++i) {
       for (int j = 0; j < dst_format.size(); ++j) {
         if (dst_format[j] == src_format[i]) {
@@ -61,9 +65,10 @@ class DataFormatDimMapOp : public XlaOpKernel {
     auto builder = context->builder();
     xla::XlaOp dst_indices =
         xla::ConstantR1(builder, absl::Span<const int32>(dst_idx_));
-    xla::XlaOp four = xla::ConstantR0<int32>(builder, 4);
+    const int dims = dst_idx_.size();
+    xla::XlaOp rank = xla::ConstantR0<int32>(builder, dims);
     xla::XlaOp src_indices =
-        (xla::ConvertElementType(context->Input(0), xla::S32) + four) % four;
+        (xla::ConvertElementType(context->Input(0), xla::S32) + rank) % rank;
     xla::XlaOp output =
         xla::TorchIndexSelect(dst_indices, src_indices, /*dim=*/0);
     context->SetOutput(
@@ -71,7 +76,7 @@ class DataFormatDimMapOp : public XlaOpKernel {
   }
 
  private:
-  std::array<int32, 4> dst_idx_ = {{-1, -1, -1, -1}};
+  std::vector<int32> dst_idx_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(DataFormatDimMapOp);
 };
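The `(x + rank) % rank` expression replaces the hard-coded `(x + four) % four` so that negative dimension indices normalize against the actual format length instead of always 4. The same arithmetic in plain Python, for intuition only:

    def normalize_dim(dim, rank):
        # Wraps an index in [-rank, rank - 1] into [0, rank - 1],
        # e.g. -1 -> rank - 1; non-negative indices are unchanged.
        return (dim + rank) % rank

    assert normalize_dim(-1, 5) == 4
    assert normalize_dim(3, 5) == 3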
@@ -1371,11 +1371,35 @@ bool ReduceTransposer::IsReduceAxisSupported(
 Status ReduceTransposer::TransposeNode(TransposeContext* context,
                                        utils::MutableNodeView* node) {
   DCHECK(IsReduceOp(*node->node()));
-  if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 0, 4) ||
+  const auto& regular_fanin = node->GetRegularFanin(0);
+  const auto* output_shape_attr =
+      regular_fanin.node_view()->GetAttr(kAttrOutputShape);
+  const auto& shape = output_shape_attr->list().shape(0);
+  const int rank = shape.dim_size();
+  std::string src_format = context->src_format;
+  std::string dst_format = context->dst_format;
+  // Update the format from 4D to 5D layout if necessary.
+  bool allow_5d = rank == 5 && (src_format == "NHWC" || src_format == "NCHW") &&
+                  (dst_format == "NHWC" || dst_format == "NCHW");
+  if (allow_5d) {
+    std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW";
+    std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW";
+    context->AssignDeviceAndDataFormats(context->target_device, src_format_3d,
+                                        dst_format_3d);
+  }
+  if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 0, rank) ||
       !IsReduceAxisSupported(*context, *node) ||
       !IsAfterDstToSrcTransform(*context, *node)) {
+    // Change back to the original layout due to early exit.
+    if (allow_5d) {
+      context->AssignDeviceAndDataFormats(context->target_device, src_format,
+                                          dst_format);
+    }
     return Status::OK();
   }
+  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
+          << "' with op '" << node->GetOp() << "' from data format '"
+          << context->src_format << "' to '" << context->dst_format << "'";
   TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
   TF_RETURN_IF_ERROR(
       UpdateFaninEdgesWithOp(context, {1}, node, kOpDataFormatDimMap));
@@ -1383,6 +1407,11 @@ Status ReduceTransposer::TransposeNode(TransposeContext* context,
     TF_RETURN_IF_ERROR(
         UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
   }
+  // Change back the format from 5D to 4D layout.
+  if (allow_5d) {
+    context->AssignDeviceAndDataFormats(context->target_device, src_format,
+                                        dst_format);
+  }
   return context->graph_view->GetMutationBuilder()->Apply();
 }
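The transposer itself still only knows the 4D formats, so a rank-5 reduce input temporarily promotes NHWC/NCHW to their 3D-convolution counterparts and restores the original pair on every exit path. A sketch of that promotion rule (the function name here is illustrative, not from the source):

    def promote_to_5d(fmt):
        # NHWC -> NDHWC and NCHW -> NCDHW; other formats are left untouched.
        return {"NHWC": "NDHWC", "NCHW": "NCDHW"}.get(fmt, fmt)

    assert promote_to_5d("NHWC") == "NDHWC"
    assert promote_to_5d("NCHW") == "NCDHW"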
@@ -37,14 +37,15 @@ class DataFormatDimMapOp : public OpKernel {
     OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format));
     string dst_format;
     OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
-    OP_REQUIRES(context, src_format.size() == 4,
+    OP_REQUIRES(context, src_format.size() == 4 || src_format.size() == 5,
                 errors::InvalidArgument(strings::StrCat(
-                    "Source format must of length 4, received src_format = ",
+                    "Source format must be of length 4 or 5, received "
+                    "src_format = ",
                     src_format)));
     OP_REQUIRES(
-        context, dst_format.size() == 4,
+        context, dst_format.size() == 4 || dst_format.size() == 5,
         errors::InvalidArgument(strings::StrCat(
-            "Destination format must of length 4, received dst_format = ",
+            "Destination format must be of length 4 or 5, received dst_format = ",
             dst_format)));
     dst_idx_ = Tensor(DT_INT32, {static_cast<int64>(src_format.size())});
     for (int i = 0; i < src_format.size(); ++i) {
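`dst_idx_` is the permutation table the loop below this hunk fills in: for each axis letter of `src_format`, the position of the same letter in `dst_format`. The equivalent one-liner in Python, for intuition only:

    def build_dst_idx(src_format, dst_format):
        # For each source axis letter, find where it lands in dst_format.
        return [dst_format.index(c) for c in src_format]

    assert build_dst_idx("NDHWC", "NCDHW") == [0, 2, 3, 4, 1]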
@@ -28,6 +28,7 @@ template <typename Device, typename T>
 struct DataFormatDimMap {
   void operator()(const Device& d, typename TTypes<T>::ConstFlat x,
                   typename TTypes<T>::Flat y, const TTypes<int>::Vec dst) {
+    if (dst.size() == 4) {
       auto zero = x.constant(0);
       auto one = x.constant(1);
       auto two = x.constant(2);
@@ -46,6 +47,31 @@ struct DataFormatDimMap {
 
       y.device(d) = is_zero.select(
           f_zero, is_one.select(f_one, is_two.select(f_two, f_three)));
+    } else {
+      auto zero = x.constant(0);
+      auto one = x.constant(1);
+      auto two = x.constant(2);
+      auto three = x.constant(3);
+
+      auto f_zero = x.constant(dst(0));
+      auto f_one = x.constant(dst(1));
+      auto f_two = x.constant(dst(2));
+      auto f_three = x.constant(dst(3));
+      auto f_four = x.constant(dst(4));
+
+      auto five = x.constant(5);
+      auto x_mod = (x + five) % 5;
+
+      auto is_zero = (x_mod == zero);
+      auto is_one = (x_mod == one);
+      auto is_two = (x_mod == two);
+      auto is_three = (x_mod == three);
+
+      y.device(d) = is_zero.select(
+          f_zero,
+          is_one.select(
+              f_one, is_two.select(f_two, is_three.select(f_three, f_four))));
+    }
   }
 };
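The Eigen functor cannot index one tensor with another, so the table lookup is expressed as a chain of elementwise selects over the modulo-normalized input. A NumPy sketch of the new 5D branch, for illustration only:

    import numpy as np

    def dim_map_5d(x, dst_idx):
        # Normalize indices modulo 5, then emulate the nested selects.
        xm = (np.asarray(x) + 5) % 5
        return np.where(xm == 0, dst_idx[0],
               np.where(xm == 1, dst_idx[1],
               np.where(xm == 2, dst_idx[2],
               np.where(xm == 3, dst_idx[3], dst_idx[4]))))

    print(dim_map_5d([1, -4, -3, -2], [0, 2, 3, 4, 1]))  # -> [2 2 3 4]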
@@ -221,6 +221,9 @@ class LayoutOptimizerTest(test.TestCase):
   def _assert_map_nhwc_to_nchw(self, name, nodes):
     self.assertIn(name + '-DimMapNHWCToNCHW-LayoutOptimizer', nodes)
 
+  def _assert_map_ndhwc_to_ncdhw(self, name, nodes):
+    self.assertIn(name + '-DataFormatDimMapNDHWCToNCDHW-LayoutOptimizer', nodes)
+
   def _assert_vec_nchw_to_nhwc(self, name, nodes):
     self.assertIn(name + '-VecPermuteNCHWToNHWC-LayoutOptimizer', nodes)
@@ -1194,6 +1197,39 @@ class LayoutOptimizerTest(test.TestCase):
       self._assert_trans_nchw_to_nhwc('LeakyReluGrad-0-0', nodes)
       self.assertAllClose(output_val_ref, output_val, atol=1e-3)
 
+  @test_util.deprecated_graph_mode_only
+  def testReduceOpsFor5DTensors(self):
+    if test.is_gpu_available(cuda_only=True):
+      random_seed.set_random_seed(0)
+      x = random_ops.truncated_normal([1, 4, 2, 3, 3], seed=0)
+      w = random_ops.truncated_normal([2, 2, 2, 3, 3], seed=0)
+      conv3d = gen_nn_ops.conv3d(x, w, [1, 1, 1, 1, 1], 'SAME')
+      y = math_ops.reduce_mean(conv3d, [0, 1, 2, 3], keepdims=True)
+      output = array_ops.identity(y)
+
+      with session.Session(config=_get_config(False)) as sess:
+        output_val_ref = sess.run(output)
+
+      with session.Session(config=_get_config()) as sess:
+        metadata = config_pb2.RunMetadata()
+        output_val = sess.run(output, run_metadata=metadata)
+
+      nodes = []
+      num_transposes = 0
+      for node in metadata.cost_graph.node:
+        if _is_transpose(node.name):
+          num_transposes += 1
+        nodes.append(node.name)
+
+      # The reduce op Mean needs to dim-map its reduction indices to NCDHW.
+      # Then, the output needs to be transposed back to NDHWC.
+      expected_num_transposes = 2
+      self.assertEqual(expected_num_transposes, num_transposes)
+      self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes)
+      self._assert_map_ndhwc_to_ncdhw('Mean-1', nodes)
+      self._assert_trans_ncdhw_to_ndhwc('Mean-0-0', nodes)
+      self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
   @test_util.deprecated_graph_mode_only
   def testConv3D(self):
     if test.is_gpu_available(cuda_only=True):
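The test runs the same graph twice, once with the generic layout optimizer off and once with it on, and compares outputs. `_get_config` is defined elsewhere in this test file; a hedged sketch of what such a helper typically builds (the exact fields in the real helper may differ):

    from tensorflow.core.protobuf import config_pb2
    from tensorflow.core.protobuf import rewriter_config_pb2

    def _get_config(layout_optimizer=True):
      # Build a ConfigProto with the layout optimizer toggled on or off.
      toggle = (rewriter_config_pb2.RewriterConfig.ON
                if layout_optimizer else rewriter_config_pb2.RewriterConfig.OFF)
      rewrite_options = rewriter_config_pb2.RewriterConfig(
          layout_optimizer=toggle)
      graph_options = config_pb2.GraphOptions(rewrite_options=rewrite_options)
      return config_pb2.ConfigProto(graph_options=graph_options)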
@@ -1235,6 +1235,33 @@ class DataFormatDimMapTest(test_lib.TestCase):
       y_val = self.evaluate(y)
       self.assertAllEqual(y_val, y_val_expected)
 
+  def testNDHWCtoNCDHW(self):
+    x_val = [1, -4, -3, -2]
+    y_val_expected = [2, 2, 3, 4]
+    x = constant_op.constant(x_val)
+    y = nn_ops.data_format_dim_map(x, src_format="NDHWC", dst_format="NCDHW")
+    with test_util.use_gpu():
+      y_val = self.evaluate(y)
+      self.assertAllEqual(y_val, y_val_expected)
+
+  def testNDHWCtoDHWNC(self):
+    x_val = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
+    y_val_expected = [3, 0, 1, 2, 4, 3, 0, 1, 2, 4]
+    x = constant_op.constant(x_val)
+    y = nn_ops.data_format_dim_map(x, src_format="NDHWC", dst_format="DHWNC")
+    with test_util.use_gpu():
+      y_val = self.evaluate(y)
+      self.assertAllEqual(y_val, y_val_expected)
+
+  def testNDHWCtoWHDCN(self):
+    x_val = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
+    y_val_expected = [4, 2, 1, 0, 3, 4, 2, 1, 0, 3]
+    x = constant_op.constant(x_val)
+    y = nn_ops.data_format_dim_map(x, src_format="NDHWC", dst_format="WHDCN")
+    with test_util.use_gpu():
+      y_val = self.evaluate(y)
+      self.assertAllEqual(y_val, y_val_expected)
+
   def testArbitraryASCII(self):
     x_val = [-4, -3, -2, -1, 0, 1, 2, 3]
     y_val_expected = [3, 2, 1, 0, 3, 2, 1, 0]
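The expected vectors in the three new tests follow directly from where each NDHWC axis letter sits in the destination format; a quick check of the DHWNC and WHDCN tables:

    # Position of each NDHWC axis letter in the destination format.
    assert ["DHWNC".index(c) for c in "NDHWC"] == [3, 0, 1, 2, 4]
    assert ["WHDCN".index(c) for c in "NDHWC"] == [4, 2, 1, 0, 3]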