[Grappler] Have GenericLayoutOptimizer be specific to GPUs for now.
PiperOrigin-RevId: 254307839
This commit is contained in:
parent
155c4acd30
commit
ffc6b92b28
@ -31,6 +31,17 @@ namespace grappler {
|
||||
|
||||
namespace {
|
||||
|
||||
inline int GetNumGPUs(const Cluster& cluster) {
|
||||
auto devices = cluster.GetDevices();
|
||||
int num_gpus = 0;
|
||||
for (const auto& device : devices) {
|
||||
if (device.second.type() == "GPU") {
|
||||
num_gpus++;
|
||||
}
|
||||
}
|
||||
return num_gpus;
|
||||
}
|
||||
|
||||
Status ExpandLayoutSensitiveOp(TransposeContext* context,
|
||||
TransposerFactory* transposer_factory) {
|
||||
const int num_nodes = context->num_nodes;
|
||||
@ -180,9 +191,10 @@ Status EraseOutputShapeAttrs(TransposeContext* context) {
|
||||
utils::Mutation* mutation = graph_view->GetMutationBuilder();
|
||||
const int num_nodes = graph_view->NumNodes();
|
||||
for (int i = 0; i < num_nodes; ++i) {
|
||||
mutation->RemoveNodeAttr(graph_view->GetNode(i), "_output_shapes");
|
||||
mutation->RemoveNodeAttr(graph_view->GetNode(i), kAttrOutputShape);
|
||||
TF_RETURN_IF_ERROR(mutation->Apply());
|
||||
}
|
||||
return mutation->Apply();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@ -190,15 +202,23 @@ Status EraseOutputShapeAttrs(TransposeContext* context) {
|
||||
Status GenericLayoutOptimizer::Optimize(Cluster* cluster,
|
||||
const GrapplerItem& item,
|
||||
GraphDef* output) {
|
||||
// If optimizer returns early with error, output will be the input graph.
|
||||
*output = item.graph;
|
||||
if (cluster == nullptr) {
|
||||
LOG(WARNING)
|
||||
<< "generic layout optimizer was called with cluster == nullptr";
|
||||
return errors::Aborted("cluster == nullptr.");
|
||||
}
|
||||
if (GetNumGPUs(*cluster) < 1) {
|
||||
return errors::Aborted(
|
||||
"No GPUs found: GenericLayoutOptimizer is currently only tuned for "
|
||||
"GPU.");
|
||||
}
|
||||
|
||||
TransposeContext context;
|
||||
TF_RETURN_IF_ERROR(
|
||||
TransposeContext::InitializeTransposeContext(item, cluster, &context));
|
||||
TF_RETURN_IF_ERROR(TransposeContext::InitializeTransposeContext(
|
||||
item, cluster, src_format_, dst_format_, target_device_, &context));
|
||||
TransposerFactory transposer_factory;
|
||||
TF_RETURN_IF_ERROR(ExpandLayoutSensitiveOp(&context, &transposer_factory));
|
||||
TF_RETURN_IF_ERROR(ExpandLayoutAgnosticOp(&context, &transposer_factory));
|
||||
// TODO(lyandy): Merge non cancellable nodes.
|
||||
TF_RETURN_IF_ERROR(EraseCancellableNodes(&context));
|
||||
TF_RETURN_IF_ERROR(EraseOutputShapeAttrs(&context));
|
||||
|
||||
@ -213,25 +233,6 @@ void GenericLayoutOptimizer::Feedback(Cluster* cluster,
|
||||
// Takes no feedback.
|
||||
}
|
||||
|
||||
string GetAndValidateParameter(const string& parameter,
|
||||
const AttrValueMap& parameter_map,
|
||||
const std::set<string>& valid_inputs,
|
||||
std::vector<string>* validation_errors,
|
||||
std::vector<string>* missing_parameters) {
|
||||
if (parameter_map.find(parameter) != parameter_map.end()) {
|
||||
string input = str_util::Uppercase(parameter_map.at(parameter).s());
|
||||
if (valid_inputs.find(input) != valid_inputs.end()) {
|
||||
return input;
|
||||
}
|
||||
validation_errors->push_back(absl::StrCat(
|
||||
"Invalid input ", input, " for parameter ", parameter,
|
||||
", must be one of [", str_util::Join(valid_inputs, ", "), "]."));
|
||||
} else {
|
||||
missing_parameters->push_back(parameter);
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
Status GenericLayoutOptimizer::Init(
|
||||
const RewriterConfig_CustomGraphOptimizer* config) {
|
||||
return Status::OK();
|
||||
|
@ -36,6 +36,11 @@ class GenericLayoutOptimizer : public CustomGraphOptimizer {
|
||||
const GraphDef& optimize_output, double result) override;
|
||||
|
||||
Status Init(const RewriterConfig_CustomGraphOptimizer* config) final;
|
||||
|
||||
private:
|
||||
string target_device_ = "GPU";
|
||||
string src_format_ = "NHWC";
|
||||
string dst_format_ = "NCHW";
|
||||
};
|
||||
|
||||
} // namespace grappler
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -36,35 +36,7 @@ namespace grappler {
|
||||
|
||||
constexpr char kAttrSrcFormat[] = "src_format";
|
||||
constexpr char kAttrDstFormat[] = "dst_format";
|
||||
|
||||
// NodeLayoutContext holds the data formats to convert from and to for a layout
|
||||
// sensitive node.
|
||||
struct NodeLayoutContext {
|
||||
void Initialize(absl::string_view src_format, absl::string_view dst_format,
|
||||
bool is_transposable);
|
||||
|
||||
bool operator==(const NodeLayoutContext& other) const {
|
||||
return src == other.src && dst == other.dst;
|
||||
}
|
||||
|
||||
bool operator!=(const NodeLayoutContext& other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
|
||||
string src;
|
||||
string dst;
|
||||
std::vector<int> src_to_dst;
|
||||
std::vector<int> dst_to_src;
|
||||
bool transposable = false;
|
||||
};
|
||||
|
||||
// TransposeContext holds the data formats to convert from and to of each data
|
||||
// fanin and fanout for a layout agnostic node.
|
||||
struct LayoutAgnosticNodeFanouts {
|
||||
std::vector<NodeLayoutContext> fanout_layouts;
|
||||
std::vector<NodeLayoutContext> fanin_layouts;
|
||||
bool transposable = false;
|
||||
};
|
||||
constexpr char kAttrOutputShape[] = "_output_shapes";
|
||||
|
||||
// TransposeContext owns all data members. Must initialize GraphProperties,
|
||||
// FrameView, GraphDef and MutableGraphView with the same graph. NodeDef
|
||||
@ -76,8 +48,10 @@ struct TransposeContext {
|
||||
// TransposeContext outside constructor.
|
||||
static Status InitializeTransposeContext(const GrapplerItem& item,
|
||||
const Cluster* cluster,
|
||||
absl::string_view src_format,
|
||||
absl::string_view dst_format,
|
||||
absl::string_view target_device,
|
||||
TransposeContext* context);
|
||||
|
||||
FrameView frames;
|
||||
GraphDef graph;
|
||||
// Number of nodes in the original graph. As new nodes are appended to the end
|
||||
@ -88,12 +62,12 @@ struct TransposeContext {
|
||||
std::unique_ptr<GraphProperties> graph_properties;
|
||||
std::unique_ptr<utils::MutableGraphView> graph_view;
|
||||
std::unique_ptr<const VirtualPlacer> virtual_placer;
|
||||
// Mapping of node index to data layout conversion of layout sensitive nodes.
|
||||
absl::flat_hash_map<int, NodeLayoutContext> sensitive_node_layouts;
|
||||
// Mapping of node index to data layout conversion of layout agnostic nodes.
|
||||
absl::flat_hash_map<int, LayoutAgnosticNodeFanouts> agnostic_node_layouts;
|
||||
NodeLayoutContext unknown_node_layout;
|
||||
LayoutAgnosticNodeFanouts unknown_agnostic_node_layout;
|
||||
|
||||
string src_format;
|
||||
string dst_format;
|
||||
string target_device;
|
||||
std::vector<int> src_to_dst;
|
||||
std::vector<int> dst_to_src;
|
||||
};
|
||||
|
||||
class Transposer {
|
||||
@ -105,6 +79,16 @@ class Transposer {
|
||||
|
||||
virtual ~Transposer() {}
|
||||
|
||||
// Returns true iff the node should be processed by this transposer.
|
||||
// NodeProcessors may perform additional oprand specific checks before
|
||||
// processing if necessary.
|
||||
// Following common conditions are checked:
|
||||
// * node's device matches target device
|
||||
// * node's source format matches config's source format
|
||||
// * node has output
|
||||
virtual bool ShouldProcess(const TransposeContext& context,
|
||||
const utils::MutableNodeView& node) const;
|
||||
|
||||
// Transposes given node from src format to dst format. Also perform other
|
||||
// necessary operations to guarantee the graph produce the same result.
|
||||
// Eg. Add Transpose node sets before fanin ports and after fanout ports.
|
||||
@ -133,7 +117,6 @@ class Transposer {
|
||||
// Update all edges between dst_node->fanin[dst_ports] and dst_node by
|
||||
// inserting an op node.
|
||||
Status UpdateFaninEdgesWithOp(TransposeContext* context,
|
||||
const NodeLayoutContext& layout,
|
||||
absl::Span<const int> dst_ports,
|
||||
utils::MutableNodeView* dst_node,
|
||||
absl::string_view op);
|
||||
@ -141,7 +124,6 @@ class Transposer {
|
||||
// Update all edges between src_node:src_ports and nodes take
|
||||
// src_node:src_ports as fanin. Also update attr _output_shape of src_node.
|
||||
Status UpdateFanoutEdgesWithOp(TransposeContext* context,
|
||||
const NodeLayoutContext& layout,
|
||||
absl::Span<const int> src_ports,
|
||||
utils::MutableNodeView* src_node,
|
||||
absl::string_view op);
|
||||
@ -149,7 +131,6 @@ class Transposer {
|
||||
// Creates a DataFromat node with given properties.
|
||||
// DataFromat op is either DataFormatVecPermute or DataFormatDimMap.
|
||||
Status CreateDataFormatNode(TransposeContext* context,
|
||||
const NodeLayoutContext& layout,
|
||||
absl::string_view node_name, absl::string_view op,
|
||||
absl::string_view device,
|
||||
const DataType& data_type, bool is_fanin_on_host,
|
||||
@ -167,14 +148,11 @@ class Transposer {
|
||||
const utils::MutableNodeView& node) const;
|
||||
string GetDeviceName(const VirtualPlacer* virtual_placer,
|
||||
const NodeDef& node) const;
|
||||
virtual string GetDstDataFormatForDevice(
|
||||
const DeviceProperties& device) const;
|
||||
// Update all edges between dst_node->fanin[dst_ports] and dst_node.
|
||||
// A node with op is created and insereted between all edges.
|
||||
// op is one of Transpose, DataFormatVecPermute or DataFormatDimMap.
|
||||
Status UpdateEdge(TransposeContext* context, const NodeLayoutContext& layout,
|
||||
absl::string_view name_format, absl::string_view op,
|
||||
const AttrValue* input_shape,
|
||||
Status UpdateEdge(TransposeContext* context, absl::string_view name_format,
|
||||
absl::string_view op, const AttrValue* input_shape,
|
||||
bool is_src_format_to_dst_format, const int src_port,
|
||||
const int dst_port, utils::MutableNodeView* src_node,
|
||||
utils::MutableNodeView* dst_node);
|
||||
@ -197,19 +175,7 @@ class LayoutSensitiveOpTransposer : public Transposer {
|
||||
|
||||
// Updates attrs data_format, ksize, strides of the given node to dst_format.
|
||||
// _output_shape is updated during UpdateOutputEdges.
|
||||
Status UpdateNode(TransposeContext* context, const NodeLayoutContext& layout,
|
||||
utils::MutableNodeView* node);
|
||||
|
||||
protected:
|
||||
// Check if a node is transposable, and return what data format to transpose
|
||||
// from and to.
|
||||
const NodeLayoutContext& CheckNodeTransposable(
|
||||
TransposeContext* context, const utils::MutableNodeView& node) const;
|
||||
|
||||
// Apply mutation. If mutation fails, reset layout being transpose to
|
||||
// associated with the node.
|
||||
Status CommitMutation(TransposeContext* context,
|
||||
const utils::MutableNodeView& node);
|
||||
Status UpdateNode(TransposeContext* context, utils::MutableNodeView* node);
|
||||
};
|
||||
|
||||
// Layout sensitive op transposers.
|
||||
@ -289,23 +255,12 @@ class LayoutAgnosticOpTransposer : public Transposer {
|
||||
explicit LayoutAgnosticOpTransposer() : Transposer() {}
|
||||
|
||||
protected:
|
||||
void ProcessFanoutDataLayouts(TransposeContext* context,
|
||||
const utils::MutableNodeView& node);
|
||||
bool IsAfterDstToSrcTransform(const TransposeContext& context,
|
||||
const utils::MutableNodeView& node) const;
|
||||
|
||||
void FetchNodeDataLayouts(TransposeContext* context,
|
||||
const utils::MutableNodeView& node);
|
||||
|
||||
// Check if a node is transposable, and return what data format to transpose
|
||||
// from and to.
|
||||
const LayoutAgnosticNodeFanouts& CheckNodeTransposable(
|
||||
TransposeContext* context, const utils::MutableNodeView& node);
|
||||
|
||||
// Validate node and get data format layout to transpose from and to for a
|
||||
// given fanout port of a node.
|
||||
const NodeLayoutContext& GetDataFormatLayoutForNode(
|
||||
TransposeContext* context, const utils::MutableNodeView& node,
|
||||
int fanout_port,
|
||||
std::function<bool(const utils::MutableNodeView& node)> node_check);
|
||||
std::vector<int> GetVariadic4DFaninPorts(
|
||||
const TransposeContext& context,
|
||||
const utils::MutableNodeView& node) const;
|
||||
};
|
||||
|
||||
class DefaultLayoutAgnosticOpTransposer : public LayoutAgnosticOpTransposer {
|
||||
@ -345,7 +300,6 @@ class BinaryOpTransposer : public LayoutAgnosticOpTransposer {
|
||||
absl::string_view shape_const_node_name,
|
||||
const DataType& data_type);
|
||||
Status MaybeReshapeVectorFanin(TransposeContext* context,
|
||||
const NodeLayoutContext& layout,
|
||||
utils::MutableNodeView* node);
|
||||
};
|
||||
|
||||
@ -381,7 +335,9 @@ class MergeTransposer : public LayoutAgnosticOpTransposer {
|
||||
utils::MutableNodeView* node) override;
|
||||
|
||||
private:
|
||||
std::vector<int> GetFaninPorts(const utils::MutableNodeView& node);
|
||||
bool IsEveryFaninAfterDstToSrcTransform(
|
||||
const TransposeContext& context,
|
||||
const utils::MutableNodeView& node) const;
|
||||
};
|
||||
|
||||
class PadTransposer : public LayoutAgnosticOpTransposer {
|
||||
@ -403,7 +359,7 @@ class ReduceTransposer : public LayoutAgnosticOpTransposer {
|
||||
bool KeepDims(const utils::MutableNodeView& node);
|
||||
bool IsAlongAxis(const utils::MutableNodeView& axis_node,
|
||||
absl::Span<const int> axis);
|
||||
bool IsReduceAxisSupported(const NodeLayoutContext& layout,
|
||||
bool IsReduceAxisSupported(const TransposeContext& context,
|
||||
const utils::MutableNodeView& node);
|
||||
};
|
||||
|
||||
@ -495,8 +451,8 @@ class StridedSliceTransposer : public LayoutAgnosticOpTransposer {
|
||||
private:
|
||||
bool IsMaskZero(const utils::MutableNodeView& node, absl::string_view mask);
|
||||
bool HasOnlyBeginEndMask(const utils::MutableNodeView& node);
|
||||
Status PermuteMask(TransposeContext* context, const NodeLayoutContext& layout,
|
||||
utils::MutableNodeView* node, absl::string_view mask);
|
||||
Status PermuteMask(TransposeContext* context, utils::MutableNodeView* node,
|
||||
absl::string_view mask);
|
||||
};
|
||||
|
||||
class SwitchTransposer : public LayoutAgnosticOpTransposer {
|
||||
@ -580,19 +536,8 @@ std::vector<int> GetDataFanoutPorts(const utils::MutableNodeView& node);
|
||||
bool GetValueAttrIfConstPermTransposeNode(const utils::MutableNodeView& node,
|
||||
Tensor* tensor);
|
||||
|
||||
// Returns true if the given node is a transpose op performing given const
|
||||
// permutation.
|
||||
bool IsValidConstPermTransposeNode(const utils::MutableNodeView& node,
|
||||
absl::Span<const int> permutation);
|
||||
|
||||
bool IsDataFormatOp(const utils::MutableNodeView& node);
|
||||
|
||||
// Returns true if the given node is DataformatDimMap or DataformatVecPermute
|
||||
// performing given permutation from scr_format to dst_format.
|
||||
bool IsValidDataFormatNode(const utils::MutableNodeView& node,
|
||||
absl::string_view src_format,
|
||||
absl::string_view dst_format);
|
||||
|
||||
std::vector<int> GetPermutation(absl::string_view src_format,
|
||||
absl::string_view dst_format);
|
||||
|
||||
|
@ -49,6 +49,7 @@ constexpr int kOutHeight = 5;
|
||||
constexpr int kDepthOut = 16;
|
||||
constexpr char kSrcFormat[] = "NHWC";
|
||||
constexpr char kDstFormat[] = "NCHW";
|
||||
constexpr char kGPU[] = "GPU";
|
||||
constexpr char kAttrOutputShapes[] = "_output_shapes";
|
||||
constexpr char kAttrDataFormat[] = "data_format";
|
||||
constexpr char kOpTranspose[] = "Transpose";
|
||||
@ -316,7 +317,7 @@ TEST_F(TransposerTest, CreateConstPermNode) {
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph));
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
TransposerImpl transposer;
|
||||
constexpr char kNodeName[] = "const_perm_node";
|
||||
@ -358,8 +359,8 @@ TEST_F(TransposerTest, CreateTransposeNode) {
|
||||
GrapplerItem item;
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph));
|
||||
TF_ASSERT_OK(
|
||||
TransposeContext::InitializeTransposeContext(item, nullptr, &context));
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
TransposerImpl transposer;
|
||||
constexpr char kNodeNameFormat[] =
|
||||
@ -398,14 +399,12 @@ TEST_F(TransposerTest, UpdateNode) {
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph));
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer transposer;
|
||||
auto* conv2d = context.graph_view->GetNode("conv2d");
|
||||
ASSERT_NE(conv2d, nullptr);
|
||||
NodeLayoutContext layout;
|
||||
layout.Initialize(kSrcFormat, kDstFormat, /*is_transposable=*/true);
|
||||
TF_ASSERT_OK(transposer.UpdateNode(&context, layout, conv2d));
|
||||
TF_ASSERT_OK(transposer.UpdateNode(&context, conv2d));
|
||||
TF_ASSERT_OK(context.graph_view->GetMutationBuilder()->Apply());
|
||||
|
||||
auto* updated_conv2d = context.graph_view->GetNode("conv2d");
|
||||
@ -430,7 +429,7 @@ TEST_F(TransposerTest, UpdateStrides) {
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph));
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), "ABCD", "ACBD", kGPU, &context));
|
||||
|
||||
AttrValue_ListValue expected_original_strides =
|
||||
MakeAttrValueListValueFromVector({1, 2, 4, 1});
|
||||
@ -449,9 +448,7 @@ TEST_F(TransposerTest, UpdateStrides) {
|
||||
TF_ASSERT_OK(context.graph_view->GetMutationBuilder()->Apply());
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer transposer;
|
||||
NodeLayoutContext layout;
|
||||
layout.Initialize("ABCD", "ACBD", /*is_transposable=*/true);
|
||||
TF_ASSERT_OK(transposer.UpdateNode(&context, layout, conv2d));
|
||||
TF_ASSERT_OK(transposer.UpdateNode(&context, conv2d));
|
||||
TF_ASSERT_OK(context.graph_view->GetMutationBuilder()->Apply());
|
||||
|
||||
auto* updated_conv2d = context.graph_view->GetNode("conv2d");
|
||||
@ -468,19 +465,17 @@ TEST_F(TransposerTest, UpdateFaninEdgesTranspose) {
|
||||
GrapplerItem item;
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(CreateSimpleFusedBatchNormGrad(&item.graph, true));
|
||||
TF_ASSERT_OK(
|
||||
TransposeContext::InitializeTransposeContext(item, nullptr, &context));
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
FusedBatchNormGradTransposer transposer;
|
||||
NodeLayoutContext layout;
|
||||
layout.Initialize(kSrcFormat, kDstFormat, /*is_transposable=*/true);
|
||||
auto* fbng = context.graph_view->GetNode("fused_batch_norm_grad");
|
||||
ASSERT_NE(fbng, nullptr);
|
||||
const auto& fbng_output_shapes_attr = fbng->GetAttr("_output_shapes");
|
||||
ASSERT_NE(fbng_output_shapes_attr, nullptr);
|
||||
const TensorShapeProto& expected_shape = fbng_output_shapes_attr->shape();
|
||||
TF_ASSERT_OK(transposer.UpdateFaninEdgesWithOp(&context, layout, {0, 1}, fbng,
|
||||
kOpTranspose));
|
||||
TF_ASSERT_OK(
|
||||
transposer.UpdateFaninEdgesWithOp(&context, {0, 1}, fbng, kOpTranspose));
|
||||
TF_ASSERT_OK(context.graph_view->GetMutationBuilder()->Apply());
|
||||
|
||||
// Verify output shape matches input shape.
|
||||
@ -528,12 +523,10 @@ TEST_F(TransposerTest, UpdateFanoutEdgesTranspose) {
|
||||
GrapplerItem item;
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph));
|
||||
TF_ASSERT_OK(
|
||||
TransposeContext::InitializeTransposeContext(item, nullptr, &context));
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
TransposerImpl transposer;
|
||||
NodeLayoutContext layout;
|
||||
layout.Initialize(kSrcFormat, kDstFormat, /*is_transposable=*/true);
|
||||
TensorShapeProto expected_original_shape =
|
||||
MakeTensorShapeFromDimensions({32, 5, 3, 16});
|
||||
TensorShapeProto expected_updated_shape =
|
||||
@ -543,8 +536,8 @@ TEST_F(TransposerTest, UpdateFanoutEdgesTranspose) {
|
||||
ASSERT_NE(conv2d, nullptr);
|
||||
VerifyShapeAttributeMatch(conv2d, 0, expected_original_shape.DebugString());
|
||||
|
||||
TF_ASSERT_OK(transposer.UpdateFanoutEdgesWithOp(&context, layout, {0}, conv2d,
|
||||
kOpTranspose));
|
||||
TF_ASSERT_OK(
|
||||
transposer.UpdateFanoutEdgesWithOp(&context, {0}, conv2d, kOpTranspose));
|
||||
TF_ASSERT_OK(context.graph_view->GetMutationBuilder()->Apply());
|
||||
|
||||
auto* updated_conv2d = context.graph_view->GetNode("conv2d");
|
||||
@ -584,7 +577,7 @@ TEST_F(TransposerTest, DefaultLayoutSensitiveOpTransposerTestFusedBatchNorm) {
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(CreateSimpleFusedBatchNorm(&item.graph));
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer transposer;
|
||||
auto* bn = context.graph_view->GetNode("bn");
|
||||
@ -639,7 +632,7 @@ TEST_F(TransposerTest, DefaultLayoutSensitiveOpTransposerTestConv2D) {
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(CreateSimpleConv2DGraph(&item.graph));
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer transposer;
|
||||
auto* conv2d = context.graph_view->GetNode("conv2d");
|
||||
@ -681,7 +674,7 @@ TEST_F(TransposerTest, MaxPoolGradTransposerTest) {
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(CreateSimpleMaxPoolGrad(&item.graph));
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
MaxPoolGradTransposer transposer;
|
||||
auto* maxpool_grad = context.graph_view->GetNode("maxpool_grad");
|
||||
@ -733,7 +726,7 @@ TEST_F(TransposerTest, BiasAddGradTransposerTest) {
|
||||
TF_ASSERT_OK(CreateSimpleBiasAddGrad(
|
||||
&item.graph, {kBatchSize, kHeight, kWidth, kDepthIn}));
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
BiasAddGradTransposer transposer;
|
||||
auto* bag = context.graph_view->GetNode("bag");
|
||||
@ -768,7 +761,7 @@ TEST_F(TransposerTest, BiasAddGradTransposerIncorrectInputTest) {
|
||||
TF_ASSERT_OK(
|
||||
CreateSimpleBiasAddGrad(&item.graph, {kHeight, kWidth, kDepthIn}));
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
BiasAddGradTransposer transposer;
|
||||
auto* bag = context.graph_view->GetNode("bag");
|
||||
@ -802,7 +795,7 @@ TEST_F(TransposerTest, Conv2DBackpropFilterTransposerTest) {
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(CreateSimpleConv2DBackpropFilter(&item.graph));
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
Conv2DBackpropFilterTransposer transposer;
|
||||
auto* conv2d_bf = context.graph_view->GetNode("conv2d_backprop_filter");
|
||||
@ -854,7 +847,7 @@ TEST_F(TransposerTest, FusedBatchNormGradTransposerIsTrainingTest) {
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(CreateSimpleFusedBatchNormGrad(&item.graph, true));
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
FusedBatchNormGradTransposer transposer;
|
||||
auto* fbng = context.graph_view->GetNode("fused_batch_norm_grad");
|
||||
@ -922,7 +915,7 @@ TEST_F(TransposerTest, FusedBatchNormGradTransposerNotTrainingTest) {
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(CreateSimpleFusedBatchNormGrad(&item.graph, false));
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
FusedBatchNormGradTransposer transposer;
|
||||
auto* fbng = context.graph_view->GetNode("fused_batch_norm_grad");
|
||||
@ -992,7 +985,7 @@ TEST_F(TransposerTest, DefaultLayoutAgnosticOpTransposerIdentityTest) {
|
||||
TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* c2d = context.graph_view->GetNode("conv2d");
|
||||
@ -1043,7 +1036,7 @@ TEST_F(TransposerTest, DefaultLayoutAgnosticOpTransposerIdentityBadInputTest) {
|
||||
TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* c2d = context.graph_view->GetNode("conv2d");
|
||||
@ -1084,7 +1077,7 @@ TEST_F(TransposerTest, AddNTransposerTest) {
|
||||
TF_ASSERT_OK(CreateSimpleAddN(&item.graph));
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* conv2d = context.graph_view->GetNode("conv2d");
|
||||
@ -1147,7 +1140,7 @@ TEST_F(TransposerTest, AddNTransposerNotAfterTransformTest) {
|
||||
TF_ASSERT_OK(CreateSimpleAddN(&item.graph));
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
AddNTransposer addn_transposer;
|
||||
auto* an = context.graph_view->GetNode("add_n");
|
||||
@ -1197,7 +1190,7 @@ TEST_F(TransposerTest, IdentityNTransposerTest) {
|
||||
TF_ASSERT_OK(CreateSimpleIdentityN(&item.graph));
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* conv2d_1 = context.graph_view->GetNode("conv2d_1");
|
||||
@ -1295,7 +1288,7 @@ TEST_F(TransposerTest, MergeTransposerTestMergeBothInputsConvertible) {
|
||||
TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* c2d = context.graph_view->GetNode("conv2d");
|
||||
@ -1354,7 +1347,7 @@ TEST_F(TransposerTest, MergeTransposerTestMergeOneInputNotConvertible) {
|
||||
TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* c2d = context.graph_view->GetNode("conv2d");
|
||||
@ -1407,7 +1400,7 @@ TEST_F(TransposerTest, PadTransposerTest) {
|
||||
TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* c2d = context.graph_view->GetNode("conv2d");
|
||||
@ -1467,7 +1460,7 @@ TEST_F(TransposerTest, SwitchTransposerTest) {
|
||||
TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* c2d = context.graph_view->GetNode("conv2d");
|
||||
@ -1533,7 +1526,7 @@ TEST_F(TransposerTest, TernaryOpTransposerTest) {
|
||||
TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* c2d = context.graph_view->GetNode("conv2d");
|
||||
@ -1600,7 +1593,7 @@ TEST_F(TransposerTest, UnaryGradTransposerTestTanhGrad) {
|
||||
TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* c2d = context.graph_view->GetNode("conv2d");
|
||||
@ -1663,7 +1656,7 @@ TEST_F(TransposerTest, UnaryGradTransposerTestRelu6Grad) {
|
||||
TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* c2d = context.graph_view->GetNode("conv2d");
|
||||
@ -1731,7 +1724,7 @@ TEST_F(TransposerTest, SqueezeTransposerTest) {
|
||||
TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* c2d = context.graph_view->GetNode("conv2d");
|
||||
@ -1785,7 +1778,7 @@ TEST_F(TransposerTest, SqueezeTransposerTestUnsupportedInputShape) {
|
||||
TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* c2d = context.graph_view->GetNode("conv2d");
|
||||
@ -1824,7 +1817,7 @@ TEST_F(TransposerTest, SqueezeTransposerTestInvalidHWAxis) {
|
||||
TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* c2d = context.graph_view->GetNode("conv2d");
|
||||
@ -1863,7 +1856,7 @@ TEST_F(TransposerTest, SqueezeTransposerTestInvalidNHWAxis) {
|
||||
TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* c2d = context.graph_view->GetNode("conv2d");
|
||||
@ -1902,7 +1895,7 @@ TEST_F(TransposerTest, SqueezeTransposerTestSqueezeDimsUpdated) {
|
||||
TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* c2d = context.graph_view->GetNode("conv2d");
|
||||
@ -1960,7 +1953,7 @@ TEST_F(TransposerTest, MaxPoolV2Transposer) {
|
||||
TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
MaxPoolV2Transposer maxpool_transposer;
|
||||
auto* maxpool = context.graph_view->GetNode("maxpoolv2");
|
||||
@ -2023,7 +2016,7 @@ TEST_F(TransposerTest, MaxPoolGradV2Transposer) {
|
||||
TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
MaxPoolGradV2Transposer maxpoolgrad_transposer;
|
||||
auto* maxpoolgrad = context.graph_view->GetNode("maxpoolgradv2");
|
||||
@ -2091,7 +2084,7 @@ TEST_F(TransposerTest, BinaryOpTransposerAdd) {
|
||||
TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* c2d = context.graph_view->GetNode("conv2d");
|
||||
@ -2162,7 +2155,7 @@ TEST_F(TransposerTest, BinaryOpTransposerMul) {
|
||||
TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* c2d = context.graph_view->GetNode("conv2d");
|
||||
@ -2235,7 +2228,7 @@ TEST_F(TransposerTest, BinaryOpTransposerPolygamma) {
|
||||
TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* c2d = context.graph_view->GetNode("conv2d");
|
||||
@ -2326,7 +2319,7 @@ TEST_F(TransposerTest, ConcatOpTransposerConcat) {
|
||||
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* c2d = context.graph_view->GetNode("conv2d");
|
||||
@ -2402,7 +2395,7 @@ TEST_F(TransposerTest, ConcatOpTransposerConcatV2) {
|
||||
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* c2d = context.graph_view->GetNode("conv2d");
|
||||
@ -2474,7 +2467,7 @@ TEST_F(TransposerTest, ReverseV2Transposer) {
|
||||
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* c2d = context.graph_view->GetNode("conv2d");
|
||||
@ -2541,7 +2534,7 @@ TEST_F(TransposerTest, TileTransposer) {
|
||||
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* c2d = context.graph_view->GetNode("conv2d");
|
||||
@ -2606,7 +2599,7 @@ TEST_F(TransposerTest, ShapeTransposer) {
|
||||
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* c2d = context.graph_view->GetNode("conv2d");
|
||||
@ -2671,7 +2664,7 @@ TEST_F(TransposerTest, ShapeNTransposer) {
|
||||
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* c2d_1 = context.graph_view->GetNode("conv2d_1");
|
||||
@ -2765,7 +2758,7 @@ TEST_F(TransposerTest, FillOpTransposer) {
|
||||
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* c2d = context.graph_view->GetNode("conv2d");
|
||||
@ -2825,7 +2818,7 @@ TEST_F(TransposerTest, SliceTransposer) {
|
||||
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* c2d = context.graph_view->GetNode("conv2d");
|
||||
@ -2901,7 +2894,7 @@ TEST_F(TransposerTest, SplitTransposer) {
|
||||
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* c2d = context.graph_view->GetNode("conv2d");
|
||||
@ -2989,7 +2982,7 @@ TEST_F(TransposerTest, SplitVTransposer) {
|
||||
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* c2d = context.graph_view->GetNode("conv2d");
|
||||
@ -3080,7 +3073,7 @@ TEST_F(TransposerTest, StridedSliceTransposer) {
|
||||
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* c2d = context.graph_view->GetNode("conv2d");
|
||||
@ -3170,7 +3163,7 @@ TEST_F(TransposerTest, StridedSliceTransposerEllipsisMaskPresent) {
|
||||
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* c2d = context.graph_view->GetNode("conv2d");
|
||||
@ -3226,7 +3219,7 @@ TEST_F(TransposerTest, ReduceTransposerKeepDims) {
|
||||
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* c2d = context.graph_view->GetNode("conv2d");
|
||||
@ -3291,7 +3284,7 @@ TEST_F(TransposerTest, ReduceTransposerValidAxisNode) {
|
||||
|
||||
TransposeContext context;
|
||||
TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
|
||||
item, virtual_cluster_.get(), &context));
|
||||
item, virtual_cluster_.get(), kSrcFormat, kDstFormat, kGPU, &context));
|
||||
|
||||
DefaultLayoutSensitiveOpTransposer conv2d_transposer;
|
||||
auto* c2d = context.graph_view->GetNode("conv2d");
|
||||
|
Loading…
Reference in New Issue
Block a user