Disabled NCHW layout optimization for integer convolutions.
Integer convolutions do not support NCHW layout (both on CPU and GPU).

PiperOrigin-RevId: 334309575
Change-Id: If35531e9678f4e735f7318f8c8fe166abeba7d90
commit 8a643858ce (parent d4ed713e95)
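
The whole change reduces to one predicate: a convolution is eligible for the NHWC-to-NCHW rewrite only when its "T" attribute is a floating-point type. Below is a minimal standalone sketch of that gate using the framework's kDataTypeIsFloating set; the function name is illustrative, not the helper the diff introduces (the diff phrases the same test two ways: IsNonFloatingConv2D in the generic layout optimizer and IsDataTypeFloat in the legacy one).

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/types.h"

namespace tensorflow {

// Illustrative helper: true when the node carries a floating-point "T"
// attribute, i.e. when the NCHW rewrite is allowed to touch it.
bool HasFloatingPointT(const NodeDef& node) {
  const auto it = node.attr().find("T");
  if (it == node.attr().end()) return false;  // No dtype attr: leave as is.
  return kDataTypeIsFloating.Contains(it->second.type());
}

}  // namespace tensorflow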
@@ -81,6 +81,7 @@ constexpr int kDepthOut = 16;
 { 0, 3, 1, 2 }
 #endif  // (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)

+template <typename T = float>
 Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
                     const string& padding, const string& device) {
   int batch_size = 8;
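
The { 0, 3, 1, 2 } constant in the context above is the NHWC-to-NCHW permutation: destination dimension i takes source dimension perm[i]. A standalone illustration of how it reorders a shape:

#include <array>
#include <cstdio>

int main() {
  // NHWC -> NCHW: dst[i] = src[perm[i]].
  const std::array<int, 4> perm = {0, 3, 1, 2};
  const std::array<int, 4> nhwc = {8, 32, 32, 3};  // {N, H, W, C}
  std::array<int, 4> nchw;
  for (int i = 0; i < 4; ++i) nchw[i] = nhwc[perm[i]];
  // Prints {8, 3, 32, 32}, i.e. {N, C, H, W}.
  std::printf("{%d, %d, %d, %d}\n", nchw[0], nchw[1], nchw[2], nchw[3]);
  return 0;
}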
@@ -91,15 +92,15 @@ Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
   int stride = 1;
   TensorShape input_shape(
       DIMS(batch_size, input_height, input_width, input_depth));
-  Tensor input_data(DT_FLOAT, input_shape);
-  test::FillIota<float>(&input_data, 1.0f);
+  Tensor input_data(DataTypeToEnum<T>::value, input_shape);
+  test::FillIota<T>(&input_data, static_cast<T>(1));
   Output input =
       ops::Const(s->WithOpName("Input"), Input::Initializer(input_data));

   TensorShape filter_shape(
       {filter_size, filter_size, input_depth, filter_count});
-  Tensor filter_data(DT_FLOAT, filter_shape);
-  test::FillIota<float>(&filter_data, 1.0f);
+  Tensor filter_data(DataTypeToEnum<T>::value, filter_shape);
+  test::FillIota<T>(&filter_data, static_cast<T>(1));
   Output filter =
       ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data));

@@ -356,6 +357,25 @@ TEST_F(GenericLayoutOptimizerTest, CPUDevice) {
 #endif  // (GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
 }

+TEST_F(GenericLayoutOptimizerTest, NoOptimizeIntegerConvolution) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  auto conv = SimpleConv2D<int32>(&s, 4, 2, "VALID", "");
+  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
+  GrapplerItem item;
+  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
+
+  GenericLayoutOptimizer optimizer(REWRITER_CONFIG);
+  GraphDef output;
+  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
+
+  Status status;
+  utils::GraphView graph_view(&output, &status);
+  TF_ASSERT_OK(status);
+  auto* conv_node = graph_view.GetNode("Conv2D");
+  ASSERT_NE(conv_node, nullptr);
+  VerifyDataFormatAttributeMatch(conv_node, SRC_DATA_FORMAT);
+}
+
 TEST_F(GenericLayoutOptimizerTest, Connectivity) {
   Scope scope = Scope::NewRootScope();
   auto conv = SimpleConv2D(&scope, 4, 2, "VALID",
@@ -30,11 +30,13 @@ limitations under the License.
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/grappler/costs/graph_properties.h"
 #include "tensorflow/core/grappler/op_types.h"
 #include "tensorflow/core/grappler/utils.h"
 #include "tensorflow/core/grappler/utils/frame.h"
+#include "tensorflow/core/grappler/utils/graph_view.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/protobuf/device_properties.pb.h"
 #include "tensorflow/core/util/device_name_utils.h"
@@ -83,6 +85,16 @@ inline bool AttrDataFormatMatch(const utils::MutableNodeView& node,
   return AttrDataFormatMatch(node, src_data_format, &missing);
 }

+bool IsNonFloatingConv2D(const utils::MutableNodeView& node) {
+  if (IsConv2D(*node.node()) || IsConv2DBackpropInput(*node.node())) {
+    const auto* attr = node.GetAttr(kAttrT);
+    if (attr != nullptr) {
+      return !kDataTypeIsFloating.Contains(attr->type());
+    }
+  }
+  return false;
+}
+
 // Utils for layout agnostic transposer.

 bool IsComparisonOp(const NodeDef& node) {
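
If IsNonFloatingConv2D were visible outside its translation unit (it is file-local), a one-node graph would be enough to exercise it. A hedged sketch, reusing only APIs the includes above already pull in; the graph-view construction mirrors the NoOptimizeIntegerConvolution test earlier in this diff:

// Sketch under the assumption that IsNonFloatingConv2D is exported for
// testing.
GraphDef graph;
NodeDef* conv = graph.add_node();
conv->set_name("Conv2D");
conv->set_op("Conv2D");
(*conv->mutable_attr())["T"].set_type(DT_INT32);

Status status;
utils::MutableGraphView graph_view(&graph, &status);
TF_CHECK_OK(status);
// True for DT_INT32, so ShouldProcess() rejects the node and no
// Transpose pair is inserted around it.
const bool skipped = IsNonFloatingConv2D(*graph_view.GetNode("Conv2D"));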
@@ -205,15 +217,19 @@ bool Transposer::ShouldProcess(const TransposeContext& context,
       GetDeviceName(context.virtual_placer.get(), *node_def);
   string device;
   string task;
-  bool is_on_target_device =
+  const bool is_on_target_device =
       DeviceNameUtils::SplitDeviceName(device_name, &task, &device) &&
       absl::StrContains(absl::AsciiStrToLower(device),
                         absl::AsciiStrToLower(context.target_device));

   // Only checks data format for layout sensitive op.
-  bool data_format_match = !IsLayoutSensitiveOp(*node_def) ||
+  const bool data_format_match = !IsLayoutSensitiveOp(*node_def) ||
                                  AttrDataFormatMatch(node, context.src_format);
-  return is_on_target_device && data_format_match &&
+
+  // Only transposes floating point nodes.
+  const bool is_integer_conv2d = IsNonFloatingConv2D(node);
+
+  return is_on_target_device && data_format_match && !is_integer_conv2d &&
          !context.nodes_to_preserve.contains(node_def->name()) &&
          !(node.NumRegularFanouts() == 0 && node.NumControlledFanouts() == 0);
 }
@@ -1100,7 +1100,8 @@ class Conv2DProcessor : public NodeProcessor {
  protected:
   bool ShouldProcess() const override {
     return !MustPreserve() && IsNHWC() && IsPortZeroDimsFour(*node_) &&
-           HasOutputs() && (!IsGemmUsed() || no_gemm_) && IsOnGPU();
+           HasOutputs() && (!IsGemmUsed() || no_gemm_) && IsOnGPU() &&
+           IsDataTypeFloat();
   }

   TensorShapeProto GetShape(const string& input_name) const {
@@ -1131,6 +1132,13 @@ class Conv2DProcessor : public NodeProcessor {
     return false;
   }

+  bool IsDataTypeFloat() const {
+    if (node_->attr().find("T") != node_->attr().end()) {
+      return kDataTypeIsFloating.Contains(node_->attr().at("T").type());
+    }
+    return false;
+  }
+
   // The logic inside this function is based on the internal implementation of
   // Conv2D, Conv2DBackpropInput, and Conv2DBackpropFilter ops, and thus
   // needs to be updated accordingly if the internal implementation changes.
@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/core/framework/kernel_shape_util.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/grappler/clusters/single_machine.h"
 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
 #include "tensorflow/core/grappler/costs/virtual_placer.h"
@@ -56,11 +57,13 @@ class LayoutOptimizerTest : public GrapplerTest {

   void TearDown() override { TF_CHECK_OK(virtual_cluster_->Shutdown()); }

+  template <typename T = float>
   Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
                       const string& padding) {
-    return SimpleConv2D(s, input_size, filter_size, padding, "");
+    return SimpleConv2D<T>(s, input_size, filter_size, padding, "");
   }

+  template <typename T = float>
   Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
                       const string& padding, const string& device) {
     int batch_size = 8;
@@ -71,15 +74,15 @@ class LayoutOptimizerTest : public GrapplerTest {
     int stride = 1;
     TensorShape input_shape(
         {batch_size, input_height, input_width, input_depth});
-    Tensor input_data(DT_FLOAT, input_shape);
-    test::FillIota<float>(&input_data, 1.0f);
+    Tensor input_data(DataTypeToEnum<T>::value, input_shape);
+    test::FillIota<T>(&input_data, static_cast<T>(1));
     Output input =
         ops::Const(s->WithOpName("Input"), Input::Initializer(input_data));

     TensorShape filter_shape(
         {filter_size, filter_size, input_depth, filter_count});
-    Tensor filter_data(DT_FLOAT, filter_shape);
-    test::FillIota<float>(&filter_data, 1.0f);
+    Tensor filter_data(DataTypeToEnum<T>::value, filter_shape);
+    test::FillIota<T>(&filter_data, static_cast<T>(1));
     Output filter =
         ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data));

@@ -359,6 +362,20 @@ TEST_F(LayoutOptimizerTest, ExplicitPadding) {
   EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
 }

+TEST_F(LayoutOptimizerTest, DataTypeIsInt32) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  auto conv = SimpleConv2D<int32>(&s, 4, 2, "EXPLICIT");
+  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  LayoutOptimizer optimizer;
+  GraphDef output;
+  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+  NodeMap node_map(&output);
+  EXPECT_FALSE(
+      node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
+}
+
 TEST_F(LayoutOptimizerTest, Pad) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
   auto conv = SimpleConv2D(&s, 4, 2, "VALID");
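
Assuming the conventional Bazel target names for the two test files touched here (the file paths are not shown in this excerpt), the new cases can be run with:

bazel test //tensorflow/core/grappler/optimizers:generic_layout_optimizer_test \
           //tensorflow/core/grappler/optimizers:layout_optimizer_test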