Disabled NCHW layout optimization for integer convolutions.

Integer convolutions do not support the NCHW layout on either CPU or GPU.

PiperOrigin-RevId: 334309575
Change-Id: If35531e9678f4e735f7318f8c8fe166abeba7d90
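
Background for the change, with a standalone illustration: NHWC and NCHW store the same logical tensor in different physical orders, so switching layouts means inserting real Transpose ops around the convolution. The sketch below is not from this commit; it only shows the offset math behind the two formats.

#include <cstdio>

// Flat offset of logical element (n, h, w, c) under each layout.
int OffsetNHWC(int n, int h, int w, int c, int H, int W, int C) {
  return ((n * H + h) * W + w) * C + c;
}
int OffsetNCHW(int n, int h, int w, int c, int H, int W, int C) {
  return ((n * C + c) * H + h) * W + w;
}

int main() {
  // The same element lands at different memory locations: 99 vs. 54.
  std::printf("NHWC: %d\n", OffsetNHWC(0, 1, 2, 3, /*H=*/4, /*W=*/4, /*C=*/16));
  std::printf("NCHW: %d\n", OffsetNCHW(0, 1, 2, 3, /*H=*/4, /*W=*/4, /*C=*/16));
}

The grappler layout optimizers rewrite NHWC graphs to NCHW on GPU because the floating-point cuDNN kernels favor that order; TensorFlow's integer Conv2D kernels are NHWC-only, so the rewrite must now leave them alone.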
Sung Jin Hwang, 2020-09-28 22:52:25 -07:00 (committed by TensorFlower Gardener)
parent d4ed713e95
commit 8a643858ce
4 changed files with 75 additions and 14 deletions


@@ -81,6 +81,7 @@ constexpr int kDepthOut = 16;
   { 0, 3, 1, 2 }
 #endif  // (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
 
+template <typename T = float>
 Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
                     const string& padding, const string& device) {
   int batch_size = 8;
@@ -91,15 +92,15 @@ Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
   int stride = 1;
   TensorShape input_shape(
       DIMS(batch_size, input_height, input_width, input_depth));
-  Tensor input_data(DT_FLOAT, input_shape);
-  test::FillIota<float>(&input_data, 1.0f);
+  Tensor input_data(DataTypeToEnum<T>::value, input_shape);
+  test::FillIota<T>(&input_data, static_cast<T>(1));
   Output input =
       ops::Const(s->WithOpName("Input"), Input::Initializer(input_data));
 
   TensorShape filter_shape(
       {filter_size, filter_size, input_depth, filter_count});
-  Tensor filter_data(DT_FLOAT, filter_shape);
-  test::FillIota<float>(&filter_data, 1.0f);
+  Tensor filter_data(DataTypeToEnum<T>::value, filter_shape);
+  test::FillIota<T>(&filter_data, static_cast<T>(1));
   Output filter =
       ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data));
@@ -356,6 +357,25 @@ TEST_F(GenericLayoutOptimizerTest, CPUDevice) {
 #endif  // (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
 }
 
+TEST_F(GenericLayoutOptimizerTest, NoOptimizeIntegerConvolution) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  auto conv = SimpleConv2D<int32>(&s, 4, 2, "VALID", "");
+  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
+
+  GrapplerItem item;
+  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
+
+  GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
+  GraphDef output;
+  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
+
+  Status status;
+  utils::GraphView graph_view(&output, &status);
+  TF_ASSERT_OK(status);
+  auto* conv_node = graph_view.GetNode("Conv2D");
+  ASSERT_NE(conv_node, nullptr);
+  VerifyDataFormatAttributeMatch(conv_node, SRC_DATA_FORMAT);
+}
+
 TEST_F(GenericLayoutOptimizerTest, Connectivity) {
   Scope scope = Scope::NewRootScope();
   auto conv = SimpleConv2D(&scope, 4, 2, "VALID",


@@ -30,11 +30,13 @@ limitations under the License.
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/grappler/costs/graph_properties.h"
 #include "tensorflow/core/grappler/op_types.h"
 #include "tensorflow/core/grappler/utils.h"
 #include "tensorflow/core/grappler/utils/frame.h"
+#include "tensorflow/core/grappler/utils/graph_view.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/protobuf/device_properties.pb.h"
 #include "tensorflow/core/util/device_name_utils.h"
@@ -83,6 +85,16 @@ inline bool AttrDataFormatMatch(const utils::MutableNodeView& node,
   return AttrDataFormatMatch(node, src_data_format, &missing);
 }
 
+bool IsNonFloatingConv2D(const utils::MutableNodeView& node) {
+  if (IsConv2D(*node.node()) || IsConv2DBackpropInput(*node.node())) {
+    const auto* attr = node.GetAttr(kAttrT);
+    if (attr != nullptr) {
+      return !kDataTypeIsFloating.Contains(attr->type());
+    }
+  }
+  return false;
+}
+
 // Utils for layout agnostic transposer.
 
 bool IsComparisonOp(const NodeDef& node) {
@@ -205,15 +217,19 @@ bool Transposer::ShouldProcess(const TransposeContext& context,
       GetDeviceName(context.virtual_placer.get(), *node_def);
   string device;
   string task;
-  bool is_on_target_device =
+  const bool is_on_target_device =
       DeviceNameUtils::SplitDeviceName(device_name, &task, &device) &&
       absl::StrContains(absl::AsciiStrToLower(device),
                         absl::AsciiStrToLower(context.target_device));
 
   // Only checks data format for layout sensitive op.
-  bool data_format_match = !IsLayoutSensitiveOp(*node_def) ||
-                           AttrDataFormatMatch(node, context.src_format);
-  return is_on_target_device && data_format_match &&
+  const bool data_format_match = !IsLayoutSensitiveOp(*node_def) ||
+                                 AttrDataFormatMatch(node, context.src_format);
+
+  // Only transposes floating point nodes.
+  const bool is_integer_conv2d = IsNonFloatingConv2D(node);
+
+  return is_on_target_device && data_format_match && !is_integer_conv2d &&
          !context.nodes_to_preserve.contains(node_def->name()) &&
          !(node.NumRegularFanouts() == 0 && node.NumControlledFanouts() == 0);
 }
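
Taken together, the new predicate and the extra ShouldProcess term mean an integer Conv2D is never handed to a transposer. A self-contained sketch of the decision (the enum and set below are simplified stand-ins for TensorFlow's DataType and kDataTypeIsFloating, not the real definitions):

#include <iostream>
#include <set>
#include <string>

// Stand-ins for TensorFlow's DataType enum and kDataTypeIsFloating set.
enum DataType { DT_FLOAT, DT_HALF, DT_DOUBLE, DT_BFLOAT16, DT_INT32, DT_QINT8 };
const std::set<DataType> kFloatingTypes = {DT_FLOAT, DT_HALF, DT_DOUBLE,
                                           DT_BFLOAT16};

// Mirrors IsNonFloatingConv2D: only Conv2D-family nodes are examined, and a
// non-floating "T" attr marks the node as one the transposer must skip.
bool IsNonFloatingConv2D(const std::string& op, DataType t) {
  if (op == "Conv2D" || op == "Conv2DBackpropInput") {
    return kFloatingTypes.count(t) == 0;
  }
  return false;
}

int main() {
  // 0: a float Conv2D is still eligible for the NHWC->NCHW rewrite.
  std::cout << IsNonFloatingConv2D("Conv2D", DT_FLOAT) << "\n";
  // 1: an int32 Conv2D is now left in NHWC by ShouldProcess.
  std::cout << IsNonFloatingConv2D("Conv2D", DT_INT32) << "\n";
  // 0: non-conv ops are unaffected by this predicate.
  std::cout << IsNonFloatingConv2D("Relu", DT_INT32) << "\n";
}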


@@ -1100,7 +1100,8 @@ class Conv2DProcessor : public NodeProcessor {
  protected:
   bool ShouldProcess() const override {
     return !MustPreserve() && IsNHWC() && IsPortZeroDimsFour(*node_) &&
-           HasOutputs() && (!IsGemmUsed() || no_gemm_) && IsOnGPU();
+           HasOutputs() && (!IsGemmUsed() || no_gemm_) && IsOnGPU() &&
+           IsDataTypeFloat();
   }
 
   TensorShapeProto GetShape(const string& input_name) const {
@@ -1131,6 +1132,13 @@ class Conv2DProcessor : public NodeProcessor {
     return false;
   }
 
+  bool IsDataTypeFloat() const {
+    if (node_->attr().find("T") != node_->attr().end()) {
+      return kDataTypeIsFloating.Contains(node_->attr().at("T").type());
+    }
+    return false;
+  }
+
   // The logic inside this function is based on the internal implementation of
   // Conv2D, Conv2DBackpropInput, and Conv2DBackpropFilter ops, and thus
   // needs to be updated accordingly if the internal implementation changes.
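
Worth noting: the two optimizers guard in opposite directions. The generic transposer above skips only a Conv2D-family node whose "T" attr is present and non-floating, while this older Conv2DProcessor demands a present, floating "T" before processing anything. A condensed comparison (hypothetical helper names, boolean inputs in place of real node inspection):

#include <cassert>

bool TransposerAllows(bool is_conv2d_family, bool has_t_attr, bool t_is_float) {
  // Generic optimizer: skip only a Conv2D-family node with a known
  // non-floating "T"; anything else stays eligible.
  return !(is_conv2d_family && has_t_attr && !t_is_float);
}

bool Conv2DProcessorAllows(bool has_t_attr, bool t_is_float) {
  // Old LayoutOptimizer: a floating "T" must be present; a missing attr
  // now also blocks the rewrite.
  return has_t_attr && t_is_float;
}

int main() {
  // Both reject the int32 Conv2D this commit targets.
  assert(!TransposerAllows(true, true, false));
  assert(!Conv2DProcessorAllows(true, false));
  // They diverge only when the "T" attr is absent.
  assert(TransposerAllows(true, false, false));
  assert(!Conv2DProcessorAllows(false, false));
}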


@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/core/framework/kernel_shape_util.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/grappler/clusters/single_machine.h"
 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
 #include "tensorflow/core/grappler/costs/virtual_placer.h"
@ -56,11 +57,13 @@ class LayoutOptimizerTest : public GrapplerTest {
void TearDown() override { TF_CHECK_OK(virtual_cluster_->Shutdown()); } void TearDown() override { TF_CHECK_OK(virtual_cluster_->Shutdown()); }
template <typename T = float>
Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size, Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
const string& padding) { const string& padding) {
return SimpleConv2D(s, input_size, filter_size, padding, ""); return SimpleConv2D<T>(s, input_size, filter_size, padding, "");
} }
template <typename T = float>
Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size, Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
const string& padding, const string& device) { const string& padding, const string& device) {
int batch_size = 8; int batch_size = 8;
@ -71,15 +74,15 @@ class LayoutOptimizerTest : public GrapplerTest {
int stride = 1; int stride = 1;
TensorShape input_shape( TensorShape input_shape(
{batch_size, input_height, input_width, input_depth}); {batch_size, input_height, input_width, input_depth});
Tensor input_data(DT_FLOAT, input_shape); Tensor input_data(DataTypeToEnum<T>::value, input_shape);
test::FillIota<float>(&input_data, 1.0f); test::FillIota<T>(&input_data, static_cast<T>(1));
Output input = Output input =
ops::Const(s->WithOpName("Input"), Input::Initializer(input_data)); ops::Const(s->WithOpName("Input"), Input::Initializer(input_data));
TensorShape filter_shape( TensorShape filter_shape(
{filter_size, filter_size, input_depth, filter_count}); {filter_size, filter_size, input_depth, filter_count});
Tensor filter_data(DT_FLOAT, filter_shape); Tensor filter_data(DataTypeToEnum<T>::value, filter_shape);
test::FillIota<float>(&filter_data, 1.0f); test::FillIota<T>(&filter_data, static_cast<T>(1));
Output filter = Output filter =
ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data)); ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data));
@@ -359,6 +362,20 @@ TEST_F(LayoutOptimizerTest, ExplicitPadding) {
   EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
 }
 
+TEST_F(LayoutOptimizerTest, DataTypeIsInt32) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  auto conv = SimpleConv2D<int32>(&s, 4, 2, "EXPLICIT");
+  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  LayoutOptimizer optimizer;
+  GraphDef output;
+  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+  NodeMap node_map(&output);
+  EXPECT_FALSE(
+      node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
+}
+
 TEST_F(LayoutOptimizerTest, Pad) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
   auto conv = SimpleConv2D(&s, 4, 2, "VALID");
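
As a reading aid for the assertions above: "Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer" is the name of the Transpose node the optimizer inserts in front of input port 0 of Conv2D when it does rewrite the layout. Roughly (assuming the usual transpose pair around the converted op):

  float T:  Input -> Transpose(NHWC->NCHW) -> Conv2D(NCHW) -> Transpose(NCHW->NHWC) -> Fetch
  int32 T:  Input -> Conv2D(NHWC) -> Fetch   (no transposes, hence the EXPECT_FALSE above)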