Add mixed precision Grappler optimizer
- Adds an opt-in Grappler pass to convert parts of a graph from float32 to float16 Co-authored-by: Ben Barsdell <bbarsdell@nvidia.com> Co-authored-by: Carl Case <carlc@nvidia.com> Co-authored-by: Trent Lo <trentl@nvidia.com> Co-authored-by: Nathan Luehr <nluehr@nvidia.com>
This commit is contained in:
parent
6a883e7f8d
commit
30d87b5299
@ -27,7 +27,8 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
int GetNumAvailableGPUs() {
|
||||
int GetNumAvailableGPUs(
|
||||
const std::pair<int, int>& min_cuda_compute_capability) {
|
||||
int num_eligible_gpus = 0;
|
||||
#if GOOGLE_CUDA
|
||||
if (ValidateGPUMachineManager().ok()) {
|
||||
@ -39,20 +40,29 @@ int GetNumAvailableGPUs() {
|
||||
if (exec_status.ok()) {
|
||||
se::StreamExecutor* se = exec_status.ValueOrDie();
|
||||
const se::DeviceDescription& desc = se->GetDeviceDescription();
|
||||
int cc_major = 0;
|
||||
int cc_minor = 0;
|
||||
desc.cuda_compute_capability(&cc_major, &cc_minor);
|
||||
std::pair<int, int> cuda_compute_capability(cc_major, cc_minor);
|
||||
int min_gpu_core_count = 8;
|
||||
if (desc.core_count() >= min_gpu_core_count) {
|
||||
if (desc.core_count() >= min_gpu_core_count &&
|
||||
cuda_compute_capability >= min_cuda_compute_capability) {
|
||||
num_eligible_gpus++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
LOG(INFO) << "Number of eligible GPUs (core count >= 8): "
|
||||
<< num_eligible_gpus;
|
||||
LOG(INFO)
|
||||
<< "Number of eligible GPUs (core count >= 8, compute capability >= "
|
||||
<< min_cuda_compute_capability.first << "."
|
||||
<< min_cuda_compute_capability.second << "): " << num_eligible_gpus;
|
||||
#else
|
||||
LOG(INFO) << "Number of eligible GPUs (core count >= 8): "
|
||||
<< num_eligible_gpus
|
||||
<< " (Note: TensorFlow was not compiled with CUDA support)";
|
||||
LOG(INFO)
|
||||
<< "Number of eligible GPUs (core count >= 8, compute capability >= "
|
||||
<< min_cuda_compute_capability.first << "."
|
||||
<< min_cuda_compute_capability.second << "): " << num_eligible_gpus
|
||||
<< " (Note: TensorFlow was not compiled with CUDA support)";
|
||||
#endif // GOOGLE_CUDA
|
||||
return num_eligible_gpus;
|
||||
}
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_CORE_GRAPPLER_DEVICES_H_
|
||||
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
@ -26,8 +27,10 @@ namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
// Get the number of available GPUs whose number of multiprocessors is no less
|
||||
// than 8.
|
||||
int GetNumAvailableGPUs();
|
||||
// than 8 and whose CUDA compute capability is no less than
|
||||
// min_cuda_compute_capability.
|
||||
int GetNumAvailableGPUs(
|
||||
const std::pair<int, int>& min_cuda_compute_capability = {0, 0});
|
||||
|
||||
// Maximum amount of gpu memory available per gpu. gpu_id must be in the range
|
||||
// [0, num_available_gpu)
|
||||
|
@ -523,6 +523,46 @@ tf_cuda_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "auto_mixed_precision",
|
||||
srcs = ["auto_mixed_precision.cc"],
|
||||
hdrs = [
|
||||
"auto_mixed_precision.h",
|
||||
"auto_mixed_precision_lists.h",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":custom_graph_optimizer_registry",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler:devices",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:mutable_graph_view",
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core/grappler/clusters:cluster",
|
||||
"//tensorflow/core/grappler/costs:virtual_placer",
|
||||
"//tensorflow/core/grappler/utils:frame",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "auto_mixed_precision_test",
|
||||
srcs = ["auto_mixed_precision_test.cc"],
|
||||
deps = [
|
||||
":auto_mixed_precision",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/grappler/clusters:single_machine",
|
||||
"//tensorflow/core/grappler/clusters:virtual_cluster",
|
||||
"//tensorflow/core/grappler/utils:grappler_test",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "meta_optimizer",
|
||||
srcs = ["meta_optimizer.cc"],
|
||||
@ -531,6 +571,7 @@ cc_library(
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":auto_mixed_precision",
|
||||
":arithmetic_optimizer",
|
||||
":auto_parallel",
|
||||
":constant_folding",
|
||||
|
1720
tensorflow/core/grappler/optimizers/auto_mixed_precision.cc
Normal file
1720
tensorflow/core/grappler/optimizers/auto_mixed_precision.cc
Normal file
File diff suppressed because it is too large
Load Diff
46
tensorflow/core/grappler/optimizers/auto_mixed_precision.h
Normal file
46
tensorflow/core/grappler/optimizers/auto_mixed_precision.h
Normal file
@ -0,0 +1,46 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_AUTO_MIXED_PRECISION_H_
|
||||
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_AUTO_MIXED_PRECISION_H_
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
|
||||
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
// Convert data types to float16 where appropriate to improve performance on
|
||||
// GPUs.
|
||||
class AutoMixedPrecision : public GraphOptimizer {
|
||||
public:
|
||||
explicit AutoMixedPrecision(
|
||||
RewriterConfig::Toggle opt_level = RewriterConfig::ON) {}
|
||||
|
||||
~AutoMixedPrecision() override {}
|
||||
|
||||
string name() const override { return "auto_mixed_precision"; };
|
||||
|
||||
Status Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
GraphDef* output) override;
|
||||
|
||||
void Feedback(Cluster* cluster, const GrapplerItem& item,
|
||||
const GraphDef& optimize_output, double result) override;
|
||||
};
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_AUTO_MIXED_PRECISION_H_
|
302
tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h
Normal file
302
tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h
Normal file
@ -0,0 +1,302 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_AUTO_MIXED_PRECISION_LISTS_H_
|
||||
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_AUTO_MIXED_PRECISION_LISTS_H_
|
||||
|
||||
#include <set>
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/util/env_var.h"
|
||||
|
||||
// TODO(benbarsdell): GOOGLE_CUDA doesn't seem to be working?
|
||||
//#if GOOGLE_CUDA
|
||||
// Needed for CUDA_VERSION macro.
|
||||
#include "cuda/include/cuda.h"
|
||||
//#endif // GOOGLE_CUDA
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
class AutoMixedPrecisionLists {
|
||||
private:
|
||||
static void UpdateList(std::set<string>* list, const string& to_add,
|
||||
const string& to_remove) {
|
||||
for (auto x : str_util::Split(to_add, ",")) {
|
||||
list->insert(x);
|
||||
}
|
||||
for (auto x : str_util::Split(to_remove, ",")) {
|
||||
list->erase(x);
|
||||
}
|
||||
}
|
||||
|
||||
static bool IsPseudoFastMath() {
|
||||
string optimization_level;
|
||||
ReadStringFromEnvVar("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL", "",
|
||||
&optimization_level);
|
||||
optimization_level = str_util::Uppercase(optimization_level);
|
||||
return optimization_level == "TENSOR_CORES_ONLY";
|
||||
}
|
||||
|
||||
public:
|
||||
// Returns the set of ops that are considered numerically-safe (for execution
|
||||
// in fp16) and performance-critical. These ops are always converted to fp16.
|
||||
static std::set<string> WhiteList() {
|
||||
string to_add, to_remove;
|
||||
ReadStringFromEnvVar("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_WHITELIST_ADD",
|
||||
"", &to_add);
|
||||
ReadStringFromEnvVar(
|
||||
"TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_WHITELIST_REMOVE", "",
|
||||
&to_remove);
|
||||
|
||||
auto list = std::set<string> {
|
||||
#if CUDA_VERSION >= 9010 // Fp16 BatchMatMul is slow before CUDA 9.1.
|
||||
"BatchMatMul",
|
||||
#endif
|
||||
"Conv2D",
|
||||
"Conv2DBackpropFilter",
|
||||
"Conv2DBackpropInput",
|
||||
// TODO(benbarsdell): Enable these when Tensor Core kernels are
|
||||
// available for 3D convolutions.
|
||||
// "Conv3D",
|
||||
// "Conv3DBackpropFilter",
|
||||
// "Conv3DBackpropFilterV2",
|
||||
// "Conv3DBackpropInput",
|
||||
// "Conv3DBackpropInputV2",
|
||||
"CudnnRNN",
|
||||
"CudnnRNNBackprop",
|
||||
"CudnnRNNBackpropV2",
|
||||
"CudnnRNNBackpropV3",
|
||||
"CudnnRNNV2",
|
||||
"CudnnRNNV3",
|
||||
// TODO(benbarsdell): Enable these when fast and safe fp16 kernels are
|
||||
// available for depthwise convolutions.
|
||||
// "DepthwiseConv2dNative",
|
||||
// "DepthwiseConv2dNativeBackpropFilter",
|
||||
// "DepthwiseConv2dNativeBackpropInput",
|
||||
"MatMul",
|
||||
};
|
||||
UpdateList(&list, to_add, to_remove);
|
||||
return list;
|
||||
}
|
||||
|
||||
// Returns the set of ops that are considered numerically-safe (for execution
|
||||
// in fp16), but which may be made unsafe by an upstream blacklist op.
|
||||
static std::set<string> GrayList() {
|
||||
if (IsPseudoFastMath()) {
|
||||
return std::set<string>{};
|
||||
}
|
||||
string to_add, to_remove;
|
||||
ReadStringFromEnvVar("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_GRAYLIST_ADD",
|
||||
"", &to_add);
|
||||
ReadStringFromEnvVar(
|
||||
"TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_GRAYLIST_REMOVE", "",
|
||||
&to_remove);
|
||||
|
||||
auto list = std::set<string>{
|
||||
"Add",
|
||||
"AddN",
|
||||
"AddV2",
|
||||
"AvgPool",
|
||||
"AvgPool3D",
|
||||
"AvgPool3DGrad",
|
||||
"AvgPoolGrad",
|
||||
"BiasAdd",
|
||||
"BiasAddGrad",
|
||||
"BiasAddV1",
|
||||
"Elu",
|
||||
"EluGrad",
|
||||
"Erf",
|
||||
"Erfc",
|
||||
"FloorDiv",
|
||||
"FusedBatchNormV2",
|
||||
"FusedBatchNormGradV2",
|
||||
"Inv",
|
||||
"LeakyRelu",
|
||||
"LeakyReluGrad",
|
||||
"Mul",
|
||||
"Prod",
|
||||
"RealDiv",
|
||||
"Reciprocal",
|
||||
"Sigmoid",
|
||||
"SigmoidGrad",
|
||||
"Softplus",
|
||||
"SoftplusGrad",
|
||||
"Sqrt",
|
||||
"Sub",
|
||||
"Sum",
|
||||
"Tanh",
|
||||
"TanhGrad",
|
||||
};
|
||||
UpdateList(&list, to_add, to_remove);
|
||||
return list;
|
||||
}
|
||||
|
||||
// Returns the set of ops that are considered numerically-dangerous (i.e.,
|
||||
// unsafe for execution in fp16) and whose effects may also be observed in
|
||||
// downstream nodes (e.g., in Exp -> Add, the Add is unsafe due to the Exp).
|
||||
static std::set<string> BlackList() {
|
||||
if (IsPseudoFastMath()) {
|
||||
return std::set<string>{};
|
||||
}
|
||||
string to_add, to_remove;
|
||||
ReadStringFromEnvVar("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_BLACKLIST_ADD",
|
||||
"", &to_add);
|
||||
ReadStringFromEnvVar(
|
||||
"TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_BLACKLIST_REMOVE", "",
|
||||
&to_remove);
|
||||
|
||||
auto list = std::set<string>{
|
||||
"Exp",
|
||||
"Expm1",
|
||||
"L2Loss",
|
||||
"Log",
|
||||
"Log1p",
|
||||
"LogSoftmax",
|
||||
"Mean",
|
||||
"Pow",
|
||||
"SaveV2",
|
||||
"Softmax",
|
||||
"SoftmaxCrossEntropyWithLogits",
|
||||
"SparseSoftmaxCrossEntropyWithLogits",
|
||||
};
|
||||
UpdateList(&list, to_add, to_remove);
|
||||
return list;
|
||||
}
|
||||
|
||||
// Returns the set of ops that do not have numerically-significant effects
|
||||
// (i.e., they are always considered safe for execution in fp16 precision).
|
||||
static std::set<string> ClearList() {
|
||||
if (IsPseudoFastMath()) {
|
||||
return std::set<string>{};
|
||||
}
|
||||
string to_add, to_remove;
|
||||
ReadStringFromEnvVar(
|
||||
"TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_CLEARLIST_ADD", "",
|
||||
&to_add);
|
||||
ReadStringFromEnvVar(
|
||||
"TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_CLEARLIST_REMOVE", "",
|
||||
&to_remove);
|
||||
|
||||
auto list = std::set<string> {
|
||||
"Abs",
|
||||
"ArgMax",
|
||||
"ArgMin",
|
||||
"BatchToSpace",
|
||||
"BatchToSpaceND",
|
||||
"BroadcastTo",
|
||||
"Ceil",
|
||||
"CheckNumerics",
|
||||
"ClipByValue",
|
||||
"Concat",
|
||||
"ConcatV2",
|
||||
"DepthToSpace",
|
||||
"DynamicPartition",
|
||||
"DynamicStitch",
|
||||
"Enter",
|
||||
"EnsureShape",
|
||||
"Equal",
|
||||
"Exit",
|
||||
"ExpandDims",
|
||||
"Fill",
|
||||
"Floor",
|
||||
"Gather",
|
||||
"GatherNd",
|
||||
"GatherV2",
|
||||
"Greater",
|
||||
"GreaterEqual",
|
||||
"Identity",
|
||||
"IdentityN",
|
||||
"IsFinite",
|
||||
"IsInf",
|
||||
"IsNan",
|
||||
"Less",
|
||||
"LessEqual",
|
||||
"Max",
|
||||
"MaxPool",
|
||||
"MaxPool3D",
|
||||
"MaxPool3DGrad",
|
||||
"MaxPool3DGradGrad",
|
||||
"MaxPoolGrad",
|
||||
"MaxPoolGradGrad",
|
||||
"MaxPoolGradGradV2",
|
||||
"MaxPoolGradV2",
|
||||
"MaxPoolV2",
|
||||
"Maximum",
|
||||
"Merge",
|
||||
"Min",
|
||||
"Minimum",
|
||||
"MirrorPad",
|
||||
"MirrorPadGrad",
|
||||
"Neg",
|
||||
"NextIteration",
|
||||
"NotEqual",
|
||||
"OnesLike",
|
||||
"Pack",
|
||||
"Pad",
|
||||
"PadV2",
|
||||
"PreventGradient",
|
||||
"Rank",
|
||||
"Relu",
|
||||
"Relu6",
|
||||
"Relu6Grad",
|
||||
"ReluGrad",
|
||||
"Reshape",
|
||||
"ResizeNearestNeighbor",
|
||||
"ResizeNearestNeighborGrad",
|
||||
"Reverse",
|
||||
"ReverseSequence",
|
||||
"ReverseV2",
|
||||
"Round",
|
||||
"Select",
|
||||
"Shape",
|
||||
"ShapeN",
|
||||
"Sign",
|
||||
"Size",
|
||||
"Slice",
|
||||
"Snapshot",
|
||||
"SpaceToBatch",
|
||||
"SpaceToBatchND",
|
||||
"SpaceToDepth",
|
||||
"Split",
|
||||
"SplitV",
|
||||
"Squeeze",
|
||||
"StackPopV2",
|
||||
"StackPushV2",
|
||||
"StopGradient",
|
||||
"StridedSlice",
|
||||
"StridedSliceGrad",
|
||||
"Switch",
|
||||
"TensorArrayConcatV3",
|
||||
"TensorArrayGatherV3",
|
||||
"TensorArrayReadV3",
|
||||
"TensorArrayScatterV3",
|
||||
"TensorArraySplitV3",
|
||||
"TensorArrayWriteV3",
|
||||
"Tile",
|
||||
"TopK",
|
||||
"TopKV2",
|
||||
"Transpose",
|
||||
"Where",
|
||||
"ZerosLike",
|
||||
};
|
||||
UpdateList(&list, to_add, to_remove);
|
||||
return list;
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_AUTO_MIXED_PRECISION_LISTS_H_
|
479
tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc
Normal file
479
tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc
Normal file
@ -0,0 +1,479 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/grappler/clusters/single_machine.h"
|
||||
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
|
||||
#include "tensorflow/core/grappler/devices.h"
|
||||
#include "tensorflow/core/grappler/graph_view.h"
|
||||
#include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h"
|
||||
#include "tensorflow/core/grappler/utils/grappler_test.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
const std::pair<int, int> kMinGPUArch = {7, 0};
|
||||
|
||||
class AutoMixedPrecisionTest : public GrapplerTest {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
int num_gpus = GetNumAvailableGPUs();
|
||||
// If GPUs are available, require that they all satisfy the min arch.
|
||||
gpu_available_ =
|
||||
num_gpus > 0 && num_gpus == GetNumAvailableGPUs(kMinGPUArch);
|
||||
|
||||
if (gpu_available_) {
|
||||
virtual_cluster_.reset(new SingleMachine(/* timeout_s = */ 10, 1, 1));
|
||||
} else {
|
||||
DeviceProperties device_properties;
|
||||
device_properties.set_type("GPU");
|
||||
device_properties.mutable_environment()->insert({"architecture", "7"});
|
||||
virtual_cluster_.reset(
|
||||
new VirtualCluster({{"/GPU:1", device_properties}}));
|
||||
}
|
||||
TF_CHECK_OK(virtual_cluster_->Provision());
|
||||
}
|
||||
|
||||
void TearDown() override { TF_CHECK_OK(virtual_cluster_->Shutdown()); }
|
||||
|
||||
NodeDef* AddSimpleNode(const string& name, const string& op,
|
||||
const std::vector<string>& inputs,
|
||||
GraphDef* graph) const {
|
||||
std::vector<std::pair<string, AttrValue>> attributes;
|
||||
if (op == "AddN" || op == "ShapeN") {
|
||||
AttrValue num_inputs;
|
||||
num_inputs.set_i(inputs.size());
|
||||
attributes.emplace_back("N", num_inputs);
|
||||
}
|
||||
if (op == "ShapeN") {
|
||||
AttrValue out_type;
|
||||
out_type.set_type(DT_INT32);
|
||||
attributes.emplace_back("out_type", out_type);
|
||||
}
|
||||
AttrValue type;
|
||||
type.set_type(DT_FLOAT);
|
||||
if (op == "Const" || op == "Placeholder" || op == "VariableV2" ||
|
||||
op == "VarHandleOp" || op == "ReadVariableOp") {
|
||||
attributes.emplace_back("dtype", type);
|
||||
} else if (op == "SparseMatMul") {
|
||||
attributes.emplace_back("Ta", type);
|
||||
attributes.emplace_back("Tb", type);
|
||||
} else if (op == "IdentityN") {
|
||||
AttrValue type_list;
|
||||
for (int i = 0; i < (int)inputs.size(); ++i) {
|
||||
type_list.mutable_list()->add_type(DT_FLOAT);
|
||||
}
|
||||
attributes.emplace_back("T", type_list);
|
||||
} else if (op == "StackV2" || op == "StackPopV2") {
|
||||
attributes.emplace_back("elem_type", type);
|
||||
} else if (op == "Cast") {
|
||||
attributes.emplace_back("SrcT", type);
|
||||
attributes.emplace_back("DstT", type);
|
||||
} else {
|
||||
attributes.emplace_back("T", type);
|
||||
}
|
||||
return AddNode(name, op, inputs, attributes, graph);
|
||||
}
|
||||
|
||||
std::unique_ptr<Cluster> virtual_cluster_;
|
||||
bool gpu_available_;
|
||||
};
|
||||
|
||||
void VerifyGraphsEquivalent(const GraphDef& original_graph,
|
||||
const GraphDef& optimized_graph,
|
||||
const string& func) {
|
||||
EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << func;
|
||||
GraphView optimized_view(&optimized_graph);
|
||||
for (int i = 0; i < original_graph.node_size(); ++i) {
|
||||
const NodeDef& original = original_graph.node(i);
|
||||
const NodeDef& optimized = *optimized_view.GetNode(original.name());
|
||||
EXPECT_EQ(original.name(), optimized.name()) << func;
|
||||
EXPECT_EQ(original.op(), optimized.op()) << func;
|
||||
EXPECT_EQ(original.input_size(), optimized.input_size()) << func;
|
||||
if (original.input_size() == optimized.input_size()) {
|
||||
for (int j = 0; j < original.input_size(); ++j) {
|
||||
EXPECT_EQ(original.input(j), optimized.input(j)) << func;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(AutoMixedPrecisionTest, NoOp) {
|
||||
GraphDef graph;
|
||||
AddSimpleNode("In", "Placeholder", {}, &graph);
|
||||
AddSimpleNode("B1", "Exp", {"In"}, &graph);
|
||||
AddSimpleNode("C1", "Relu", {"B1"}, &graph);
|
||||
AddSimpleNode("G1", "Sqrt", {"C1"}, &graph);
|
||||
AddSimpleNode("C2", "Relu", {"G1"}, &graph);
|
||||
|
||||
GrapplerItem item;
|
||||
item.graph = graph;
|
||||
AutoMixedPrecision optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
|
||||
|
||||
VLOG(1) << output.DebugString();
|
||||
|
||||
VerifyGraphsEquivalent(item.graph, output, __FUNCTION__);
|
||||
|
||||
GraphView output_view(&output);
|
||||
EXPECT_EQ(output_view.GetNode("In")->attr().at("dtype").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("B1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("C1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("G1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("C2")->attr().at("T").type(), DT_FLOAT);
|
||||
}
|
||||
|
||||
TEST_F(AutoMixedPrecisionTest, AlreadyFp16) {
|
||||
GraphDef graph;
|
||||
AddSimpleNode("In", "Placeholder", {}, &graph);
|
||||
NodeDef* cast1 = AddSimpleNode("Cast1", "Cast", {"In"}, &graph);
|
||||
cast1->mutable_attr()->at("DstT").set_type(DT_HALF);
|
||||
NodeDef* w1 = AddSimpleNode("W1", "MatMul", {"Cast1", "Cast1"}, &graph);
|
||||
w1->mutable_attr()->at("T").set_type(DT_HALF);
|
||||
NodeDef* c1 = AddSimpleNode("C1", "Relu", {"W1"}, &graph);
|
||||
c1->mutable_attr()->at("T").set_type(DT_HALF);
|
||||
NodeDef* cast2 = AddSimpleNode("Cast2", "Cast", {"C1"}, &graph);
|
||||
cast2->mutable_attr()->at("SrcT").set_type(DT_HALF);
|
||||
AddSimpleNode("C2", "Relu", {"Cast2"}, &graph);
|
||||
|
||||
GrapplerItem item;
|
||||
item.graph = graph;
|
||||
AutoMixedPrecision optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
|
||||
|
||||
VLOG(1) << output.DebugString();
|
||||
|
||||
VerifyGraphsEquivalent(item.graph, output, __FUNCTION__);
|
||||
|
||||
GraphView output_view(&output);
|
||||
EXPECT_EQ(output_view.GetNode("In")->attr().at("dtype").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("Cast1")->attr().at("DstT").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("W1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("C1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("Cast2")->attr().at("SrcT").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("Cast2")->attr().at("DstT").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("C2")->attr().at("T").type(), DT_FLOAT);
|
||||
}
|
||||
|
||||
TEST_F(AutoMixedPrecisionTest, Simple) {
|
||||
GraphDef graph;
|
||||
AddSimpleNode("In", "Placeholder", {}, &graph);
|
||||
AddSimpleNode("B1", "Exp", {"In"}, &graph);
|
||||
AddSimpleNode("C1", "Relu", {"B1"}, &graph);
|
||||
AddSimpleNode("G1", "Sqrt", {"C1"}, &graph);
|
||||
AddSimpleNode("C2", "Relu", {"G1"}, &graph);
|
||||
AddSimpleNode("W1", "MatMul", {"C2", "C2"}, &graph);
|
||||
AddSimpleNode("C3", "Relu", {"W1"}, &graph);
|
||||
AddSimpleNode("B2", "Exp", {"C3"}, &graph);
|
||||
AddSimpleNode("C4", "Relu", {"B2"}, &graph);
|
||||
AddSimpleNode("B4", "SparseMatMul", {"C4", "C4"}, &graph);
|
||||
AddSimpleNode("C5", "Relu", {"B4"}, &graph);
|
||||
|
||||
GrapplerItem item;
|
||||
item.graph = graph;
|
||||
AutoMixedPrecision optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
|
||||
|
||||
VLOG(1) << output.DebugString();
|
||||
|
||||
GraphView output_view(&output);
|
||||
EXPECT_EQ(output.node_size(), graph.node_size() + 2);
|
||||
EXPECT_EQ(output_view.GetNode("In")->attr().at("dtype").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("B1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("C1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("G1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("C2")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("W1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("C3")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("B2")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("C4")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("B4")->attr().at("Ta").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("B4")->attr().at("Tb").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("C5")->attr().at("T").type(), DT_FLOAT);
|
||||
}
|
||||
|
||||
TEST_F(AutoMixedPrecisionTest, BidirectionalClearChain) {
|
||||
GraphDef graph;
|
||||
AddSimpleNode("In", "Placeholder", {}, &graph);
|
||||
AddSimpleNode("C1", "Relu", {"In"}, &graph);
|
||||
AddSimpleNode("C2", "Relu", {"In"}, &graph);
|
||||
AddSimpleNode("W1", "MatMul", {"C1", "C1"}, &graph);
|
||||
AddSimpleNode("C3", "ShapeN", {"C1", "C2"}, &graph);
|
||||
AddSimpleNode("C4", "Relu", {"C2"}, &graph);
|
||||
|
||||
GrapplerItem item;
|
||||
item.graph = graph;
|
||||
AutoMixedPrecision optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
|
||||
|
||||
VLOG(1) << output.DebugString();
|
||||
|
||||
GraphView output_view(&output);
|
||||
EXPECT_EQ(output.node_size(), graph.node_size() + 1);
|
||||
EXPECT_EQ(output_view.GetNode("In")->attr().at("dtype").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("C1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("C2")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("W1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("C3")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("C4")->attr().at("T").type(), DT_HALF);
|
||||
};
|
||||
|
||||
TEST_F(AutoMixedPrecisionTest, PreserveFetches) {
|
||||
GraphDef graph;
|
||||
AddSimpleNode("In", "Placeholder", {}, &graph);
|
||||
AddSimpleNode("Const1", "Const", {}, &graph);
|
||||
AddSimpleNode("W1", "MatMul", {"In", "Const1"}, &graph);
|
||||
AddSimpleNode("C1", "Relu", {"W1"}, &graph);
|
||||
AddSimpleNode("G1", "Sqrt", {"C1"}, &graph);
|
||||
AddSimpleNode("B1", "Exp", {"G1"}, &graph);
|
||||
AddSimpleNode("C2", "Relu", {"B1"}, &graph);
|
||||
AddSimpleNode("W2", "MatMul", {"C2", "C2"}, &graph);
|
||||
AddSimpleNode("C3", "Relu", {"W2"}, &graph);
|
||||
AddSimpleNode("B2", "Exp", {"C3"}, &graph);
|
||||
AddSimpleNode("C4", "Relu", {"B2"}, &graph);
|
||||
|
||||
GrapplerItem item;
|
||||
item.graph = graph;
|
||||
item.fetch.push_back("W1");
|
||||
item.fetch.push_back("C2");
|
||||
item.fetch.push_back("C3");
|
||||
AutoMixedPrecision optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
|
||||
|
||||
VLOG(1) << output.DebugString();
|
||||
|
||||
GraphView output_view(&output);
|
||||
EXPECT_EQ(output.node_size(), graph.node_size() + 2);
|
||||
EXPECT_EQ(output_view.GetNode("In")->attr().at("dtype").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("Const1")->attr().at("dtype").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("W1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("C1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("G1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("B1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("C2")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("W2")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("C3")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("B2")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("C4")->attr().at("T").type(), DT_FLOAT);
|
||||
}
|
||||
|
||||
TEST_F(AutoMixedPrecisionTest, PreserveCPUNodes) {
|
||||
GraphDef graph;
|
||||
AddSimpleNode("In", "Placeholder", {}, &graph);
|
||||
AddSimpleNode("C1", "Relu", {"In"}, &graph);
|
||||
AddSimpleNode("W1", "MatMul", {"C1", "C1"}, &graph);
|
||||
AddSimpleNode("G1", "Tanh", {"W1"}, &graph);
|
||||
NodeDef* w2 = AddSimpleNode("W2", "MatMul", {"G1", "G1"}, &graph);
|
||||
w2->set_device("/job:localhost/replica:0/task:0/device:CPU:0");
|
||||
AddSimpleNode("C2", "Relu", {"W2"}, &graph);
|
||||
|
||||
GrapplerItem item;
|
||||
item.graph = graph;
|
||||
AutoMixedPrecision optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
|
||||
|
||||
VLOG(1) << output.DebugString();
|
||||
|
||||
GraphView output_view(&output);
|
||||
EXPECT_EQ(output.node_size(), graph.node_size() + 2);
|
||||
EXPECT_EQ(output_view.GetNode("In")->attr().at("dtype").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("C1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("W1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("G1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("W2")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("C2")->attr().at("T").type(), DT_FLOAT);
|
||||
}
|
||||
|
||||
TEST_F(AutoMixedPrecisionTest, PreserveIdentityAfterVariable) {
|
||||
GraphDef graph;
|
||||
AddSimpleNode("In", "Placeholder", {}, &graph);
|
||||
AddSimpleNode("V1", "VariableV2", {}, &graph);
|
||||
AddSimpleNode("C1", "Identity", {"V1"}, &graph);
|
||||
AddSimpleNode("W1", "MatMul", {"In", "C1"}, &graph);
|
||||
AddSimpleNode("VarHandle1", "VarHandleOp", {}, &graph);
|
||||
AddSimpleNode("V2", "ReadVariableOp", {"VarHandle1"}, &graph);
|
||||
AddSimpleNode("W2", "MatMul", {"In", "V2"}, &graph);
|
||||
AddSimpleNode("Const1", "Const", {}, &graph);
|
||||
AddSimpleNode("C2", "Identity", {"Const1"}, &graph);
|
||||
AddSimpleNode("W3", "MatMul", {"In", "C2"}, &graph);
|
||||
|
||||
GrapplerItem item;
|
||||
item.graph = graph;
|
||||
AutoMixedPrecision optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
|
||||
|
||||
VLOG(1) << output.DebugString();
|
||||
|
||||
GraphView output_view(&output);
|
||||
EXPECT_EQ(output.node_size(), graph.node_size() + 4);
|
||||
EXPECT_EQ(output_view.GetNode("In")->attr().at("dtype").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("V1")->attr().at("dtype").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("C1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("W1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("VarHandle1")->attr().at("dtype").type(),
|
||||
DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("V2")->attr().at("dtype").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("W2")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("Const1")->attr().at("dtype").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("C2")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("W3")->attr().at("T").type(), DT_HALF);
|
||||
}
|
||||
|
||||
TEST_F(AutoMixedPrecisionTest, FusedBatchNorm) {
|
||||
GraphDef graph;
|
||||
AddSimpleNode("X", "Placeholder", {}, &graph);
|
||||
AddSimpleNode("Const1", "Const", {}, &graph);
|
||||
AddSimpleNode("Scale", "Placeholder", {}, &graph);
|
||||
AddSimpleNode("Offset", "Placeholder", {}, &graph);
|
||||
AddSimpleNode("Mean", "Placeholder", {}, &graph);
|
||||
AddSimpleNode("Variance", "Placeholder", {}, &graph);
|
||||
AddSimpleNode("W1", "Conv2D", {"X", "Const1"}, &graph);
|
||||
AddSimpleNode("BN1", "FusedBatchNorm",
|
||||
{"W1", "Scale", "Offset", "Mean", "Variance"}, &graph);
|
||||
AddSimpleNode("BNG1", "FusedBatchNormGrad",
|
||||
{"BN1", "W1", "Scale", "Mean", "Variance"}, &graph);
|
||||
AddSimpleNode("G1", "Add", {"BN1", "BNG1"}, &graph);
|
||||
AddSimpleNode("W2", "Conv2D", {"G1", "Const1"}, &graph);
|
||||
|
||||
GrapplerItem item;
|
||||
item.graph = graph;
|
||||
AutoMixedPrecision optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
|
||||
|
||||
VLOG(1) << output.DebugString();
|
||||
|
||||
GraphView output_view(&output);
|
||||
EXPECT_EQ(output.node_size(), graph.node_size() + 2);
|
||||
EXPECT_EQ(output_view.GetNode("W1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("BN1")->op(), "FusedBatchNormV2");
|
||||
EXPECT_EQ(output_view.GetNode("BN1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("BN1")->attr().at("U").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("BNG1")->op(), "FusedBatchNormGradV2");
|
||||
EXPECT_EQ(output_view.GetNode("BNG1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("BNG1")->attr().at("U").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("G1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("W2")->attr().at("T").type(), DT_HALF);
|
||||
}
|
||||
|
||||
TEST_F(AutoMixedPrecisionTest, RepeatedAndListTypeAttrs) {
|
||||
GraphDef graph;
|
||||
AddSimpleNode("In", "Placeholder", {}, &graph);
|
||||
AddSimpleNode("W1", "MatMul", {"In", "In"}, &graph);
|
||||
AddSimpleNode("ID", "IdentityN", {"W1", "W1", "W1"}, &graph);
|
||||
AddSimpleNode("G1", "AddN", {"ID:0", "ID:1", "ID:2"}, &graph);
|
||||
AddSimpleNode("W2", "MatMul", {"G1", "G1"}, &graph);
|
||||
|
||||
GrapplerItem item;
|
||||
item.graph = graph;
|
||||
AutoMixedPrecision optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
|
||||
|
||||
VLOG(1) << output.DebugString();
|
||||
|
||||
GraphView output_view(&output);
|
||||
EXPECT_EQ(output.node_size(), graph.node_size() + 1);
|
||||
EXPECT_EQ(output_view.GetNode("In")->attr().at("dtype").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("W1")->attr().at("T").type(), DT_HALF);
|
||||
for (auto type : output_view.GetNode("ID")->attr().at("T").list().type()) {
|
||||
EXPECT_EQ(type, DT_HALF);
|
||||
}
|
||||
EXPECT_EQ(output_view.GetNode("G1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("W2")->attr().at("T").type(), DT_HALF);
|
||||
}
|
||||
|
||||
TEST_F(AutoMixedPrecisionTest, ExistingCast) {
|
||||
GraphDef graph;
|
||||
NodeDef* ph = AddSimpleNode("In", "Placeholder", {}, &graph);
|
||||
ph->mutable_attr()->at("dtype").set_type(DT_BOOL);
|
||||
NodeDef* cast = AddSimpleNode("Cast1", "Cast", {"In"}, &graph);
|
||||
cast->mutable_attr()->at("SrcT").set_type(DT_BOOL);
|
||||
AddSimpleNode("W1", "MatMul", {"Cast1", "Cast1"}, &graph);
|
||||
|
||||
GrapplerItem item;
|
||||
item.graph = graph;
|
||||
AutoMixedPrecision optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
|
||||
|
||||
VLOG(1) << output.DebugString();
|
||||
|
||||
GraphView output_view(&output);
|
||||
EXPECT_EQ(output.node_size(), graph.node_size());
|
||||
EXPECT_EQ(output_view.GetNode("Cast1")->attr().at("DstT").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("W1")->attr().at("T").type(), DT_HALF);
|
||||
}
|
||||
|
||||
TEST_F(AutoMixedPrecisionTest, StackV2) {
|
||||
GraphDef graph;
|
||||
AddSimpleNode("Handle1", "Const", {}, &graph);
|
||||
AddSimpleNode("Stack1", "StackV2", {"Handle1"}, &graph);
|
||||
AddSimpleNode("In", "Placeholder", {}, &graph);
|
||||
AddSimpleNode("Push1", "StackPushV2", {"Stack1", "In"}, &graph);
|
||||
AddSimpleNode("W1", "MatMul", {"In", "In"}, &graph);
|
||||
AddSimpleNode("Push2", "StackPushV2", {"Stack1", "W1"}, &graph);
|
||||
AddSimpleNode("Pop1", "StackPopV2", {"Stack1"}, &graph);
|
||||
AddSimpleNode("G1", "Tanh", {"Pop1"}, &graph);
|
||||
AddSimpleNode("W2", "MatMul", {"G1", "G1"}, &graph);
|
||||
AddSimpleNode("Push3", "StackPushV2", {"Stack1", "W2"}, &graph);
|
||||
AddSimpleNode("Handle2", "Const", {}, &graph);
|
||||
AddSimpleNode("Stack2", "StackV2", {"Handle2"}, &graph);
|
||||
AddSimpleNode("Push1-2", "StackPushV2", {"Stack2", "In"}, &graph);
|
||||
AddSimpleNode("Pop1-2", "StackPopV2", {"Stack2"}, &graph);
|
||||
|
||||
GrapplerItem item;
|
||||
item.graph = graph;
|
||||
AutoMixedPrecision optimizer;
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
|
||||
|
||||
VLOG(1) << output.DebugString();
|
||||
|
||||
GraphView output_view(&output);
|
||||
EXPECT_EQ(output.node_size(), graph.node_size() + 1);
|
||||
EXPECT_EQ(output_view.GetNode("Stack1")->attr().at("elem_type").type(),
|
||||
DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("Push1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("W1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("Push2")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("Pop1")->attr().at("elem_type").type(),
|
||||
DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("G1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("W2")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("Push3")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("Stack2")->attr().at("elem_type").type(),
|
||||
DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("Push1-2")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("Pop1-2")->attr().at("elem_type").type(),
|
||||
DT_FLOAT);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
|
||||
#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
|
||||
#include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h"
|
||||
#include "tensorflow/core/grappler/optimizers/auto_parallel.h"
|
||||
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
|
||||
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
|
||||
@ -44,6 +45,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/util/dump_graph.h"
|
||||
#include "tensorflow/core/util/env_var.h"
|
||||
#include "tensorflow/core/util/ptr_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -78,7 +80,7 @@ int NumIterations(const RewriterConfig& cfg) {
|
||||
// Check if optimizer is allowed to run only once.
|
||||
bool IsRunOnceOptimizer(const string& name) {
|
||||
return name == "layout" || name == "memory_optimizer" ||
|
||||
name == "loop_optimizer";
|
||||
name == "loop_optimizer" || name == "auto_mixed_precision";
|
||||
}
|
||||
|
||||
// Check if the graphdef contains nodes that indicate TPU execution.
|
||||
@ -115,6 +117,25 @@ Status CompressConstants(GraphDef* graph) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// A helper function to decide whether to enable the automatic mixed precision
|
||||
// optimizer.
|
||||
bool AutoMixedPrecisionEnabled(RewriterConfig::Toggle opt_level) {
|
||||
if (opt_level == RewriterConfig::ON ||
|
||||
opt_level == RewriterConfig::AGGRESSIVE) {
|
||||
return true;
|
||||
}
|
||||
if (opt_level == RewriterConfig::OFF) return false;
|
||||
// Default is to check env var, otherwise off.
|
||||
static bool is_enabled = [] {
|
||||
bool ret = false;
|
||||
TF_CHECK_OK(
|
||||
ReadBoolFromEnvVar("TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE",
|
||||
/*default_val=*/false, &ret));
|
||||
return ret;
|
||||
}();
|
||||
return is_enabled;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
#define MK_OPT(NAME, VALUE) \
|
||||
@ -128,6 +149,8 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
|
||||
MK_OPT("shape", new ShapeOptimizer());
|
||||
MK_OPT("remap", new Remapper(cfg_.remapping()));
|
||||
MK_OPT("layout", new LayoutOptimizer());
|
||||
MK_OPT("auto_mixed_precision",
|
||||
new AutoMixedPrecision(cfg_.auto_mixed_precision()));
|
||||
MK_OPT("memory", new MemoryOptimizer(RewriterConfig::MANUAL));
|
||||
MK_OPT("arithmetic", new ArithmeticOptimizer(cfg_.arithmetic_optimization()));
|
||||
MK_OPT("autoparallel", new AutoParallel(cfg_.auto_parallel().num_replicas()));
|
||||
@ -161,6 +184,10 @@ Status MetaOptimizer::InitializeOptimizers(
|
||||
if (!cfg_.disable_model_pruning()) {
|
||||
optimizers->push_back(MakeUnique<ModelPruner>());
|
||||
}
|
||||
if (AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision())) {
|
||||
optimizers->push_back(
|
||||
MakeUnique<AutoMixedPrecision>(cfg_.auto_mixed_precision()));
|
||||
}
|
||||
if (cfg_.implementation_selector() != RewriterConfig::OFF) {
|
||||
optimizers->push_back(MakeUnique<ImplementationSelector>());
|
||||
}
|
||||
@ -694,6 +721,7 @@ bool MetaOptimizerEnabled(const ConfigProto& cfg) {
|
||||
rewrite_cfg.debug_stripper() == RewriterConfig::ON ||
|
||||
rewrite_cfg.scoped_allocator_optimization() == RewriterConfig::ON ||
|
||||
rewrite_cfg.pin_to_host_optimization() == RewriterConfig::ON ||
|
||||
AutoMixedPrecisionEnabled(rewrite_cfg.auto_mixed_precision()) ||
|
||||
!rewrite_cfg.optimizers().empty() ||
|
||||
!rewrite_cfg.custom_optimizers().empty();
|
||||
}
|
||||
|
@ -81,6 +81,11 @@ message RewriterConfig {
|
||||
// Enable the swap of kernel implementations based on the device placement
|
||||
// (default is ON).
|
||||
Toggle implementation_selector = 22;
|
||||
// Optimize data types (default is OFF).
|
||||
// e.g., This will try to use float16 on GPU which is faster.
|
||||
// Note that this can change the numerical stability of the graph and may
|
||||
// require the use of loss scaling to maintain model convergence.
|
||||
Toggle auto_mixed_precision = 23;
|
||||
// Disable the entire meta optimizer (off by default).
|
||||
bool disable_meta_optimizer = 19;
|
||||
|
||||
|
@ -6096,6 +6096,32 @@ tf_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "auto_mixed_precision_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"grappler/auto_mixed_precision_test.py",
|
||||
],
|
||||
additional_deps = [
|
||||
":client_testlib",
|
||||
":framework_for_generated_wrappers",
|
||||
":array_ops",
|
||||
":constant_op",
|
||||
":dtypes",
|
||||
":math_ops",
|
||||
":nn",
|
||||
":ops",
|
||||
":random_ops",
|
||||
":control_flow_ops",
|
||||
":training",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
],
|
||||
tags = [
|
||||
"grappler",
|
||||
],
|
||||
)
|
||||
|
||||
tf_gen_op_wrapper_private_py(
|
||||
name = "nccl_ops_gen",
|
||||
visibility = ["//tensorflow:internal"],
|
||||
|
533
tensorflow/python/grappler/auto_mixed_precision_test.py
Normal file
533
tensorflow/python/grappler/auto_mixed_precision_test.py
Normal file
@ -0,0 +1,533 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for Grappler AutoMixedPrecision."""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.core.framework import types_pb2
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2 as rwcpb2
|
||||
from tensorflow.core.protobuf import config_pb2 as cpb2
|
||||
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn
|
||||
from tensorflow.python.ops import nn_impl
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import gradient_descent
|
||||
|
||||
|
||||
def _input(shape):
|
||||
"""Generates an input of a given shape."""
|
||||
return variables.Variable(random_ops.truncated_normal(shape, seed=0))
|
||||
|
||||
|
||||
def _weight(shape):
|
||||
"""Generates a weight of a given shape."""
|
||||
# Note that the lambda is needed to allow construction inside loops.
|
||||
return variables.Variable(
|
||||
lambda: init_ops.glorot_uniform_initializer(seed=0)(shape))
|
||||
|
||||
|
||||
def _bias(shape):
|
||||
"""Generates a bias of a given shape."""
|
||||
return constant_op.constant(0.1, shape=shape)
|
||||
|
||||
|
||||
def _conv2d(x, w):
|
||||
"""Returns a 2d convolution layer with full stride."""
|
||||
return nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME')
|
||||
|
||||
|
||||
def _max_pool_2x2(x):
|
||||
"""Downsamples a feature map by 2X."""
|
||||
return nn.max_pool(
|
||||
x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
|
||||
|
||||
|
||||
def _fused_batchnorm(x, scale, offset):
|
||||
"""Batchnorm."""
|
||||
return nn_impl.fused_batch_norm(
|
||||
x, scale=scale, offset=offset, is_training=True)
|
||||
|
||||
|
||||
def _conv_bn(x):
|
||||
"""Conv followed by batchnorm."""
|
||||
i = array_ops.reshape(x, [-1, 8, 8, 1])
|
||||
f = _weight([3, 3, 1, 6])
|
||||
x = _conv2d(i, f)
|
||||
s = _weight([6])
|
||||
o = _weight([6])
|
||||
y, m, v = _fused_batchnorm(x, s, o)
|
||||
y = array_ops.identity(y)
|
||||
return y
|
||||
|
||||
|
||||
def _matmul_act(x):
|
||||
"""Matmul followed by activation."""
|
||||
i = array_ops.reshape(x, [8, 8])
|
||||
f = _weight([8, 8])
|
||||
x = math_ops.matmul(i, f)
|
||||
y = nn.relu(x)
|
||||
return y
|
||||
|
||||
|
||||
def _conv_pool(x):
|
||||
"""Conv followed by pooling."""
|
||||
x_image = array_ops.reshape(x, [-1, 8, 8, 1])
|
||||
w_conv1 = _weight([3, 3, 1, 6])
|
||||
b_conv1 = _bias([6])
|
||||
h_conv1 = nn.relu(nn.bias_add(_conv2d(x_image, w_conv1), b_conv1))
|
||||
h_pool1 = _max_pool_2x2(h_conv1)
|
||||
w_conv2 = _weight([3, 3, 6, 4])
|
||||
b_conv2 = _bias([4])
|
||||
h_conv2 = nn.relu(nn.bias_add(_conv2d(h_pool1, w_conv2), b_conv2))
|
||||
h_pool2 = _max_pool_2x2(h_conv2)
|
||||
return h_pool2
|
||||
|
||||
|
||||
def _simple_loop(x, functor):
|
||||
"""Simple loop whose body is provided by the functor."""
|
||||
init = (constant_op.constant(0), x)
|
||||
c = lambda i, j: i < 4
|
||||
b = lambda i, j: (i+1, functor(j))
|
||||
ij = control_flow_ops.while_loop(c, b, init)
|
||||
return ij
|
||||
|
||||
|
||||
|
||||
def _loop_vars_intertwined(x0, y0, functor_x, functor_y):
|
||||
"""Loop whose loop variables are intertwined."""
|
||||
c = lambda i, j, x, y: j < 4
|
||||
b = lambda i, j, x, y: (j+1, i+1, functor_y(y), functor_x(x))
|
||||
init = (constant_op.constant(0), constant_op.constant(0), x0, y0)
|
||||
ijzw = control_flow_ops.while_loop(c, b, init)
|
||||
return ijzw
|
||||
|
||||
|
||||
def _lstm_cell(prev_c, prev_h, x):
|
||||
""" LSTMCell
|
||||
i: input gate
|
||||
f: forget gate
|
||||
o: output gate
|
||||
c: cell state
|
||||
x: input
|
||||
h: embedding
|
||||
"""
|
||||
bias = _bias([4])
|
||||
w = _weight([8, 16])
|
||||
ifoc = math_ops.matmul(array_ops.concat([x, prev_h], axis=1), w)
|
||||
i, f, o, c = array_ops.split(ifoc, 4, axis=1)
|
||||
i = math_ops.sigmoid(nn.bias_add(i, bias))
|
||||
f = math_ops.sigmoid(nn.bias_add(f, bias))
|
||||
o = math_ops.sigmoid(nn.bias_add(o, bias))
|
||||
c = math_ops.tanh(nn.bias_add(c, bias))
|
||||
next_c = f * prev_c + i * c
|
||||
next_h = o * math_ops.tanh(next_c)
|
||||
return next_c, next_h
|
||||
|
||||
|
||||
def _recurrent_lstm(c, h):
|
||||
""" Dynamic single-layer LSTM with TensorArray """
|
||||
def cond(i, c, h, ta_x):
|
||||
return i<4
|
||||
|
||||
def body(i, c, h, ta_x):
|
||||
x = ta_x.read(i)
|
||||
next_c, next_h = _lstm_cell(c, h, x)
|
||||
return (i+1, next_c, next_h, ta_x)
|
||||
|
||||
ta_x = tensor_array_ops.TensorArray(
|
||||
dtype=dtypes.float32,
|
||||
size=4)
|
||||
for i in range(0, 4):
|
||||
ta_x = ta_x.write(
|
||||
i, constant_op.constant(0.1, shape=[8, 4],
|
||||
dtype=dtypes.float32))
|
||||
init = (constant_op.constant(0), c, h, ta_x)
|
||||
r = control_flow_ops.while_loop(cond, body, init)
|
||||
return r
|
||||
|
||||
|
||||
def _make_node_with_color(color, input_tensor, name=None):
|
||||
color = color.lower()
|
||||
if color == 'w': # White node
|
||||
weights = _weight(input_tensor.get_shape().as_list())
|
||||
return math_ops.matmul(input_tensor, weights, name=name)
|
||||
elif color == 'g': # Gray node
|
||||
return math_ops.sqrt(input_tensor, name=name)
|
||||
elif color == 'c': # Clear node
|
||||
return nn.relu(input_tensor, name=name)
|
||||
elif color == 'b': # Black node
|
||||
return math_ops.log(input_tensor, name=name)
|
||||
else:
|
||||
raise ValueError("Invalid node color: " + str(color))
|
||||
|
||||
|
||||
def _build_intertwined_loop_graph(inpA_colors, inpB_colors, bodyA_colors,
|
||||
bodyB_colors, outA_colors, outB_colors):
|
||||
a = _input([8, 8])
|
||||
for i, color in enumerate(inpA_colors):
|
||||
a = _make_node_with_color(color, a, 'inputA_%i' % i)
|
||||
b = _input([8, 8])
|
||||
for i, color in enumerate(inpB_colors):
|
||||
b = _make_node_with_color(color, b, 'inputB_%i' % i)
|
||||
def bodyA(x):
|
||||
for i, color in enumerate(bodyA_colors):
|
||||
x = _make_node_with_color(color, x, 'bodyA_%i' % i)
|
||||
return x
|
||||
def bodyB(x):
|
||||
for i, color in enumerate(bodyB_colors):
|
||||
x = _make_node_with_color(color, x, 'bodyB_%i' % i)
|
||||
return x
|
||||
a, b = _loop_vars_intertwined(a, b, bodyA, bodyB)[2:]
|
||||
for i, color in enumerate(outA_colors):
|
||||
a = _make_node_with_color(color, a, 'outputA_%i' % i)
|
||||
for i, color in enumerate(outB_colors):
|
||||
b = _make_node_with_color(color, b, 'outputB_%i' % i)
|
||||
a = array_ops.identity(a)
|
||||
b = array_ops.identity(b)
|
||||
return a, b
|
||||
|
||||
|
||||
def _get_config(auto_mixed_precision=True):
|
||||
if auto_mixed_precision:
|
||||
rewrite_config = rwcpb2.RewriterConfig(
|
||||
auto_mixed_precision=rwcpb2.RewriterConfig.ON,
|
||||
# do not remove duplicated nodes
|
||||
arithmetic_optimization=rwcpb2.RewriterConfig.OFF)
|
||||
else:
|
||||
rewrite_config = rwcpb2.RewriterConfig(
|
||||
auto_mixed_precision=rwcpb2.RewriterConfig.OFF,
|
||||
# do not remove duplicated nodes
|
||||
arithmetic_optimization=rwcpb2.RewriterConfig.OFF)
|
||||
rewrite_config.min_graph_nodes = -1
|
||||
graph_options = cpb2.GraphOptions(
|
||||
rewrite_options=rewrite_config, build_cost_model=1)
|
||||
config = cpb2.ConfigProto(graph_options=graph_options)
|
||||
config.graph_options.optimizer_options.opt_level = -1
|
||||
return config
|
||||
|
||||
|
||||
def _is_cast_to_fp16(node_name):
|
||||
return node_name.endswith('-CastToFp16-AutoMixedPrecision')
|
||||
|
||||
|
||||
def _is_cast_to_fp32(node_name):
|
||||
return node_name.endswith('-CastToFp32-AutoMixedPrecision')
|
||||
|
||||
|
||||
def _count_casts(nodes):
|
||||
num_to_fp16 = 0
|
||||
num_to_fp32 = 0
|
||||
for node in nodes:
|
||||
if _is_cast_to_fp16(node.name):
|
||||
num_to_fp16 += 1
|
||||
elif _is_cast_to_fp32(node.name):
|
||||
num_to_fp32 += 1
|
||||
return num_to_fp16, num_to_fp32
|
||||
|
||||
|
||||
def _build_node_map(nodes):
|
||||
node_map = {}
|
||||
for node in nodes:
|
||||
node_map[node.name] = node
|
||||
return node_map
|
||||
|
||||
|
||||
class AutoMixedPrecisionTest(test.TestCase):
|
||||
"""Tests the Grappler auto mixed precision optimizer."""
|
||||
MIN_GPU_ARCH = (7, 0)
|
||||
|
||||
def _assert_output_fp16(self, node_map, node_name, output_port=0):
|
||||
self.assertEqual(node_map[node_name].output_info[output_port].dtype,
|
||||
types_pb2.DT_HALF)
|
||||
|
||||
def _run(self, fetches):
|
||||
with session.Session(config=_get_config(False)) as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
output_val_ref = self.evaluate(fetches)
|
||||
|
||||
with session.Session(config=_get_config()) as sess:
|
||||
sess.run(variables.global_variables_initializer())
|
||||
metadata = cpb2.RunMetadata()
|
||||
output_val = sess.run(fetches, run_metadata=metadata)
|
||||
|
||||
return output_val_ref, output_val, metadata.cost_graph
|
||||
|
||||
def _run_intertwined_loop_test(self, inpA, inpB, bodyA, bodyB, outA, outB,
|
||||
expected_num_to_fp16, expected_num_to_fp32):
|
||||
"""Runs a test of an intertwined loop with different node colors in
|
||||
different sections of the graph. The arguments must be strings where each
|
||||
character represents the color of a node in that section of the graph:
|
||||
w = white, g = gray, c = clear, b = black. CAPITALIZED characters indicate
|
||||
that the node is expected to be changed to DT_HALF during graph
|
||||
optimization.
|
||||
|
||||
inpA -> loop [ bodyA ] -> outA
|
||||
: |
|
||||
=====<<=====
|
||||
| :
|
||||
inpB -> loop [ bodyB ] -> outB
|
||||
"""
|
||||
if test.is_gpu_available(cuda_only=True,
|
||||
min_cuda_compute_capability=self.MIN_GPU_ARCH):
|
||||
random_seed.set_random_seed(0)
|
||||
expected_types = []
|
||||
for section in [inpA, inpB, bodyA, bodyB, outA, outB]:
|
||||
section_expected_types = []
|
||||
for color in section:
|
||||
if color.isupper():
|
||||
expected_type = types_pb2.DT_HALF
|
||||
else:
|
||||
expected_type = types_pb2.DT_FLOAT
|
||||
section_expected_types.append(expected_type)
|
||||
expected_types.append(section_expected_types)
|
||||
|
||||
a, b = _build_intertwined_loop_graph(inpA, inpB, bodyA, bodyB, outA, outB)
|
||||
output_val_ref, output_val, cost_graph = self._run((a, b))
|
||||
node_map = _build_node_map(cost_graph.node)
|
||||
num_to_fp16, num_to_fp32 = _count_casts(cost_graph.node)
|
||||
|
||||
section_names = ['inputA', 'inputB', 'while/bodyA', 'while/bodyB',
|
||||
'outputA', 'outputB']
|
||||
all_types_correct = True
|
||||
for section_name, expected_types in zip(section_names, expected_types):
|
||||
for i, expected_type in enumerate(expected_types):
|
||||
node_name = section_name + '_%i' % i
|
||||
output_port = 0
|
||||
optimized_type = node_map[node_name].output_info[output_port].dtype
|
||||
if (optimized_type != expected_type):
|
||||
print("Expected node %s to have type %s but got type %s" %
|
||||
(node_name, expected_type, optimized_type))
|
||||
all_types_correct = False
|
||||
self.assertTrue(all_types_correct)
|
||||
self.assertEqual(num_to_fp16, expected_num_to_fp16)
|
||||
self.assertEqual(num_to_fp32, expected_num_to_fp32)
|
||||
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
|
||||
|
||||
def testConvBN(self):
|
||||
if test.is_gpu_available(cuda_only=True,
|
||||
min_cuda_compute_capability=self.MIN_GPU_ARCH):
|
||||
random_seed.set_random_seed(0)
|
||||
x = _input([2, 8, 8, 1])
|
||||
x = _conv_bn(x)
|
||||
output = _conv_bn(x)
|
||||
|
||||
output_val_ref, output_val, cost_graph = self._run(output)
|
||||
node_map = _build_node_map(cost_graph.node)
|
||||
num_to_fp16, num_to_fp32 = _count_casts(cost_graph.node)
|
||||
|
||||
self._assert_output_fp16(node_map, 'Conv2D')
|
||||
self._assert_output_fp16(node_map, 'FusedBatchNorm')
|
||||
self._assert_output_fp16(node_map, 'Conv2D_1')
|
||||
self.assertEqual(num_to_fp16, 3) # Before Conv2D:0, Conv2D:1, Conv2D_1:1
|
||||
self.assertEqual(num_to_fp32, 1) # After FusedBatchNorm:0
|
||||
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
|
||||
|
||||
def testConvBNDropout(self):
|
||||
if test.is_gpu_available(cuda_only=True,
|
||||
min_cuda_compute_capability=self.MIN_GPU_ARCH):
|
||||
random_seed.set_random_seed(0)
|
||||
x = _input([2, 8, 8, 1])
|
||||
y = _conv_bn(x)
|
||||
y = nn.dropout(y, rate=0.5)
|
||||
y = _conv_bn(y)
|
||||
y = array_ops.identity(y)
|
||||
optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
|
||||
g = optimizer.compute_gradients(y, [x])
|
||||
output = (y, g)
|
||||
|
||||
output_val_ref, output_val, cost_graph = self._run(output)
|
||||
node_map = _build_node_map(cost_graph.node)
|
||||
num_to_fp16, num_to_fp32 = _count_casts(cost_graph.node)
|
||||
self._assert_output_fp16(node_map, 'Conv2D')
|
||||
self._assert_output_fp16(node_map, 'FusedBatchNorm')
|
||||
self._assert_output_fp16(node_map, 'dropout/mul')
|
||||
self._assert_output_fp16(node_map, 'dropout/Cast')
|
||||
self._assert_output_fp16(node_map, 'dropout/mul_1')
|
||||
self._assert_output_fp16(node_map, 'Conv2D_1')
|
||||
|
||||
output_val_ref, output_val, cost_graph = self._run(output)
|
||||
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
|
||||
|
||||
def testConvPool(self):
|
||||
if test.is_gpu_available(cuda_only=True,
|
||||
min_cuda_compute_capability=self.MIN_GPU_ARCH):
|
||||
random_seed.set_random_seed(0)
|
||||
x = _input([2, 8, 8, 1])
|
||||
output = _conv_pool(x)
|
||||
|
||||
output_val_ref, output_val, cost_graph = self._run(output)
|
||||
node_map = _build_node_map(cost_graph.node)
|
||||
num_to_fp16, num_to_fp32 = _count_casts(cost_graph.node)
|
||||
|
||||
self._assert_output_fp16(node_map, 'Conv2D')
|
||||
self._assert_output_fp16(node_map, 'Relu')
|
||||
self._assert_output_fp16(node_map, 'MaxPool')
|
||||
self._assert_output_fp16(node_map, 'Conv2D_1')
|
||||
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
|
||||
|
||||
def testSimpleLoop(self):
|
||||
if test.is_gpu_available(cuda_only=True,
|
||||
min_cuda_compute_capability=self.MIN_GPU_ARCH):
|
||||
random_seed.set_random_seed(0)
|
||||
x = _input([8, 8])
|
||||
y = _simple_loop(x, _matmul_act)[1]
|
||||
optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
|
||||
g = optimizer.compute_gradients(y, [x])
|
||||
output = (y, g)
|
||||
|
||||
output_val_ref, output_val, cost_graph = self._run(output)
|
||||
node_map = _build_node_map(cost_graph.node)
|
||||
|
||||
self._assert_output_fp16(node_map, 'while/MatMul')
|
||||
self._assert_output_fp16(node_map, 'while/Relu')
|
||||
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
|
||||
|
||||
def testLoopWithVarsIntertwined(self):
|
||||
if test.is_gpu_available(cuda_only=True,
|
||||
min_cuda_compute_capability=self.MIN_GPU_ARCH):
|
||||
random_seed.set_random_seed(0)
|
||||
x = _input([8, 8])
|
||||
i, j, k, l = _loop_vars_intertwined(array_ops.ones(array_ops.shape(x)),
|
||||
x, _matmul_act, _matmul_act)
|
||||
optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
|
||||
g = optimizer.compute_gradients(k, [x])
|
||||
output = (k, l, g)
|
||||
|
||||
output_val_ref, output_val, cost_graph = self._run(output)
|
||||
node_map = _build_node_map(cost_graph.node)
|
||||
|
||||
self._assert_output_fp16(node_map, 'while/MatMul')
|
||||
self._assert_output_fp16(node_map, 'while/Relu')
|
||||
self._assert_output_fp16(node_map, 'while/MatMul_1')
|
||||
self._assert_output_fp16(node_map, 'while/Relu_1')
|
||||
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
|
||||
|
||||
def testMultiPaths(self):
|
||||
if test.is_gpu_available(cuda_only=True,
|
||||
min_cuda_compute_capability=self.MIN_GPU_ARCH):
|
||||
random_seed.set_random_seed(0)
|
||||
x = _input([2, 8, 8, 3])
|
||||
x1, x2, x3 = array_ops.split(x, num_or_size_splits=3, axis=3)
|
||||
y1 = _conv_pool(x1)
|
||||
y2 = _conv_pool(x2)
|
||||
y3 = _conv_pool(x3)
|
||||
y = array_ops.concat([y1, y2, y3], axis=3)
|
||||
y = array_ops.identity(y)
|
||||
optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
|
||||
g = optimizer.compute_gradients(y, [x])
|
||||
output = (y, g)
|
||||
|
||||
output_val_ref, output_val, cost_graph = self._run(output)
|
||||
node_map = _build_node_map(cost_graph.node)
|
||||
|
||||
self._assert_output_fp16(node_map, 'split')
|
||||
for suffix in [''] + ['_%i' % i for i in range(1, 6)]:
|
||||
self._assert_output_fp16(node_map, 'Conv2D' + suffix)
|
||||
self._assert_output_fp16(node_map, 'Relu' + suffix)
|
||||
self._assert_output_fp16(node_map, 'MaxPool' + suffix)
|
||||
self._assert_output_fp16(node_map, 'concat')
|
||||
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
|
||||
|
||||
def testMultiPaths2(self):
|
||||
if test.is_gpu_available(cuda_only=True,
|
||||
min_cuda_compute_capability=self.MIN_GPU_ARCH):
|
||||
random_seed.set_random_seed(0)
|
||||
x = _input([8, 8])
|
||||
y1 = _matmul_act(x)
|
||||
y2 = _matmul_act(x)
|
||||
y = y1 + y2 + x
|
||||
optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
|
||||
g = optimizer.compute_gradients(y, [x])
|
||||
output = (g, y)
|
||||
|
||||
output_val_ref, output_val, cost_graph = self._run(output)
|
||||
node_map = _build_node_map(cost_graph.node)
|
||||
|
||||
self._assert_output_fp16(node_map, 'MatMul')
|
||||
self._assert_output_fp16(node_map, 'Relu')
|
||||
self._assert_output_fp16(node_map, 'MatMul_1')
|
||||
self._assert_output_fp16(node_map, 'Relu_1')
|
||||
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
|
||||
|
||||
def testRecurrentLSTM(self):
|
||||
if test.is_gpu_available(cuda_only=True,
|
||||
min_cuda_compute_capability=self.MIN_GPU_ARCH):
|
||||
random_seed.set_random_seed(0)
|
||||
init_c = _input([8, 4])
|
||||
init_h = _input([8, 4])
|
||||
i, c, h, ta = _recurrent_lstm(init_c, init_h)
|
||||
optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
|
||||
g = optimizer.compute_gradients(h, [init_c, init_h])
|
||||
output = (h, g)
|
||||
|
||||
output_val_ref, output_val, cost_graph = self._run(output)
|
||||
node_map = _build_node_map(cost_graph.node)
|
||||
|
||||
self._assert_output_fp16(node_map, 'while/concat')
|
||||
self._assert_output_fp16(node_map, 'while/MatMul')
|
||||
self._assert_output_fp16(node_map, 'while/split')
|
||||
self._assert_output_fp16(node_map, 'while/Sigmoid')
|
||||
self._assert_output_fp16(node_map, 'while/Sigmoid_1')
|
||||
self._assert_output_fp16(node_map, 'while/Sigmoid_2')
|
||||
self._assert_output_fp16(node_map, 'while/Tanh')
|
||||
self._assert_output_fp16(node_map, 'while/Tanh_1')
|
||||
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
|
||||
|
||||
def testPropagationThroughIntertwinedLoop1(self):
|
||||
self._run_intertwined_loop_test('C', 'C', 'bgW', 'C', 'g', 'b', 4, 3)
|
||||
|
||||
def testPropagationThroughIntertwinedLoop2(self):
|
||||
# Note that this results in NextIteration and Merge being painted different
|
||||
# colors, requiring NextIteration to be forced to match.
|
||||
self._run_intertwined_loop_test('b', 'g', 'gW', 'C', 'c', 'C', 3, 2)
|
||||
|
||||
def testPropagationThroughIntertwinedLoop3(self):
|
||||
self._run_intertwined_loop_test('g', 'g', 'g', 'g', 'W', 'c', 3, 2)
|
||||
|
||||
def testPropagationThroughIntertwinedLoop4(self):
|
||||
self._run_intertwined_loop_test('W', 'g', 'g', 'g', 'g', 'g', 3, 2)
|
||||
|
||||
def testPropagationThroughIntertwinedLoop5(self):
|
||||
self._run_intertwined_loop_test('W', 'c', 'b', 'c', 'c', 'W', 4, 2)
|
||||
|
||||
def testPropagationThroughIntertwinedLoop6(self):
|
||||
self._run_intertwined_loop_test('b', 'g', 'g', 'g', 'g', 'W', 2, 1)
|
||||
|
||||
def testPropagationThroughIntertwinedLoop7(self):
|
||||
self._run_intertwined_loop_test('c', 'c', 'bWg', 'c', 'g', 'b', 2, 1)
|
||||
|
||||
def testPropagationThroughIntertwinedLoop8(self):
|
||||
self._run_intertwined_loop_test('C', 'C', 'C', 'C', 'W', 'g', 3, 2)
|
||||
|
||||
def testPropagationThroughIntertwinedLoop9(self):
|
||||
self._run_intertwined_loop_test('W', 'g', 'G', 'G', 'g', 'W', 4, 2)
|
||||
|
||||
def testPropagationThroughIntertwinedLoop10(self):
|
||||
self._run_intertwined_loop_test('g', 'g', 'GWG', 'G', 'g', 'g', 3, 2)
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
Loading…
Reference in New Issue
Block a user