Merge pull request #40596 from reedwm:auto_mp_mkl2

PiperOrigin-RevId: 317381920
Change-Id: I8e7fe93090dafeedba1e7dccfb093d16c6e5b742

commit 8e88146931
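
Editor's note: this change adds an MKL (bfloat16-on-CPU) mode to the Grappler auto mixed precision optimizer, alongside the existing CUDA (float16-on-GPU) mode. As a usage sketch only (not part of the commit): the new pass is driven by the `auto_mixed_precision_mkl` toggle added to `RewriterConfig` in this diff, and the ConfigProto/Session wiring below mirrors the updated Python test.

    # Sketch: enable the new MKL (bfloat16) auto mixed precision pass.
    # The auto_mixed_precision_mkl field is added to RewriterConfig in this
    # diff; the session plumbing is illustrative.
    from tensorflow.core.protobuf import config_pb2
    from tensorflow.core.protobuf import rewriter_config_pb2

    rewrite_options = rewriter_config_pb2.RewriterConfig(
        auto_mixed_precision_mkl=rewriter_config_pb2.RewriterConfig.ON)
    config = config_pb2.ConfigProto(
        graph_options=config_pb2.GraphOptions(rewrite_options=rewrite_options))
    # Sessions created with this config run the auto_mixed_precision_mkl pass,
    # which converts eligible float32 ops to bfloat16 on CPU.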
@@ -1,11 +1,14 @@
-load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_cc_test_mkl", "tf_kernel_library")
-load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
+load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_cc_test_mkl", "tf_copts", "tf_cuda_cc_test", "tf_kernel_library")
 
 # Platform specific build config
 load(
     "//tensorflow/core/platform:build_config_root.bzl",
     "if_static",
 )
+load(
+    "//third_party/mkl:build_defs.bzl",
+    "mkl_deps",
+)
 
 package(
     licenses = ["notice"],  # Apache 2.0
@@ -610,6 +613,7 @@ cc_library(
         "auto_mixed_precision.h",
         "auto_mixed_precision_lists.h",
     ],
+    copts = tf_copts(),
     visibility = ["//visibility:public"],
     deps = [
         ":custom_graph_optimizer_registry",
@@ -627,7 +631,7 @@ cc_library(
         "//tensorflow/core/grappler/costs:virtual_placer",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
-    ],
+    ] + mkl_deps(),
 )
 
 tf_cuda_cc_test(
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h"
 
+#include <fstream>
 #include <memory>
 
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
@@ -52,6 +53,7 @@ const std::pair<int, int> kMinGPUArch = {0, 0};
 
 const char kSuffix[] = "AutoMixedPrecision";
 const char kCastToFp16[] = "CastToFp16";
+const char kCastToBf16[] = "CastToBf16";
 const char kCastToFp32[] = "CastToFp32";
 
 // Instances of this class represent unique type attribute identifiers within a
@@ -840,22 +842,6 @@ DataTypeSet AllowedDataTypes(const OpDef& op_def, const TypeAttrId& t_attr_id) {
   return AllowedDataTypes(*attr_def);
 }
 
-NodeDef BuildCastNode(const MutableGraphView::OutputPort& src, bool to_fp16,
-                      const string& device) {
-  const char* cast_string = to_fp16 ? kCastToFp16 : kCastToFp32;
-  string name = strings::StrCat(src.node->name(), "-", src.port_id, "-",
-                                cast_string, "-", kSuffix);
-  NodeDef node;
-  node.set_name(name);
-  node.set_op("Cast");
-  node.set_device(device);
-  node.add_input(strings::StrCat(src.node->name(), ":", src.port_id));
-  (*node.mutable_attr())["SrcT"].set_type(to_fp16 ? DT_FLOAT : DT_HALF);
-  (*node.mutable_attr())["DstT"].set_type(to_fp16 ? DT_HALF : DT_FLOAT);
-  (*node.mutable_attr())["Truncate"].set_b(false);
-  return node;
-}
-
 Status ValidateLists(const gtl::FlatSet<string>& white_list,
                      const gtl::FlatSet<string>& black_list,
                      const gtl::FlatSet<string>& gray_list,
@@ -941,7 +927,8 @@ class AutoMixedPrecisionImpl {
  public:
   AutoMixedPrecisionImpl(Cluster* cluster,
                          const std::unordered_set<string>& nodes_to_preserve,
-                         GraphDef* graph, string id)
+                         GraphDef* graph, string id,
+                         AutoMixedPrecisionMode mode)
       : virtual_placer_(cluster->GetDevices()),
         nodes_to_preserve_(nodes_to_preserve),
         graph_(graph),
@@ -949,23 +936,35 @@ class AutoMixedPrecisionImpl {
         id_(id),
         graph_view_(graph),
         cuda_version_(GetCudaVersion(*cluster)),
-        cudnn_version_(GetCudnnVersion(*cluster)) {}
+        cudnn_version_(GetCudnnVersion(*cluster)),
+        mode_(mode),
+        target_dtype_(mode_ == AutoMixedPrecisionMode::CUDA ? DT_HALF
+                                                            : DT_BFLOAT16) {}
 
   Status Optimize();
 
  private:
   typedef absl::flat_hash_set<NodeTypeId> NodeTypeIdSet;
 
+  std::unique_ptr<AutoMixedPrecisionLists> get_mixed_precision_lists() const {
+    switch (mode_) {
+      case AutoMixedPrecisionMode::CUDA:
+        return std::make_unique<AutoMixedPrecisionListsCuda>(cuda_version_,
+                                                             cudnn_version_);
+      case AutoMixedPrecisionMode::MKL:
+        return std::make_unique<AutoMixedPrecisionListsMkl>();
+    }
+  }
   Status PrintDebugLogs(bool preop, size_t timestamp);
   void LogSkippedNode(const NodeDef& node) const;
   bool MustPreserve(const NodeDef& node) const;
-  bool IsOnGPU(const NodeDef& node) const;
+  bool IsOnDevice(const NodeDef& node, const string& device_type) const;
   bool IsOnSuitableGPUArch(const NodeDef& node) const;
   bool ShouldProcess(const NodeDef& node) const;
-  bool NodeHasFP16KernelForTypeAttr(const NodeDef& node, TypeAttrId taid) const;
+  bool NodeHasF16KernelForTypeAttr(const NodeDef& node, TypeAttrId taid) const;
   bool NodeImplicitlyReadsNonResourceVariable(const NodeDef& node) const;
   void ConvertBatchNormOpsToV2();
-  bool SupportsFloat16(const NodeTypeId& node_type) const;
+  bool SupportsF16(const NodeTypeId& node_type) const;
   const NodeTypeId* GetTensorListFloat32NodeTypeId(const NodeDef& node) const;
   bool IsSourceOrSinkOp(const string& op) const;
   void FindFloat32TensorListOpClustersAndBlacklistUnsafe(
@@ -990,6 +989,8 @@ class AutoMixedPrecisionImpl {
       absl::flat_hash_set<int>* white_set) const;
   void MakeCastsWhiteIfAllOutputsWhite(
       absl::flat_hash_set<int>* white_set) const;
+  NodeDef BuildCastNode(const MutableGraphView::OutputPort& src, bool to_f16,
+                        const string& device) const;
   Status ChangeTypeAttrsAndAddCasts(const absl::flat_hash_set<int>& white_set);
 
   VirtualPlacer virtual_placer_;
@@ -1003,21 +1004,44 @@ class AutoMixedPrecisionImpl {
   NodeTypeAttrMap node_type_map_;
   GraphTypeTopologyView graph_type_view_;
   bool force_all_fp16_;
-  gtl::FlatSet<string> fp16_whitelist_;
-  gtl::FlatSet<string> fp16_blacklist_;
-  gtl::FlatSet<string> fp16_graylist_;
-  gtl::FlatSet<string> fp16_clearlist_;
+  AutoMixedPrecisionMode mode_;
+  gtl::FlatSet<string> f16_whitelist_;
+  gtl::FlatSet<string> f16_blacklist_;
+  gtl::FlatSet<string> f16_graylist_;
+  gtl::FlatSet<string> f16_clearlist_;
   absl::flat_hash_set<const NodeDef*> should_process_nodes_;
+  DataType target_dtype_;  // Either DT_HALF or DT_BFLOAT16
 };
 
-bool AutoMixedPrecisionImpl::NodeHasFP16KernelForTypeAttr(
+NodeDef AutoMixedPrecisionImpl::BuildCastNode(
+    const MutableGraphView::OutputPort& src, bool to_f16,
+    const string& device) const {
+  DataType src_type = to_f16 ? DT_FLOAT : target_dtype_;
+  DataType dst_type = to_f16 ? target_dtype_ : DT_FLOAT;
+  const char* cast_string =
+      !to_f16 ? kCastToFp32
+              : target_dtype_ == DT_HALF ? kCastToFp16 : kCastToBf16;
+  string name = strings::StrCat(src.node->name(), "-", src.port_id, "-",
+                                cast_string, "-", kSuffix);
+  NodeDef node;
+  node.set_name(name);
+  node.set_op("Cast");
+  node.set_device(device);
+  node.add_input(strings::StrCat(src.node->name(), ":", src.port_id));
+  (*node.mutable_attr())["SrcT"].set_type(src_type);
+  (*node.mutable_attr())["DstT"].set_type(dst_type);
+  (*node.mutable_attr())["Truncate"].set_b(false);
+  return node;
+}
+
+bool AutoMixedPrecisionImpl::NodeHasF16KernelForTypeAttr(
     const NodeDef& node, TypeAttrId taid) const {
   NodeDef node_copy(node);
   if (node.device().empty()) {
     string device_name = virtual_placer_.get_canonical_device_name(node);
     node_copy.set_device(device_name);
   }
-  if (!SetDataType(&node_copy, taid, DataType::DT_HALF)) {
+  if (!SetDataType(&node_copy, taid, target_dtype_)) {
     return false;
   }
   return IsKernelRegisteredForNode(node_copy).ok();
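
Editor's note: the Cast nodes assembled by `BuildCastNode` above serialize to NodeDefs like the following sketch. The node name `wht1` and the device string are made-up examples; the op, input, and attr fields follow the C++ above.

    # Sketch of the NodeDef produced for a float32 -> bfloat16 cast in MKL mode.
    from tensorflow.core.framework import node_def_pb2
    from tensorflow.core.framework import types_pb2

    node = node_def_pb2.NodeDef()
    # Name pattern: <src node>-<port>-<CastToFp16|CastToBf16|CastToFp32>-<kSuffix>
    node.name = 'wht1-0-CastToBf16-AutoMixedPrecision'
    node.op = 'Cast'
    node.device = '/job:localhost/replica:0/task:0/device:CPU:0'
    node.input.append('wht1:0')                     # src node name ":" port id
    node.attr['SrcT'].type = types_pb2.DT_FLOAT     # casting to f16: src is fp32
    node.attr['DstT'].type = types_pb2.DT_BFLOAT16  # target_dtype_ in MKL mode
    node.attr['Truncate'].b = False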
@@ -1053,21 +1077,22 @@ Status AutoMixedPrecisionImpl::PrintDebugLogs(bool preop, size_t timestamp) {
     fname = io::JoinPath(prepend_path,
                          strings::StrCat("paintbuckets", suffix, ".txt"));
     f.open(fname.c_str(), std::fstream::out);
+    std::unique_ptr<AutoMixedPrecisionLists> mp_lists =
+        get_mixed_precision_lists();
     f << "WhiteList:\n";
-    for (const auto& x :
-         AutoMixedPrecisionLists::WhiteList(cuda_version_, cudnn_version_)) {
+    for (const auto& x : mp_lists->WhiteList()) {
       f << x << "\n";
     }
     f << "\nBlackList:\n";
-    for (const auto& x : AutoMixedPrecisionLists::BlackList()) {
+    for (const auto& x : mp_lists->BlackList()) {
       f << x << "\n";
     }
     f << "\nGrayList:\n";
-    for (const auto& x : AutoMixedPrecisionLists::GrayList()) {
+    for (const auto& x : mp_lists->GrayList()) {
       f << x << "\n";
     }
     f << "\nClearList:\n";
-    for (const auto& x : AutoMixedPrecisionLists::ClearList()) {
+    for (const auto& x : mp_lists->ClearList()) {
       f << x << "\n";
     }
     f.close();
@@ -1088,7 +1113,8 @@ bool AutoMixedPrecisionImpl::MustPreserve(const NodeDef& node) const {
   return nodes_to_preserve_.count(node.name());
 }
 
-bool AutoMixedPrecisionImpl::IsOnGPU(const NodeDef& node) const {
+bool AutoMixedPrecisionImpl::IsOnDevice(const NodeDef& node,
+                                        const string& device_type) const {
   string device_name;
   if (node.device().empty()) {
     device_name = virtual_placer_.get_canonical_device_name(node);
@@ -1099,7 +1125,7 @@ bool AutoMixedPrecisionImpl::IsOnGPU(const NodeDef& node) const {
   string not_used;
   if (DeviceNameUtils::SplitDeviceName(device_name, &not_used, &device) &&
       absl::StrContains(absl::AsciiStrToLower(device),
-                        absl::AsciiStrToLower(DEVICE_GPU))) {
+                        absl::AsciiStrToLower(device_type))) {
     return true;
   }
   return false;
@@ -1164,15 +1190,14 @@ bool IsTensorListWriterOp(const string& op) {
   return tensor_list_writer_ops.count(op);
 }
 
-bool AutoMixedPrecisionImpl::SupportsFloat16(
-    const NodeTypeId& node_type) const {
+bool AutoMixedPrecisionImpl::SupportsF16(const NodeTypeId& node_type) const {
   const OpDef* op_def;
   Status status =
       OpRegistry::Global()->LookUpOpDef(node_type.node->op(), &op_def);
   if (!status.ok()) return false;
   return AllowedDataTypes(*op_def, node_type.type_attr)
-             .Contains(DataType::DT_HALF) &&
-         NodeHasFP16KernelForTypeAttr(*node_type.node, node_type.type_attr);
+             .Contains(target_dtype_) &&
+         NodeHasF16KernelForTypeAttr(*node_type.node, node_type.type_attr);
 }
 
 // TODO(mconley): Make this change the node's name (to aid debugging). Need to
@@ -1219,22 +1244,40 @@ Status AutoMixedPrecisionImpl::Optimize() {
       "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL", "", &optimization_level));
   optimization_level = absl::AsciiStrToUpper(optimization_level);
   force_all_fp16_ = optimization_level == "UNSAFE_FORCE_ALL";
+  if (force_all_fp16_ && mode_ == AutoMixedPrecisionMode::MKL) {
+    // Many ops do not support bfloat16 on the CPU, so we disallow forcing to
+    // bfloat16.
+    return errors::InvalidArgument(
+        "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL cannot be set to "
+        "UNSAFE_FORCE_ALL when MKL is used");
+  }
 
-  fp16_whitelist_ =
-      AutoMixedPrecisionLists::WhiteList(cuda_version_, cudnn_version_);
-  fp16_blacklist_ = AutoMixedPrecisionLists::BlackList();
-  fp16_graylist_ = AutoMixedPrecisionLists::GrayList();
-  fp16_clearlist_ = AutoMixedPrecisionLists::ClearList();
-  TF_RETURN_IF_ERROR(ValidateLists(fp16_whitelist_, fp16_blacklist_,
-                                   fp16_graylist_, fp16_clearlist_));
+  std::unique_ptr<AutoMixedPrecisionLists> mp_lists =
+      get_mixed_precision_lists();
+  f16_whitelist_ = mp_lists->WhiteList();
+  f16_blacklist_ = mp_lists->BlackList();
+  f16_graylist_ = mp_lists->GrayList();
+  f16_clearlist_ = mp_lists->ClearList();
+  TF_RETURN_IF_ERROR(ValidateLists(f16_whitelist_, f16_blacklist_,
+                                   f16_graylist_, f16_clearlist_));
 
   size_t timestamp = Env::Default()->NowMicros() / 1000;
   TF_RETURN_IF_ERROR(PrintDebugLogs(/* preop = */ true, timestamp));
 
   VLOG(2) << "Identifying nodes that should be processed";
   for (const NodeDef& node : graph_->node()) {
-    if (!MustPreserve(node) && IsOnGPU(node) &&
-        (ShouldIgnorePerformance() || IsOnSuitableGPUArch(node))) {
+    bool should_process;
+    switch (mode_) {
+      case AutoMixedPrecisionMode::CUDA:
+        should_process =
+            !MustPreserve(node) && IsOnDevice(node, DEVICE_GPU) &&
+            (ShouldIgnorePerformance() || IsOnSuitableGPUArch(node));
+        break;
+      case AutoMixedPrecisionMode::MKL:
+        should_process = !MustPreserve(node) && IsOnDevice(node, DEVICE_CPU);
+        break;
+    }
+    if (should_process) {
       should_process_nodes_.insert(&node);
     } else {
       LogSkippedNode(node);
@@ -1260,29 +1303,29 @@ Status AutoMixedPrecisionImpl::Optimize() {
   for (const auto& cluster : tensor_list_clusters) {
     VLOG(1) << "Found safe Tensor List cluster of size " << cluster.size();
     for (const NodeDef* node : cluster) {
-      VLOG(2) << "Cluster member: " << node->op() << " node " << node->name();
+      VLOG(2) << "  Cluster member: " << node->op() << " node " << node->name();
     }
     FindTensorListImplicitFloat32Edges(cluster, &ephemeral_edges);
   }
   TF_RETURN_IF_ERROR(graph_type_view_.AddEphemeralEdges(ephemeral_edges));
 
-  // The goal here is to change performance-critical ops to fp16, and to do so
-  // with the minimal number of casts, subject to the constraint that the
+  // The goal here is to change performance-critical ops to fp16 or bf16, and
+  // to do so with the minimal number of casts, subject to the constraint that
   // model's convergence is not affected. This is achieved by first identifying
-  // which nodes should be changed to fp16 and then inserting casts at the
-  // boundaries between fp16/non-fp16 nodes.
+  // which nodes should be changed to f16 and then inserting casts at the
+  // boundaries between f16/non-f16 nodes.
 
-  // The algorithm for deciding which nodes to change to fp16 is as follows:
+  // The algorithm for deciding which nodes to change to f16 is as follows:
   // 1) Add all performance-critical ops (aka "whitelist" ops) to the white_set.
   //    This is done under the assumption that whitelist ops are always
-  //    numerically-safe in fp16 and that they are the most important ops for
+  //    numerically-safe in f16 and that they are the most important ops for
   //    improving performance.
   // 2) Add nodes to the black_set iff they are numerically-dangerous (aka
   //    "blacklist" ops) or they are on a forward path from a blacklist node to
   //    a black/gray node (including the node at the end of the path) through
   //    non-numerically-dangerous ops (aka "greylist" and "clearlist" ops).
   //    This is done to prevent numerically-dangerous ops and their downstream
-  //    effects from being changed to fp16, which would risk breaking the
+  //    effects from being changed to f16, which would risk breaking the
   //    numerical accuracy of the model.
   // 3) For all remaining nodes that are not considered dangerous (greylist
   //    and clearlist ops), find those that are between (i.e., both upstream
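
Editor's aside: the painting scheme described in the comments above can be illustrated on a toy linear chain of ops. This sketch is not the real implementation; the pass operates on type attributes of a GraphDef via the DFS traversals that follow, and step 2 is simplified here to "taint everything reachable from a blacklist op through gray/clear ops".

    WHITE, GRAY, CLEAR, BLACK = 'white', 'gray', 'clear', 'black'

    def paint(ops, lists):
      """ops: op names of a linear chain; lists: op name -> list color."""
      # Step 1: whitelist ops start out white.
      white = {i for i, op in enumerate(ops) if lists[op] == WHITE}
      # Step 2: blacklist ops, and gray/clear ops downstream of one, go black.
      black, tainted = set(), False
      for i, op in enumerate(ops):
        if lists[op] == BLACK:
          black.add(i)
          tainted = True
        elif tainted and lists[op] in (GRAY, CLEAR):
          black.add(i)
        else:
          tainted = False
      # Step 3: untouched gray/clear ops between two white ops become white.
      for i, op in enumerate(ops):
        if i not in white and i not in black and lists[op] in (GRAY, CLEAR):
          if any(j in white for j in range(i)) and any(
              j in white for j in range(i + 1, len(ops))):
            white.add(i)
      return white, black

    ops = ['MatMul', 'Relu', 'MatMul', 'Exp', 'Mean']
    lists = {'MatMul': WHITE, 'Relu': CLEAR, 'Exp': BLACK, 'Mean': GRAY}
    print(paint(ops, lists))  # -> ({0, 1, 2}, {3, 4})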
@@ -1480,7 +1523,7 @@ void AutoMixedPrecisionImpl::AddWhitelistOps(
     const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
     if (!ShouldProcess(*root.node)) continue;
     bool force_white = force_all_fp16_ && CanForceFP16(*root.node);
-    if (fp16_whitelist_.count(root.node->op()) || force_white) {
+    if (f16_whitelist_.count(root.node->op()) || force_white) {
       bool inserted = white_set->insert(root_idx).second;
       if (VLOG_IS_ON(2) && inserted) {
         VLOG(2) << "Painting type " << root.type_attr.DebugString()
@@ -1504,8 +1547,8 @@ void AutoMixedPrecisionImpl::PropagateBlackFwdThroughClearAndGray(
   absl::flat_hash_set<int> upstream_of_black_or_gray_set;
   for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
     const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
-    if (!(fp16_blacklist_.count(root.node->op()) ||
-          fp16_graylist_.count(root.node->op()))) {
+    if (!(f16_blacklist_.count(root.node->op()) ||
+          f16_graylist_.count(root.node->op()))) {
       continue;
     }
     DfsTypeTraversal(graph_type_view_, {&root},
@@ -1514,7 +1557,7 @@ void AutoMixedPrecisionImpl::PropagateBlackFwdThroughClearAndGray(
           const NodeTypeId& item = *graph_type_view_.GetNode(idx);
           return idx == root_idx ||
                  (!upstream_of_black_or_gray_set.count(idx) &&
-                  fp16_clearlist_.count(item.node->op()));
+                  f16_clearlist_.count(item.node->op()));
         }),
         DfsTypeCallbacks::PreOrder([&](int idx) {
           upstream_of_black_or_gray_set.insert(idx);
@@ -1524,7 +1567,7 @@ void AutoMixedPrecisionImpl::PropagateBlackFwdThroughClearAndGray(
   // Propagate black forward through nodes in upstream_of_black_or_gray_set.
   for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
     const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
-    if (black_set->count(root_idx) || !fp16_blacklist_.count(root.node->op())) {
+    if (black_set->count(root_idx) || !f16_blacklist_.count(root.node->op())) {
       continue;
     }
     DfsTypeTraversal(
@@ -1552,7 +1595,7 @@ void AutoMixedPrecisionImpl::AddClearAndGrayToWhiteIfBetweenWhite(
   absl::flat_hash_set<int> downstream_of_white_set;
   for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
     const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
-    if (!ShouldProcess(*root.node) || !fp16_whitelist_.count(root.node->op())) {
+    if (!ShouldProcess(*root.node) || !f16_whitelist_.count(root.node->op())) {
       continue;
     }
     DfsTypeTraversal(
@@ -1561,14 +1604,14 @@ void AutoMixedPrecisionImpl::AddClearAndGrayToWhiteIfBetweenWhite(
           const NodeTypeId& item = *graph_type_view_.GetNode(idx);
           return idx == root_idx ||
                  (!downstream_of_white_set.count(idx) &&
-                  !fp16_whitelist_.count(item.node->op()) &&
+                  !f16_whitelist_.count(item.node->op()) &&
                   !black_set.count(idx) && ShouldProcess(*item.node) &&
                   // TODO(benbarsdell): Consider allowing propagation through
                   // ops that are already float16 in order to reduce the number
                   // of casts.
-                  IsFloat32(item) && SupportsFloat16(item) &&
-                  (fp16_clearlist_.count(item.node->op()) ||
-                   fp16_graylist_.count(item.node->op())));
+                  IsFloat32(item) && SupportsF16(item) &&
+                  (f16_clearlist_.count(item.node->op()) ||
+                   f16_graylist_.count(item.node->op())));
         }),
         DfsTypeCallbacks::PreOrder(
             [&](int idx) { downstream_of_white_set.insert(idx); }));
@@ -1579,7 +1622,7 @@ void AutoMixedPrecisionImpl::AddClearAndGrayToWhiteIfBetweenWhite(
   for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
     const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
     if (!ShouldProcess(*root.node) || upstream_of_white_set.count(root_idx) ||
-        !fp16_whitelist_.count(root.node->op())) {
+        !f16_whitelist_.count(root.node->op())) {
       continue;
     }
     DfsTypeTraversal(
@@ -1620,8 +1663,8 @@ void AutoMixedPrecisionImpl::PropagateWhiteThroughClear(
           return idx == root_idx ||
                  (!white_set->count(idx) && !black_set.count(idx) &&
                   ShouldProcess(*item.node) && IsFloat32(item) &&
-                  SupportsFloat16(item) &&
-                  (fp16_clearlist_.count(item.node->op())) &&
+                  SupportsF16(item) &&
+                  (f16_clearlist_.count(item.node->op())) &&
                   // We don't propagate (backwards) through nodes that read
                   // Variables because it can break the behavior of TensorBoard
                   // visualization and/or (in the case of Enter nodes) the model
@@ -1806,13 +1849,13 @@ void AutoMixedPrecisionImpl::MakeCastsWhiteIfAllOutputsWhite(
   }
 }
 
-// Changes all white-painted type attributes to DT_HALF, and inserts Cast nodes
-// at node outputs for all edges that connect white-painted <->
-// non-white-painted type attributes.
+// Changes all white-painted type attributes to DT_HALF or DT_BFLOAT16, and
+// inserts Cast nodes at node outputs for all edges that connect
+// white-painted <-> non-white-painted type attributes.
 Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts(
     const absl::flat_hash_set<int>& white_set) {
   int num_nodes_changed = 0;
-  int num_nonvar_casts_to_fp16 = 0;
+  int num_nonvar_casts_to_f16 = 0;
   int num_nodes_preop = graph_->node_size();
   for (int node_idx = 0; node_idx < num_nodes_preop; ++node_idx) {
     NodeDef* node = graph_->mutable_node(node_idx);
@@ -1829,8 +1872,9 @@ Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts(
       bool src_is_white = white_set.count(node_type_idx);
       if (src_is_white) {
         VLOG(1) << "Changing type " << type_attr.DebugString() << " of "
-                << node->op() << " node " << node->name() << " to DT_HALF";
-        if (!SetDataType(node, type_attr, DT_HALF)) {
+                << node->op() << " node " << node->name() << " to "
+                << DataTypeString(target_dtype_);
+        if (!SetDataType(node, type_attr, target_dtype_)) {
           return errors::Internal("Failed to set type attribute");
         }
         ++num_nodes_changed;
@@ -1855,16 +1899,16 @@ Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts(
           bool dst_is_white = white_set.count(dst_type_idx);
           if (src_is_white != dst_is_white) {
             if (!added_cast_node) {
-              bool to_fp16 = dst_is_white;
+              bool to_f16 = dst_is_white;
               VLOG(1) << "Inserting cast to "
-                      << (to_fp16 ? "DT_HALF" : "DT_FLOAT") << " at "
-                      << src.node->op() << " " << src.node->name() << ":"
-                      << src.port_id;
+                      << (to_f16 ? DataTypeString(target_dtype_) : "DT_FLOAT")
+                      << " at " << src.node->op() << " " << src.node->name()
+                      << ":" << src.port_id;
               added_cast_node = graph_view_.AddNode(
-                  BuildCastNode(src, to_fp16, src.node->device()));
-              if (to_fp16 && !IsConstant(*node) && !IsVariable(*node) &&
+                  BuildCastNode(src, to_f16, src.node->device()));
+              if (to_f16 && !IsConstant(*node) && !IsVariable(*node) &&
                   !NodeImplicitlyReadsNonResourceVariable(*node)) {
-                ++num_nonvar_casts_to_fp16;
+                ++num_nonvar_casts_to_f16;
               }
             }
             TF_RETURN_IF_ERROR(graph_view_.UpdateRegularFaninByPort(
@@ -1874,9 +1918,13 @@ Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts(
       }
     }
   }
+  // Use Python type names (e.g. float16) instead of C++ type names (e.g. half)
+  // since many Python users will see this message.
+  const char* type_str = target_dtype_ == DT_HALF ? "float16" : "bfloat16";
   LOG(INFO) << "Converted " << num_nodes_changed << "/" << num_nodes_preop
-            << " nodes to float16 precision using " << num_nonvar_casts_to_fp16
-            << " cast(s) to float16 (excluding Const and Variable casts)";
+            << " nodes to " << type_str << " precision using "
+            << num_nonvar_casts_to_f16 << " cast(s) to " << type_str
+            << " (excluding Const and Variable casts)";
   return Status::OK();
 }
 
@@ -1902,12 +1950,24 @@ Status AutoMixedPrecision::Optimize(Cluster* cluster, const GrapplerItem& item,
     return errors::InvalidArgument("cluster == nullptr");
   }
 
+#if !defined(INTEL_MKL) || !defined(ENABLE_INTEL_MKL_BFLOAT16)
+  if (mode_ == AutoMixedPrecisionMode::MKL) {
+    return errors::Unimplemented(
+        "The auto_mixed_precision_mkl optimizer cannot be used since "
+        "this build of TensorFlow is not compiled with MKL support for "
+        "bfloat16. "
+        "For information on MKL builds, see: "
+        "https://software.intel.com/en-us/articles/intel-optimization-for-"
+        "tensorflow-installation-guide");
+  }
+#endif
+
   // Start by copying input graph to output.
   *output = item.graph;
 
   int num_gpus = ShouldIgnorePerformance() ? GetNumGPUs(*cluster)
                                            : GetNumGPUs(*cluster, kMinGPUArch);
-  if (num_gpus < 1) {
+  if (num_gpus < 1 && mode_ == AutoMixedPrecisionMode::CUDA) {
     // AutoMixedPrecision is currently only tuned for GPU.
     LOG(WARNING) << "No (suitable) GPUs detected, skipping " << name()
                  << " graph optimizer";
@@ -1916,7 +1976,7 @@ Status AutoMixedPrecision::Optimize(Cluster* cluster, const GrapplerItem& item,
 
   // Optimize the output graph in-place.
   AutoMixedPrecisionImpl optimizer(cluster, item.NodesToPreserve(), output,
-                                   item.id);
+                                   item.id, mode_);
   if (item.id == "tf_graph") {
     LOG(INFO) << "Running " << name() << " graph optimizer";
   } else {
@@ -22,16 +22,25 @@ limitations under the License.
 namespace tensorflow {
 namespace grappler {
 
-// Convert data types to float16 where appropriate to improve performance on
-// GPUs.
+enum class AutoMixedPrecisionMode { CUDA, MKL };
+
+// Convert data types to float16 or bfloat16 where appropriate to improve
+// performance on GPUs or CPUs.
 class AutoMixedPrecision : public GraphOptimizer {
  public:
+  // If 'mode' is CUDA, converts nodes to float16 on Nvidia GPUs. If MKL,
+  // converts nodes to bfloat16 on CPUs in order to take advantage of MKL
+  // performance improvements with bfloat16.
   explicit AutoMixedPrecision(
-      RewriterConfig::Toggle opt_level = RewriterConfig::ON) {}
+      AutoMixedPrecisionMode mode = AutoMixedPrecisionMode::CUDA)
+      : mode_(mode) {}
 
   ~AutoMixedPrecision() override {}
 
-  string name() const override { return "auto_mixed_precision"; };
+  string name() const override {
+    return mode_ == AutoMixedPrecisionMode::CUDA ? "auto_mixed_precision_cuda"
+                                                 : "auto_mixed_precision_mkl";
+  };
 
   bool UsesFunctionLibrary() const override { return false; }
 
@@ -40,6 +49,9 @@ class AutoMixedPrecision : public GraphOptimizer {
 
   void Feedback(Cluster* cluster, const GrapplerItem& item,
                 const GraphDef& optimize_output, double result) override;
 
+ private:
+  const AutoMixedPrecisionMode mode_;
 };
 
 }  // end namespace grappler
@@ -23,10 +23,43 @@ limitations under the License.
 namespace tensorflow {
 namespace grappler {
 
+// Represents the four lists of ops: the white list, gray list, black list, and
+// clear list. These lists determine which ops are converted to fp16/bf16
+// (referred to as 'f16' for short) and which ops stay as fp32.
 class AutoMixedPrecisionLists {
- private:
-  static void UpdateList(gtl::FlatSet<string>* list, const string& to_add,
-                         const string& to_remove) {
+ public:
+  virtual ~AutoMixedPrecisionLists() {}
+
+  // Returns the set of ops that are considered numerically-safe (for execution
+  // in f16), performance-critical, and can run in f16. These ops are always
+  // converted to f16.
+  virtual gtl::FlatSet<string> WhiteList() = 0;
+  // Returns the set of ops that can run in f16 and are considered numerically-
+  // safe (for execution in f16), but which may be made unsafe by an upstream
+  // blacklist op.
+  virtual gtl::FlatSet<string> GrayList() = 0;
+  // Returns the set of ops that are considered numerically-dangerous (i.e.,
+  // unsafe for execution in f16) and whose effects may also be observed in
+  // downstream nodes (e.g. for f16, in Exp -> Add, the Add is unsafe due to
+  // the Exp).
+  virtual gtl::FlatSet<string> BlackList() = 0;
+  // Returns the set of ops that do not have numerically-significant effects
+  // (i.e., they are always considered safe for execution in f16 precision),
+  // and can run in f16.
+  virtual gtl::FlatSet<string> ClearList() = 0;
+
+ protected:
+  // Adds or removes ops from the list if certain environment variables are
+  // set.
+  static void UpdateList(const string& list_name, gtl::FlatSet<string>* list) {
+    CHECK(list_name == "WHITELIST" || list_name == "GRAYLIST" ||  // Crash OK.
+          list_name == "BLACKLIST" || list_name == "CLEARLIST");
+    string add_env_var =
+        "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_" + list_name + "_ADD";
+    string remove_env_var =
+        "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_" + list_name + "_REMOVE";
+    string to_add, to_remove;
+    TF_CHECK_OK(ReadStringFromEnvVar(add_env_var, "", &to_add));
+    TF_CHECK_OK(ReadStringFromEnvVar(remove_env_var, "", &to_remove));
     for (const auto& x : str_util::Split(to_add, ",")) {
       list->insert(x);
     }
@@ -35,6 +68,29 @@ class AutoMixedPrecisionLists {
     }
   }
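
Editor's note: `UpdateList` assembles the environment variable names from the list name, so the lists can be tuned per process without rebuilding. A usage sketch (the op names are illustrative; entries are comma-separated and must be set before the optimizer runs):

    import os

    # Illustrative values: treat Mul as performance-critical for this process
    # and keep Tanh out of the clear list.
    os.environ['TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_WHITELIST_ADD'] = 'Mul'
    os.environ['TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_CLEARLIST_REMOVE'] = 'Tanh'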
 
+  // Subclasses should include these on the ClearList.
+  static void AddTensorListOps(gtl::FlatSet<string>* list) {
+    // Note: if a data structure op (such as TensorListPopBack) is added here,
+    // IsTensorListReaderOp or IsTensorListWriterOp may need to be modified
+    // LINT.IfChange
+    constexpr const char* tensor_list_ops[] = {
+        "TensorListConcat",     "TensorListConcatLists",
+        "TensorListConcatV2",   "TensorListGather",
+        "TensorListGetItem",    "TensorListPopBack",
+        "TensorListPushBack",   "TensorListPushBackBatch",
+        "TensorListFromTensor", "TensorListScatter",
+        "TensorListScatterV2",  "TensorListScatterIntoExistingList",
+        "TensorListSetItem",    "TensorListSplit",
+        "TensorListStack"};
+    // LINT.ThenChange(//tensorflow/core/grappler/optimizers/auto_mixed_precision.cc)
+    for (auto op : tensor_list_ops) {
+      list->insert(op);
+    }
+  }
+};
+
+class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists {
+ private:
   static bool IsPseudoFastMath() {
     string optimization_level;
     TF_CHECK_OK(
@@ -45,16 +101,10 @@ class AutoMixedPrecisionLists {
   }
 
  public:
-  // Returns the set of ops that are considered numerically-safe (for execution
-  // in fp16) and performance-critical. These ops are always converted to fp16.
-  static gtl::FlatSet<string> WhiteList(int cuda_version, int cudnn_version) {
-    string to_add, to_remove;
-    TF_CHECK_OK(ReadStringFromEnvVar(
-        "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_WHITELIST_ADD", "", &to_add));
-    TF_CHECK_OK(ReadStringFromEnvVar(
-        "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_WHITELIST_REMOVE", "",
-        &to_remove));
+  AutoMixedPrecisionListsCuda(int cuda_version, int cudnn_version)
+      : cuda_version_(cuda_version), cudnn_version_(cudnn_version) {}
+
+  gtl::FlatSet<string> WhiteList() override {
     auto list = gtl::FlatSet<string>{
         "BlockLSTM",
         "BlockLSTMV2",
@@ -81,12 +131,12 @@ class AutoMixedPrecisionLists {
         // "DepthwiseConv2dNativeBackpropInput",
         "MatMul",
     };
-    if (cuda_version >= 9010) {
+    if (cuda_version_ >= 9010) {
       // Fp16 BatchMatMul is slow before CUDA 9.1.
       list.insert("BatchMatMul");
       list.insert("BatchMatMulV2");
     }
-    if (cudnn_version >= 7602) {
+    if (cudnn_version_ >= 7602) {
       // Fp16 3D conv is slow before CUDNN 7.6.2.
       list.insert("Conv3D");
       list.insert("Conv3DBackpropFilter");
@@ -94,22 +144,14 @@ class AutoMixedPrecisionLists {
       list.insert("Conv3DBackpropInput");
       list.insert("Conv3DBackpropInputV2");
     }
-    UpdateList(&list, to_add, to_remove);
+    UpdateList("WHITELIST", &list);
     return list;
   }
 
-  // Returns the set of ops that are considered numerically-safe (for execution
-  // in fp16), but which may be made unsafe by an upstream blacklist op.
-  static gtl::FlatSet<string> GrayList() {
+  gtl::FlatSet<string> GrayList() override {
     if (IsPseudoFastMath()) {
       return gtl::FlatSet<string>{};
     }
-    string to_add, to_remove;
-    TF_CHECK_OK(ReadStringFromEnvVar(
-        "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_GRAYLIST_ADD", "", &to_add));
-    TF_CHECK_OK(ReadStringFromEnvVar(
-        "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_GRAYLIST_REMOVE", "",
-        &to_remove));
 
     auto list = gtl::FlatSet<string>{
         "Add",
@@ -156,23 +198,14 @@ class AutoMixedPrecisionLists {
         "Tanh",
         "TanhGrad",
     };
-    UpdateList(&list, to_add, to_remove);
+    UpdateList("GRAYLIST", &list);
     return list;
   }
 
-  // Returns the set of ops that are considered numerically-dangerous (i.e.,
-  // unsafe for execution in fp16) and whose effects may also be observed in
-  // downstream nodes (e.g., in Exp -> Add, the Add is unsafe due to the Exp).
-  static gtl::FlatSet<string> BlackList() {
+  gtl::FlatSet<string> BlackList() override {
     if (IsPseudoFastMath()) {
       return gtl::FlatSet<string>{};
     }
-    string to_add, to_remove;
-    TF_CHECK_OK(ReadStringFromEnvVar(
-        "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_BLACKLIST_ADD", "", &to_add));
-    TF_CHECK_OK(ReadStringFromEnvVar(
-        "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_BLACKLIST_REMOVE", "",
-        &to_remove));
 
     auto list = gtl::FlatSet<string>{
         "Exp",
@@ -185,22 +218,14 @@ class AutoMixedPrecisionLists {
         "SparseSoftmaxCrossEntropyWithLogits",
         "Sum",
     };
-    UpdateList(&list, to_add, to_remove);
+    UpdateList("BLACKLIST", &list);
     return list;
   }
 
-  // Returns the set of ops that do not have numerically-significant effects
-  // (i.e., they are always considered safe for execution in fp16 precision).
-  static gtl::FlatSet<string> ClearList() {
+  gtl::FlatSet<string> ClearList() override {
     if (IsPseudoFastMath()) {
       return gtl::FlatSet<string>{};
     }
-    string to_add, to_remove;
-    TF_CHECK_OK(ReadStringFromEnvVar(
-        "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_CLEARLIST_ADD", "", &to_add));
-    TF_CHECK_OK(ReadStringFromEnvVar(
-        "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_CLEARLIST_REMOVE", "",
-        &to_remove));
 
     auto list = gtl::FlatSet<string>{
         "Abs",
@@ -291,21 +316,6 @@ class AutoMixedPrecisionLists {
         "StridedSlice",
         "StridedSliceGrad",
         "Switch",
-        "TensorListConcat",
-        "TensorListConcatLists",
-        "TensorListConcatV2",
-        "TensorListGather",
-        "TensorListGetItem",
-        "TensorListPopBack",
-        "TensorListPushBack",
-        "TensorListPushBackBatch",
-        "TensorListFromTensor",
-        "TensorListScatter",
-        "TensorListScatterV2",
-        "TensorListScatterIntoExistingList",
-        "TensorListSetItem",
-        "TensorListSplit",
-        "TensorListStack",
         "Tile",
         "TopK",
         "TopKV2",
@@ -313,7 +323,96 @@ class AutoMixedPrecisionLists {
         "Where",
         "ZerosLike",
     };
-    UpdateList(&list, to_add, to_remove);
+    AddTensorListOps(&list);
+    UpdateList("CLEARLIST", &list);
     return list;
   }
+
+ private:
+  int cuda_version_;
+  int cudnn_version_;
+};
+
+class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists {
+ public:
+  AutoMixedPrecisionListsMkl() {}
+
+  // Only ops which are supported by MKL in bfloat16 should be added to the
+  // white list, gray list, or clear list.
+  gtl::FlatSet<string> WhiteList() override {
+    auto list = gtl::FlatSet<string>{"Conv2D",
+                                     "Conv2DBackpropFilter",
+                                     "Conv2DBackpropInput",
+                                     "Conv3D",
+                                     "Conv3DBackpropFilterV2",
+                                     "Conv3DBackpropInputV2",
+                                     "DepthwiseConv2dNative",
+                                     "DepthwiseConv2dNativeBackpropFilter",
+                                     "DepthwiseConv2dNativeBackpropInput",
+                                     "MatMul",
+                                     "BatchMatMul",
+                                     "BatchMatMulV2"};
+
+    UpdateList("WHITELIST", &list);
+    return list;
+  }
+
+  gtl::FlatSet<string> GrayList() override {
+    auto list = gtl::FlatSet<string>{
+        "Add",
+        "AddN",
+        "AddV2",
+        "AvgPool",
+        "AvgPool3D",
+        "AvgPool3DGrad",
+        "AvgPoolGrad",
+        "BiasAdd",
+        "BiasAddGrad",
+        "BiasAddV1",
+        "FusedBatchNormV2",
+        "FusedBatchNormGradV2",
+        "FusedBatchNormV3",
+        "FusedBatchNormGradV3",
+        "LeakyRelu",
+        "LeakyReluGrad",
+        "Mul",
+        "Sub",
+    };
+    UpdateList("GRAYLIST", &list);
+    return list;
+  }
+
+  gtl::FlatSet<string> BlackList() override {
+    auto list = gtl::FlatSet<string>{
+        "Exp",
+        "Expm1",
+        "L2Loss",
+        "Mean",
+        "Pow",
+        "SaveV2",
+        "Softmax",
+        "SoftmaxCrossEntropyWithLogits",
+        "SparseSoftmaxCrossEntropyWithLogits",
+        "Sum",
+    };
+    UpdateList("BLACKLIST", &list);
+    return list;
+  }
+
+  gtl::FlatSet<string> ClearList() override {
+    auto list = gtl::FlatSet<string>{
+        "Concat",          "ConcatV2",  "Enter",         "EnsureShape",
+        "Equal",           "Exit",      "ExpandDims",    "Identity",
+        "MaxPool",         "MaxPool3D", "MaxPool3DGrad", "MaxPoolGrad",
+        "MaxPoolV2",       "Maximum",   "Merge",         "NextIteration",
+        "PreventGradient", "Relu",      "Relu6",         "Relu6Grad",
+        "ReluGrad",        "Reshape",   "Select",        "SelectV2",
+        "Shape",           "ShapeN",    "Slice",         "Split",
+        "SplitV",          "Squeeze",   "StopGradient",  "Switch",
+        "Transpose",       "ZerosLike",
+    };
+    AddTensorListOps(&list);
+    UpdateList("CLEARLIST", &list);
+    return list;
+  }
+};
@@ -13,11 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-// Currently, this test only passes when TensorFlow passes with CUDA, because
-// otherwise the optimizer will not turn clearlist nodes to float16. When
-// looking at clearlist nodes, this optimizer checks if the nodes have a float16
-// GPU OpKernel, but without CUDA there are no GPU OpKernels at all.
-#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM || \
+    (INTEL_MKL && defined(ENABLE_INTEL_MKL_BFLOAT16))
 
 #include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h"
 
@@ -70,6 +67,31 @@ Tensor GenerateRandomTensorInRange(const TensorShape& shape, double minval,
   return tensor;
 }
 
+void VerifyGraphsEquivalent(const GraphDef& original_graph,
+                            const GraphDef& optimized_graph,
+                            const string& func) {
+  EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << func;
+  GraphView optimized_view(&optimized_graph);
+  for (int i = 0; i < original_graph.node_size(); ++i) {
+    const NodeDef& original = original_graph.node(i);
+    const NodeDef& optimized = *optimized_view.GetNode(original.name());
+    EXPECT_EQ(original.name(), optimized.name()) << func;
+    EXPECT_EQ(original.op(), optimized.op()) << func;
+    EXPECT_EQ(original.input_size(), optimized.input_size()) << func;
+    if (original.input_size() == optimized.input_size()) {
+      for (int j = 0; j < original.input_size(); ++j) {
+        EXPECT_EQ(original.input(j), optimized.input(j)) << func;
+      }
+    }
+  }
+}
+
+// Currently, this test suite only passes when TensorFlow passes with CUDA,
+// because otherwise the optimizer will not turn clearlist nodes to float16.
+// When looking at clearlist nodes, this optimizer checks if the nodes have a
+// float16 GPU OpKernel, but without CUDA there are no GPU OpKernels at all.
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
 const std::pair<int, int> kMinGPUArch = {7, 0};
 
 class AutoMixedPrecisionTest : public GrapplerTest {
@@ -184,25 +206,6 @@ class AutoMixedPrecisionTest : public GrapplerTest {
   bool gpu_available_;
 };
 
-void VerifyGraphsEquivalent(const GraphDef& original_graph,
-                            const GraphDef& optimized_graph,
-                            const string& func) {
-  EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << func;
-  GraphView optimized_view(&optimized_graph);
-  for (int i = 0; i < original_graph.node_size(); ++i) {
-    const NodeDef& original = original_graph.node(i);
-    const NodeDef& optimized = *optimized_view.GetNode(original.name());
-    EXPECT_EQ(original.name(), optimized.name()) << func;
-    EXPECT_EQ(original.op(), optimized.op()) << func;
-    EXPECT_EQ(original.input_size(), optimized.input_size()) << func;
-    if (original.input_size() == optimized.input_size()) {
-      for (int j = 0; j < original.input_size(); ++j) {
-        EXPECT_EQ(original.input(j), optimized.input(j)) << func;
-      }
-    }
-  }
-}
-
 TEST_F(AutoMixedPrecisionTest, NoOp) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
   Output input = ops::Const(s.WithOpName("input"), 1.234f, {32});
@@ -1164,8 +1167,191 @@ TEST_F(AutoMixedPrecisionTest, TanhOp) {
   });
 }
 
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
+#if INTEL_MKL
+#ifdef ENABLE_INTEL_MKL_BFLOAT16
+
+class AutoMixedPrecisionMklTest : public GrapplerTest {
+ protected:
+  void SetUp() override {
+    virtual_cluster_.reset(new SingleMachine(/* timeout_s = */ 10, 1, 0));
+    TF_CHECK_OK(virtual_cluster_->Provision());
+  }
+  void TearDown() override { TF_CHECK_OK(virtual_cluster_->Shutdown()); }
+
+  std::unique_ptr<Cluster> virtual_cluster_;
+};
+
+TEST_F(AutoMixedPrecisionMklTest, AlreadyBf16) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output input = ops::Const(s.WithOpName("input"), 1.f, {32, 32});
+  Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_BFLOAT16);
+  Output wht1 = ops::MatMul(s.WithOpName("wht1"), cst1, cst1);
+  Output clr1 = ops::Relu(s.WithOpName("clr1"), wht1);
+  Output cst2 = ops::Cast(s.WithOpName("cst2"), clr1, DT_FLOAT);
+  Output clr2 = ops::Relu(s.WithOpName("clr2"), cst2);
+  Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);
+
+  GrapplerItem item;
+  item.fetch = {"fetch"};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+  AutoMixedPrecision optimizer{AutoMixedPrecisionMode::MKL};
+  GraphDef output;
+  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
+  VLOG(1) << output.DebugString();
+
+  VerifyGraphsEquivalent(item.graph, output, __FUNCTION__);
+  GraphView output_view(&output);
+  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
+  EXPECT_EQ(output_view.GetNode("cst1")->attr().at("DstT").type(), DT_BFLOAT16);
+  EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_BFLOAT16);
+  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_BFLOAT16);
+  EXPECT_EQ(output_view.GetNode("cst2")->attr().at("SrcT").type(), DT_BFLOAT16);
+  EXPECT_EQ(output_view.GetNode("cst2")->attr().at("DstT").type(), DT_FLOAT);
+  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);
+
+  auto tensors = EvaluateNodes(output, item.fetch);
+  EXPECT_EQ(tensors.size(), tensors_expected.size());
+  EXPECT_EQ(tensors.size(), item.fetch.size());
+  for (int i = 0; i < item.fetch.size(); ++i) {
+    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
+  }
+}
+
+TEST_F(AutoMixedPrecisionMklTest, Simple) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
+  Output blk1 = ops::Exp(s.WithOpName("blk1"), input);
+  Output clr1 = ops::Relu(s.WithOpName("clr1"), blk1);
+  Output gry1 = ops::Sqrt(s.WithOpName("gry1"), clr1);
+  Output clr2 = ops::Relu(s.WithOpName("clr2"), gry1);
+  Output wht1 = ops::MatMul(s.WithOpName("wht1"), clr2, clr2);
+  Output clr3 = ops::Relu(s.WithOpName("clr3"), wht1);
+  Output blk2 = ops::Log(s.WithOpName("blk2"), clr3);
+  Output clr4 = ops::Relu(s.WithOpName("clr4"), blk2);
+  Output blk3 = ops::SparseMatMul(s.WithOpName("blk3"), clr4, clr4);
+  Output clr5 = ops::Relu(s.WithOpName("clr5"), blk3);
+  Output fetch = ops::Identity(s.WithOpName("fetch"), clr5);
+
+  GrapplerItem item;
+  item.fetch = {"fetch"};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+  AutoMixedPrecision optimizer{AutoMixedPrecisionMode::MKL};
+  GraphDef output;
+  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
+
+  VLOG(1) << output.DebugString();
+
+  GraphView output_view(&output);
+  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
+  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
+  EXPECT_EQ(output_view.GetNode("blk1")->attr().at("T").type(), DT_FLOAT);
+  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
+  EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT);
+  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_BFLOAT16);
+  EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_BFLOAT16);
+  EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_BFLOAT16);
+  EXPECT_EQ(output_view.GetNode("blk2")->attr().at("T").type(), DT_FLOAT);
+  EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
+  EXPECT_EQ(output_view.GetNode("blk3")->attr().at("Ta").type(), DT_FLOAT);
+  EXPECT_EQ(output_view.GetNode("blk3")->attr().at("Tb").type(), DT_FLOAT);
+  EXPECT_EQ(output_view.GetNode("clr5")->attr().at("T").type(), DT_FLOAT);
+
+  auto tensors = EvaluateNodes(output, item.fetch);
+  EXPECT_EQ(tensors.size(), tensors_expected.size());
+  EXPECT_EQ(tensors.size(), item.fetch.size());
+  for (int i = 0; i < item.fetch.size(); ++i) {
+    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
+  }
+}
+
+TEST_F(AutoMixedPrecisionMklTest, TensorListSetGet) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  tensorflow::Input shape = {32, 32};
+  auto tl1 = ops::TensorListReserve(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
+  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
+  Output idx1 = ops::Const(s.WithOpName("idx1"), 1);
+  Output idx2 = ops::Const(s.WithOpName("idx2"), 2);
+  Output idx3 = ops::Const(s.WithOpName("idx3"), 3);
+  auto tl1w1 =
+      ops::TensorListSetItem(s.WithOpName("tl1w1"), tl1.handle, idx1, input);
+  Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, input);
+  auto tl1w2 =
+      ops::TensorListSetItem(s.WithOpName("tl1w2"), tl1.handle, idx2, wht1);
+  // Ensure that TensorListResize doesn't cause any problems.
+  Output tl1rs =
+      ops::TensorListResize(s.WithOpName("tl1rs"), tl1w2.output_handle, 6);
+  Output tl1r1 = ops::TensorListGetItem(s.WithOpName("tl1r1"), tl1rs, idx2,
+                                        shape, DT_FLOAT)
+                     .item;
+  Output gry1 = ops::Mul(s.WithOpName("gry1"), tl1r1, tl1r1);
+  Output wht2 = ops::MatMul(s.WithOpName("wht2"), gry1, gry1);
+  auto tl1w3 =
+      ops::TensorListSetItem(s.WithOpName("tl1w3"), tl1.handle, idx3, wht2);
+  Output tl1r2 =
+      ops::TensorListGetItem(s.WithOpName("tl1r2"), tl1w3.output_handle, idx3,
+                             shape, DT_FLOAT)
+          .item;
+  auto tl2 = ops::TensorListReserve(s.WithOpName("tl2"), shape, 8, DT_FLOAT);
+  auto tl2w1 =
+      ops::TensorListSetItem(s.WithOpName("tl2w1"), tl2.handle, idx1, input);
+  Output tl2r1 =
+      ops::TensorListGetItem(s.WithOpName("tl2r1"), tl2w1.output_handle, idx1,
+                             shape, DT_FLOAT)
+          .item;
+  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), tl1r2);
+  Output fetch2 = ops::Identity(s.WithOpName("fetch2"), tl2r1);
+
+  GrapplerItem item;
+  item.fetch = {"fetch1", "fetch2"};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+  AutoMixedPrecision optimizer{AutoMixedPrecisionMode::MKL};
+  GraphDef output;
+  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
+
+  VLOG(1) << output.DebugString();
+
+  GraphView output_view(&output);
+  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
+  const char* type_key = "element_dtype";
+  EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(),
+            DT_BFLOAT16);
+  EXPECT_EQ(output_view.GetNode("tl1w1")->attr().at(type_key).type(),
+            DT_BFLOAT16);
+  EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_BFLOAT16);
+  EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(),
+            DT_BFLOAT16);
+  EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(),
+            DT_BFLOAT16);
+  EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_BFLOAT16);
+  EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_BFLOAT16);
+  EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(),
+            DT_BFLOAT16);
+  EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT);
+  EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_FLOAT);
+  EXPECT_EQ(output_view.GetNode("tl2r1")->attr().at(type_key).type(), DT_FLOAT);
+
+  auto tensors = EvaluateNodes(output, item.fetch);
+  EXPECT_EQ(tensors.size(), tensors_expected.size());
+  EXPECT_EQ(tensors.size(), item.fetch.size());
+  for (int i = 0; i < item.fetch.size(); ++i) {
+    test::ExpectClose(tensors_expected[i], tensors[i], -1, 1e-2);
+  }
+}
+
+#endif  // ENABLE_INTEL_MKL_BFLOAT16
+#endif  // INTEL_MKL
+
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow
 
-#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM || (INTEL_MKL &&
+        // defined(ENABLE_INTEL_MKL_BFLOAT16))
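Editor's aside on the loose tolerances in the MKL tests above (5e-4 in Simple, 1e-2 after the TensorList round-trips, versus 1e-6 on the pure float32 path in AlreadyBf16): bfloat16 keeps float32's 8-bit exponent but only 7 explicit significand bits, so each rounding carries a relative error on the order of 2^-8. A quick numpy illustration (truncation shown for simplicity; real kernels typically round to nearest):

    import numpy as np

    def bfloat16_trunc(x):
      """Keep only the top 16 bits of a float32: sign, 8 exponent, 7 mantissa."""
      bits = np.asarray(x, dtype=np.float32).view(np.uint32)
      return (bits & np.uint32(0xFFFF0000)).view(np.float32)

    x = np.float32(1.0 / 3.0)
    print(x, bfloat16_trunc(x), abs(x - bfloat16_trunc(x)) / x)
    # The relative error here is about 2**-8, i.e. on the order of 4e-3.
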
@@ -188,7 +188,9 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
   MK_OPT("remap", new Remapper(cfg_.remapping()));
   MK_OPT("layout", new GenericLayoutOptimizer());
   MK_OPT("auto_mixed_precision",
-         new AutoMixedPrecision(cfg_.auto_mixed_precision()));
+         new AutoMixedPrecision(AutoMixedPrecisionMode::CUDA));
+  MK_OPT("auto_mixed_precision_mkl",
+         new AutoMixedPrecision(AutoMixedPrecisionMode::MKL));
   MK_OPT("memory", new MemoryOptimizer(RewriterConfig::MANUAL));
   MK_OPT("common_subgraph_elimination",
          new CommonSubgraphElimination(cfg_.common_subgraph_elimination()));
@@ -249,7 +251,11 @@ Status MetaOptimizer::InitializeOptimizers(
   }
   if (AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision())) {
     optimizers->push_back(
-        MakeUnique<AutoMixedPrecision>(cfg_.auto_mixed_precision()));
+        MakeUnique<AutoMixedPrecision>(AutoMixedPrecisionMode::CUDA));
   }
+  if (AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision_mkl())) {
+    optimizers->push_back(
+        MakeUnique<AutoMixedPrecision>(AutoMixedPrecisionMode::MKL));
+  }
   if (cfg_.pin_to_host_optimization() == RewriterConfig::ON) {
     optimizers->push_back(MakeUnique<PinToHostOptimizer>());
@@ -835,6 +841,7 @@ bool MetaOptimizerEnabled(const ConfigProto& cfg) {
          rewrite_cfg.scoped_allocator_optimization() == RewriterConfig::ON ||
          rewrite_cfg.pin_to_host_optimization() == RewriterConfig::ON ||
          AutoMixedPrecisionEnabled(rewrite_cfg.auto_mixed_precision()) ||
+         AutoMixedPrecisionEnabled(rewrite_cfg.auto_mixed_precision_mkl()) ||
          !rewrite_cfg.optimizers().empty() ||
          !rewrite_cfg.custom_optimizers().empty();
 }
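Editor's note: with the MK_OPT registration above, the MKL pass can also be requested by its registered optimizer name, independent of the dedicated toggle. A sketch; the name 'auto_mixed_precision_mkl' matches the MK_OPT string in this diff:

    from tensorflow.core.protobuf import rewriter_config_pb2

    rewrite_options = rewriter_config_pb2.RewriterConfig()
    rewrite_options.optimizers.append('auto_mixed_precision_mkl')
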
@@ -85,11 +85,15 @@ message RewriterConfig {
   // Enable the swap of kernel implementations based on the device placement
   // (default is ON).
   Toggle implementation_selector = 22;
-  // Optimize data types (default is OFF).
-  // e.g., This will try to use float16 on GPU which is faster.
+  // Optimize data types for CUDA (default is OFF).
+  // This will try to use float16 on GPU which is faster.
   // Note that this can change the numerical stability of the graph and may
   // require the use of loss scaling to maintain model convergence.
   Toggle auto_mixed_precision = 23;
+  // Optimize data types for MKL (default is OFF).
+  // This will try to use bfloat16 on CPUs, which is faster.
+  // Note that this can change the numerical stability of the graph.
+  Toggle auto_mixed_precision_mkl = 25;
   // Disable the entire meta optimizer (off by default).
   bool disable_meta_optimizer = 19;
 
@ -19,8 +19,8 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.core.framework import types_pb2
|
||||
@ -209,7 +209,7 @@ def _make_node_with_color(color, input_tensor, name=None):
|
||||
if color == 'c': # Clear node
|
||||
return nn.relu(input_tensor, name=name)
|
||||
if color == 'b': # Black node
|
||||
return math_ops.sqrt(math_ops.pow(input_tensor, 2.), name=name)
|
||||
return math_ops.pow(math_ops.pow(input_tensor, 2.), 0.5, name=name)
|
||||
raise ValueError('Invalid node color: ' + str(color))
|
||||
|
||||
|
||||
@ -231,18 +231,21 @@ def _build_simple_loop_graph(inp_colors, body_colors, out_colors):
|
||||
return a
|
||||
|
||||
|
||||
def _get_config(auto_mixed_precision=True):
|
||||
def _get_config(auto_mixed_precision_mode):
|
||||
"""Returns a ConfigProto with auto mixed precision enabled if appropriate."""
|
||||
if auto_mixed_precision:
|
||||
rewrite_config = rewriter_config_pb2.RewriterConfig(
|
||||
auto_mixed_precision=rewriter_config_pb2.RewriterConfig.ON,
|
||||
# do not remove duplicated nodes
|
||||
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF)
|
||||
rewrite_config = rewriter_config_pb2.RewriterConfig(
|
||||
# do not remove duplicated nodes
|
||||
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
|
||||
# do not turn Conv2D and other nodes into _FusedConv2D
|
||||
remapping=rewriter_config_pb2.RewriterConfig.OFF,
|
||||
)
|
||||
if auto_mixed_precision_mode == 'cuda':
|
||||
rewrite_config.auto_mixed_precision = rewriter_config_pb2.RewriterConfig.ON
|
||||
elif auto_mixed_precision_mode == 'mkl':
|
||||
rewrite_config.auto_mixed_precision_mkl = (
|
||||
rewriter_config_pb2.RewriterConfig.ON)
|
||||
else:
|
||||
rewrite_config = rewriter_config_pb2.RewriterConfig(
|
||||
auto_mixed_precision=rewriter_config_pb2.RewriterConfig.OFF,
|
||||
# do not remove duplicated nodes
|
||||
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF)
|
||||
assert auto_mixed_precision_mode is None
|
||||
rewrite_config.min_graph_nodes = -1
|
||||
graph_options = config_pb2.GraphOptions(
|
||||
rewrite_options=rewrite_config, build_cost_model=1)
|
||||
@ -255,19 +258,33 @@ def _is_cast_to_fp16(node_name):
  return node_name.endswith('-CastToFp16-AutoMixedPrecision')


def _is_cast_to_bf16(node_name):
  return node_name.endswith('-CastToBf16-AutoMixedPrecision')


def _is_cast_to_fp32(node_name):
  return node_name.endswith('-CastToFp32-AutoMixedPrecision')


def _count_casts(nodes):
def _count_casts(mode, nodes):
  """Counts the number of casts to f16 and fp32."""
  num_to_fp16 = 0
  num_to_bf16 = 0
  num_to_fp32 = 0
  for node in nodes:
    if _is_cast_to_fp16(node.name):
      num_to_fp16 += 1
    if _is_cast_to_bf16(node.name):
      num_to_bf16 += 1
    elif _is_cast_to_fp32(node.name):
      num_to_fp32 += 1
  return num_to_fp16, num_to_fp32
  if mode == 'cuda':
    assert num_to_bf16 == 0
    return num_to_fp16, num_to_fp32
  else:
    assert mode == 'mkl'
    assert num_to_fp16 == 0
    return num_to_bf16, num_to_fp32

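These predicates key off the name suffix the grappler pass gives the Cast nodes it inserts; a small illustration with hypothetical node names, using the new _count_casts(mode, nodes) signature:

    class _FakeNode(object):  # stand-in for a cost-graph node, illustration only

      def __init__(self, name):
        self.name = name

    nodes = [_FakeNode('Conv2D-0-CastToFp16-AutoMixedPrecision'),
             _FakeNode('FusedBatchNormV3-0-CastToFp32-AutoMixedPrecision')]
    # One cast into f16 in front of Conv2D, one back to fp32 after the bn.
    assert _count_casts('cuda', nodes) == (1, 1)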
def _build_node_map(nodes):
@ -303,7 +320,7 @@ def _example_noninlined_funcdef(features):
  return features * math_ops.sigmoid(features)


class AutoMixedPrecisionTest(test.TestCase):
class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
  """Tests the Grappler auto mixed precision optimizer."""

  IGNORE_PERF_VAR = 'TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_IGNORE_PERFORMANCE'
@ -311,8 +328,8 @@ class AutoMixedPrecisionTest(test.TestCase):

  def setUp(self):
    super(AutoMixedPrecisionTest, self).setUp()
    # Enable the tests to be run on pre-Volta GPUs by telling the grappler pass
    # to ignore performance and always transform the graph.
    # Enable the CUDA tests to be run on pre-Volta GPUs by telling the grappler
    # pass to ignore performance and always transform the graph.
    self._original_ignore_perf_value = os.getenv(self.IGNORE_PERF_VAR)
    os.environ[self.IGNORE_PERF_VAR] = '1'

@ -323,24 +340,33 @@ class AutoMixedPrecisionTest(test.TestCase):
      del os.environ[self.IGNORE_PERF_VAR]
    super(AutoMixedPrecisionTest, self).tearDown()

  def _assert_output_fp16(self, node_map, node_name, output_port=0):
    self.assertEqual(node_map[node_name].output_info[output_port].dtype,
                     types_pb2.DT_HALF)
  def _lower_precision_dtype(self, mode):
    return dtypes.float16 if mode == 'cuda' else dtypes.bfloat16

  def _run(self, fetches):
  def _assert_output_f16(self, mode, node_map, node_name, output_port=0):
    self.assertEqual(node_map[node_name].output_info[output_port].dtype,
                     self._lower_precision_dtype(mode).as_datatype_enum)

  def _run(self, mode, fetches):
    """Runs the graph and returns the evaluation of the fetches."""
    with session.Session(config=_get_config(False)) as sess:
    with session.Session(config=_get_config(None)) as sess:
      sess.run(variables.global_variables_initializer())
      output_val_ref = self.evaluate(fetches)

    with session.Session(config=_get_config()) as sess:
    with session.Session(config=_get_config(mode)) as sess:
      sess.run(variables.global_variables_initializer())
      metadata = config_pb2.RunMetadata()
      output_val = sess.run(fetches, run_metadata=metadata)

    return output_val_ref, output_val, metadata.cost_graph

  def _run_simple_loop_test(self, inp, body, out):
  def _maybe_skip(self, mode):
    if mode == 'cuda' and not test.is_gpu_available(cuda_only=True):
      self.skipTest('No GPU is available')
    if mode == 'mkl' and not test_util.IsMklEnabled():
      self.skipTest('MKL is not enabled')

  def _run_simple_loop_test(self, mode, inp, body, out):
    """Runs a test of a simple loop.

    The loop has different node colors in different sections of the graph. The
@ -352,6 +378,7 @@ class AutoMixedPrecisionTest(test.TestCase):
    inp -> loop [ body ] -> out.

    Args:
      mode: Either 'cuda' or 'mkl'.
      inp: A string of letters indicating the colors and expected dtypes of the
        input nodes.
      body: A string of letters indicating the colors and expected dtypes of the
@ -359,398 +386,446 @@ class AutoMixedPrecisionTest(test.TestCase):
      out: A string of letters indicating the colors and expected dtypes of the
        output nodes.
    """
    if test.is_gpu_available(cuda_only=True):
      random_seed.set_random_seed(0)
      expected_types = []
      for section in [inp, body, out]:
        section_expected_types = []
        for color in section:
          if color.isupper():
            expected_type = types_pb2.DT_HALF
          else:
            expected_type = types_pb2.DT_FLOAT
          section_expected_types.append(expected_type)
        expected_types.append(section_expected_types)
    self._maybe_skip(mode)
    random_seed.set_random_seed(0)
    expected_types = []
    for section in [inp, body, out]:
      section_expected_types = []
      for color in section:
        if color.isupper():
          expected_type = self._lower_precision_dtype(mode).as_datatype_enum
        else:
          expected_type = types_pb2.DT_FLOAT
        section_expected_types.append(expected_type)
      expected_types.append(section_expected_types)

      a = _build_simple_loop_graph(inp, body, out)
      output_val_ref, output_val, cost_graph = self._run(a)
      node_map = _build_node_map(cost_graph.node)
    a = _build_simple_loop_graph(inp, body, out)
    output_val_ref, output_val, cost_graph = self._run(mode, a)
    node_map = _build_node_map(cost_graph.node)

      section_names = ['input', 'while/body', 'output']
      all_types_correct = True
      for section_name, expected_types in zip(section_names, expected_types):
        for i, expected_type in enumerate(expected_types):
          node_name = section_name + '_%i' % i
          output_port = 0
          optimized_type = node_map[node_name].output_info[output_port].dtype
          if optimized_type != expected_type:
            print('Expected node %s to have type %s but got type %s' %
                  (node_name, expected_type, optimized_type))
            all_types_correct = False
      self.assertTrue(all_types_correct)
    section_names = ['input', 'while/body', 'output']
    all_types_correct = True
    for section_name, expected_types in zip(section_names, expected_types):
      for i, expected_type in enumerate(expected_types):
        node_name = section_name + '_%i' % i
        output_port = 0
        optimized_type = node_map[node_name].output_info[output_port].dtype
        if optimized_type != expected_type:
          print('Expected node %s to have type %s but got type %s' %
                (node_name, expected_type, optimized_type))
          all_types_correct = False
    self.assertTrue(all_types_correct)
    if mode == 'mkl':
      self.assertAllClose(output_val_ref, output_val, atol=2e-2, rtol=2e-2)
    else:
      self.assertAllClose(output_val_ref, output_val, atol=2e-3, rtol=1e-3)

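Restated in isolation, the color-string convention used by this helper (a hedged sketch; _expected_dtypes is a hypothetical equivalent of the loop above, not part of the diff):

    def _expected_dtypes(section, lower_precision_dtype):
      # Uppercase letters mark nodes the pass is expected to convert.
      return [lower_precision_dtype.as_datatype_enum if color.isupper()
              else types_pb2.DT_FLOAT
              for color in section]

    # e.g. in 'cuda' mode, 'Wc' expects a float16 node then a float32 node.
    assert _expected_dtypes('Wc', dtypes.float16) == [
        types_pb2.DT_HALF, types_pb2.DT_FLOAT]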
  @parameterized.parameters(['cuda', 'mkl'])
  @test_util.run_deprecated_v1
  @test_util.disable_xla('This test does not pass with XLA')
  def test_conv_bn(self):
  def test_conv_bn(self, mode):
    """Test graph with convolution followed by batch norm."""
    if test.is_gpu_available(cuda_only=True):
      random_seed.set_random_seed(0)
      x = _input([2, 8, 8, 1])
      x = _conv_bn(x)
      output = _conv_bn(x)
    self._maybe_skip(mode)
    random_seed.set_random_seed(0)
    x = _input([2, 8, 8, 1])
    x = _conv_bn(x)
    output = _conv_bn(x)

      output_val_ref, output_val, cost_graph = self._run(output)
      node_map = _build_node_map(cost_graph.node)
      num_to_fp16, num_to_fp32 = _count_casts(cost_graph.node)
    output_val_ref, output_val, cost_graph = self._run(mode, output)
    node_map = _build_node_map(cost_graph.node)
    num_to_f16, num_to_fp32 = _count_casts(mode, cost_graph.node)

      self._assert_output_fp16(node_map, 'Conv2D')
      self._assert_output_fp16(node_map, 'FusedBatchNormV3')
      self._assert_output_fp16(node_map, 'Conv2D_1')
      self.assertEqual(num_to_fp16,
                       3)  # Before Conv2D:0, Conv2D:1, Conv2D_1:1
      self.assertEqual(num_to_fp32, 1)  # After FusedBatchNormV3:0
    self._assert_output_f16(mode, node_map, 'Conv2D')
    self._assert_output_f16(mode, node_map, 'FusedBatchNormV3')
    self._assert_output_f16(mode, node_map, 'Conv2D_1')
    self.assertEqual(num_to_f16, 3)  # Before Conv2D:0, Conv2D:1, Conv2D_1:1
    self.assertEqual(num_to_fp32, 1)  # After FusedBatchNormV3:0
    if mode == 'mkl':
      tol = 1e-2
    elif test.is_built_with_rocm():
      # Bump up the tolerance for the ROCm platform
      # The default tolerance (1e-3) results in a tiny fraction (<1%) of
      # miscompares on ROCm platform, and hence the tolerance bump
      tol = 2e-3 if test.is_built_with_rocm else 1e-3
      self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)
      tol = 2e-3
    else:
      tol = 1e-3
    self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)

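For orientation, the four casts counted above sit roughly like this (a sketch only; the exact fp32 boundary is wherever a converted node feeds a float32 consumer):

    input  --CastToF16--> Conv2D:0
    filter --CastToF16--> Conv2D:1              (first conv_bn)
    filter --CastToF16--> Conv2D_1:1            (second conv_bn)
    ...FusedBatchNormV3 --CastToFp32--> float32 consumers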
  # TODO: enable these tests when cuDNN is upgraded to >= 7.6.2. Same with the
  # test_conv3d() below.
  @unittest.skip('Test case should be skipped when cuDNN < 7.6.2')
  @parameterized.parameters(['cuda', 'mkl'])
  @test_util.run_deprecated_v1
  @test_util.disable_xla('This test does not pass with XLA')
  def test_conv3d_bn(self):
  def test_conv3d_bn(self, mode):
    """Test graph with convolution followed by batch norm."""
    if test.is_gpu_available(cuda_only=True):
      random_seed.set_random_seed(0)
      x = _input([2, 8, 8, 8, 1])
      x = _conv3d_bn(x)
      output = _conv3d_bn(x)
    self._maybe_skip(mode)
    if mode == 'cuda':
      # TODO(reedwm): enable these tests when cuDNN is upgraded to >= 7.6.2.
      self.skipTest('Test case should be skipped when cuDNN < 7.6.2')
    random_seed.set_random_seed(0)
    x = _input([2, 8, 8, 8, 1])
    x = _conv3d_bn(x)
    output = _conv3d_bn(x)

      output_val_ref, output_val, cost_graph = self._run(output)
      node_map = _build_node_map(cost_graph.node)
      num_to_fp16, num_to_fp32 = _count_casts(cost_graph.node)
    output_val_ref, output_val, cost_graph = self._run(mode, output)
    node_map = _build_node_map(cost_graph.node)
    num_to_fp16, num_to_fp32 = _count_casts(mode, cost_graph.node)

      self._assert_output_fp16(node_map, 'Conv3D')
      self._assert_output_fp16(node_map, 'FusedBatchNormV3')
      self._assert_output_fp16(node_map, 'Conv3D_1')
      self.assertEqual(num_to_fp16, 3)  # Before Conv3D:0, Conv3D:1, Conv3D_1:1
      self.assertEqual(num_to_fp32, 1)  # After FusedBatchNormV3:0
      self.assertAllClose(output_val_ref, output_val, atol=1e-2, rtol=1e-2)
    self._assert_output_f16(mode, node_map, 'Conv3D')
    self._assert_output_f16(mode, node_map, 'FusedBatchNormV3')
    self._assert_output_f16(mode, node_map, 'Conv3D_1')
    self.assertEqual(num_to_fp16, 3)  # Before Conv3D:0, Conv3D:1, Conv3D_1:1
    self.assertEqual(num_to_fp32, 1)  # After FusedBatchNormV3:0
    self.assertAllClose(output_val_ref, output_val, atol=1e-2, rtol=1e-2)

  @unittest.skip('Test case should be skipped when cuDNN < 7.6.2')
  @parameterized.parameters(['cuda', 'mkl'])
  @test_util.run_deprecated_v1
  @test_util.disable_xla('This test does not pass with XLA')
  def test_conv3d(self):
  def test_conv3d(self, mode):
    """Test grad ops with convolution3d graph."""
    if test.is_gpu_available(cuda_only=True):
      random_seed.set_random_seed(0)
      x = _input([2, 8, 8, 8, 1])
      f = _weight([3, 3, 3, 1, 6])
      y = _conv3d(x, f)
      y = array_ops.identity(y)
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
      g = optimizer.compute_gradients(y, [x, f])
      output = (y, g)
    self._maybe_skip(mode)
    if mode == 'cuda':
      # TODO(reedwm): enable these tests when cuDNN is upgraded to >= 7.6.2.
      self.skipTest('Test case should be skipped when cuDNN < 7.6.2')
    random_seed.set_random_seed(0)
    x = _input([2, 8, 8, 8, 1])
    f = _weight([3, 3, 3, 1, 6])
    y = _conv3d(x, f)
    y = array_ops.identity(y)
    optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
    g = optimizer.compute_gradients(y, [x, f])
    output = (y, g)

      output_val_ref, output_val, cost_graph = self._run(output)
      node_map = _build_node_map(cost_graph.node)
      self._assert_output_fp16(node_map, 'Conv3D')
      self._assert_output_fp16(node_map,
                               'gradients/Conv3D_grad/Conv3DBackpropInputV2')
      self._assert_output_fp16(node_map,
                               'gradients/Conv3D_grad/Conv3DBackpropFilterV2')
    output_val_ref, output_val, cost_graph = self._run(mode, output)
    node_map = _build_node_map(cost_graph.node)
    self._assert_output_f16(mode, node_map, 'Conv3D')
    self._assert_output_f16(mode, node_map,
                            'gradients/Conv3D_grad/Conv3DBackpropInputV2')
    self._assert_output_f16(mode, node_map,
                            'gradients/Conv3D_grad/Conv3DBackpropFilterV2')

      output_val_ref, output_val, cost_graph = self._run(output)
      self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
    output_val_ref, output_val, cost_graph = self._run(mode, output)
    tol = 5e-2 if mode == 'mkl' else 1e-3
    self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)

  # TODO(reedwm): Fix and enable this test with MKL. Currently this crashes with
  # MKL.
  @parameterized.parameters(['cuda'])
  @test_util.run_deprecated_v1
  @test_util.disable_xla('This test does not pass with XLA')
  def test_conv_bn_dropout(self):
  def test_conv_bn_dropout(self, mode):
    """Test dropout precision of convolution batch norm graph."""
    if test.is_gpu_available(cuda_only=True):
      random_seed.set_random_seed(0)
      x = _input([2, 8, 8, 1])
      y = _conv_bn(x)
      y = nn.dropout(y, rate=0.5)
      y = math_ops.add(y, 1, name='addition')
      y = _conv_bn(y)
      y = array_ops.identity(y)
      optimizer = gradient_descent.GradientDescentOptimizer(
          learning_rate=0.01)
      g = optimizer.compute_gradients(y, [x])
      output = (y, g)
    self._maybe_skip(mode)
    random_seed.set_random_seed(0)
    x = _input([2, 8, 8, 1])
    y = _conv_bn(x)
    y = nn.dropout(y, rate=0.5)
    y = math_ops.add(y, 1, name='addition')
    y = _conv_bn(y)
    y = array_ops.identity(y)
    optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
    g = optimizer.compute_gradients(y, [x])
    output = (y, g)

      output_val_ref, output_val, cost_graph = self._run(output)
      node_map = _build_node_map(cost_graph.node)
      self._assert_output_fp16(node_map, 'Conv2D')
      self._assert_output_fp16(node_map, 'FusedBatchNormV3')
      # We do not assert dropout's dtype because we do not want to rely on the
      # node names of dropout's internal implementation.
      self._assert_output_fp16(node_map, 'addition')
      self._assert_output_fp16(node_map, 'Conv2D_1')
    output_val_ref, output_val, cost_graph = self._run(mode, output)
    node_map = _build_node_map(cost_graph.node)
    self._assert_output_f16(mode, node_map, 'Conv2D')
    self._assert_output_f16(mode, node_map, 'FusedBatchNormV3')
    # We do not assert dropout's dtype because we do not want to rely on the
    # node names of dropout's internal implementation.
    self._assert_output_f16(mode, node_map, 'addition')
    self._assert_output_f16(mode, node_map, 'Conv2D_1')

      output_val_ref, output_val, cost_graph = self._run(output)
      # Bump up the tolerance for the ROCm platform
      # The default tolerance (1e-3) results in a tiny fraction (<1%) of
      # miscompares on ROCm platform, and hence the tolerance bump
      tol = 2e-3 if test.is_built_with_rocm else 1e-3
      self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)
    output_val_ref, output_val, cost_graph = self._run(mode, output)
    # Bump up the tolerance for the ROCm platform
    # The default tolerance (1e-3) results in a tiny fraction (<1%) of
    # miscompares on ROCm platform, and hence the tolerance bump
    tol = 2e-3 if test.is_built_with_rocm() else 1e-3
    self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)

  # TODO(reedwm): Fix and enable this test with MKL. Currently this crashes with
  # MKL.
  @parameterized.parameters(['cuda'])
  @test_util.run_deprecated_v1
  @test_util.disable_xla('This test does not pass with XLA')
  def test_conv_pool(self):
  def test_conv_pool(self, mode):
    """Test graph with convolution followed by pooling."""
    if test.is_gpu_available(cuda_only=True):
      random_seed.set_random_seed(0)
      x = _input([2, 8, 8, 1])
      output = _conv_pool(x)
    self._maybe_skip(mode)
    random_seed.set_random_seed(0)
    x = _input([2, 8, 8, 1])
    output = _conv_pool(x)

      output_val_ref, output_val, cost_graph = self._run(output)
      node_map = _build_node_map(cost_graph.node)
      num_to_fp16, num_to_fp32 = _count_casts(cost_graph.node)
    output_val_ref, output_val, cost_graph = self._run(mode, output)
    node_map = _build_node_map(cost_graph.node)
    num_to_f16, num_to_fp32 = _count_casts(mode, cost_graph.node)

      self._assert_output_fp16(node_map, 'Conv2D')
      self._assert_output_fp16(node_map, 'Relu')
      self._assert_output_fp16(node_map, 'MaxPool')
      self._assert_output_fp16(node_map, 'Conv2D_1')
      self.assertEqual(num_to_fp16, 4)
      self.assertEqual(num_to_fp32, 1)
      self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
    self._assert_output_f16(mode, node_map, 'Conv2D')
    self._assert_output_f16(mode, node_map, 'Relu')
    self._assert_output_f16(mode, node_map, 'MaxPool')
    self._assert_output_f16(mode, node_map, 'Conv2D_1')
    self.assertEqual(num_to_f16, 4)
    self.assertEqual(num_to_fp32, 1)
    tol = 5e-3 if mode == 'mkl' else 1e-3
    self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)

  @parameterized.parameters(['cuda', 'mkl'])
  @test_util.run_v1_only('b/138749235')
  @test_util.disable_xla('This test does not pass with XLA')
  def test_simple_loop(self):
  def test_simple_loop(self, mode):
    """Test graph with while loop."""
    if test.is_gpu_available(cuda_only=True):
      random_seed.set_random_seed(0)
      x = _input([8, 8])
      y = _simple_loop(x, _matmul_act)[1]
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
      g = optimizer.compute_gradients(y, [x])
      output = (y, g)
    self._maybe_skip(mode)
    random_seed.set_random_seed(0)
    x = _input([8, 8])
    y = _simple_loop(x, _matmul_act)[1]
    optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
    g = optimizer.compute_gradients(y, [x])
    output = (y, g)

      output_val_ref, output_val, cost_graph = self._run(output)
      node_map = _build_node_map(cost_graph.node)
    output_val_ref, output_val, cost_graph = self._run(mode, output)
    node_map = _build_node_map(cost_graph.node)

      self._assert_output_fp16(node_map, 'while/MatMul')
      self._assert_output_fp16(node_map, 'while/Relu')
      self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
    self._assert_output_f16(mode, node_map, 'while/MatMul')
    self._assert_output_f16(mode, node_map, 'while/Relu')
    tol = 1e-2 if mode == 'mkl' else 1e-3
    self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)

  @parameterized.parameters(['cuda', 'mkl'])
  @test_util.run_v1_only('b/138749235')
  @test_util.disable_xla('This test does not pass with XLA')
  def test_loop_with_vars_intertwined(self):
  def test_loop_with_vars_intertwined(self, mode):
    """Test graph with intertwined while loops."""
    if test.is_gpu_available(cuda_only=True):
      random_seed.set_random_seed(0)
      x = _input([8, 8])
      _, _, k, l = _loop_vars_intertwined(
          array_ops.ones(array_ops.shape(x)), x, _matmul_act, _matmul_act)
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
      g = optimizer.compute_gradients(k, [x])
      output = (k, l, g)
    self._maybe_skip(mode)
    random_seed.set_random_seed(0)
    x = _input([8, 8])
    _, _, k, l = _loop_vars_intertwined(
        array_ops.ones(array_ops.shape(x)), x, _matmul_act, _matmul_act)
    optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
    g = optimizer.compute_gradients(k, [x])
    output = (k, l, g)

      output_val_ref, output_val, cost_graph = self._run(output)
      node_map = _build_node_map(cost_graph.node)
    output_val_ref, output_val, cost_graph = self._run(mode, output)
    node_map = _build_node_map(cost_graph.node)

      self._assert_output_fp16(node_map, 'while/MatMul')
      self._assert_output_fp16(node_map, 'while/Relu')
      self._assert_output_fp16(node_map, 'while/MatMul_1')
      self._assert_output_fp16(node_map, 'while/Relu_1')
      self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
    self._assert_output_f16(mode, node_map, 'while/MatMul')
    self._assert_output_f16(mode, node_map, 'while/Relu')
    self._assert_output_f16(mode, node_map, 'while/MatMul_1')
    self._assert_output_f16(mode, node_map, 'while/Relu_1')
    tol = 5e-3 if mode == 'mkl' else 1e-3
    self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)

  @parameterized.parameters(['cuda'])
  @test_util.run_deprecated_v1
  @test_util.disable_xla('This test does not pass with XLA')
  def test_multi_paths(self):
  def test_multi_paths(self, mode):
    """Test graph with multiple paths."""
    if test.is_gpu_available(cuda_only=True):
      random_seed.set_random_seed(0)
      x = _input([2, 8, 8, 3])
      x1, x2, x3 = array_ops.split(x, num_or_size_splits=3, axis=3)
      y1 = _conv_pool(x1)
      y2 = _conv_pool(x2)
      y3 = _conv_pool(x3)
      y = array_ops.concat([y1, y2, y3], axis=3)
      y = array_ops.identity(y)
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
      g = optimizer.compute_gradients(y, [x])
      output = (y, g)
    self._maybe_skip(mode)
    random_seed.set_random_seed(0)
    x = _input([2, 8, 8, 3])
    x1, x2, x3 = array_ops.split(x, num_or_size_splits=3, axis=3)
    y1 = _conv_pool(x1)
    y2 = _conv_pool(x2)
    y3 = _conv_pool(x3)
    y = array_ops.concat([y1, y2, y3], axis=3)
    y = array_ops.identity(y)
    optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
    g = optimizer.compute_gradients(y, [x])
    output = (y, g)

      output_val_ref, output_val, cost_graph = self._run(output)
      node_map = _build_node_map(cost_graph.node)
    output_val_ref, output_val, cost_graph = self._run(mode, output)
    node_map = _build_node_map(cost_graph.node)

      self._assert_output_fp16(node_map, 'split')
      for suffix in [''] + ['_%i' % i for i in range(1, 6)]:
        self._assert_output_fp16(node_map, 'Conv2D' + suffix)
        self._assert_output_fp16(node_map, 'Relu' + suffix)
        self._assert_output_fp16(node_map, 'MaxPool' + suffix)
      self._assert_output_fp16(node_map, 'concat')
      self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
    self._assert_output_f16(mode, node_map, 'split')
    for suffix in [''] + ['_%i' % i for i in range(1, 6)]:
      self._assert_output_f16(mode, node_map, 'Conv2D' + suffix)
      self._assert_output_f16(mode, node_map, 'Relu' + suffix)
      self._assert_output_f16(mode, node_map, 'MaxPool' + suffix)
    self._assert_output_f16(mode, node_map, 'concat')
    self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)

  @parameterized.parameters(['cuda', 'mkl'])
  @test_util.run_deprecated_v1
  @test_util.disable_xla('This test does not pass with XLA')
  def test_multi_paths_2(self):
  def test_multi_paths_2(self, mode):
    """Test graph with multiple paths."""
    if test.is_gpu_available(cuda_only=True):
      random_seed.set_random_seed(0)
      x = _input([8, 8])
      y1 = _matmul_act(x)
      y2 = _matmul_act(x)
      y = y1 + y2 + x
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
      g = optimizer.compute_gradients(y, [x])
      output = (g, y)
    self._maybe_skip(mode)
    random_seed.set_random_seed(0)
    x = _input([8, 8])
    y1 = _matmul_act(x)
    y2 = _matmul_act(x)
    y = y1 + y2 + x
    optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
    g = optimizer.compute_gradients(y, [x])
    output = (g, y)

      output_val_ref, output_val, cost_graph = self._run(output)
      node_map = _build_node_map(cost_graph.node)
    output_val_ref, output_val, cost_graph = self._run(mode, output)
    node_map = _build_node_map(cost_graph.node)

      self._assert_output_fp16(node_map, 'MatMul')
      self._assert_output_fp16(node_map, 'Relu')
      self._assert_output_fp16(node_map, 'MatMul_1')
      self._assert_output_fp16(node_map, 'Relu_1')
    self._assert_output_f16(mode, node_map, 'MatMul')
    self._assert_output_f16(mode, node_map, 'Relu')
    self._assert_output_f16(mode, node_map, 'MatMul_1')
    self._assert_output_f16(mode, node_map, 'Relu_1')
    if mode == 'mkl':
      tol = 2e-2
    elif test.is_built_with_rocm():
      # Bump up the tolerance for the ROCm platform
      # The default tolerance (1e-3) results in a tiny fraction (<1%) of
      # miscompares on ROCm platform, and hence the tolerance bump
      tol = 2e-3 if test.is_built_with_rocm else 1e-3
      self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)
      tol = 2e-3
    else:
      tol = 1e-3
    self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)

  @parameterized.parameters(['cuda'])  # MKL doesn't support bf16 Sigmoid
  @test_util.run_v1_only('b/138749235')
  @test_util.disable_xla('This test does not pass with XLA')
  def test_recurrent_lstm(self):
  def test_recurrent_lstm(self, mode):
    """Test graph with recurrent lstm."""
    if test.is_gpu_available(cuda_only=True):
      random_seed.set_random_seed(0)
      init_c = _input([8, 4])
      init_h = _input([8, 4])
      _, _, h, _ = _recurrent_lstm(init_c, init_h)
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
      g = optimizer.compute_gradients(h, [init_c, init_h])
      output = (h, g)
    self._maybe_skip(mode)
    random_seed.set_random_seed(0)
    init_c = _input([8, 4])
    init_h = _input([8, 4])
    _, _, h, _ = _recurrent_lstm(init_c, init_h)
    optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
    g = optimizer.compute_gradients(h, [init_c, init_h])
    output = (h, g)

      output_val_ref, output_val, cost_graph = self._run(output)
      node_map = _build_node_map(cost_graph.node)
    output_val_ref, output_val, cost_graph = self._run(mode, output)
    node_map = _build_node_map(cost_graph.node)

      self._assert_output_fp16(node_map, 'while/concat')
      self._assert_output_fp16(node_map, 'while/MatMul')
      self._assert_output_fp16(node_map, 'while/split')
      self._assert_output_fp16(node_map, 'while/Sigmoid')
      self._assert_output_fp16(node_map, 'while/Sigmoid_1')
      self._assert_output_fp16(node_map, 'while/Sigmoid_2')
      self._assert_output_fp16(node_map, 'while/Tanh')
      self._assert_output_fp16(node_map, 'while/Tanh_1')
      self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
    self._assert_output_f16(mode, node_map, 'while/concat')
    self._assert_output_f16(mode, node_map, 'while/MatMul')
    self._assert_output_f16(mode, node_map, 'while/split')
    self._assert_output_f16(mode, node_map, 'while/Sigmoid')
    self._assert_output_f16(mode, node_map, 'while/Sigmoid_1')
    self._assert_output_f16(mode, node_map, 'while/Sigmoid_2')
    self._assert_output_f16(mode, node_map, 'while/Tanh')
    self._assert_output_f16(mode, node_map, 'while/Tanh_1')
    self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)

  @parameterized.parameters(['cuda', 'mkl'])
  @test_util.run_v1_only('v1 loop test')
  @test_util.disable_xla('This test does not pass with XLA')
  def test_propagation_through_simple_loop_1(self):
    self._run_simple_loop_test('W', 'C', 'C')
  def test_propagation_through_simple_loop_1(self, mode):
    self._run_simple_loop_test(mode, 'W', 'C', 'C')

  @parameterized.parameters(['cuda', 'mkl'])
  @test_util.run_v1_only('v1 loop test')
  @test_util.disable_xla('This test does not pass with XLA')
  def test_propagation_through_simple_loop_2(self):
    self._run_simple_loop_test('C', 'C', 'W')
  def test_propagation_through_simple_loop_2(self, mode):
    self._run_simple_loop_test(mode, 'C', 'C', 'W')

  @parameterized.parameters(['cuda', 'mkl'])
  @test_util.run_v1_only('v1 loop test')
  @test_util.disable_xla('This test does not pass with XLA')
  def test_propagation_through_simple_loop_3(self):
    self._run_simple_loop_test('W', 'G', 'W')
  def test_propagation_through_simple_loop_3(self, mode):
    self._run_simple_loop_test(mode, 'W', 'G', 'W')

  @parameterized.parameters(['cuda', 'mkl'])
  @test_util.run_v1_only('v1 loop test')
  @test_util.disable_xla('This test does not pass with XLA')
  def test_propagation_through_simple_loop_4(self):
    self._run_simple_loop_test('W', 'gbg', 'W')
  def test_propagation_through_simple_loop_4(self, mode):
    self._run_simple_loop_test(mode, 'W', 'gbg', 'W')

  @parameterized.parameters(['cuda', 'mkl'])
  @test_util.run_v1_only('b/138749235')
  @test_util.disable_xla('This test does not pass with XLA')
  def test_propagation_through_simple_loop_5(self):
    self._run_simple_loop_test('b', 'gWC', 'c')
  def test_propagation_through_simple_loop_5(self, mode):
    self._run_simple_loop_test(mode, 'b', 'gWC', 'c')

  @parameterized.parameters(['cuda', 'mkl'])
  @test_util.run_v1_only('b/138749235')
  @test_util.disable_xla('This test does not pass with XLA')
  def test_propagation_through_simple_loop_6(self):
    self._run_simple_loop_test('b', 'CWCG', 'C')
  def test_propagation_through_simple_loop_6(self, mode):
    self._run_simple_loop_test(mode, 'b', 'CWCG', 'C')

  @parameterized.parameters(['cuda', 'mkl'])
  @test_util.run_v1_only('b/138749235')
  @test_util.disable_xla('This test does not pass with XLA')
  def test_propagation_through_simple_loop_7(self):
    self._run_simple_loop_test('C', 'GWCG', 'C')
  def test_propagation_through_simple_loop_7(self, mode):
    self._run_simple_loop_test(mode, 'C', 'GWCG', 'C')

  @parameterized.parameters(['cuda', 'mkl'])
  @test_util.run_v1_only('b/138749235')
  @test_util.disable_xla('This test does not pass with XLA')
  def test_propagation_through_simple_loop_8(self):
    self._run_simple_loop_test('C', 'CgbgWC', 'g')
  def test_propagation_through_simple_loop_8(self, mode):
    self._run_simple_loop_test(mode, 'C', 'CgbgWC', 'g')

  @parameterized.parameters(['cuda', 'mkl'])
  @test_util.run_deprecated_v1
  @test_util.disable_xla('This test does not pass with XLA')
  def test_noninlined_funcdef(self):
  def test_noninlined_funcdef(self, mode):
    """Test graph with non-inlined function subgraph.

    This requires the grappler pass to handle an OpDef that only appears in the
    graph's function registry instead of the global op registry.

    Args:
      mode: Either 'cuda' or 'mkl'.
    """
    if test.is_gpu_available(cuda_only=True):
      random_seed.set_random_seed(0)
      x = _input([8, 8])
      y = _matmul_act(x)
      y = _example_noninlined_funcdef(y)
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
      g = optimizer.compute_gradients(y, [x])
      output = (g, y)
    self._maybe_skip(mode)
    random_seed.set_random_seed(0)
    x = _input([8, 8])
    y = _matmul_act(x)
    y = _example_noninlined_funcdef(y)
    optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
    g = optimizer.compute_gradients(y, [x])
    output = (g, y)

      output_val_ref, output_val, cost_graph = self._run(output)
      node_map = _build_node_map(cost_graph.node)
    output_val_ref, output_val, cost_graph = self._run(mode, output)
    node_map = _build_node_map(cost_graph.node)

      self._assert_output_fp16(node_map, 'MatMul')
      self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
    self._assert_output_f16(mode, node_map, 'MatMul')
    tol = 1e-2 if mode == 'mkl' else 1e-3
    self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)

  @parameterized.parameters(['cuda', 'mkl'])
  @test_util.run_deprecated_v1
  @test_util.disable_xla('This test does not pass with XLA')
  def test_ingraph_train_loop(self):
  def test_ingraph_train_loop(self, mode):
    """Tests a graph containing a while loop around a training update.

    This requires the grappler pass to take special care with its handling of
    Enter ops that appear in front of reads from non-resource variables. See
    the use of NodeImplicitlyReadsVariable in auto_mixed_precision.cc.

    Args:
      mode: Either 'cuda' or 'mkl'.
    """
    self._maybe_skip(mode)
    if tf2.enabled():
      # This test tests non-resource variables, which are only used in TF1.
      self.skipTest('TensorFlow 1 required')
    if test.is_gpu_available(cuda_only=True):
      random_seed.set_random_seed(1234)
      np.random.seed(1234)
      num_iter, bs, nchan, nclass = 100, 64, 32, 100
    random_seed.set_random_seed(1234)
    np.random.seed(1234)
    num_iter, bs, nchan, nclass = 100, 64, 32, 100

      data = np.random.normal(size=(bs * num_iter, nchan)).astype(np.float32)
      labels = np.random.randint(nclass, size=(bs * num_iter,))
      ds = dataset_ops.Dataset.from_tensor_slices((data, labels))
      ds = ds.batch(bs).prefetch(3)
      it = ds.make_one_shot_iterator()
    data = np.random.normal(size=(bs * num_iter, nchan)).astype(np.float32)
    labels = np.random.randint(nclass, size=(bs * num_iter,))
    ds = dataset_ops.Dataset.from_tensor_slices((data, labels))
    ds = ds.batch(bs).prefetch(3)
    it = ds.make_one_shot_iterator()

      def body(_, i):
        i += 1
        x, yt = it.get_next()
        dense = layers.Dense(nclass)
        y = dense(x)
        loss = losses.sparse_softmax_cross_entropy(yt, y)
        opt = adam.AdamOptimizer()
        train_op = opt.minimize(loss, var_list=dense.trainable_weights)
        with ops.control_dependencies([train_op]):
          loss = array_ops.identity(loss)
        return loss, i
    def body(_, i):
      i += 1
      x, yt = it.get_next()
      dense = layers.Dense(nclass)
      y = dense(x)
      loss = losses.sparse_softmax_cross_entropy(yt, y)
      opt = adam.AdamOptimizer()
      train_op = opt.minimize(loss, var_list=dense.trainable_weights)
      with ops.control_dependencies([train_op]):
        loss = array_ops.identity(loss)
      return loss, i

      begin, end = constant_op.constant(0), constant_op.constant(num_iter)
      loss, _ = control_flow_ops.while_loop(
          lambda loss, i: math_ops.less(i, end), body, [0.0, begin])
    begin, end = constant_op.constant(0), constant_op.constant(num_iter)
    loss, _ = control_flow_ops.while_loop(lambda loss, i: math_ops.less(i, end),
                                          body, [0.0, begin])

      output_val_ref, output_val, cost_graph = self._run(loss)
      node_map = _build_node_map(cost_graph.node)
    output_val_ref, output_val, cost_graph = self._run(mode, loss)
    node_map = _build_node_map(cost_graph.node)

      self._assert_output_fp16(node_map, 'while/dense/MatMul')
      self._assert_output_fp16(
          node_map, 'while/gradients/while/dense/MatMul_grad/MatMul_1')
      self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
    self._assert_output_f16(mode, node_map, 'while/dense/MatMul')
    self._assert_output_f16(mode, node_map,
                            'while/gradients/while/dense/MatMul_grad/MatMul_1')
    self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)


# TODO(benbarsdell): Add tests for list ops (TensorList*) that pass through
# graph source/sink nodes, similar to the TensorListThroughFunction C++ test.