[Grappler] Reorder cast and transpose.

A common pattern after the layout optimizer is casting an uint8 NHWC
image to float before transposing it to NCHW. It is beneficial to reorder
the cast and the transpose to make the transpose process smaller amount
of data. This optimization converts

  Transpose(Cast(image, dst_type), perm)

to

  Cast(Transpose(image, perm), dst_type)

when sizeof(image.type) < sizeof(dst_type).

PiperOrigin-RevId: 171294111
This commit is contained in:
Jingyue Wu 2017-10-06 08:15:48 -07:00 committed by TensorFlower Gardener
parent e7ab55b01f
commit 9d8346a120
2 changed files with 147 additions and 0 deletions

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
namespace grappler {
@ -274,6 +275,26 @@ static bool SimplyReordersData(const NodeDef& node) {
return node.op() == "Transpose";
}
// Returns the data type in attribute `attr_name` of `node`. If that attribute
// doesn't exist, returns DT_INVALID.
static DataType GetDataTypeFromAttr(const NodeDef& node,
const string& attr_name) {
if (!node.attr().count(attr_name)) {
return DT_INVALID;
}
const auto& attr = node.attr().at(attr_name);
if (attr.value_case() != AttrValue::kType) {
return DT_INVALID;
}
return attr.type();
}
static bool IsNumberType(DataType dtype) {
DataTypeVector number_types = NumberTypes();
return std::find(number_types.begin(), number_types.end(), dtype) !=
number_types.end();
}
string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
const NodeDef* node, GraphDef* graph_def, NodeMap* node_map,
std::vector<const NodeDef*>* new_nodes) const {
@ -320,6 +341,66 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
}
}
if (node->op() == "Transpose") {
// Reorder Cast and Transpose if beneficial.
//
// A common pattern after the layout optimizer is casting an uint8 NHWC
// image to float before transposing it to NCHW. It is beneficial to reorder
// the cast and the transpose to make the transpose process smaller amount
// of data. This optimization converts
// Transpose(Cast(image, dst_type), perm)
// to
// Cast(Transpose(image, perm), dst_type)
// when sizeof(image.type) < sizeof(dst_type).
//
// TODO(jingyue): This optimization can be generalized to a cast followed by
// a chain of ops that merely reorder elements (e.g. Reshape and
// DepthToSpace).
const NodeDef* transpose = node;
string dontcare;
string device;
// This optimization can be dangerous on devices other than CPU and GPU. The
// transpose might not be implemented for image.type, or might be slower
// with image.type than with dst_type.
if (DeviceNameUtils::SplitDeviceName(transpose->device(), &dontcare,
&device) &&
(StringPiece(device).contains(DEVICE_CPU) ||
StringPiece(device).contains(DEVICE_GPU))) {
const NodeDef* cast = node_map->GetNode(transpose->input(0));
if (cast->op() == "Cast") {
const NodeDef* input = node_map->GetNode(cast->input(0));
const DataType src_type = GetDataTypeFromAttr(*cast, "SrcT");
const DataType dst_type = GetDataTypeFromAttr(*cast, "DstT");
if (IsNumberType(src_type) && IsNumberType(dst_type) &&
DataTypeSize(src_type) < DataTypeSize(dst_type)) {
NodeDef* new_transpose = graph_def->add_node();
*new_transpose = *transpose;
new_transpose->set_name(transpose->name() + "_" +
DataTypeString(src_type));
(*new_transpose->mutable_attr())["T"].set_type(src_type);
node_map->AddNode(new_transpose->name(), new_transpose);
new_transpose->set_input(0, cast->input(0));
node_map->AddOutput(input->name(), new_transpose->name());
node_map->AddOutput(NodeName(new_transpose->input(1)),
new_transpose->name());
NodeDef* new_cast = graph_def->add_node();
*new_cast = *cast;
new_cast->set_name(cast->name() + "_new");
node_map->AddNode(new_cast->name(), new_cast);
new_cast->set_input(0, new_transpose->name());
node_map->AddOutput(new_transpose->name(), new_cast->name());
new_nodes->push_back(new_transpose);
new_nodes->push_back(new_cast);
return new_cast->name();
}
}
}
}
// Fold a multiply of a scalar into the following convolution. This folding
// can jump across nodes that merely reorders data (such as reshape and
// transpose). For example, we can optimize

View File

@ -109,6 +109,72 @@ TEST_F(ArithmeticOptimizerTest, CombineReshapes) {
[](const NodeDef& node) { return node.op() == "Reshape"; }));
}
TEST_F(ArithmeticOptimizerTest, ReorderTransposeCast) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/gpu:0");
Output nhwc_uint8 =
ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
Output nhwc_fp32 = ops::Cast(s, nhwc_uint8, DT_FLOAT);
Output nchw_fp32 =
ops::Transpose(s, nhwc_fp32, ops::Const(s, {0, 3, 1, 2}, {4}));
Output outputs = ops::Identity(s.WithOpName("outputs"), nchw_fp32);
GrapplerItem item;
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphDef output;
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
item.graph = output;
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
const NodeDef* transpose_node = nullptr;
for (const NodeDef& node : output.node()) {
if (node.op() == "Transpose") {
EXPECT_EQ(transpose_node, nullptr);
EXPECT_EQ(DT_UINT8, node.attr().at("T").type());
transpose_node = &node;
}
}
EXPECT_NE(transpose_node, nullptr);
for (const NodeDef& node : output.node()) {
if (node.op() == "Cast") {
EXPECT_EQ(NodeName(node.input(0)), transpose_node->name());
}
}
}
TEST_F(ArithmeticOptimizerTest, NoReorderTransposeCast) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/gpu:0");
Output nhwc_fp32 =
ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({8, 28, 28, 3}));
Output nhwc_uint8 = ops::Cast(s, nhwc_fp32, DT_UINT8);
Output nchw_uint8 =
ops::Transpose(s, nhwc_uint8, ops::Const(s, {0, 3, 1, 2}, {4}));
Output outputs = ops::Identity(s.WithOpName("outputs"), nchw_uint8);
GrapplerItem item;
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphDef output;
TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
item.graph = output;
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
int num_transposes = 0;
for (const NodeDef& node : output.node()) {
if (node.op() == "Transpose") {
EXPECT_EQ(DT_UINT8, node.attr().at("T").type());
EXPECT_EQ(node.input(0), "Cast");
++num_transposes;
}
}
EXPECT_EQ(1, num_transposes);
}
TEST_F(ArithmeticOptimizerTest, RemoveInverseTransposes) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output inputs_shape =