[Grappler] Have GenericLayoutOptimizer be specific to GPUs for now.

PiperOrigin-RevId: 254307839
This commit is contained in:
Andy Ly 2019-06-20 17:46:58 -07:00 committed by TensorFlower Gardener
parent 155c4acd30
commit ffc6b92b28
5 changed files with 488 additions and 886 deletions

View File

@ -31,6 +31,17 @@ namespace grappler {
namespace { namespace {
// Returns the number of devices reported by `cluster` whose type is "GPU".
inline int GetNumGPUs(const Cluster& cluster) {
  // Bind by const reference so that no copy of the device container is made
  // when GetDevices() hands back a reference; if it returns by value, the
  // temporary's lifetime is extended for the duration of this scope.
  const auto& devices = cluster.GetDevices();
  int num_gpus = 0;
  for (const auto& device : devices) {
    if (device.second.type() == "GPU") {
      ++num_gpus;
    }
  }
  return num_gpus;
}
Status ExpandLayoutSensitiveOp(TransposeContext* context, Status ExpandLayoutSensitiveOp(TransposeContext* context,
TransposerFactory* transposer_factory) { TransposerFactory* transposer_factory) {
const int num_nodes = context->num_nodes; const int num_nodes = context->num_nodes;
@ -180,9 +191,10 @@ Status EraseOutputShapeAttrs(TransposeContext* context) {
utils::Mutation* mutation = graph_view->GetMutationBuilder(); utils::Mutation* mutation = graph_view->GetMutationBuilder();
const int num_nodes = graph_view->NumNodes(); const int num_nodes = graph_view->NumNodes();
for (int i = 0; i < num_nodes; ++i) { for (int i = 0; i < num_nodes; ++i) {
mutation->RemoveNodeAttr(graph_view->GetNode(i), "_output_shapes"); mutation->RemoveNodeAttr(graph_view->GetNode(i), kAttrOutputShape);
TF_RETURN_IF_ERROR(mutation->Apply());
} }
return mutation->Apply(); return Status::OK();
} }
} // namespace } // namespace
@ -190,15 +202,23 @@ Status EraseOutputShapeAttrs(TransposeContext* context) {
Status GenericLayoutOptimizer::Optimize(Cluster* cluster, Status GenericLayoutOptimizer::Optimize(Cluster* cluster,
const GrapplerItem& item, const GrapplerItem& item,
GraphDef* output) { GraphDef* output) {
// If optimizer returns early with error, output will be the input graph. if (cluster == nullptr) {
*output = item.graph; LOG(WARNING)
<< "generic layout optimizer was called with cluster == nullptr";
return errors::Aborted("cluster == nullptr.");
}
if (GetNumGPUs(*cluster) < 1) {
return errors::Aborted(
"No GPUs found: GenericLayoutOptimizer is currently only tuned for "
"GPU.");
}
TransposeContext context; TransposeContext context;
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(TransposeContext::InitializeTransposeContext(
TransposeContext::InitializeTransposeContext(item, cluster, &context)); item, cluster, src_format_, dst_format_, target_device_, &context));
TransposerFactory transposer_factory; TransposerFactory transposer_factory;
TF_RETURN_IF_ERROR(ExpandLayoutSensitiveOp(&context, &transposer_factory)); TF_RETURN_IF_ERROR(ExpandLayoutSensitiveOp(&context, &transposer_factory));
TF_RETURN_IF_ERROR(ExpandLayoutAgnosticOp(&context, &transposer_factory)); TF_RETURN_IF_ERROR(ExpandLayoutAgnosticOp(&context, &transposer_factory));
// TODO(lyandy): Merge non cancellable nodes.
TF_RETURN_IF_ERROR(EraseCancellableNodes(&context)); TF_RETURN_IF_ERROR(EraseCancellableNodes(&context));
TF_RETURN_IF_ERROR(EraseOutputShapeAttrs(&context)); TF_RETURN_IF_ERROR(EraseOutputShapeAttrs(&context));
@ -213,25 +233,6 @@ void GenericLayoutOptimizer::Feedback(Cluster* cluster,
// Takes no feedback. // Takes no feedback.
} }
// Fetches `parameter` from `parameter_map` and validates it.
//
// The stored string value is upper-cased before being compared against
// `valid_inputs`. On success the upper-cased value is returned. Otherwise an
// empty string is returned and:
//   * if the parameter is absent, its name is appended to
//     `missing_parameters`;
//   * if the value is not in `valid_inputs`, a description of the problem is
//     appended to `validation_errors`.
string GetAndValidateParameter(const string& parameter,
                               const AttrValueMap& parameter_map,
                               const std::set<string>& valid_inputs,
                               std::vector<string>* validation_errors,
                               std::vector<string>* missing_parameters) {
  const auto entry = parameter_map.find(parameter);
  if (entry == parameter_map.end()) {
    missing_parameters->push_back(parameter);
    return "";
  }

  string candidate = str_util::Uppercase(entry->second.s());
  if (valid_inputs.count(candidate) > 0) {
    return candidate;
  }

  validation_errors->push_back(absl::StrCat(
      "Invalid input ", candidate, " for parameter ", parameter,
      ", must be one of [", str_util::Join(valid_inputs, ", "), "]."));
  return "";
}
Status GenericLayoutOptimizer::Init( Status GenericLayoutOptimizer::Init(
const RewriterConfig_CustomGraphOptimizer* config) { const RewriterConfig_CustomGraphOptimizer* config) {
return Status::OK(); return Status::OK();

View File

@ -36,6 +36,11 @@ class GenericLayoutOptimizer : public CustomGraphOptimizer {
const GraphDef& optimize_output, double result) override; const GraphDef& optimize_output, double result) override;
Status Init(const RewriterConfig_CustomGraphOptimizer* config) final; Status Init(const RewriterConfig_CustomGraphOptimizer* config) final;
private:
string target_device_ = "GPU";
string src_format_ = "NHWC";
string dst_format_ = "NCHW";
}; };
} // namespace grappler } // namespace grappler

View File

@ -36,35 +36,7 @@ namespace grappler {
constexpr char kAttrSrcFormat[] = "src_format"; constexpr char kAttrSrcFormat[] = "src_format";
constexpr char kAttrDstFormat[] = "dst_format"; constexpr char kAttrDstFormat[] = "dst_format";
constexpr char kAttrOutputShape[] = "_output_shapes";
// NodeLayoutContext records, for one layout sensitive node, the data format
// it is converted from (`src`) and the data format it is converted to (`dst`).
struct NodeLayoutContext {
  void Initialize(absl::string_view src_format, absl::string_view dst_format,
                  bool is_transposable);

  // Two contexts compare equal when both their source and destination formats
  // match; the remaining members do not take part in the comparison.
  bool operator==(const NodeLayoutContext& other) const {
    if (src != other.src) {
      return false;
    }
    return dst == other.dst;
  }

  bool operator!=(const NodeLayoutContext& other) const {
    return !operator==(other);
  }

  // Source data format.
  string src;
  // Destination data format.
  string dst;
  std::vector<int> src_to_dst;
  std::vector<int> dst_to_src;
  bool transposable = false;
};
// LayoutAgnosticNodeFanouts holds the data formats to convert from and to for
// each data fanin and fanout of a layout agnostic node.
struct LayoutAgnosticNodeFanouts {
std::vector<NodeLayoutContext> fanout_layouts;
std::vector<NodeLayoutContext> fanin_layouts;
bool transposable = false;
};
// TransposeContext owns all data members. Must initialize GraphProperties, // TransposeContext owns all data members. Must initialize GraphProperties,
// FrameView, GraphDef and MutableGraphView with the same graph. NodeDef // FrameView, GraphDef and MutableGraphView with the same graph. NodeDef
@ -76,8 +48,10 @@ struct TransposeContext {
// TransposeContext outside constructor. // TransposeContext outside constructor.
static Status InitializeTransposeContext(const GrapplerItem& item, static Status InitializeTransposeContext(const GrapplerItem& item,
const Cluster* cluster, const Cluster* cluster,
absl::string_view src_format,
absl::string_view dst_format,
absl::string_view target_device,
TransposeContext* context); TransposeContext* context);
FrameView frames; FrameView frames;
GraphDef graph; GraphDef graph;
// Number of nodes in the original graph. As new nodes are appended to the end // Number of nodes in the original graph. As new nodes are appended to the end
@ -88,12 +62,12 @@ struct TransposeContext {
std::unique_ptr<GraphProperties> graph_properties; std::unique_ptr<GraphProperties> graph_properties;
std::unique_ptr<utils::MutableGraphView> graph_view; std::unique_ptr<utils::MutableGraphView> graph_view;
std::unique_ptr<const VirtualPlacer> virtual_placer; std::unique_ptr<const VirtualPlacer> virtual_placer;
// Mapping of node index to data layout conversion of layout sensitive nodes.
absl::flat_hash_map<int, NodeLayoutContext> sensitive_node_layouts; string src_format;
// Mapping of node index to data layout conversion of layout agnostic nodes. string dst_format;
absl::flat_hash_map<int, LayoutAgnosticNodeFanouts> agnostic_node_layouts; string target_device;
NodeLayoutContext unknown_node_layout; std::vector<int> src_to_dst;
LayoutAgnosticNodeFanouts unknown_agnostic_node_layout; std::vector<int> dst_to_src;
}; };
class Transposer { class Transposer {
@ -105,6 +79,16 @@ class Transposer {
virtual ~Transposer() {} virtual ~Transposer() {}
// Returns true iff the node should be processed by this transposer.
// NodeProcessors may perform additional operand-specific checks before
// processing if necessary.
// The following common conditions are checked:
// * node's device matches target device
// * node's source format matches config's source format
// * node has output
virtual bool ShouldProcess(const TransposeContext& context,
const utils::MutableNodeView& node) const;
// Transposes given node from src format to dst format. Also perform other // Transposes given node from src format to dst format. Also perform other
// necessary operations to guarantee the graph produce the same result. // necessary operations to guarantee the graph produce the same result.
// Eg. Add Transpose node sets before fanin ports and after fanout ports. // Eg. Add Transpose node sets before fanin ports and after fanout ports.
@ -133,7 +117,6 @@ class Transposer {
// Update all edges between dst_node->fanin[dst_ports] and dst_node by // Update all edges between dst_node->fanin[dst_ports] and dst_node by
// inserting an op node. // inserting an op node.
Status UpdateFaninEdgesWithOp(TransposeContext* context, Status UpdateFaninEdgesWithOp(TransposeContext* context,
const NodeLayoutContext& layout,
absl::Span<const int> dst_ports, absl::Span<const int> dst_ports,
utils::MutableNodeView* dst_node, utils::MutableNodeView* dst_node,
absl::string_view op); absl::string_view op);
@ -141,7 +124,6 @@ class Transposer {
// Update all edges between src_node:src_ports and nodes take // Update all edges between src_node:src_ports and nodes take
// src_node:src_ports as fanin. Also update attr _output_shape of src_node. // src_node:src_ports as fanin. Also update attr _output_shape of src_node.
Status UpdateFanoutEdgesWithOp(TransposeContext* context, Status UpdateFanoutEdgesWithOp(TransposeContext* context,
const NodeLayoutContext& layout,
absl::Span<const int> src_ports, absl::Span<const int> src_ports,
utils::MutableNodeView* src_node, utils::MutableNodeView* src_node,
absl::string_view op); absl::string_view op);
@ -149,7 +131,6 @@ class Transposer {
// Creates a DataFormat node with given properties. // Creates a DataFormat node with given properties.
// DataFormat op is either DataFormatVecPermute or DataFormatDimMap. // DataFormat op is either DataFormatVecPermute or DataFormatDimMap.
Status CreateDataFormatNode(TransposeContext* context, Status CreateDataFormatNode(TransposeContext* context,
const NodeLayoutContext& layout,
absl::string_view node_name, absl::string_view op, absl::string_view node_name, absl::string_view op,
absl::string_view device, absl::string_view device,
const DataType& data_type, bool is_fanin_on_host, const DataType& data_type, bool is_fanin_on_host,
@ -167,14 +148,11 @@ class Transposer {
const utils::MutableNodeView& node) const; const utils::MutableNodeView& node) const;
string GetDeviceName(const VirtualPlacer* virtual_placer, string GetDeviceName(const VirtualPlacer* virtual_placer,
const NodeDef& node) const; const NodeDef& node) const;
virtual string GetDstDataFormatForDevice(
const DeviceProperties& device) const;
// Update all edges between dst_node->fanin[dst_ports] and dst_node. // Update all edges between dst_node->fanin[dst_ports] and dst_node.
// A node with op is created and inserted between all edges. // A node with op is created and inserted between all edges.
// op is one of Transpose, DataFormatVecPermute or DataFormatDimMap. // op is one of Transpose, DataFormatVecPermute or DataFormatDimMap.
Status UpdateEdge(TransposeContext* context, const NodeLayoutContext& layout, Status UpdateEdge(TransposeContext* context, absl::string_view name_format,
absl::string_view name_format, absl::string_view op, absl::string_view op, const AttrValue* input_shape,
const AttrValue* input_shape,
bool is_src_format_to_dst_format, const int src_port, bool is_src_format_to_dst_format, const int src_port,
const int dst_port, utils::MutableNodeView* src_node, const int dst_port, utils::MutableNodeView* src_node,
utils::MutableNodeView* dst_node); utils::MutableNodeView* dst_node);
@ -197,19 +175,7 @@ class LayoutSensitiveOpTransposer : public Transposer {
// Updates attrs data_format, ksize, strides of the given node to dst_format. // Updates attrs data_format, ksize, strides of the given node to dst_format.
// _output_shape is updated during UpdateOutputEdges. // _output_shape is updated during UpdateOutputEdges.
Status UpdateNode(TransposeContext* context, const NodeLayoutContext& layout, Status UpdateNode(TransposeContext* context, utils::MutableNodeView* node);
utils::MutableNodeView* node);
protected:
// Check if a node is transposable, and return what data format to transpose
// from and to.
const NodeLayoutContext& CheckNodeTransposable(
TransposeContext* context, const utils::MutableNodeView& node) const;
// Apply mutation. If mutation fails, reset layout being transpose to
// associated with the node.
Status CommitMutation(TransposeContext* context,
const utils::MutableNodeView& node);
}; };
// Layout sensitive op transposers. // Layout sensitive op transposers.
@ -289,23 +255,12 @@ class LayoutAgnosticOpTransposer : public Transposer {
explicit LayoutAgnosticOpTransposer() : Transposer() {} explicit LayoutAgnosticOpTransposer() : Transposer() {}
protected: protected:
void ProcessFanoutDataLayouts(TransposeContext* context, bool IsAfterDstToSrcTransform(const TransposeContext& context,
const utils::MutableNodeView& node); const utils::MutableNodeView& node) const;
void FetchNodeDataLayouts(TransposeContext* context, std::vector<int> GetVariadic4DFaninPorts(
const utils::MutableNodeView& node); const TransposeContext& context,
const utils::MutableNodeView& node) const;
// Check if a node is transposable, and return what data format to transpose
// from and to.
const LayoutAgnosticNodeFanouts& CheckNodeTransposable(
TransposeContext* context, const utils::MutableNodeView& node);
// Validate node and get data format layout to transpose from and to for a
// given fanout port of a node.
const NodeLayoutContext& GetDataFormatLayoutForNode(
TransposeContext* context, const utils::MutableNodeView& node,
int fanout_port,
std::function<bool(const utils::MutableNodeView& node)> node_check);
}; };
class DefaultLayoutAgnosticOpTransposer : public LayoutAgnosticOpTransposer { class DefaultLayoutAgnosticOpTransposer : public LayoutAgnosticOpTransposer {
@ -345,7 +300,6 @@ class BinaryOpTransposer : public LayoutAgnosticOpTransposer {
absl::string_view shape_const_node_name, absl::string_view shape_const_node_name,
const DataType& data_type); const DataType& data_type);
Status MaybeReshapeVectorFanin(TransposeContext* context, Status MaybeReshapeVectorFanin(TransposeContext* context,
const NodeLayoutContext& layout,
utils::MutableNodeView* node); utils::MutableNodeView* node);
}; };
@ -381,7 +335,9 @@ class MergeTransposer : public LayoutAgnosticOpTransposer {
utils::MutableNodeView* node) override; utils::MutableNodeView* node) override;
private: private:
std::vector<int> GetFaninPorts(const utils::MutableNodeView& node); bool IsEveryFaninAfterDstToSrcTransform(
const TransposeContext& context,
const utils::MutableNodeView& node) const;
}; };
class PadTransposer : public LayoutAgnosticOpTransposer { class PadTransposer : public LayoutAgnosticOpTransposer {
@ -403,7 +359,7 @@ class ReduceTransposer : public LayoutAgnosticOpTransposer {
bool KeepDims(const utils::MutableNodeView& node); bool KeepDims(const utils::MutableNodeView& node);
bool IsAlongAxis(const utils::MutableNodeView& axis_node, bool IsAlongAxis(const utils::MutableNodeView& axis_node,
absl::Span<const int> axis); absl::Span<const int> axis);
bool IsReduceAxisSupported(const NodeLayoutContext& layout, bool IsReduceAxisSupported(const TransposeContext& context,
const utils::MutableNodeView& node); const utils::MutableNodeView& node);
}; };
@ -495,8 +451,8 @@ class StridedSliceTransposer : public LayoutAgnosticOpTransposer {
private: private:
bool IsMaskZero(const utils::MutableNodeView& node, absl::string_view mask); bool IsMaskZero(const utils::MutableNodeView& node, absl::string_view mask);
bool HasOnlyBeginEndMask(const utils::MutableNodeView& node); bool HasOnlyBeginEndMask(const utils::MutableNodeView& node);
Status PermuteMask(TransposeContext* context, const NodeLayoutContext& layout, Status PermuteMask(TransposeContext* context, utils::MutableNodeView* node,
utils::MutableNodeView* node, absl::string_view mask); absl::string_view mask);
}; };
class SwitchTransposer : public LayoutAgnosticOpTransposer { class SwitchTransposer : public LayoutAgnosticOpTransposer {
@ -580,19 +536,8 @@ std::vector<int> GetDataFanoutPorts(const utils::MutableNodeView& node);
bool GetValueAttrIfConstPermTransposeNode(const utils::MutableNodeView& node, bool GetValueAttrIfConstPermTransposeNode(const utils::MutableNodeView& node,
Tensor* tensor); Tensor* tensor);
// Returns true if the given node is a transpose op performing given const
// permutation.
bool IsValidConstPermTransposeNode(const utils::MutableNodeView& node,
absl::Span<const int> permutation);
bool IsDataFormatOp(const utils::MutableNodeView& node); bool IsDataFormatOp(const utils::MutableNodeView& node);
// Returns true if the given node is DataFormatDimMap or DataFormatVecPermute
// performing given permutation from src_format to dst_format.
bool IsValidDataFormatNode(const utils::MutableNodeView& node,
absl::string_view src_format,
absl::string_view dst_format);
std::vector<int> GetPermutation(absl::string_view src_format, std::vector<int> GetPermutation(absl::string_view src_format,
absl::string_view dst_format); absl::string_view dst_format);

View File

@ -49,6 +49,7 @@ constexpr int kOutHeight = 5;
constexpr int kDepthOut = 16; constexpr int kDepthOut = 16;
constexpr char kSrcFormat[] = "NHWC"; constexpr char kSrcFormat[] = "NHWC";
constexpr char kDstFormat[] = "NCHW"; constexpr char kDstFormat[] = "NCHW";
constexpr char kGPU[] = "GPU";
constexpr char kAttrOutputShapes[] = "_output_shapes"; constexpr char kAttrOutputShapes[] = "_output_shapes";
constexpr char kAttrDataFormat[] = "data_format"; constexpr char kAttrDataFormat[] = "data_format";
constexpr char kOpTranspose[] = "Transpose"; constexpr char kOpTranspose[] = "Transpose";
@ -316,7 +317,7 @@ TEST_F(TransposerTest, CreateConstPermNode) {
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph)); TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph));
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
TransposerImpl transposer; TransposerImpl transposer;
constexpr char kNodeName[] = "const_perm_node"; constexpr char kNodeName[] = "const_perm_node";
@ -358,8 +359,8 @@ TEST_F(TransposerTest, CreateTransposeNode) {
GrapplerItem item; GrapplerItem item;
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph)); TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph));
TF_ASSERT_OK( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
TransposeContext::InitializeTransposeContext(item, nullptr, &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
TransposerImpl transposer; TransposerImpl transposer;
constexpr char kNodeNameFormat[] = constexpr char kNodeNameFormat[] =
@ -398,14 +399,12 @@ TEST_F(TransposerTest, UpdateNode) {
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph)); TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph));
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer transposer; DefaultLayoutSensitiveOpTransposer transposer;
auto* conv2d = context.graph_view->GetNode("conv2d"); auto* conv2d = context.graph_view->GetNode("conv2d");
ASSERT_NE(conv2d, nullptr); ASSERT_NE(conv2d, nullptr);
NodeLayoutContext layout; TF_ASSERT_OK(transposer.UpdateNode(&context, conv2d));
layout.Initialize(kSrcFormat, kDstFormat, /*is_transposable=*/true);
TF_ASSERT_OK(transposer.UpdateNode(&context, layout, conv2d));
TF_ASSERT_OK(context.graph_view->GetMutationBuilder()->Apply()); TF_ASSERT_OK(context.graph_view->GetMutationBuilder()->Apply());
auto* updated_conv2d = context.graph_view->GetNode("conv2d"); auto* updated_conv2d = context.graph_view->GetNode("conv2d");
@ -430,7 +429,7 @@ TEST_F(TransposerTest, UpdateStrides) {
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph)); TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph));
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), "ABCD", "ACBD", kGPU, &context));
AttrValue_ListValue expected_original_strides = AttrValue_ListValue expected_original_strides =
MakeAttrValueListValueFromVector({1, 2, 4, 1}); MakeAttrValueListValueFromVector({1, 2, 4, 1});
@ -449,9 +448,7 @@ TEST_F(TransposerTest, UpdateStrides) {
TF_ASSERT_OK(context.graph_view->GetMutationBuilder()->Apply()); TF_ASSERT_OK(context.graph_view->GetMutationBuilder()->Apply());
DefaultLayoutSensitiveOpTransposer transposer; DefaultLayoutSensitiveOpTransposer transposer;
NodeLayoutContext layout; TF_ASSERT_OK(transposer.UpdateNode(&context, conv2d));
layout.Initialize("ABCD", "ACBD", /*is_transposable=*/true);
TF_ASSERT_OK(transposer.UpdateNode(&context, layout, conv2d));
TF_ASSERT_OK(context.graph_view->GetMutationBuilder()->Apply()); TF_ASSERT_OK(context.graph_view->GetMutationBuilder()->Apply());
auto* updated_conv2d = context.graph_view->GetNode("conv2d"); auto* updated_conv2d = context.graph_view->GetNode("conv2d");
@ -468,19 +465,17 @@ TEST_F(TransposerTest, UpdateFaninEdgesTranspose) {
GrapplerItem item; GrapplerItem item;
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(CreateSimpleFusedBatchNormGrad(&item.graph, true)); TF_ASSERT_OK(CreateSimpleFusedBatchNormGrad(&item.graph, true));
TF_ASSERT_OK( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
TransposeContext::InitializeTransposeContext(item, nullptr, &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
FusedBatchNormGradTransposer transposer; FusedBatchNormGradTransposer transposer;
NodeLayoutContext layout;
layout.Initialize(kSrcFormat, kDstFormat, /*is_transposable=*/true);
auto* fbng = context.graph_view->GetNode("fused_batch_norm_grad"); auto* fbng = context.graph_view->GetNode("fused_batch_norm_grad");
ASSERT_NE(fbng, nullptr); ASSERT_NE(fbng, nullptr);
const auto& fbng_output_shapes_attr = fbng->GetAttr("_output_shapes"); const auto& fbng_output_shapes_attr = fbng->GetAttr("_output_shapes");
ASSERT_NE(fbng_output_shapes_attr, nullptr); ASSERT_NE(fbng_output_shapes_attr, nullptr);
const TensorShapeProto& expected_shape = fbng_output_shapes_attr->shape(); const TensorShapeProto& expected_shape = fbng_output_shapes_attr->shape();
TF_ASSERT_OK(transposer.UpdateFaninEdgesWithOp(&context, layout, {0, 1}, fbng, TF_ASSERT_OK(
kOpTranspose)); transposer.UpdateFaninEdgesWithOp(&context, {0, 1}, fbng, kOpTranspose));
TF_ASSERT_OK(context.graph_view->GetMutationBuilder()->Apply()); TF_ASSERT_OK(context.graph_view->GetMutationBuilder()->Apply());
// Verify output shape matches input shape. // Verify output shape matches input shape.
@ -528,12 +523,10 @@ TEST_F(TransposerTest, UpdateFanoutEdgesTranspose) {
GrapplerItem item; GrapplerItem item;
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph)); TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph));
TF_ASSERT_OK( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
TransposeContext::InitializeTransposeContext(item, nullptr, &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
TransposerImpl transposer; TransposerImpl transposer;
NodeLayoutContext layout;
layout.Initialize(kSrcFormat, kDstFormat, /*is_transposable=*/true);
TensorShapeProto expected_original_shape = TensorShapeProto expected_original_shape =
MakeTensorShapeFromDimensions({32, 5, 3, 16}); MakeTensorShapeFromDimensions({32, 5, 3, 16});
TensorShapeProto expected_updated_shape = TensorShapeProto expected_updated_shape =
@ -543,8 +536,8 @@ TEST_F(TransposerTest, UpdateFanoutEdgesTranspose) {
ASSERT_NE(conv2d, nullptr); ASSERT_NE(conv2d, nullptr);
VerifyShapeAttributeMatch(conv2d, 0, expected_original_shape.DebugString()); VerifyShapeAttributeMatch(conv2d, 0, expected_original_shape.DebugString());
TF_ASSERT_OK(transposer.UpdateFanoutEdgesWithOp(&context, layout, {0}, conv2d, TF_ASSERT_OK(
kOpTranspose)); transposer.UpdateFanoutEdgesWithOp(&context, {0}, conv2d, kOpTranspose));
TF_ASSERT_OK(context.graph_view->GetMutationBuilder()->Apply()); TF_ASSERT_OK(context.graph_view->GetMutationBuilder()->Apply());
auto* updated_conv2d = context.graph_view->GetNode("conv2d"); auto* updated_conv2d = context.graph_view->GetNode("conv2d");
@ -584,7 +577,7 @@ TEST_F(TransposerTest, DefaultLayoutSensitiveOpTransposerTestFusedBatchNorm) {
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(CreateSimpleFusedBatchNorm(&item.graph)); TF_ASSERT_OK(CreateSimpleFusedBatchNorm(&item.graph));
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer transposer; DefaultLayoutSensitiveOpTransposer transposer;
auto* bn = context.graph_view->GetNode("bn"); auto* bn = context.graph_view->GetNode("bn");
@ -639,7 +632,7 @@ TEST_F(TransposerTest, DefaultLayoutSensitiveOpTransposerTestConv2D) {
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph)); TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph));
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer transposer; DefaultLayoutSensitiveOpTransposer transposer;
auto* conv2d = context.graph_view->GetNode("conv2d"); auto* conv2d = context.graph_view->GetNode("conv2d");
@ -681,7 +674,7 @@ TEST_F(TransposerTest, MaxPoolGradTransposerTest) {
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(CreateSimpleMaxPoolGrad(&item.graph)); TF_ASSERT_OK(CreateSimpleMaxPoolGrad(&item.graph));
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
MaxPoolGradTransposer transposer; MaxPoolGradTransposer transposer;
auto* maxpool_grad = context.graph_view->GetNode("maxpool_grad"); auto* maxpool_grad = context.graph_view->GetNode("maxpool_grad");
@ -733,7 +726,7 @@ TEST_F(TransposerTest, BiasAddGradTransposerTest) {
TF_ASSERT_OK(CreateSimpleBiasAddGrad( TF_ASSERT_OK(CreateSimpleBiasAddGrad(
&item.graph, {kBatchSize, kHeight, kWidth, kDepthIn})); &item.graph, {kBatchSize, kHeight, kWidth, kDepthIn}));
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
BiasAddGradTransposer transposer; BiasAddGradTransposer transposer;
auto* bag = context.graph_view->GetNode("bag"); auto* bag = context.graph_view->GetNode("bag");
@ -768,7 +761,7 @@ TEST_F(TransposerTest, BiasAddGradTransposerIncorrectInputTest) {
TF_ASSERT_OK( TF_ASSERT_OK(
CreateSimpleBiasAddGrad(&item.graph, {kHeight, kWidth, kDepthIn})); CreateSimpleBiasAddGrad(&item.graph, {kHeight, kWidth, kDepthIn}));
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
BiasAddGradTransposer transposer; BiasAddGradTransposer transposer;
auto* bag = context.graph_view->GetNode("bag"); auto* bag = context.graph_view->GetNode("bag");
@ -802,7 +795,7 @@ TEST_F(TransposerTest, Conv2DBackpropFilterTransposerTest) {
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(CreateSimpleConv2DBackpropFilter(&item.graph)); TF_ASSERT_OK(CreateSimpleConv2DBackpropFilter(&item.graph));
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
Conv2DBackpropFilterTransposer transposer; Conv2DBackpropFilterTransposer transposer;
auto* conv2d_bf = context.graph_view->GetNode("conv2d_backprop_filter"); auto* conv2d_bf = context.graph_view->GetNode("conv2d_backprop_filter");
@ -854,7 +847,7 @@ TEST_F(TransposerTest, FusedBatchNormGradTransposerIsTrainingTest) {
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(CreateSimpleFusedBatchNormGrad(&item.graph, true)); TF_ASSERT_OK(CreateSimpleFusedBatchNormGrad(&item.graph, true));
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
FusedBatchNormGradTransposer transposer; FusedBatchNormGradTransposer transposer;
auto* fbng = context.graph_view->GetNode("fused_batch_norm_grad"); auto* fbng = context.graph_view->GetNode("fused_batch_norm_grad");
@ -922,7 +915,7 @@ TEST_F(TransposerTest, FusedBatchNormGradTransposerNotTrainingTest) {
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(CreateSimpleFusedBatchNormGrad(&item.graph, false)); TF_ASSERT_OK(CreateSimpleFusedBatchNormGrad(&item.graph, false));
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
FusedBatchNormGradTransposer transposer; FusedBatchNormGradTransposer transposer;
auto* fbng = context.graph_view->GetNode("fused_batch_norm_grad"); auto* fbng = context.graph_view->GetNode("fused_batch_norm_grad");
@ -992,7 +985,7 @@ TEST_F(TransposerTest, DefaultLayoutAgnosticOpTransposerIdentityTest) {
TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer conv2d_transposer; DefaultLayoutSensitiveOpTransposer conv2d_transposer;
auto* c2d = context.graph_view->GetNode("conv2d"); auto* c2d = context.graph_view->GetNode("conv2d");
@ -1043,7 +1036,7 @@ TEST_F(TransposerTest, DefaultLayoutAgnosticOpTransposerIdentityBadInputTest) {
TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer conv2d_transposer; DefaultLayoutSensitiveOpTransposer conv2d_transposer;
auto* c2d = context.graph_view->GetNode("conv2d"); auto* c2d = context.graph_view->GetNode("conv2d");
@ -1084,7 +1077,7 @@ TEST_F(TransposerTest, AddNTransposerTest) {
TF_ASSERT_OK(CreateSimpleAddN(&item.graph)); TF_ASSERT_OK(CreateSimpleAddN(&item.graph));
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer conv2d_transposer; DefaultLayoutSensitiveOpTransposer conv2d_transposer;
auto* conv2d = context.graph_view->GetNode("conv2d"); auto* conv2d = context.graph_view->GetNode("conv2d");
@ -1147,7 +1140,7 @@ TEST_F(TransposerTest, AddNTransposerNotAfterTransformTest) {
TF_ASSERT_OK(CreateSimpleAddN(&item.graph)); TF_ASSERT_OK(CreateSimpleAddN(&item.graph));
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
AddNTransposer addn_transposer; AddNTransposer addn_transposer;
auto* an = context.graph_view->GetNode("add_n"); auto* an = context.graph_view->GetNode("add_n");
@ -1197,7 +1190,7 @@ TEST_F(TransposerTest, IdentityNTransposerTest) {
TF_ASSERT_OK(CreateSimpleIdentityN(&item.graph)); TF_ASSERT_OK(CreateSimpleIdentityN(&item.graph));
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer conv2d_transposer; DefaultLayoutSensitiveOpTransposer conv2d_transposer;
auto* conv2d_1 = context.graph_view->GetNode("conv2d_1"); auto* conv2d_1 = context.graph_view->GetNode("conv2d_1");
@ -1295,7 +1288,7 @@ TEST_F(TransposerTest, MergeTransposerTestMergeBothInputsConvertible) {
TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer conv2d_transposer; DefaultLayoutSensitiveOpTransposer conv2d_transposer;
auto* c2d = context.graph_view->GetNode("conv2d"); auto* c2d = context.graph_view->GetNode("conv2d");
@ -1354,7 +1347,7 @@ TEST_F(TransposerTest, MergeTransposerTestMergeOneInputNotConvertible) {
TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer conv2d_transposer; DefaultLayoutSensitiveOpTransposer conv2d_transposer;
auto* c2d = context.graph_view->GetNode("conv2d"); auto* c2d = context.graph_view->GetNode("conv2d");
@ -1407,7 +1400,7 @@ TEST_F(TransposerTest, PadTransposerTest) {
TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer conv2d_transposer; DefaultLayoutSensitiveOpTransposer conv2d_transposer;
auto* c2d = context.graph_view->GetNode("conv2d"); auto* c2d = context.graph_view->GetNode("conv2d");
@ -1467,7 +1460,7 @@ TEST_F(TransposerTest, SwitchTransposerTest) {
TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer conv2d_transposer; DefaultLayoutSensitiveOpTransposer conv2d_transposer;
auto* c2d = context.graph_view->GetNode("conv2d"); auto* c2d = context.graph_view->GetNode("conv2d");
@ -1533,7 +1526,7 @@ TEST_F(TransposerTest, TernaryOpTransposerTest) {
TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer conv2d_transposer; DefaultLayoutSensitiveOpTransposer conv2d_transposer;
auto* c2d = context.graph_view->GetNode("conv2d"); auto* c2d = context.graph_view->GetNode("conv2d");
@ -1600,7 +1593,7 @@ TEST_F(TransposerTest, UnaryGradTransposerTestTanhGrad) {
TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer conv2d_transposer; DefaultLayoutSensitiveOpTransposer conv2d_transposer;
auto* c2d = context.graph_view->GetNode("conv2d"); auto* c2d = context.graph_view->GetNode("conv2d");
@ -1663,7 +1656,7 @@ TEST_F(TransposerTest, UnaryGradTransposerTestRelu6Grad) {
TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer conv2d_transposer; DefaultLayoutSensitiveOpTransposer conv2d_transposer;
auto* c2d = context.graph_view->GetNode("conv2d"); auto* c2d = context.graph_view->GetNode("conv2d");
@ -1731,7 +1724,7 @@ TEST_F(TransposerTest, SqueezeTransposerTest) {
TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer conv2d_transposer; DefaultLayoutSensitiveOpTransposer conv2d_transposer;
auto* c2d = context.graph_view->GetNode("conv2d"); auto* c2d = context.graph_view->GetNode("conv2d");
@ -1785,7 +1778,7 @@ TEST_F(TransposerTest, SqueezeTransposerTestUnsupportedInputShape) {
TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer conv2d_transposer; DefaultLayoutSensitiveOpTransposer conv2d_transposer;
auto* c2d = context.graph_view->GetNode("conv2d"); auto* c2d = context.graph_view->GetNode("conv2d");
@ -1824,7 +1817,7 @@ TEST_F(TransposerTest, SqueezeTransposerTestInvalidHWAxis) {
TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer conv2d_transposer; DefaultLayoutSensitiveOpTransposer conv2d_transposer;
auto* c2d = context.graph_view->GetNode("conv2d"); auto* c2d = context.graph_view->GetNode("conv2d");
@ -1863,7 +1856,7 @@ TEST_F(TransposerTest, SqueezeTransposerTestInvalidNHWAxis) {
TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer conv2d_transposer; DefaultLayoutSensitiveOpTransposer conv2d_transposer;
auto* c2d = context.graph_view->GetNode("conv2d"); auto* c2d = context.graph_view->GetNode("conv2d");
@ -1902,7 +1895,7 @@ TEST_F(TransposerTest, SqueezeTransposerTestSqueezeDimsUpdated) {
TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer conv2d_transposer; DefaultLayoutSensitiveOpTransposer conv2d_transposer;
auto* c2d = context.graph_view->GetNode("conv2d"); auto* c2d = context.graph_view->GetNode("conv2d");
@ -1960,7 +1953,7 @@ TEST_F(TransposerTest, MaxPoolV2Transposer) {
TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
MaxPoolV2Transposer maxpool_transposer; MaxPoolV2Transposer maxpool_transposer;
auto* maxpool = context.graph_view->GetNode("maxpoolv2"); auto* maxpool = context.graph_view->GetNode("maxpoolv2");
@ -2023,7 +2016,7 @@ TEST_F(TransposerTest, MaxPoolGradV2Transposer) {
TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
MaxPoolGradV2Transposer maxpoolgrad_transposer; MaxPoolGradV2Transposer maxpoolgrad_transposer;
auto* maxpoolgrad = context.graph_view->GetNode("maxpoolgradv2"); auto* maxpoolgrad = context.graph_view->GetNode("maxpoolgradv2");
@ -2091,7 +2084,7 @@ TEST_F(TransposerTest, BinaryOpTransposerAdd) {
TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer conv2d_transposer; DefaultLayoutSensitiveOpTransposer conv2d_transposer;
auto* c2d = context.graph_view->GetNode("conv2d"); auto* c2d = context.graph_view->GetNode("conv2d");
@ -2162,7 +2155,7 @@ TEST_F(TransposerTest, BinaryOpTransposerMul) {
TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer conv2d_transposer; DefaultLayoutSensitiveOpTransposer conv2d_transposer;
auto* c2d = context.graph_view->GetNode("conv2d"); auto* c2d = context.graph_view->GetNode("conv2d");
@ -2235,7 +2228,7 @@ TEST_F(TransposerTest, BinaryOpTransposerPolygamma) {
TF_ASSERT_OK(scope.ToGraphDef(&item.graph)); TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer conv2d_transposer; DefaultLayoutSensitiveOpTransposer conv2d_transposer;
auto* c2d = context.graph_view->GetNode("conv2d"); auto* c2d = context.graph_view->GetNode("conv2d");
@ -2326,7 +2319,7 @@ TEST_F(TransposerTest, ConcatOpTransposerConcat) {
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer conv2d_transposer; DefaultLayoutSensitiveOpTransposer conv2d_transposer;
auto* c2d = context.graph_view->GetNode("conv2d"); auto* c2d = context.graph_view->GetNode("conv2d");
@ -2402,7 +2395,7 @@ TEST_F(TransposerTest, ConcatOpTransposerConcatV2) {
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer conv2d_transposer; DefaultLayoutSensitiveOpTransposer conv2d_transposer;
auto* c2d = context.graph_view->GetNode("conv2d"); auto* c2d = context.graph_view->GetNode("conv2d");
@ -2474,7 +2467,7 @@ TEST_F(TransposerTest, ReverseV2Transposer) {
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer conv2d_transposer; DefaultLayoutSensitiveOpTransposer conv2d_transposer;
auto* c2d = context.graph_view->GetNode("conv2d"); auto* c2d = context.graph_view->GetNode("conv2d");
@ -2541,7 +2534,7 @@ TEST_F(TransposerTest, TileTransposer) {
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer conv2d_transposer; DefaultLayoutSensitiveOpTransposer conv2d_transposer;
auto* c2d = context.graph_view->GetNode("conv2d"); auto* c2d = context.graph_view->GetNode("conv2d");
@ -2606,7 +2599,7 @@ TEST_F(TransposerTest, ShapeTransposer) {
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer conv2d_transposer; DefaultLayoutSensitiveOpTransposer conv2d_transposer;
auto* c2d = context.graph_view->GetNode("conv2d"); auto* c2d = context.graph_view->GetNode("conv2d");
@ -2671,7 +2664,7 @@ TEST_F(TransposerTest, ShapeNTransposer) {
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer conv2d_transposer; DefaultLayoutSensitiveOpTransposer conv2d_transposer;
auto* c2d_1 = context.graph_view->GetNode("conv2d_1"); auto* c2d_1 = context.graph_view->GetNode("conv2d_1");
@ -2765,7 +2758,7 @@ TEST_F(TransposerTest, FillOpTransposer) {
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer conv2d_transposer; DefaultLayoutSensitiveOpTransposer conv2d_transposer;
auto* c2d = context.graph_view->GetNode("conv2d"); auto* c2d = context.graph_view->GetNode("conv2d");
@ -2825,7 +2818,7 @@ TEST_F(TransposerTest, SliceTransposer) {
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer conv2d_transposer; DefaultLayoutSensitiveOpTransposer conv2d_transposer;
auto* c2d = context.graph_view->GetNode("conv2d"); auto* c2d = context.graph_view->GetNode("conv2d");
@ -2901,7 +2894,7 @@ TEST_F(TransposerTest, SplitTransposer) {
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer conv2d_transposer; DefaultLayoutSensitiveOpTransposer conv2d_transposer;
auto* c2d = context.graph_view->GetNode("conv2d"); auto* c2d = context.graph_view->GetNode("conv2d");
@ -2989,7 +2982,7 @@ TEST_F(TransposerTest, SplitVTransposer) {
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer conv2d_transposer; DefaultLayoutSensitiveOpTransposer conv2d_transposer;
auto* c2d = context.graph_view->GetNode("conv2d"); auto* c2d = context.graph_view->GetNode("conv2d");
@ -3080,7 +3073,7 @@ TEST_F(TransposerTest, StridedSliceTransposer) {
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer conv2d_transposer; DefaultLayoutSensitiveOpTransposer conv2d_transposer;
auto* c2d = context.graph_view->GetNode("conv2d"); auto* c2d = context.graph_view->GetNode("conv2d");
@ -3170,7 +3163,7 @@ TEST_F(TransposerTest, StridedSliceTransposerEllipsisMaskPresent) {
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer conv2d_transposer; DefaultLayoutSensitiveOpTransposer conv2d_transposer;
auto* c2d = context.graph_view->GetNode("conv2d"); auto* c2d = context.graph_view->GetNode("conv2d");
@ -3226,7 +3219,7 @@ TEST_F(TransposerTest, ReduceTransposerKeepDims) {
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer conv2d_transposer; DefaultLayoutSensitiveOpTransposer conv2d_transposer;
auto* c2d = context.graph_view->GetNode("conv2d"); auto* c2d = context.graph_view->GetNode("conv2d");
@ -3291,7 +3284,7 @@ TEST_F(TransposerTest, ReduceTransposerValidAxisNode) {
TransposeContext context; TransposeContext context;
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext( TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
item, virtual_cluster_.get(), &context)); item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
DefaultLayoutSensitiveOpTransposer conv2d_transposer; DefaultLayoutSensitiveOpTransposer conv2d_transposer;
auto* c2d = context.graph_view->GetNode("conv2d"); auto* c2d = context.graph_view->GetNode("conv2d");