Handle int64 axis in ReduceTransposer
PiperOrigin-RevId: 322830387
Change-Id: I4c5c7a536926fd032d5efc08cddf67e3844bca38
This commit is contained in:
parent
78688104bc
commit
dc9685322d
@ -1236,7 +1236,12 @@ bool ReduceTransposer::IsAlongAxis(const Tensor& tensor,
|
||||
return false;
|
||||
}
|
||||
for (int i = 0; i < axis_size; ++i) {
|
||||
int local_axis = tensor.flat<int>()(i);
|
||||
int local_axis = 0;
|
||||
if (tensor.dtype() == DT_INT32) {
|
||||
local_axis = tensor.flat<int32>()(i);
|
||||
} else {
|
||||
local_axis = tensor.flat<int64>()(i);
|
||||
}
|
||||
if (local_axis < 0) {
|
||||
local_axis += rank;
|
||||
}
|
||||
|
@ -370,6 +370,136 @@ class TransposerTest : public ::testing::Test {
|
||||
|
||||
// Per-test cleanup: shuts down the fixture's virtual cluster and asserts that
// the shutdown status is OK.
void TearDown() override {
  TF_ASSERT_OK(virtual_cluster_->Shutdown());
}
|
||||
|
||||
template <typename T>
|
||||
void ReduceTransposerKeepDims() {
|
||||
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
|
||||
GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
|
||||
#endif // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
|
||||
GrapplerItem item;
|
||||
Scope scope = Scope::NewRootScope();
|
||||
|
||||
auto input =
|
||||
ops::RandomUniform(scope.WithOpName("input"),
|
||||
{kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
|
||||
auto filter =
|
||||
ops::RandomUniform(scope.WithOpName("filter"),
|
||||
{kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
|
||||
Output conv2d = ops::Conv2D(
|
||||
scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
|
||||
{1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
|
||||
|
||||
auto axis = ops::Const<T>(scope.WithOpName("axis"), {0, 1, 2}, {3});
|
||||
auto attrs = ops::Sum::Attrs().KeepDims(true);
|
||||
auto sum_op = ops::Sum(scope.WithOpName("sum").WithDevice("/device:GPU:0"),
|
||||
conv2d, axis, attrs);
|
||||
|
||||
auto z = ops::Identity(scope.WithOpName("z"), sum_op);
|
||||
TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
|
||||
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* c2d = context.graph_view->GetNode("conv2d");
|
||||
ASSERT_NE(c2d, nullptr);
|
||||
TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
|
||||
|
||||
ReduceTransposer reducer_transposer;
|
||||
auto* sum = context.graph_view->GetNode("sum");
|
||||
ASSERT_NE(sum, nullptr);
|
||||
TF_ASSERT_OK(reducer_transposer.TransposeNode(&context, sum));
|
||||
|
||||
auto* input_transpose_node = context.graph_view->GetNode(
|
||||
"sum-0-TransposeNHWCToNCHW-LayoutOptimizer");
|
||||
ASSERT_NE(input_transpose_node, nullptr);
|
||||
|
||||
auto* updated_sum_node = context.graph_view->GetNode("sum");
|
||||
ASSERT_NE(updated_sum_node, nullptr);
|
||||
ASSERT_EQ(updated_sum_node->NumRegularFanins(), 2);
|
||||
VerifyRegularFaninMatch(updated_sum_node, 0,
|
||||
input_transpose_node->GetName(), 0);
|
||||
|
||||
auto* axis_node = context.graph_view->GetNode(
|
||||
"sum-1-DataFormatDimMapNHWCToNCHW-LayoutOptimizer");
|
||||
ASSERT_NE(axis_node, nullptr);
|
||||
ASSERT_EQ(axis_node->NumRegularFanins(), 1);
|
||||
VerifyRegularFaninMatch(axis_node, 0, "axis", 0);
|
||||
|
||||
auto* output_transpose_node = context.graph_view->GetNode(
|
||||
"sum-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
|
||||
ASSERT_NE(output_transpose_node, nullptr);
|
||||
|
||||
auto* z_output_node = context.graph_view->GetNode("z");
|
||||
ASSERT_NE(z_output_node, nullptr);
|
||||
ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
|
||||
VerifyRegularFaninMatch(z_output_node, 0, output_transpose_node->GetName(),
|
||||
0);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ReduceTransposerValidAxisNode() {
|
||||
#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
|
||||
GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
|
||||
#endif // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
|
||||
GrapplerItem item;
|
||||
Scope scope = Scope::NewRootScope();
|
||||
|
||||
auto input =
|
||||
ops::RandomUniform(scope.WithOpName("input"),
|
||||
{kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
|
||||
auto filter =
|
||||
ops::RandomUniform(scope.WithOpName("filter"),
|
||||
{kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
|
||||
Output conv2d = ops::Conv2D(
|
||||
scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
|
||||
{1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
|
||||
|
||||
auto axis = ops::Const<T>(scope.WithOpName("axis"), {0, 1, 2}, {3});
|
||||
auto sum_op = ops::Max(scope.WithOpName("max").WithDevice("/device:GPU:0"),
|
||||
conv2d, axis);
|
||||
|
||||
auto z = ops::Identity(scope.WithOpName("z"), sum_op);
|
||||
TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
|
||||
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* c2d = context.graph_view->GetNode("conv2d");
|
||||
ASSERT_NE(c2d, nullptr);
|
||||
TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
|
||||
|
||||
ReduceTransposer reducer_transposer;
|
||||
auto* max = context.graph_view->GetNode("max");
|
||||
ASSERT_NE(max, nullptr);
|
||||
TF_ASSERT_OK(reducer_transposer.TransposeNode(&context, max));
|
||||
|
||||
auto* input_transpose_node = context.graph_view->GetNode(
|
||||
"max-0-TransposeNHWCToNCHW-LayoutOptimizer");
|
||||
ASSERT_NE(input_transpose_node, nullptr);
|
||||
|
||||
auto* updated_max_node = context.graph_view->GetNode("max");
|
||||
ASSERT_NE(updated_max_node, nullptr);
|
||||
ASSERT_EQ(updated_max_node->NumRegularFanins(), 2);
|
||||
VerifyRegularFaninMatch(updated_max_node, 0,
|
||||
input_transpose_node->GetName(), 0);
|
||||
|
||||
auto* axis_node = context.graph_view->GetNode(
|
||||
"max-1-DataFormatDimMapNHWCToNCHW-LayoutOptimizer");
|
||||
ASSERT_NE(axis_node, nullptr);
|
||||
ASSERT_EQ(axis_node->NumRegularFanins(), 1);
|
||||
VerifyRegularFaninMatch(axis_node, 0, "axis", 0);
|
||||
|
||||
auto* z_output_node = context.graph_view->GetNode("z");
|
||||
ASSERT_NE(z_output_node, nullptr);
|
||||
ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
|
||||
VerifyRegularFaninMatch(z_output_node, 0, updated_max_node->GetName(), 0);
|
||||
}
|
||||
|
||||
std::unique_ptr<Cluster> virtual_cluster_;
|
||||
};
|
||||
|
||||
@ -3637,131 +3767,13 @@ TEST_F(TransposerTest, StridedSliceTransposerConstFaninBadRank) {
|
||||
}
|
||||
|
||||
// Verifies ReduceTransposer on a keep_dims Sum for both axis element types.
//
// Graph construction, transposition and all assertions live in the templated
// fixture helper TransposerTest::ReduceTransposerKeepDims<T>(). The test body
// only selects the axis dtypes, so the int32 scenario is not spelled out a
// second time inline. The helper itself skips when neither CUDA nor ROCm is
// enabled.
TEST_F(TransposerTest, ReduceTransposerKeepDims) {
  ReduceTransposerKeepDims<int32>();
  ReduceTransposerKeepDims<int64>();
}
|
||||
|
||||
// Verifies ReduceTransposer on a Max reduction for both axis element types.
//
// Graph construction, transposition and all assertions live in the templated
// fixture helper TransposerTest::ReduceTransposerValidAxisNode<T>(). The test
// body only selects the axis dtypes, so the int32 scenario is not spelled out
// a second time inline. The helper itself skips when neither CUDA nor ROCm is
// enabled.
TEST_F(TransposerTest, ReduceTransposerValidAxisNode) {
  ReduceTransposerValidAxisNode<int32>();
  ReduceTransposerValidAxisNode<int64>();
}
|
||||
|
||||
TEST(PermutationTest, PermutesVector) {
|
||||
|
Loading…
Reference in New Issue
Block a user