Handle int64 axis in ReduceTransposer

PiperOrigin-RevId: 322830387
Change-Id: I4c5c7a536926fd032d5efc08cddf67e3844bca38
Author: Gaurav Jain
Date: 2020-07-23 11:37:54 -07:00
Committed-by: TensorFlower Gardener
Parent: 78688104bc
Commit: dc9685322d

2 changed files with 140 additions and 123 deletions

@@ -1236,7 +1236,12 @@ bool ReduceTransposer::IsAlongAxis(const Tensor& tensor,
     return false;
   }
   for (int i = 0; i < axis_size; ++i) {
-    int local_axis = tensor.flat<int>()(i);
+    int local_axis = 0;
+    if (tensor.dtype() == DT_INT32) {
+      local_axis = tensor.flat<int32>()(i);
+    } else {
+      local_axis = tensor.flat<int64>()(i);
+    }
     if (local_axis < 0) {
       local_axis += rank;
     }
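A note on why the dispatch above is needed (this explanation and the helper below are editorial, not part of the commit): reduction ops such as Sum and Max accept their axis input as either DT_INT32 or DT_INT64, and Tensor::flat<T>() CHECK-fails at runtime unless T matches the tensor's dtype, so the old flat<int>() call could not handle an int64 axis at all. A minimal sketch of the same dispatch factored into a standalone helper, with the hypothetical name AxisElement:

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"

namespace tensorflow {

// Hypothetical helper (not part of this commit): read element i of a 1-D
// axis tensor that may be DT_INT32 or DT_INT64, widening to int64 and
// normalizing negative axis values against `rank`.
int64 AxisElement(const Tensor& tensor, int i, int rank) {
  int64 local_axis = tensor.dtype() == DT_INT32
                         ? static_cast<int64>(tensor.flat<int32>()(i))
                         : tensor.flat<int64>()(i);
  if (local_axis < 0) local_axis += rank;  // e.g. axis -1 on rank 4 -> 3
  return local_axis;
}

}  // namespace tensorflow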

@@ -370,6 +370,136 @@ class TransposerTest : public ::testing::Test {
   void TearDown() override { TF_ASSERT_OK(virtual_cluster_->Shutdown()); }
+  template <typename T>
+  void ReduceTransposerKeepDims() {
+#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
+    GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
+#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
+    GrapplerItem item;
+    Scope scope = Scope::NewRootScope();
+    auto input =
+        ops::RandomUniform(scope.WithOpName("input"),
+                           {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
+    auto filter =
+        ops::RandomUniform(scope.WithOpName("filter"),
+                           {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
+    Output conv2d = ops::Conv2D(
+        scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
+        {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
+    auto axis = ops::Const<T>(scope.WithOpName("axis"), {0, 1, 2}, {3});
+    auto attrs = ops::Sum::Attrs().KeepDims(true);
+    auto sum_op = ops::Sum(scope.WithOpName("sum").WithDevice("/device:GPU:0"),
+                           conv2d, axis, attrs);
+    auto z = ops::Identity(scope.WithOpName("z"), sum_op);
+    TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
+    TransposeContext context;
+    TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
+        item, virtual_cluster_.get(), &context));
+    context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
+    DefaultLayoutSensitiveOpTransposer conv2d_transposer;
+    auto* c2d = context.graph_view->GetNode("conv2d");
+    ASSERT_NE(c2d, nullptr);
+    TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
+    ReduceTransposer reducer_transposer;
+    auto* sum = context.graph_view->GetNode("sum");
+    ASSERT_NE(sum, nullptr);
+    TF_ASSERT_OK(reducer_transposer.TransposeNode(&context, sum));
+    auto* input_transpose_node = context.graph_view->GetNode(
+        "sum-0-TransposeNHWCToNCHW-LayoutOptimizer");
+    ASSERT_NE(input_transpose_node, nullptr);
+    auto* updated_sum_node = context.graph_view->GetNode("sum");
+    ASSERT_NE(updated_sum_node, nullptr);
+    ASSERT_EQ(updated_sum_node->NumRegularFanins(), 2);
+    VerifyRegularFaninMatch(updated_sum_node, 0,
+                            input_transpose_node->GetName(), 0);
+    auto* axis_node = context.graph_view->GetNode(
+        "sum-1-DataFormatDimMapNHWCToNCHW-LayoutOptimizer");
+    ASSERT_NE(axis_node, nullptr);
+    ASSERT_EQ(axis_node->NumRegularFanins(), 1);
+    VerifyRegularFaninMatch(axis_node, 0, "axis", 0);
+    auto* output_transpose_node = context.graph_view->GetNode(
+        "sum-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
+    ASSERT_NE(output_transpose_node, nullptr);
+    auto* z_output_node = context.graph_view->GetNode("z");
+    ASSERT_NE(z_output_node, nullptr);
+    ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
+    VerifyRegularFaninMatch(z_output_node, 0, output_transpose_node->GetName(),
+                            0);
+  }
+  template <typename T>
+  void ReduceTransposerValidAxisNode() {
+#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
+    GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
+#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
+    GrapplerItem item;
+    Scope scope = Scope::NewRootScope();
+    auto input =
+        ops::RandomUniform(scope.WithOpName("input"),
+                           {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
+    auto filter =
+        ops::RandomUniform(scope.WithOpName("filter"),
+                           {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
+    Output conv2d = ops::Conv2D(
+        scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
+        {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
+    auto axis = ops::Const<T>(scope.WithOpName("axis"), {0, 1, 2}, {3});
+    auto sum_op = ops::Max(scope.WithOpName("max").WithDevice("/device:GPU:0"),
+                           conv2d, axis);
+    auto z = ops::Identity(scope.WithOpName("z"), sum_op);
+    TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
+    TransposeContext context;
+    TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
+        item, virtual_cluster_.get(), &context));
+    context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
+    DefaultLayoutSensitiveOpTransposer conv2d_transposer;
+    auto* c2d = context.graph_view->GetNode("conv2d");
+    ASSERT_NE(c2d, nullptr);
+    TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
+    ReduceTransposer reducer_transposer;
+    auto* max = context.graph_view->GetNode("max");
+    ASSERT_NE(max, nullptr);
+    TF_ASSERT_OK(reducer_transposer.TransposeNode(&context, max));
+    auto* input_transpose_node = context.graph_view->GetNode(
+        "max-0-TransposeNHWCToNCHW-LayoutOptimizer");
+    ASSERT_NE(input_transpose_node, nullptr);
+    auto* updated_max_node = context.graph_view->GetNode("max");
+    ASSERT_NE(updated_max_node, nullptr);
+    ASSERT_EQ(updated_max_node->NumRegularFanins(), 2);
+    VerifyRegularFaninMatch(updated_max_node, 0,
+                            input_transpose_node->GetName(), 0);
+    auto* axis_node = context.graph_view->GetNode(
+        "max-1-DataFormatDimMapNHWCToNCHW-LayoutOptimizer");
+    ASSERT_NE(axis_node, nullptr);
+    ASSERT_EQ(axis_node->NumRegularFanins(), 1);
+    VerifyRegularFaninMatch(axis_node, 0, "axis", 0);
+    auto* z_output_node = context.graph_view->GetNode("z");
+    ASSERT_NE(z_output_node, nullptr);
+    ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
+    VerifyRegularFaninMatch(z_output_node, 0, updated_max_node->GetName(), 0);
+  }
   std::unique_ptr<Cluster> virtual_cluster_;
 };
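The two templated helpers above are instantiated explicitly for int32 and int64 from the rewritten TEST_F bodies below. For comparison, a hypothetical sketch of the same axis-type coverage expressed with gtest typed tests (fixture and suite names here are illustrative, not from the commit):

#include <cstdint>
#include "gtest/gtest.h"

// Illustrative only: a typed-test fixture runs each test body once per
// listed type, avoiding the explicit <int32>/<int64> call pairs.
template <typename T>
class ReduceTransposerAxisTest : public ::testing::Test {};

using AxisTypes = ::testing::Types<int32_t, int64_t>;
TYPED_TEST_SUITE(ReduceTransposerAxisTest, AxisTypes);

TYPED_TEST(ReduceTransposerAxisTest, KeepDims) {
  // A real body would mirror ReduceTransposerKeepDims<TypeParam>() above.
  SUCCEED();
}

Keeping plain TEST_F tests, as the commit does, preserves the existing test names that external tooling may filter on.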
@@ -3637,131 +3767,13 @@ TEST_F(TransposerTest, StridedSliceTransposerConstFaninBadRank) {
 }
 TEST_F(TransposerTest, ReduceTransposerKeepDims) {
-#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
-  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
-#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
-  GrapplerItem item;
-  Scope scope = Scope::NewRootScope();
-  auto input =
-      ops::RandomUniform(scope.WithOpName("input"),
-                         {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
-  auto filter =
-      ops::RandomUniform(scope.WithOpName("filter"),
-                         {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
-  Output conv2d = ops::Conv2D(
-      scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
-      {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
-  auto axis = ops::Const(scope.WithOpName("axis"), {0, 1, 2}, {3});
-  auto attrs = ops::Sum::Attrs().KeepDims(true);
-  auto sum_op = ops::Sum(scope.WithOpName("sum").WithDevice("/device:GPU:0"),
-                         conv2d, axis, attrs);
-  auto z = ops::Identity(scope.WithOpName("z"), sum_op);
-  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
-  TransposeContext context;
-  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
-      item, virtual_cluster_.get(), &context));
-  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
-  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
-  auto* c2d = context.graph_view->GetNode("conv2d");
-  ASSERT_NE(c2d, nullptr);
-  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
-  ReduceTransposer reducer_transposer;
-  auto* sum = context.graph_view->GetNode("sum");
-  ASSERT_NE(sum, nullptr);
-  TF_ASSERT_OK(reducer_transposer.TransposeNode(&context, sum));
-  auto* input_transpose_node =
-      context.graph_view->GetNode("sum-0-TransposeNHWCToNCHW-LayoutOptimizer");
-  ASSERT_NE(input_transpose_node, nullptr);
-  auto* updated_sum_node = context.graph_view->GetNode("sum");
-  ASSERT_NE(updated_sum_node, nullptr);
-  ASSERT_EQ(updated_sum_node->NumRegularFanins(), 2);
-  VerifyRegularFaninMatch(updated_sum_node, 0, input_transpose_node->GetName(),
-                          0);
-  auto* axis_node = context.graph_view->GetNode(
-      "sum-1-DataFormatDimMapNHWCToNCHW-LayoutOptimizer");
-  ASSERT_NE(axis_node, nullptr);
-  ASSERT_EQ(axis_node->NumRegularFanins(), 1);
-  VerifyRegularFaninMatch(axis_node, 0, "axis", 0);
-  auto* output_transpose_node = context.graph_view->GetNode(
-      "sum-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
-  ASSERT_NE(output_transpose_node, nullptr);
-  auto* z_output_node = context.graph_view->GetNode("z");
-  ASSERT_NE(z_output_node, nullptr);
-  ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
-  VerifyRegularFaninMatch(z_output_node, 0, output_transpose_node->GetName(),
-                          0);
+  ReduceTransposerKeepDims<int32>();
+  ReduceTransposerKeepDims<int64>();
 }
 TEST_F(TransposerTest, ReduceTransposerValidAxisNode) {
-#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
-  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
-#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
-  GrapplerItem item;
-  Scope scope = Scope::NewRootScope();
-  auto input =
-      ops::RandomUniform(scope.WithOpName("input"),
-                         {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
-  auto filter =
-      ops::RandomUniform(scope.WithOpName("filter"),
-                         {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
-  Output conv2d = ops::Conv2D(
-      scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
-      {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
-  auto axis = ops::Const(scope.WithOpName("axis"), {0, 1, 2}, {3});
-  auto sum_op = ops::Max(scope.WithOpName("max").WithDevice("/device:GPU:0"),
-                         conv2d, axis);
-  auto z = ops::Identity(scope.WithOpName("z"), sum_op);
-  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
-  TransposeContext context;
-  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
-      item, virtual_cluster_.get(), &context));
-  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
-  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
-  auto* c2d = context.graph_view->GetNode("conv2d");
-  ASSERT_NE(c2d, nullptr);
-  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
-  ReduceTransposer reducer_transposer;
-  auto* max = context.graph_view->GetNode("max");
-  ASSERT_NE(max, nullptr);
-  TF_ASSERT_OK(reducer_transposer.TransposeNode(&context, max));
-  auto* input_transpose_node =
-      context.graph_view->GetNode("max-0-TransposeNHWCToNCHW-LayoutOptimizer");
-  ASSERT_NE(input_transpose_node, nullptr);
-  auto* updated_max_node = context.graph_view->GetNode("max");
-  ASSERT_NE(updated_max_node, nullptr);
-  ASSERT_EQ(updated_max_node->NumRegularFanins(), 2);
-  VerifyRegularFaninMatch(updated_max_node, 0, input_transpose_node->GetName(),
-                          0);
-  auto* axis_node = context.graph_view->GetNode(
-      "max-1-DataFormatDimMapNHWCToNCHW-LayoutOptimizer");
-  ASSERT_NE(axis_node, nullptr);
-  ASSERT_EQ(axis_node->NumRegularFanins(), 1);
-  VerifyRegularFaninMatch(axis_node, 0, "axis", 0);
-  auto* z_output_node = context.graph_view->GetNode("z");
-  ASSERT_NE(z_output_node, nullptr);
-  ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
-  VerifyRegularFaninMatch(z_output_node, 0, updated_max_node->GetName(), 0);
+  ReduceTransposerValidAxisNode<int32>();
+  ReduceTransposerValidAxisNode<int64>();
 }
 TEST(PermutationTest, PermutesVector) {