Add MKL support to auto_mixed_precision.

This extends the auto mixed precision grappler pass to support converting nodes to bfloat16 on MKL-supported CPUs.

Co-authored-by: Niranjan Hasabnis <niranjan.hasabnis@intel.com>
Reed 2020-06-18 12:42:19 -07:00
parent 81041bcd82
commit d8bfc935fd
8 changed files with 951 additions and 479 deletions
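As a usage note, here is a minimal sketch of how the new rewrite could be enabled from Python once this change is in place. This is not part of the diff below; it assumes a TensorFlow build that contains the new auto_mixed_precision_mkl option and MKL bfloat16 support.

    import tensorflow.compat.v1 as tf
    from tensorflow.core.protobuf import rewriter_config_pb2

    # Turn on the new MKL/bfloat16 rewrite. The existing auto_mixed_precision
    # field still controls the CUDA/float16 rewrite.
    rewrite_options = rewriter_config_pb2.RewriterConfig(
        auto_mixed_precision_mkl=rewriter_config_pb2.RewriterConfig.ON)
    config = tf.ConfigProto(
        graph_options=tf.GraphOptions(rewrite_options=rewrite_options))

    with tf.Session(config=config) as sess:
        pass  # Build and run a float32 graph; eligible ops are rewritten to bfloat16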


@ -1,5 +1,5 @@
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_cc_test_mkl", "tf_kernel_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_cc_test_mkl", "tf_kernel_library")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") load("//tensorflow:tensorflow.bzl", "tf_copts", "tf_cuda_cc_test")
# Platform specific build config # Platform specific build config
load( load(
@ -7,6 +7,11 @@ load(
"if_static", "if_static",
) )
load(
"//third_party/mkl:build_defs.bzl",
"mkl_deps",
)
package( package(
licenses = ["notice"], # Apache 2.0 licenses = ["notice"], # Apache 2.0
) )
@ -611,6 +616,7 @@ cc_library(
"auto_mixed_precision_lists.h", "auto_mixed_precision_lists.h",
], ],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
copts = tf_copts(),
deps = [ deps = [
":custom_graph_optimizer_registry", ":custom_graph_optimizer_registry",
":graph_optimizer", ":graph_optimizer",
@ -627,7 +633,7 @@ cc_library(
"//tensorflow/core/grappler/costs:virtual_placer", "//tensorflow/core/grappler/costs:virtual_placer",
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:flat_hash_set",
], ] + mkl_deps(),
) )
tf_cuda_cc_test( tf_cuda_cc_test(


@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h" #include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h"
#include <fstream> #include <fstream>
#include <memory>
#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h" #include "absl/container/flat_hash_set.h"
@ -52,6 +53,7 @@ const std::pair<int, int> kMinGPUArch = {0, 0};
const char kSuffix[] = "AutoMixedPrecision"; const char kSuffix[] = "AutoMixedPrecision";
const char kCastToFp16[] = "CastToFp16"; const char kCastToFp16[] = "CastToFp16";
const char kCastToBf16[] = "CastToBf16";
const char kCastToFp32[] = "CastToFp32"; const char kCastToFp32[] = "CastToFp32";
// Instances of this class represent unique type attribute identifiers within a // Instances of this class represent unique type attribute identifiers within a
@ -840,22 +842,6 @@ DataTypeSet AllowedDataTypes(const OpDef& op_def, const TypeAttrId& t_attr_id) {
return AllowedDataTypes(*attr_def); return AllowedDataTypes(*attr_def);
} }
NodeDef BuildCastNode(const MutableGraphView::OutputPort& src, bool to_fp16,
const string& device) {
const char* cast_string = to_fp16 ? kCastToFp16 : kCastToFp32;
string name = strings::StrCat(src.node->name(), "-", src.port_id, "-",
cast_string, "-", kSuffix);
NodeDef node;
node.set_name(name);
node.set_op("Cast");
node.set_device(device);
node.add_input(strings::StrCat(src.node->name(), ":", src.port_id));
(*node.mutable_attr())["SrcT"].set_type(to_fp16 ? DT_FLOAT : DT_HALF);
(*node.mutable_attr())["DstT"].set_type(to_fp16 ? DT_HALF : DT_FLOAT);
(*node.mutable_attr())["Truncate"].set_b(false);
return node;
}
Status ValidateLists(const gtl::FlatSet<string>& white_list, Status ValidateLists(const gtl::FlatSet<string>& white_list,
const gtl::FlatSet<string>& black_list, const gtl::FlatSet<string>& black_list,
const gtl::FlatSet<string>& gray_list, const gtl::FlatSet<string>& gray_list,
@ -941,7 +927,8 @@ class AutoMixedPrecisionImpl {
public: public:
AutoMixedPrecisionImpl(Cluster* cluster, AutoMixedPrecisionImpl(Cluster* cluster,
const std::unordered_set<string>& nodes_to_preserve, const std::unordered_set<string>& nodes_to_preserve,
GraphDef* graph, string id) GraphDef* graph, string id,
AutoMixedPrecisionMode mode)
: virtual_placer_(cluster->GetDevices()), : virtual_placer_(cluster->GetDevices()),
nodes_to_preserve_(nodes_to_preserve), nodes_to_preserve_(nodes_to_preserve),
graph_(graph), graph_(graph),
@ -949,23 +936,35 @@ class AutoMixedPrecisionImpl {
id_(id), id_(id),
graph_view_(graph), graph_view_(graph),
cuda_version_(GetCudaVersion(*cluster)), cuda_version_(GetCudaVersion(*cluster)),
cudnn_version_(GetCudnnVersion(*cluster)) {} cudnn_version_(GetCudnnVersion(*cluster)),
mode_(mode),
target_dtype_(mode_ == AutoMixedPrecisionMode::CUDA ? DT_HALF
: DT_BFLOAT16) {}
Status Optimize(); Status Optimize();
private: private:
typedef absl::flat_hash_set<NodeTypeId> NodeTypeIdSet; typedef absl::flat_hash_set<NodeTypeId> NodeTypeIdSet;
std::unique_ptr<AutoMixedPrecisionLists> get_mixed_precision_lists() const {
switch (mode_) {
case AutoMixedPrecisionMode::CUDA:
return std::make_unique<AutoMixedPrecisionListsCuda>(cuda_version_,
cudnn_version_);
case AutoMixedPrecisionMode::MKL:
return std::make_unique<AutoMixedPrecisionListsMkl>();
}
}
Status PrintDebugLogs(bool preop, size_t timestamp); Status PrintDebugLogs(bool preop, size_t timestamp);
void LogSkippedNode(const NodeDef& node) const; void LogSkippedNode(const NodeDef& node) const;
bool MustPreserve(const NodeDef& node) const; bool MustPreserve(const NodeDef& node) const;
bool IsOnGPU(const NodeDef& node) const; bool IsOnDevice(const NodeDef& node, const string& device_type) const;
bool IsOnSuitableGPUArch(const NodeDef& node) const; bool IsOnSuitableGPUArch(const NodeDef& node) const;
bool ShouldProcess(const NodeDef& node) const; bool ShouldProcess(const NodeDef& node) const;
bool NodeHasFP16KernelForTypeAttr(const NodeDef& node, TypeAttrId taid) const; bool NodeHasF16KernelForTypeAttr(const NodeDef& node, TypeAttrId taid) const;
bool NodeImplicitlyReadsNonResourceVariable(const NodeDef& node) const; bool NodeImplicitlyReadsNonResourceVariable(const NodeDef& node) const;
void ConvertBatchNormOpsToV2(); void ConvertBatchNormOpsToV2();
bool SupportsFloat16(const NodeTypeId& node_type) const; bool SupportsF16(const NodeTypeId& node_type) const;
const NodeTypeId* GetTensorListFloat32NodeTypeId(const NodeDef& node) const; const NodeTypeId* GetTensorListFloat32NodeTypeId(const NodeDef& node) const;
bool IsSourceOrSinkOp(const string& op) const; bool IsSourceOrSinkOp(const string& op) const;
void FindFloat32TensorListOpClustersAndBlacklistUnsafe( void FindFloat32TensorListOpClustersAndBlacklistUnsafe(
@ -990,6 +989,8 @@ class AutoMixedPrecisionImpl {
absl::flat_hash_set<int>* white_set) const; absl::flat_hash_set<int>* white_set) const;
void MakeCastsWhiteIfAllOutputsWhite( void MakeCastsWhiteIfAllOutputsWhite(
absl::flat_hash_set<int>* white_set) const; absl::flat_hash_set<int>* white_set) const;
NodeDef BuildCastNode(const MutableGraphView::OutputPort& src, bool to_f16,
const string& device) const;
Status ChangeTypeAttrsAndAddCasts(const absl::flat_hash_set<int>& white_set); Status ChangeTypeAttrsAndAddCasts(const absl::flat_hash_set<int>& white_set);
VirtualPlacer virtual_placer_; VirtualPlacer virtual_placer_;
@ -1003,21 +1004,44 @@ class AutoMixedPrecisionImpl {
NodeTypeAttrMap node_type_map_; NodeTypeAttrMap node_type_map_;
GraphTypeTopologyView graph_type_view_; GraphTypeTopologyView graph_type_view_;
bool force_all_fp16_; bool force_all_fp16_;
gtl::FlatSet<string> fp16_whitelist_; AutoMixedPrecisionMode mode_;
gtl::FlatSet<string> fp16_blacklist_; gtl::FlatSet<string> f16_whitelist_;
gtl::FlatSet<string> fp16_graylist_; gtl::FlatSet<string> f16_blacklist_;
gtl::FlatSet<string> fp16_clearlist_; gtl::FlatSet<string> f16_graylist_;
gtl::FlatSet<string> f16_clearlist_;
absl::flat_hash_set<const NodeDef*> should_process_nodes_; absl::flat_hash_set<const NodeDef*> should_process_nodes_;
DataType target_dtype_; // Either DT_HALF or DT_BFLOAT16
}; };
bool AutoMixedPrecisionImpl::NodeHasFP16KernelForTypeAttr( NodeDef AutoMixedPrecisionImpl::BuildCastNode(
const MutableGraphView::OutputPort& src, bool to_f16,
const string& device) const {
DataType src_type = to_f16 ? DT_FLOAT : target_dtype_;
DataType dst_type = to_f16 ? target_dtype_ : DT_FLOAT;
const char* cast_string =
!to_f16 ? kCastToFp32
: target_dtype_ == DT_HALF ? kCastToFp16 : kCastToBf16;
string name = strings::StrCat(src.node->name(), "-", src.port_id, "-",
cast_string, "-", kSuffix);
NodeDef node;
node.set_name(name);
node.set_op("Cast");
node.set_device(device);
node.add_input(strings::StrCat(src.node->name(), ":", src.port_id));
(*node.mutable_attr())["SrcT"].set_type(src_type);
(*node.mutable_attr())["DstT"].set_type(dst_type);
(*node.mutable_attr())["Truncate"].set_b(false);
return node;
}
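For illustration, the Cast node emitted here when a float32 output feeds a bfloat16 consumer looks roughly like the following Python sketch. The node and port names are made up; the name pattern and attributes mirror the code above.

    from tensorflow.core.framework import node_def_pb2, types_pb2

    # Name follows the "<src node>-<port>-CastToBf16-AutoMixedPrecision" pattern.
    cast = node_def_pb2.NodeDef(
        name='wht1-0-CastToBf16-AutoMixedPrecision', op='Cast', input=['wht1:0'])
    cast.attr['SrcT'].type = types_pb2.DT_FLOAT
    cast.attr['DstT'].type = types_pb2.DT_BFLOAT16
    cast.attr['Truncate'].b = False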
bool AutoMixedPrecisionImpl::NodeHasF16KernelForTypeAttr(
const NodeDef& node, TypeAttrId taid) const { const NodeDef& node, TypeAttrId taid) const {
NodeDef node_copy(node); NodeDef node_copy(node);
if (node.device().empty()) { if (node.device().empty()) {
string device_name = virtual_placer_.get_canonical_device_name(node); string device_name = virtual_placer_.get_canonical_device_name(node);
node_copy.set_device(device_name); node_copy.set_device(device_name);
} }
if (!SetDataType(&node_copy, taid, DataType::DT_HALF)) { if (!SetDataType(&node_copy, taid, target_dtype_)) {
return false; return false;
} }
return IsKernelRegisteredForNode(node_copy).ok(); return IsKernelRegisteredForNode(node_copy).ok();
@ -1053,21 +1077,22 @@ Status AutoMixedPrecisionImpl::PrintDebugLogs(bool preop, size_t timestamp) {
fname = io::JoinPath(prepend_path, fname = io::JoinPath(prepend_path,
strings::StrCat("paintbuckets", suffix, ".txt")); strings::StrCat("paintbuckets", suffix, ".txt"));
f.open(fname.c_str(), std::fstream::out); f.open(fname.c_str(), std::fstream::out);
std::unique_ptr<AutoMixedPrecisionLists> mp_lists =
get_mixed_precision_lists();
f << "WhiteList:\n"; f << "WhiteList:\n";
for (const auto& x : for (const auto& x : mp_lists->WhiteList()) {
AutoMixedPrecisionLists::WhiteList(cuda_version_, cudnn_version_)) {
f << x << "\n"; f << x << "\n";
} }
f << "\nBlackList:\n"; f << "\nBlackList:\n";
for (const auto& x : AutoMixedPrecisionLists::BlackList()) { for (const auto& x : mp_lists->BlackList()) {
f << x << "\n"; f << x << "\n";
} }
f << "\nGrayList:\n"; f << "\nGrayList:\n";
for (const auto& x : AutoMixedPrecisionLists::GrayList()) { for (const auto& x : mp_lists->GrayList()) {
f << x << "\n"; f << x << "\n";
} }
f << "\nClearList:\n"; f << "\nClearList:\n";
for (const auto& x : AutoMixedPrecisionLists::ClearList()) { for (const auto& x : mp_lists->ClearList()) {
f << x << "\n"; f << x << "\n";
} }
f.close(); f.close();
@ -1088,7 +1113,8 @@ bool AutoMixedPrecisionImpl::MustPreserve(const NodeDef& node) const {
return nodes_to_preserve_.count(node.name()); return nodes_to_preserve_.count(node.name());
} }
bool AutoMixedPrecisionImpl::IsOnGPU(const NodeDef& node) const { bool AutoMixedPrecisionImpl::IsOnDevice(const NodeDef& node,
const string& device_type) const {
string device_name; string device_name;
if (node.device().empty()) { if (node.device().empty()) {
device_name = virtual_placer_.get_canonical_device_name(node); device_name = virtual_placer_.get_canonical_device_name(node);
@ -1099,7 +1125,7 @@ bool AutoMixedPrecisionImpl::IsOnGPU(const NodeDef& node) const {
string not_used; string not_used;
if (DeviceNameUtils::SplitDeviceName(device_name, &not_used, &device) && if (DeviceNameUtils::SplitDeviceName(device_name, &not_used, &device) &&
absl::StrContains(absl::AsciiStrToLower(device), absl::StrContains(absl::AsciiStrToLower(device),
absl::AsciiStrToLower(DEVICE_GPU))) { absl::AsciiStrToLower(device_type))) {
return true; return true;
} }
return false; return false;
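IsOnDevice generalizes the old GPU-only check to any device type via a case-insensitive substring match on the device component of the node's (possibly canonicalized) device name. A toy Python illustration with a made-up device string:

    device_name = '/job:localhost/replica:0/task:0/device:CPU:0'
    device = device_name.rsplit('/', 1)[-1]  # roughly what SplitDeviceName yields
    assert 'cpu' in device.lower()           # matches DEVICE_CPU in MKL mode
    assert 'gpu' not in device.lower()       # would not match DEVICE_GPU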
@ -1164,15 +1190,14 @@ bool IsTensorListWriterOp(const string& op) {
return tensor_list_writer_ops.count(op); return tensor_list_writer_ops.count(op);
} }
bool AutoMixedPrecisionImpl::SupportsFloat16( bool AutoMixedPrecisionImpl::SupportsF16(const NodeTypeId& node_type) const {
const NodeTypeId& node_type) const {
const OpDef* op_def; const OpDef* op_def;
Status status = Status status =
OpRegistry::Global()->LookUpOpDef(node_type.node->op(), &op_def); OpRegistry::Global()->LookUpOpDef(node_type.node->op(), &op_def);
if (!status.ok()) return false; if (!status.ok()) return false;
return AllowedDataTypes(*op_def, node_type.type_attr) return AllowedDataTypes(*op_def, node_type.type_attr)
.Contains(DataType::DT_HALF) && .Contains(target_dtype_) &&
NodeHasFP16KernelForTypeAttr(*node_type.node, node_type.type_attr); NodeHasF16KernelForTypeAttr(*node_type.node, node_type.type_attr);
} }
// TODO(mconley): Make this change the node's name (to aid debugging). Need to // TODO(mconley): Make this change the node's name (to aid debugging). Need to
@ -1219,22 +1244,40 @@ Status AutoMixedPrecisionImpl::Optimize() {
"TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL", "", &optimization_level)); "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL", "", &optimization_level));
optimization_level = absl::AsciiStrToUpper(optimization_level); optimization_level = absl::AsciiStrToUpper(optimization_level);
force_all_fp16_ = optimization_level == "UNSAFE_FORCE_ALL"; force_all_fp16_ = optimization_level == "UNSAFE_FORCE_ALL";
if (force_all_fp16_ && mode_ == AutoMixedPrecisionMode::MKL) {
// Many ops do not support bfloat16 on the CPU, so we disallow forcing the
// graph to bfloat16.
return errors::InvalidArgument(
"TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL cannot be set to "
"UNSAFE_FORCE_ALL when MKL is used");
}
fp16_whitelist_ = std::unique_ptr<AutoMixedPrecisionLists> mp_lists =
AutoMixedPrecisionLists::WhiteList(cuda_version_, cudnn_version_); get_mixed_precision_lists();
fp16_blacklist_ = AutoMixedPrecisionLists::BlackList(); f16_whitelist_ = mp_lists->WhiteList();
fp16_graylist_ = AutoMixedPrecisionLists::GrayList(); f16_blacklist_ = mp_lists->BlackList();
fp16_clearlist_ = AutoMixedPrecisionLists::ClearList(); f16_graylist_ = mp_lists->GrayList();
TF_RETURN_IF_ERROR(ValidateLists(fp16_whitelist_, fp16_blacklist_, f16_clearlist_ = mp_lists->ClearList();
fp16_graylist_, fp16_clearlist_)); TF_RETURN_IF_ERROR(ValidateLists(f16_whitelist_, f16_blacklist_,
f16_graylist_, f16_clearlist_));
size_t timestamp = Env::Default()->NowMicros() / 1000; size_t timestamp = Env::Default()->NowMicros() / 1000;
TF_RETURN_IF_ERROR(PrintDebugLogs(/* preop = */ true, timestamp)); TF_RETURN_IF_ERROR(PrintDebugLogs(/* preop = */ true, timestamp));
VLOG(2) << "Identifying nodes that should be processed"; VLOG(2) << "Identifying nodes that should be processed";
for (const NodeDef& node : graph_->node()) { for (const NodeDef& node : graph_->node()) {
if (!MustPreserve(node) && IsOnGPU(node) && bool should_process;
(ShouldIgnorePerformance() || IsOnSuitableGPUArch(node))) { switch (mode_) {
case AutoMixedPrecisionMode::CUDA:
should_process =
!MustPreserve(node) && IsOnDevice(node, DEVICE_GPU) &&
(ShouldIgnorePerformance() || IsOnSuitableGPUArch(node));
break;
case AutoMixedPrecisionMode::MKL:
should_process = !MustPreserve(node) && IsOnDevice(node, DEVICE_CPU);
break;
}
if (should_process) {
should_process_nodes_.insert(&node); should_process_nodes_.insert(&node);
} else { } else {
LogSkippedNode(node); LogSkippedNode(node);
@ -1260,29 +1303,29 @@ Status AutoMixedPrecisionImpl::Optimize() {
for (const auto& cluster : tensor_list_clusters) { for (const auto& cluster : tensor_list_clusters) {
VLOG(1) << "Found safe Tensor List cluster of size " << cluster.size(); VLOG(1) << "Found safe Tensor List cluster of size " << cluster.size();
for (const NodeDef* node : cluster) { for (const NodeDef* node : cluster) {
VLOG(2) << "Cluster member: " << node->op() << " node " << node->name(); VLOG(2) << " Cluster member: " << node->op() << " node " << node->name();
} }
FindTensorListImplicitFloat32Edges(cluster, &ephemeral_edges); FindTensorListImplicitFloat32Edges(cluster, &ephemeral_edges);
} }
TF_RETURN_IF_ERROR(graph_type_view_.AddEphemeralEdges(ephemeral_edges)); TF_RETURN_IF_ERROR(graph_type_view_.AddEphemeralEdges(ephemeral_edges));
// The goal here is to change performance-critical ops to fp16, and to do so // The goal here is to change performance-critical ops to fp16 or bf16, and to
// with the minimal number of casts, subject to the constraint that the // do so with the minimal number of casts, subject to the constraint that the
// model's convergence is not affected. This is achieved by first identifying // model's convergence is not affected. This is achieved by first identifying
// which nodes should be changed to fp16 and then inserting casts at the // which nodes should be changed to f16 and then inserting casts at the
// boundaries between fp16/non-fp16 nodes. // boundaries between f16/non-f16 nodes.
// The algorithm for deciding which nodes to change to fp16 is as follows: // The algorithm for deciding which nodes to change to f16 is as follows:
// 1) Add all performance-critical ops (aka "whitelist" ops) to the white_set. // 1) Add all performance-critical ops (aka "whitelist" ops) to the white_set.
// This is done under the assumption that whitelist ops are always // This is done under the assumption that whitelist ops are always
// numerically-safe in fp16 and that they are the most important ops for // numerically-safe in f16 and that they are the most important ops for
// improving performance. // improving performance.
// 2) Add nodes to the black_set iff they are numerically-dangerous (aka // 2) Add nodes to the black_set iff they are numerically-dangerous (aka
// "blacklist" ops) or they are on a forward path from a blacklist node to // "blacklist" ops) or they are on a forward path from a blacklist node to
// a black/gray node (including the node at the end of the path) through // a black/gray node (including the node at the end of the path) through
// non-numerically-dangerous ops (aka "greylist" and "clearlist" ops). // non-numerically-dangerous ops (aka "greylist" and "clearlist" ops).
// This is done to prevent numerically-dangerous ops and their downstream // This is done to prevent numerically-dangerous ops and their downstream
// effects from being changed to fp16, which would risk breaking the // effects from being changed to f16, which would risk breaking the
// numerical accuracy of the model. // numerical accuracy of the model.
// 3) For all remaining nodes that are not considered dangerous (greylist // 3) For all remaining nodes that are not considered dangerous (greylist
// and clearlist ops), find those that are between (i.e., both upstream // and clearlist ops), find those that are between (i.e., both upstream
@ -1480,7 +1523,7 @@ void AutoMixedPrecisionImpl::AddWhitelistOps(
const NodeTypeId& root = *graph_type_view_.GetNode(root_idx); const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
if (!ShouldProcess(*root.node)) continue; if (!ShouldProcess(*root.node)) continue;
bool force_white = force_all_fp16_ && CanForceFP16(*root.node); bool force_white = force_all_fp16_ && CanForceFP16(*root.node);
if (fp16_whitelist_.count(root.node->op()) || force_white) { if (f16_whitelist_.count(root.node->op()) || force_white) {
bool inserted = white_set->insert(root_idx).second; bool inserted = white_set->insert(root_idx).second;
if (VLOG_IS_ON(2) && inserted) { if (VLOG_IS_ON(2) && inserted) {
VLOG(2) << "Painting type " << root.type_attr.DebugString() VLOG(2) << "Painting type " << root.type_attr.DebugString()
@ -1504,8 +1547,8 @@ void AutoMixedPrecisionImpl::PropagateBlackFwdThroughClearAndGray(
absl::flat_hash_set<int> upstream_of_black_or_gray_set; absl::flat_hash_set<int> upstream_of_black_or_gray_set;
for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) { for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
const NodeTypeId& root = *graph_type_view_.GetNode(root_idx); const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
if (!(fp16_blacklist_.count(root.node->op()) || if (!(f16_blacklist_.count(root.node->op()) ||
fp16_graylist_.count(root.node->op()))) { f16_graylist_.count(root.node->op()))) {
continue; continue;
} }
DfsTypeTraversal(graph_type_view_, {&root}, DfsTypeTraversal(graph_type_view_, {&root},
@ -1514,7 +1557,7 @@ void AutoMixedPrecisionImpl::PropagateBlackFwdThroughClearAndGray(
const NodeTypeId& item = *graph_type_view_.GetNode(idx); const NodeTypeId& item = *graph_type_view_.GetNode(idx);
return idx == root_idx || return idx == root_idx ||
(!upstream_of_black_or_gray_set.count(idx) && (!upstream_of_black_or_gray_set.count(idx) &&
fp16_clearlist_.count(item.node->op())); f16_clearlist_.count(item.node->op()));
}), }),
DfsTypeCallbacks::PreOrder([&](int idx) { DfsTypeCallbacks::PreOrder([&](int idx) {
upstream_of_black_or_gray_set.insert(idx); upstream_of_black_or_gray_set.insert(idx);
@ -1524,7 +1567,7 @@ void AutoMixedPrecisionImpl::PropagateBlackFwdThroughClearAndGray(
// Propagate black forward through nodes in upstream_of_black_or_gray_set. // Propagate black forward through nodes in upstream_of_black_or_gray_set.
for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) { for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
const NodeTypeId& root = *graph_type_view_.GetNode(root_idx); const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
if (black_set->count(root_idx) || !fp16_blacklist_.count(root.node->op())) { if (black_set->count(root_idx) || !f16_blacklist_.count(root.node->op())) {
continue; continue;
} }
DfsTypeTraversal( DfsTypeTraversal(
@ -1552,7 +1595,7 @@ void AutoMixedPrecisionImpl::AddClearAndGrayToWhiteIfBetweenWhite(
absl::flat_hash_set<int> downstream_of_white_set; absl::flat_hash_set<int> downstream_of_white_set;
for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) { for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
const NodeTypeId& root = *graph_type_view_.GetNode(root_idx); const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
if (!ShouldProcess(*root.node) || !fp16_whitelist_.count(root.node->op())) { if (!ShouldProcess(*root.node) || !f16_whitelist_.count(root.node->op())) {
continue; continue;
} }
DfsTypeTraversal( DfsTypeTraversal(
@ -1561,14 +1604,14 @@ void AutoMixedPrecisionImpl::AddClearAndGrayToWhiteIfBetweenWhite(
const NodeTypeId& item = *graph_type_view_.GetNode(idx); const NodeTypeId& item = *graph_type_view_.GetNode(idx);
return idx == root_idx || return idx == root_idx ||
(!downstream_of_white_set.count(idx) && (!downstream_of_white_set.count(idx) &&
!fp16_whitelist_.count(item.node->op()) && !f16_whitelist_.count(item.node->op()) &&
!black_set.count(idx) && ShouldProcess(*item.node) && !black_set.count(idx) && ShouldProcess(*item.node) &&
// TODO(benbarsdell): Consider allowing propagation through // TODO(benbarsdell): Consider allowing propagation through
// ops that are already float16 in order to reduce the number // ops that are already float16 in order to reduce the number
// of casts. // of casts.
IsFloat32(item) && SupportsFloat16(item) && IsFloat32(item) && SupportsF16(item) &&
(fp16_clearlist_.count(item.node->op()) || (f16_clearlist_.count(item.node->op()) ||
fp16_graylist_.count(item.node->op()))); f16_graylist_.count(item.node->op())));
}), }),
DfsTypeCallbacks::PreOrder( DfsTypeCallbacks::PreOrder(
[&](int idx) { downstream_of_white_set.insert(idx); })); [&](int idx) { downstream_of_white_set.insert(idx); }));
@ -1579,7 +1622,7 @@ void AutoMixedPrecisionImpl::AddClearAndGrayToWhiteIfBetweenWhite(
for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) { for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
const NodeTypeId& root = *graph_type_view_.GetNode(root_idx); const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
if (!ShouldProcess(*root.node) || upstream_of_white_set.count(root_idx) || if (!ShouldProcess(*root.node) || upstream_of_white_set.count(root_idx) ||
!fp16_whitelist_.count(root.node->op())) { !f16_whitelist_.count(root.node->op())) {
continue; continue;
} }
DfsTypeTraversal( DfsTypeTraversal(
@ -1620,8 +1663,8 @@ void AutoMixedPrecisionImpl::PropagateWhiteThroughClear(
return idx == root_idx || return idx == root_idx ||
(!white_set->count(idx) && !black_set.count(idx) && (!white_set->count(idx) && !black_set.count(idx) &&
ShouldProcess(*item.node) && IsFloat32(item) && ShouldProcess(*item.node) && IsFloat32(item) &&
SupportsFloat16(item) && SupportsF16(item) &&
(fp16_clearlist_.count(item.node->op())) && (f16_clearlist_.count(item.node->op())) &&
// We don't propagate (backwards) through nodes that read // We don't propagate (backwards) through nodes that read
// Variables because it can break the behavior of TensorBoard // Variables because it can break the behavior of TensorBoard
// visualization and/or (in the case of Enter nodes) the model // visualization and/or (in the case of Enter nodes) the model
@ -1806,13 +1849,13 @@ void AutoMixedPrecisionImpl::MakeCastsWhiteIfAllOutputsWhite(
} }
} }
// Changes all white-painted type attributes to DT_HALF, and inserts Cast nodes // Changes all white-painted type attributes to DT_HALF or DT_BFLOAT16, and
// at node outputs for all edges that connect white-painted <-> // inserts Cast nodes at node outputs for all edges that connect
// non-white-painted type attributes. // white-painted <-> non-white-painted type attributes.
Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts( Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts(
const absl::flat_hash_set<int>& white_set) { const absl::flat_hash_set<int>& white_set) {
int num_nodes_changed = 0; int num_nodes_changed = 0;
int num_nonvar_casts_to_fp16 = 0; int num_nonvar_casts_to_f16 = 0;
int num_nodes_preop = graph_->node_size(); int num_nodes_preop = graph_->node_size();
for (int node_idx = 0; node_idx < num_nodes_preop; ++node_idx) { for (int node_idx = 0; node_idx < num_nodes_preop; ++node_idx) {
NodeDef* node = graph_->mutable_node(node_idx); NodeDef* node = graph_->mutable_node(node_idx);
@ -1829,8 +1872,9 @@ Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts(
bool src_is_white = white_set.count(node_type_idx); bool src_is_white = white_set.count(node_type_idx);
if (src_is_white) { if (src_is_white) {
VLOG(1) << "Changing type " << type_attr.DebugString() << " of " VLOG(1) << "Changing type " << type_attr.DebugString() << " of "
<< node->op() << " node " << node->name() << " to DT_HALF"; << node->op() << " node " << node->name() << " to "
if (!SetDataType(node, type_attr, DT_HALF)) { << DataTypeString(target_dtype_);
if (!SetDataType(node, type_attr, target_dtype_)) {
return errors::Internal("Failed to set type attribute"); return errors::Internal("Failed to set type attribute");
} }
++num_nodes_changed; ++num_nodes_changed;
@ -1855,16 +1899,16 @@ Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts(
bool dst_is_white = white_set.count(dst_type_idx); bool dst_is_white = white_set.count(dst_type_idx);
if (src_is_white != dst_is_white) { if (src_is_white != dst_is_white) {
if (!added_cast_node) { if (!added_cast_node) {
bool to_fp16 = dst_is_white; bool to_f16 = dst_is_white;
VLOG(1) << "Inserting cast to " VLOG(1) << "Inserting cast to "
<< (to_fp16 ? "DT_HALF" : "DT_FLOAT") << " at " << (to_f16 ? DataTypeString(target_dtype_) : "DT_FLOAT")
<< src.node->op() << " " << src.node->name() << ":" << " at " << src.node->op() << " " << src.node->name()
<< src.port_id; << ":" << src.port_id;
added_cast_node = graph_view_.AddNode( added_cast_node = graph_view_.AddNode(
BuildCastNode(src, to_fp16, src.node->device())); BuildCastNode(src, to_f16, src.node->device()));
if (to_fp16 && !IsConstant(*node) && !IsVariable(*node) && if (to_f16 && !IsConstant(*node) && !IsVariable(*node) &&
!NodeImplicitlyReadsNonResourceVariable(*node)) { !NodeImplicitlyReadsNonResourceVariable(*node)) {
++num_nonvar_casts_to_fp16; ++num_nonvar_casts_to_f16;
} }
} }
TF_RETURN_IF_ERROR(graph_view_.UpdateRegularFaninByPort( TF_RETURN_IF_ERROR(graph_view_.UpdateRegularFaninByPort(
@ -1874,9 +1918,13 @@ Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts(
} }
} }
} }
// Use Python type names (e.g. float16) instead of C++ type names (e.g. half)
// since many Python users will see this message.
const char* type_str = target_dtype_ == DT_HALF ? "float16" : "bfloat16";
LOG(INFO) << "Converted " << num_nodes_changed << "/" << num_nodes_preop LOG(INFO) << "Converted " << num_nodes_changed << "/" << num_nodes_preop
<< " nodes to float16 precision using " << num_nonvar_casts_to_fp16 << " nodes to " << type_str << " precision using "
<< " cast(s) to float16 (excluding Const and Variable casts)"; << num_nonvar_casts_to_f16 << " cast(s) to " << type_str
<< " (excluding Const and Variable casts)";
return Status::OK(); return Status::OK();
} }
@ -1902,12 +1950,23 @@ Status AutoMixedPrecision::Optimize(Cluster* cluster, const GrapplerItem& item,
return errors::InvalidArgument("cluster == nullptr"); return errors::InvalidArgument("cluster == nullptr");
} }
#if !defined(INTEL_MKL) || !defined(ENABLE_INTEL_MKL_BFLOAT16)
if (mode_ == AutoMixedPrecisionMode::MKL) {
return errors::Unimplemented(
"The auto_mixed_precision_mkl optimizer cannot be used since "
"this build of TensorFlow is not compiled with MKL support for bfloat16. "
"For information on MKL builds, see: "
"https://software.intel.com/en-us/articles/intel-optimization-for-"
"tensorflow-installation-guide");
}
#endif
// Start by copying input graph to output. // Start by copying input graph to output.
*output = item.graph; *output = item.graph;
int num_gpus = ShouldIgnorePerformance() ? GetNumGPUs(*cluster) int num_gpus = ShouldIgnorePerformance() ? GetNumGPUs(*cluster)
: GetNumGPUs(*cluster, kMinGPUArch); : GetNumGPUs(*cluster, kMinGPUArch);
if (num_gpus < 1) { if (num_gpus < 1 && mode_ == AutoMixedPrecisionMode::CUDA) {
// AutoMixedPrecision is currently only tuned for GPU. // AutoMixedPrecision is currently only tuned for GPU.
LOG(WARNING) << "No (suitable) GPUs detected, skipping " << name() LOG(WARNING) << "No (suitable) GPUs detected, skipping " << name()
<< " graph optimizer"; << " graph optimizer";
@ -1916,7 +1975,7 @@ Status AutoMixedPrecision::Optimize(Cluster* cluster, const GrapplerItem& item,
// Optimize the output graph in-place. // Optimize the output graph in-place.
AutoMixedPrecisionImpl optimizer(cluster, item.NodesToPreserve(), output, AutoMixedPrecisionImpl optimizer(cluster, item.NodesToPreserve(), output,
item.id); item.id, mode_);
if (item.id == "tf_graph") { if (item.id == "tf_graph") {
LOG(INFO) << "Running " << name() << " graph optimizer"; LOG(INFO) << "Running " << name() << " graph optimizer";
} else { } else {
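
To make the white/gray/black/clear painting described above concrete: on the small op chain used by the new AutoMixedPrecisionMklTest.Simple test later in this commit, only the neighborhood of the whitelisted MatMul ends up in bfloat16, with two Cast nodes inserted at the boundaries. A Python summary of that test's expectations (types copied from its EXPECT_EQ checks):

    expected_types = {
        'blk1': ('Exp',    'DT_FLOAT'),     # stays float32
        'clr1': ('Relu',   'DT_FLOAT'),
        'gry1': ('Sqrt',   'DT_FLOAT'),
        'clr2': ('Relu',   'DT_BFLOAT16'),  # feeds the whitelisted MatMul
        'wht1': ('MatMul', 'DT_BFLOAT16'),  # whitelist op, always converted
        'clr3': ('Relu',   'DT_BFLOAT16'),  # consumes the MatMul output
        'blk2': ('Log',    'DT_FLOAT'),     # float32 again from here on
    }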


@ -22,16 +22,25 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace grappler { namespace grappler {
// Convert data types to float16 where appropriate to improve performance on enum class AutoMixedPrecisionMode { CUDA, MKL };
// GPUs.
// Convert data types to float16 or bfloat16 where appropriate to improve
// performance on GPUs or CPUs.
class AutoMixedPrecision : public GraphOptimizer { class AutoMixedPrecision : public GraphOptimizer {
public: public:
// If 'mode' is CUDA, converts nodes to float16 on Nvidia GPUs. If MKL,
// converts nodes to bfloat16 on CPUs in order to take advantage of MKL
// performance improvements with bfloat16.
explicit AutoMixedPrecision( explicit AutoMixedPrecision(
RewriterConfig::Toggle opt_level = RewriterConfig::ON) {} AutoMixedPrecisionMode mode = AutoMixedPrecisionMode::CUDA)
: mode_(mode) {}
~AutoMixedPrecision() override {} ~AutoMixedPrecision() override {}
string name() const override { return "auto_mixed_precision"; }; string name() const override {
return mode_ == AutoMixedPrecisionMode::CUDA ? "auto_mixed_precision_cuda"
: "auto_mixed_precision_mkl";
};
bool UsesFunctionLibrary() const override { return false; } bool UsesFunctionLibrary() const override { return false; }
@ -40,6 +49,9 @@ class AutoMixedPrecision : public GraphOptimizer {
void Feedback(Cluster* cluster, const GrapplerItem& item, void Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& optimize_output, double result) override; const GraphDef& optimize_output, double result) override;
private:
const AutoMixedPrecisionMode mode_;
}; };
} // end namespace grappler } // end namespace grappler


@ -23,10 +23,44 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace grappler { namespace grappler {
// Represents the four lists of ops: the white list, gray list, black list, and
// clear list. These lists determine which ops are converted to fp16/bf16
// (referred to as 'f16' for short) and which ops stay as fp32.
class AutoMixedPrecisionLists { class AutoMixedPrecisionLists {
private: public:
static void UpdateList(gtl::FlatSet<string>* list, const string& to_add,
const string& to_remove) { virtual ~AutoMixedPrecisionLists() {}
// Returns the set of ops that are considered numerically-safe (for execution
// in f16), performance-critical, and can run in f16. These ops are always
// converted to f16.
virtual gtl::FlatSet<string> WhiteList() = 0;
// Returns the set of ops that can run in f16 and are considered numerically-
// safe (for execution in f16), but which may be made unsafe by an upstream
// blacklist op.
virtual gtl::FlatSet<string> GrayList() = 0;
// Returns the set of ops that are considered numerically-dangerous (i.e.,
// unsafe for execution in f16) and whose effects may also be observed in
// downstream nodes (e.g. for f16, in Exp -> Add, the Add is unsafe due to
// the Exp).
virtual gtl::FlatSet<string> BlackList() = 0;
// Returns the set of ops that do not have numerically-significant effects
// (i.e., they are always considered safe for execution in f16 precision), and
// can run in f16.
virtual gtl::FlatSet<string> ClearList() = 0;
protected:
// Adds or removes ops from list if certain environmental variables are set.
static void UpdateList(const string& list_name, gtl::FlatSet<string>* list) {
CHECK(list_name == "WHITELIST" || list_name == "GRAYLIST" || // Crash OK.
list_name == "BLACKLIST" || list_name == "CLEARLIST");
string add_env_var =
"TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_" + list_name + "_ADD";
string remove_env_var =
"TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_" + list_name + "_REMOVE";
string to_add, to_remove;
TF_CHECK_OK(ReadStringFromEnvVar(add_env_var, "", &to_add));
TF_CHECK_OK(ReadStringFromEnvVar(remove_env_var, "", &to_remove));
for (const auto& x : str_util::Split(to_add, ",")) { for (const auto& x : str_util::Split(to_add, ",")) {
list->insert(x); list->insert(x);
} }
@ -35,6 +69,35 @@ class AutoMixedPrecisionLists {
} }
} }
// Subclasses should include these on the ClearList.
static void AddTensorListOps(gtl::FlatSet<string>* list) {
// Note: if a data structure op (such as TensorListPopBack) is added here,
// IsTensorListReaderOp or IsTensorListWriterOp may need to be modified.
constexpr const char* tensor_list_ops[] = {
"TensorListConcat",
"TensorListConcatLists",
"TensorListConcatV2",
"TensorListGather",
"TensorListGetItem",
"TensorListPopBack",
"TensorListPushBack",
"TensorListPushBackBatch",
"TensorListFromTensor",
"TensorListScatter",
"TensorListScatterV2",
"TensorListScatterIntoExistingList",
"TensorListSetItem",
"TensorListSplit",
"TensorListStack"
};
for (auto op : tensor_list_ops) {
list->insert(op);
}
}
};
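Because UpdateList reads environment variables, the four lists can be adjusted without rebuilding. A hedged Python sketch; the op names here are purely illustrative, and the variables must be set before the grappler pass runs.

    import os

    # Read by UpdateList as TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_<LIST>_ADD / _REMOVE,
    # where <LIST> is WHITELIST, GRAYLIST, BLACKLIST or CLEARLIST (comma-separated op
    # names). This applies to both the CUDA and MKL variants of the pass.
    os.environ['TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_CLEARLIST_ADD'] = 'Reciprocal'
    os.environ['TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_GRAYLIST_REMOVE'] = 'Mul,Sub'

    # Note: TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL=UNSAFE_FORCE_ALL is rejected
    # with InvalidArgument in the new MKL (bfloat16) mode.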
class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists {
private:
static bool IsPseudoFastMath() { static bool IsPseudoFastMath() {
string optimization_level; string optimization_level;
TF_CHECK_OK( TF_CHECK_OK(
@ -45,16 +108,10 @@ class AutoMixedPrecisionLists {
} }
public: public:
// Returns the set of ops that are considered numerically-safe (for execution AutoMixedPrecisionListsCuda(int cuda_version, int cudnn_version)
// in fp16) and performance-critical. These ops are always converted to fp16. : cuda_version_(cuda_version), cudnn_version_(cudnn_version) {}
static gtl::FlatSet<string> WhiteList(int cuda_version, int cudnn_version) {
string to_add, to_remove;
TF_CHECK_OK(ReadStringFromEnvVar(
"TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_WHITELIST_ADD", "", &to_add));
TF_CHECK_OK(ReadStringFromEnvVar(
"TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_WHITELIST_REMOVE", "",
&to_remove));
gtl::FlatSet<string> WhiteList() override {
auto list = gtl::FlatSet<string>{ auto list = gtl::FlatSet<string>{
"BlockLSTM", "BlockLSTM",
"BlockLSTMV2", "BlockLSTMV2",
@ -81,12 +138,12 @@ class AutoMixedPrecisionLists {
// "DepthwiseConv2dNativeBackpropInput", // "DepthwiseConv2dNativeBackpropInput",
"MatMul", "MatMul",
}; };
if (cuda_version >= 9010) { if (cuda_version_ >= 9010) {
// Fp16 BatchMatMul is slow before CUDA 9.1. // Fp16 BatchMatMul is slow before CUDA 9.1.
list.insert("BatchMatMul"); list.insert("BatchMatMul");
list.insert("BatchMatMulV2"); list.insert("BatchMatMulV2");
} }
if (cudnn_version >= 7602) { if (cudnn_version_ >= 7602) {
// Fp16 3D conv is slow before CUDNN 7.6.2. // Fp16 3D conv is slow before CUDNN 7.6.2.
list.insert("Conv3D"); list.insert("Conv3D");
list.insert("Conv3DBackpropFilter"); list.insert("Conv3DBackpropFilter");
@ -94,22 +151,14 @@ class AutoMixedPrecisionLists {
list.insert("Conv3DBackpropInput"); list.insert("Conv3DBackpropInput");
list.insert("Conv3DBackpropInputV2"); list.insert("Conv3DBackpropInputV2");
} }
UpdateList(&list, to_add, to_remove); UpdateList("WHITELIST", &list);
return list; return list;
} }
// Returns the set of ops that are considered numerically-safe (for execution gtl::FlatSet<string> GrayList() override {
// in fp16), but which may be made unsafe by an upstream blacklist op.
static gtl::FlatSet<string> GrayList() {
if (IsPseudoFastMath()) { if (IsPseudoFastMath()) {
return gtl::FlatSet<string>{}; return gtl::FlatSet<string>{};
} }
string to_add, to_remove;
TF_CHECK_OK(ReadStringFromEnvVar(
"TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_GRAYLIST_ADD", "", &to_add));
TF_CHECK_OK(ReadStringFromEnvVar(
"TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_GRAYLIST_REMOVE", "",
&to_remove));
auto list = gtl::FlatSet<string>{ auto list = gtl::FlatSet<string>{
"Add", "Add",
@ -156,23 +205,14 @@ class AutoMixedPrecisionLists {
"Tanh", "Tanh",
"TanhGrad", "TanhGrad",
}; };
UpdateList(&list, to_add, to_remove); UpdateList("GRAYLIST", &list);
return list; return list;
} }
// Returns the set of ops that are considered numerically-dangerous (i.e., gtl::FlatSet<string> BlackList() override {
// unsafe for execution in fp16) and whose effects may also be observed in
// downstream nodes (e.g., in Exp -> Add, the Add is unsafe due to the Exp).
static gtl::FlatSet<string> BlackList() {
if (IsPseudoFastMath()) { if (IsPseudoFastMath()) {
return gtl::FlatSet<string>{}; return gtl::FlatSet<string>{};
} }
string to_add, to_remove;
TF_CHECK_OK(ReadStringFromEnvVar(
"TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_BLACKLIST_ADD", "", &to_add));
TF_CHECK_OK(ReadStringFromEnvVar(
"TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_BLACKLIST_REMOVE", "",
&to_remove));
auto list = gtl::FlatSet<string>{ auto list = gtl::FlatSet<string>{
"Exp", "Exp",
@ -185,22 +225,14 @@ class AutoMixedPrecisionLists {
"SparseSoftmaxCrossEntropyWithLogits", "SparseSoftmaxCrossEntropyWithLogits",
"Sum", "Sum",
}; };
UpdateList(&list, to_add, to_remove); UpdateList("BLACKLIST", &list);
return list; return list;
} }
// Returns the set of ops that do not have numerically-significant effects gtl::FlatSet<string> ClearList() override {
// (i.e., they are always considered safe for execution in fp16 precision).
static gtl::FlatSet<string> ClearList() {
if (IsPseudoFastMath()) { if (IsPseudoFastMath()) {
return gtl::FlatSet<string>{}; return gtl::FlatSet<string>{};
} }
string to_add, to_remove;
TF_CHECK_OK(ReadStringFromEnvVar(
"TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_CLEARLIST_ADD", "", &to_add));
TF_CHECK_OK(ReadStringFromEnvVar(
"TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_CLEARLIST_REMOVE", "",
&to_remove));
auto list = gtl::FlatSet<string>{ auto list = gtl::FlatSet<string>{
"Abs", "Abs",
@ -291,21 +323,6 @@ class AutoMixedPrecisionLists {
"StridedSlice", "StridedSlice",
"StridedSliceGrad", "StridedSliceGrad",
"Switch", "Switch",
"TensorListConcat",
"TensorListConcatLists",
"TensorListConcatV2",
"TensorListGather",
"TensorListGetItem",
"TensorListPopBack",
"TensorListPushBack",
"TensorListPushBackBatch",
"TensorListFromTensor",
"TensorListScatter",
"TensorListScatterV2",
"TensorListScatterIntoExistingList",
"TensorListSetItem",
"TensorListSplit",
"TensorListStack",
"Tile", "Tile",
"TopK", "TopK",
"TopKV2", "TopKV2",
@ -313,7 +330,125 @@ class AutoMixedPrecisionLists {
"Where", "Where",
"ZerosLike", "ZerosLike",
}; };
UpdateList(&list, to_add, to_remove); AddTensorListOps(&list);
UpdateList("CLEARLIST", &list);
return list;
}
private:
int cuda_version_;
int cudnn_version_;
};
class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists {
public:
AutoMixedPrecisionListsMkl() {}
// Only ops which are supported by MKL in bfloat16 should be added to the
// white list, gray list, or clear list.
gtl::FlatSet<string> WhiteList() override {
auto list = gtl::FlatSet<string>{
"Conv2D",
"Conv2DBackpropFilter",
"Conv2DBackpropInput",
"Conv3D",
"Conv3DBackpropFilterV2",
"Conv3DBackpropInputV2",
"DepthwiseConv2dNative",
"DepthwiseConv2dNativeBackpropFilter",
"DepthwiseConv2dNativeBackpropInput",
"MatMul",
"BatchMatMul",
"BatchMatMulV2"
};
UpdateList("WHITELIST", &list);
return list;
}
gtl::FlatSet<string> GrayList() override {
auto list = gtl::FlatSet<string>{
"Add",
"AddN",
"AddV2",
"AvgPool",
"AvgPool3D",
"AvgPool3DGrad",
"AvgPoolGrad",
"BiasAdd",
"BiasAddGrad",
"BiasAddV1",
"FusedBatchNormV2",
"FusedBatchNormGradV2",
"FusedBatchNormV3",
"FusedBatchNormGradV3",
"LeakyRelu",
"LeakyReluGrad",
"Mul",
"Sub",
};
UpdateList("GRAYLIST", &list);
return list;
}
gtl::FlatSet<string> BlackList() override {
auto list = gtl::FlatSet<string>{
"Exp",
"Expm1",
"L2Loss",
"Mean",
"Pow",
"SaveV2",
"Softmax",
"SoftmaxCrossEntropyWithLogits",
"SparseSoftmaxCrossEntropyWithLogits",
"Sum",
};
UpdateList("BLACKLIST", &list);
return list;
}
gtl::FlatSet<string> ClearList() override {
auto list = gtl::FlatSet<string>{
"Concat",
"ConcatV2",
"Enter",
"EnsureShape",
"Equal",
"Exit",
"ExpandDims",
"Identity",
"MaxPool",
"MaxPool3D",
"MaxPool3DGrad",
"MaxPoolGrad",
"MaxPoolV2",
"Maximum",
"Merge",
"NextIteration",
"PreventGradient",
"Relu",
"Relu6",
"Relu6Grad",
"ReluGrad",
"Reshape",
"Select",
"SelectV2",
"Shape",
"ShapeN",
"Slice",
"Split",
"SplitV",
"Squeeze",
"StopGradient",
"Switch",
"Transpose",
"ZerosLike",
};
AddTensorListOps(&list);
UpdateList("CLEARLIST", &list);
return list; return list;
} }
}; };


@ -13,12 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
// Currently, this test only passes when TensorFlow passes with CUDA, because
// otherwise the optimizer will not turn clearlist nodes to float16. When
// looking at clearlist nodes, this optimizer checks if the nodes have a float16
// GPU OpKernel, but without CUDA there are no GPU OpKernels at all.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h" #include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h"
#include <utility> #include <utility>
@ -70,6 +64,31 @@ Tensor GenerateRandomTensorInRange(const TensorShape& shape, double minval,
return tensor; return tensor;
} }
void VerifyGraphsEquivalent(const GraphDef& original_graph,
const GraphDef& optimized_graph,
const string& func) {
EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << func;
GraphView optimized_view(&optimized_graph);
for (int i = 0; i < original_graph.node_size(); ++i) {
const NodeDef& original = original_graph.node(i);
const NodeDef& optimized = *optimized_view.GetNode(original.name());
EXPECT_EQ(original.name(), optimized.name()) << func;
EXPECT_EQ(original.op(), optimized.op()) << func;
EXPECT_EQ(original.input_size(), optimized.input_size()) << func;
if (original.input_size() == optimized.input_size()) {
for (int j = 0; j < original.input_size(); ++j) {
EXPECT_EQ(original.input(j), optimized.input(j)) << func;
}
}
}
}
// Currently, this test suite only passes when TensorFlow is built with CUDA,
// because otherwise the optimizer will not turn clearlist nodes to float16.
// When looking at clearlist nodes, this optimizer checks if the nodes have a
// float16 GPU OpKernel, but without CUDA there are no GPU OpKernels at all.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
const std::pair<int, int> kMinGPUArch = {7, 0}; const std::pair<int, int> kMinGPUArch = {7, 0};
class AutoMixedPrecisionTest : public GrapplerTest { class AutoMixedPrecisionTest : public GrapplerTest {
@ -184,25 +203,6 @@ class AutoMixedPrecisionTest : public GrapplerTest {
bool gpu_available_; bool gpu_available_;
}; };
void VerifyGraphsEquivalent(const GraphDef& original_graph,
const GraphDef& optimized_graph,
const string& func) {
EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << func;
GraphView optimized_view(&optimized_graph);
for (int i = 0; i < original_graph.node_size(); ++i) {
const NodeDef& original = original_graph.node(i);
const NodeDef& optimized = *optimized_view.GetNode(original.name());
EXPECT_EQ(original.name(), optimized.name()) << func;
EXPECT_EQ(original.op(), optimized.op()) << func;
EXPECT_EQ(original.input_size(), optimized.input_size()) << func;
if (original.input_size() == optimized.input_size()) {
for (int j = 0; j < original.input_size(); ++j) {
EXPECT_EQ(original.input(j), optimized.input(j)) << func;
}
}
}
}
TEST_F(AutoMixedPrecisionTest, NoOp) { TEST_F(AutoMixedPrecisionTest, NoOp) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope(); tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output input = ops::Const(s.WithOpName("input"), 1.234f, {32}); Output input = ops::Const(s.WithOpName("input"), 1.234f, {32});
@ -1164,8 +1164,188 @@ TEST_F(AutoMixedPrecisionTest, TanhOp) {
}); });
} }
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if INTEL_MKL
#ifdef ENABLE_INTEL_MKL_BFLOAT16
class AutoMixedPrecisionMklTest : public GrapplerTest {
protected:
void SetUp() override {
virtual_cluster_.reset(new SingleMachine(/* timeout_s = */ 10, 1, 0));
TF_CHECK_OK(virtual_cluster_->Provision());
}
void TearDown() override { TF_CHECK_OK(virtual_cluster_->Shutdown()); }
std::unique_ptr<Cluster> virtual_cluster_;
};
TEST_F(AutoMixedPrecisionMklTest, AlreadyBf16) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output input = ops::Const(s.WithOpName("input"), 1.f, {32, 32});
Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_BFLOAT16);
Output wht1 = ops::MatMul(s.WithOpName("wht1"), cst1, cst1);
Output clr1 = ops::Relu(s.WithOpName("clr1"), wht1);
Output cst2 = ops::Cast(s.WithOpName("cst2"), clr1, DT_FLOAT);
Output clr2 = ops::Relu(s.WithOpName("clr2"), cst2);
Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);
GrapplerItem item;
item.fetch = {"fetch"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
AutoMixedPrecision optimizer{AutoMixedPrecisionMode::MKL};
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
VLOG(1) << output.DebugString();
VerifyGraphsEquivalent(item.graph, output, __FUNCTION__);
GraphView output_view(&output);
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("cst1")->attr().at("DstT").type(), DT_BFLOAT16);
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_BFLOAT16);
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_BFLOAT16);
EXPECT_EQ(output_view.GetNode("cst2")->attr().at("SrcT").type(), DT_BFLOAT16);
EXPECT_EQ(output_view.GetNode("cst2")->attr().at("DstT").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);
auto tensors = EvaluateNodes(output, item.fetch);
EXPECT_EQ(tensors.size(), tensors_expected.size());
EXPECT_EQ(tensors.size(), item.fetch.size());
for (int i = 0; i < item.fetch.size(); ++i) {
test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
}
}
TEST_F(AutoMixedPrecisionMklTest, Simple) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
Output blk1 = ops::Exp(s.WithOpName("blk1"), input);
Output clr1 = ops::Relu(s.WithOpName("clr1"), blk1);
Output gry1 = ops::Sqrt(s.WithOpName("gry1"), clr1);
Output clr2 = ops::Relu(s.WithOpName("clr2"), gry1);
Output wht1 = ops::MatMul(s.WithOpName("wht1"), clr2, clr2);
Output clr3 = ops::Relu(s.WithOpName("clr3"), wht1);
Output blk2 = ops::Log(s.WithOpName("blk2"), clr3);
Output clr4 = ops::Relu(s.WithOpName("clr4"), blk2);
Output blk3 = ops::SparseMatMul(s.WithOpName("blk3"), clr4, clr4);
Output clr5 = ops::Relu(s.WithOpName("clr5"), blk3);
Output fetch = ops::Identity(s.WithOpName("fetch"), clr5);
GrapplerItem item;
item.fetch = {"fetch"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
AutoMixedPrecision optimizer{AutoMixedPrecisionMode::MKL};
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
VLOG(1) << output.DebugString();
GraphView output_view(&output);
EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("blk1")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_BFLOAT16);
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_BFLOAT16);
EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_BFLOAT16);
EXPECT_EQ(output_view.GetNode("blk2")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("blk3")->attr().at("Ta").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("blk3")->attr().at("Tb").type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("clr5")->attr().at("T").type(), DT_FLOAT);
auto tensors = EvaluateNodes(output, item.fetch);
EXPECT_EQ(tensors.size(), tensors_expected.size());
EXPECT_EQ(tensors.size(), item.fetch.size());
for (int i = 0; i < item.fetch.size(); ++i) {
test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
}
}
TEST_F(AutoMixedPrecisionMklTest, TensorListSetGet) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
tensorflow::Input shape = {32, 32};
auto tl1 = ops::TensorListReserve(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
Output idx1 = ops::Const(s.WithOpName("idx1"), 1);
Output idx2 = ops::Const(s.WithOpName("idx2"), 2);
Output idx3 = ops::Const(s.WithOpName("idx3"), 3);
auto tl1w1 =
ops::TensorListSetItem(s.WithOpName("tl1w1"), tl1.handle, idx1, input);
Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, input);
auto tl1w2 =
ops::TensorListSetItem(s.WithOpName("tl1w2"), tl1.handle, idx2, wht1);
// Ensure that TensorListResize doesn't cause any problems.
Output tl1rs =
ops::TensorListResize(s.WithOpName("tl1rs"), tl1w2.output_handle, 6);
Output tl1r1 = ops::TensorListGetItem(s.WithOpName("tl1r1"), tl1rs, idx2,
shape, DT_FLOAT)
.item;
Output gry1 = ops::Mul(s.WithOpName("gry1"), tl1r1, tl1r1);
Output wht2 = ops::MatMul(s.WithOpName("wht2"), gry1, gry1);
auto tl1w3 =
ops::TensorListSetItem(s.WithOpName("tl1w3"), tl1.handle, idx3, wht2);
Output tl1r2 =
ops::TensorListGetItem(s.WithOpName("tl1r2"), tl1w3.output_handle, idx3,
shape, DT_FLOAT)
.item;
auto tl2 = ops::TensorListReserve(s.WithOpName("tl2"), shape, 8, DT_FLOAT);
auto tl2w1 =
ops::TensorListSetItem(s.WithOpName("tl2w1"), tl2.handle, idx1, input);
Output tl2r1 =
ops::TensorListGetItem(s.WithOpName("tl2r1"), tl2w1.output_handle, idx1,
shape, DT_FLOAT)
.item;
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), tl1r2);
Output fetch2 = ops::Identity(s.WithOpName("fetch2"), tl2r1);
GrapplerItem item;
item.fetch = {"fetch1", "fetch2"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
AutoMixedPrecision optimizer{AutoMixedPrecisionMode::MKL};
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
VLOG(1) << output.DebugString();
GraphView output_view(&output);
EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
const char* type_key = "element_dtype";
EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(),
DT_BFLOAT16);
EXPECT_EQ(output_view.GetNode("tl1w1")->attr().at(type_key).type(),
DT_BFLOAT16);
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_BFLOAT16);
EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(),
DT_BFLOAT16);
EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(),
DT_BFLOAT16);
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_BFLOAT16);
EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_BFLOAT16);
EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(),
DT_BFLOAT16);
EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_FLOAT);
EXPECT_EQ(output_view.GetNode("tl2r1")->attr().at(type_key).type(), DT_FLOAT);
auto tensors = EvaluateNodes(output, item.fetch);
EXPECT_EQ(tensors.size(), tensors_expected.size());
EXPECT_EQ(tensors.size(), item.fetch.size());
for (int i = 0; i < item.fetch.size(); ++i) {
test::ExpectClose(tensors_expected[i], tensors[i], -1, 1e-2);
}
}
#endif // ENABLE_INTEL_MKL_BFLOAT16
#endif // INTEL_MKL
} // namespace } // namespace
} // namespace grappler } // namespace grappler
} // namespace tensorflow } // namespace tensorflow
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM


@ -188,7 +188,9 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
MK_OPT("remap", new Remapper(cfg_.remapping())); MK_OPT("remap", new Remapper(cfg_.remapping()));
MK_OPT("layout", new GenericLayoutOptimizer()); MK_OPT("layout", new GenericLayoutOptimizer());
MK_OPT("auto_mixed_precision", MK_OPT("auto_mixed_precision",
new AutoMixedPrecision(cfg_.auto_mixed_precision())); new AutoMixedPrecision(AutoMixedPrecisionMode::CUDA));
MK_OPT("auto_mixed_precision_mkl",
new AutoMixedPrecision(AutoMixedPrecisionMode::MKL));
MK_OPT("memory", new MemoryOptimizer(RewriterConfig::MANUAL)); MK_OPT("memory", new MemoryOptimizer(RewriterConfig::MANUAL));
MK_OPT("common_subgraph_elimination", MK_OPT("common_subgraph_elimination",
new CommonSubgraphElimination(cfg_.common_subgraph_elimination())); new CommonSubgraphElimination(cfg_.common_subgraph_elimination()));
@ -249,7 +251,11 @@ Status MetaOptimizer::InitializeOptimizers(
} }
if (AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision())) { if (AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision())) {
optimizers->push_back( optimizers->push_back(
MakeUnique<AutoMixedPrecision>(cfg_.auto_mixed_precision())); MakeUnique<AutoMixedPrecision>(AutoMixedPrecisionMode::CUDA));
}
if (AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision_mkl())) {
optimizers->push_back(
MakeUnique<AutoMixedPrecision>(AutoMixedPrecisionMode::MKL));
} }
if (cfg_.pin_to_host_optimization() == RewriterConfig::ON) { if (cfg_.pin_to_host_optimization() == RewriterConfig::ON) {
optimizers->push_back(MakeUnique<PinToHostOptimizer>()); optimizers->push_back(MakeUnique<PinToHostOptimizer>());
@ -835,6 +841,7 @@ bool MetaOptimizerEnabled(const ConfigProto& cfg) {
rewrite_cfg.scoped_allocator_optimization() == RewriterConfig::ON || rewrite_cfg.scoped_allocator_optimization() == RewriterConfig::ON ||
rewrite_cfg.pin_to_host_optimization() == RewriterConfig::ON || rewrite_cfg.pin_to_host_optimization() == RewriterConfig::ON ||
AutoMixedPrecisionEnabled(rewrite_cfg.auto_mixed_precision()) || AutoMixedPrecisionEnabled(rewrite_cfg.auto_mixed_precision()) ||
AutoMixedPrecisionEnabled(rewrite_cfg.auto_mixed_precision_mkl()) ||
!rewrite_cfg.optimizers().empty() || !rewrite_cfg.optimizers().empty() ||
!rewrite_cfg.custom_optimizers().empty(); !rewrite_cfg.custom_optimizers().empty();
} }
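Since MK_OPT registers the new pass under the name "auto_mixed_precision_mkl", it should also be reachable through the explicit optimizer-name list, in addition to the auto_mixed_precision_mkl toggle. A sketch, assuming the usual meta-optimizer handling of RewriterConfig.optimizers:

    from tensorflow.core.protobuf import config_pb2, rewriter_config_pb2

    rewrite_options = rewriter_config_pb2.RewriterConfig(
        optimizers=['auto_mixed_precision_mkl'])  # must match the MK_OPT name
    config = config_pb2.ConfigProto(
        graph_options=config_pb2.GraphOptions(rewrite_options=rewrite_options))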


@ -85,11 +85,15 @@ message RewriterConfig {
// Enable the swap of kernel implementations based on the device placement // Enable the swap of kernel implementations based on the device placement
// (default is ON). // (default is ON).
Toggle implementation_selector = 22; Toggle implementation_selector = 22;
// Optimize data types (default is OFF). // Optimize data types for CUDA (default is OFF).
// e.g., This will try to use float16 on GPU which is faster. // This will try to use float16 on GPUs, which is faster.
// Note that this can change the numerical stability of the graph and may // Note that this can change the numerical stability of the graph and may
// require the use of loss scaling to maintain model convergence. // require the use of loss scaling to maintain model convergence.
Toggle auto_mixed_precision = 23; Toggle auto_mixed_precision = 23;
// Optimize data types for MKL (default is OFF).
// This will try to use bfloat16 on CPUs, which is faster.
// Note that this can change the numerical stability of the graph.
Toggle auto_mixed_precision_mkl = 25;
// Disable the entire meta optimizer (off by default). // Disable the entire meta optimizer (off by default).
bool disable_meta_optimizer = 19; bool disable_meta_optimizer = 19;
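
For completeness, a short sketch of turning the new toggle on from a Python session config; it mirrors the _get_config helper in the Python test changes further down:

# Sketch: enable the MKL auto mixed precision toggle (mirrors _get_config below).
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2

rewrite_options = rewriter_config_pb2.RewriterConfig(
    auto_mixed_precision_mkl=rewriter_config_pb2.RewriterConfig.ON)
config = config_pb2.ConfigProto(
    graph_options=config_pb2.GraphOptions(rewrite_options=rewrite_options))
# On an MKL-enabled CPU build, sessions created with this config will have
# eligible float32 ops rewritten to bfloat16 by the grappler pass.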


@ -19,8 +19,8 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
import unittest
from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.core.framework import types_pb2 from tensorflow.core.framework import types_pb2
@ -209,7 +209,7 @@ def _make_node_with_color(color, input_tensor, name=None):
if color == 'c': # Clear node if color == 'c': # Clear node
return nn.relu(input_tensor, name=name) return nn.relu(input_tensor, name=name)
if color == 'b': # Black node if color == 'b': # Black node
return math_ops.sqrt(math_ops.pow(input_tensor, 2.), name=name) return math_ops.pow(math_ops.pow(input_tensor, 2.), 0.5, name=name)
raise ValueError('Invalid node color: ' + str(color)) raise ValueError('Invalid node color: ' + str(color))
@ -231,18 +231,21 @@ def _build_simple_loop_graph(inp_colors, body_colors, out_colors):
return a return a
def _get_config(auto_mixed_precision=True): def _get_config(auto_mixed_precision_mode):
"""Returns a ConfigProto with auto mixed precision enabled if appropriate.""" """Returns a ConfigProto with auto mixed precision enabled if appropriate."""
if auto_mixed_precision: rewrite_config = rewriter_config_pb2.RewriterConfig(
rewrite_config = rewriter_config_pb2.RewriterConfig( # do not remove duplicated nodes
auto_mixed_precision=rewriter_config_pb2.RewriterConfig.ON, arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
# do not remove duplicated nodes # do not turn Conv2D and other nodes into _FusedConv2D
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF) remapping=rewriter_config_pb2.RewriterConfig.OFF,
)
if auto_mixed_precision_mode == 'cuda':
rewrite_config.auto_mixed_precision = rewriter_config_pb2.RewriterConfig.ON
elif auto_mixed_precision_mode == 'mkl':
rewrite_config.auto_mixed_precision_mkl = (
rewriter_config_pb2.RewriterConfig.ON)
else: else:
rewrite_config = rewriter_config_pb2.RewriterConfig( assert auto_mixed_precision_mode is None
auto_mixed_precision=rewriter_config_pb2.RewriterConfig.OFF,
# do not remove duplicated nodes
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF)
rewrite_config.min_graph_nodes = -1 rewrite_config.min_graph_nodes = -1
graph_options = config_pb2.GraphOptions( graph_options = config_pb2.GraphOptions(
rewrite_options=rewrite_config, build_cost_model=1) rewrite_options=rewrite_config, build_cost_model=1)
@ -255,19 +258,33 @@ def _is_cast_to_fp16(node_name):
return node_name.endswith('-CastToFp16-AutoMixedPrecision') return node_name.endswith('-CastToFp16-AutoMixedPrecision')
def _is_cast_to_bf16(node_name):
return node_name.endswith('-CastToBf16-AutoMixedPrecision')
def _is_cast_to_fp32(node_name): def _is_cast_to_fp32(node_name):
return node_name.endswith('-CastToFp32-AutoMixedPrecision') return node_name.endswith('-CastToFp32-AutoMixedPrecision')
def _count_casts(nodes): def _count_casts(mode, nodes):
"""Counts the number of casts to f16 and fp32."""
num_to_fp16 = 0 num_to_fp16 = 0
num_to_bf16 = 0
num_to_fp32 = 0 num_to_fp32 = 0
for node in nodes: for node in nodes:
if _is_cast_to_fp16(node.name): if _is_cast_to_fp16(node.name):
num_to_fp16 += 1 num_to_fp16 += 1
if _is_cast_to_bf16(node.name):
num_to_bf16 += 1
elif _is_cast_to_fp32(node.name): elif _is_cast_to_fp32(node.name):
num_to_fp32 += 1 num_to_fp32 += 1
return num_to_fp16, num_to_fp32 if mode == 'cuda':
assert num_to_bf16 == 0
return num_to_fp16, num_to_fp32
else:
assert mode == 'mkl'
assert num_to_fp16 == 0
return num_to_bf16, num_to_fp32
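
To make the mode-dependent return value concrete, here is an illustrative use of _count_casts; FakeNode and the node names are hypothetical but follow the '-CastTo*-AutoMixedPrecision' suffix convention the helpers above match:

# Illustrative only: FakeNode and these node names are made up.
import collections

FakeNode = collections.namedtuple('FakeNode', ['name'])
nodes = [
    FakeNode('Conv2D-0-CastToBf16-AutoMixedPrecision'),
    FakeNode('FusedBatchNormV3-0-CastToFp32-AutoMixedPrecision'),
]
assert _count_casts('mkl', nodes) == (1, 1)  # (num_to_bf16, num_to_fp32)
assert _count_casts(
    'cuda', [FakeNode('MatMul-0-CastToFp16-AutoMixedPrecision')]) == (1, 0)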
def _build_node_map(nodes): def _build_node_map(nodes):
@ -303,7 +320,7 @@ def _example_noninlined_funcdef(features):
return features * math_ops.sigmoid(features) return features * math_ops.sigmoid(features)
class AutoMixedPrecisionTest(test.TestCase): class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
"""Tests the Grappler auto mixed precision optimizer.""" """Tests the Grappler auto mixed precision optimizer."""
IGNORE_PERF_VAR = 'TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_IGNORE_PERFORMANCE' IGNORE_PERF_VAR = 'TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_IGNORE_PERFORMANCE'
@ -311,8 +328,8 @@ class AutoMixedPrecisionTest(test.TestCase):
def setUp(self): def setUp(self):
super(AutoMixedPrecisionTest, self).setUp() super(AutoMixedPrecisionTest, self).setUp()
# Enable the tests to be run on pre-Volta GPUs by telling the grappler pass # Enable the CUDA tests to be run on pre-Volta GPUs by telling the grappler
# to ignore performance and always transform the graph. # pass to ignore performance and always transform the graph.
self._original_ignore_perf_value = os.getenv(self.IGNORE_PERF_VAR) self._original_ignore_perf_value = os.getenv(self.IGNORE_PERF_VAR)
os.environ[self.IGNORE_PERF_VAR] = '1' os.environ[self.IGNORE_PERF_VAR] = '1'
@ -323,24 +340,33 @@ class AutoMixedPrecisionTest(test.TestCase):
del os.environ[self.IGNORE_PERF_VAR] del os.environ[self.IGNORE_PERF_VAR]
super(AutoMixedPrecisionTest, self).tearDown() super(AutoMixedPrecisionTest, self).tearDown()
def _assert_output_fp16(self, node_map, node_name, output_port=0): def _lower_precision_dtype(self, mode):
self.assertEqual(node_map[node_name].output_info[output_port].dtype, return dtypes.float16 if mode == 'cuda' else dtypes.bfloat16
types_pb2.DT_HALF)
def _run(self, fetches): def _assert_output_f16(self, mode, node_map, node_name, output_port=0):
self.assertEqual(node_map[node_name].output_info[output_port].dtype,
self._lower_precision_dtype(mode).as_datatype_enum)
def _run(self, mode, fetches):
"""Runs the graph and returns the evaluation of the fetches.""" """Runs the graph and returns the evaluation of the fetches."""
with session.Session(config=_get_config(False)) as sess: with session.Session(config=_get_config(None)) as sess:
sess.run(variables.global_variables_initializer()) sess.run(variables.global_variables_initializer())
output_val_ref = self.evaluate(fetches) output_val_ref = self.evaluate(fetches)
with session.Session(config=_get_config()) as sess: with session.Session(config=_get_config(mode)) as sess:
sess.run(variables.global_variables_initializer()) sess.run(variables.global_variables_initializer())
metadata = config_pb2.RunMetadata() metadata = config_pb2.RunMetadata()
output_val = sess.run(fetches, run_metadata=metadata) output_val = sess.run(fetches, run_metadata=metadata)
return output_val_ref, output_val, metadata.cost_graph return output_val_ref, output_val, metadata.cost_graph
def _run_simple_loop_test(self, inp, body, out): def _maybe_skip(self, mode):
if mode == 'cuda' and not test.is_gpu_available(cuda_only=True):
self.skipTest('No GPU is available')
if mode == 'mkl' and not test_util.IsMklEnabled():
self.skipTest('MKL is not enabled')
def _run_simple_loop_test(self, mode, inp, body, out):
"""Runs a test of a simple loop. """Runs a test of a simple loop.
The loop has different node colors in different sections of the graph. The The loop has different node colors in different sections of the graph. The
@ -359,398 +385,441 @@ class AutoMixedPrecisionTest(test.TestCase):
out: A string of letters indicating the colors and expected dtypes of the out: A string of letters indicating the colors and expected dtypes of the
output nodes. output nodes.
""" """
if test.is_gpu_available(cuda_only=True): self._maybe_skip(mode)
random_seed.set_random_seed(0) random_seed.set_random_seed(0)
expected_types = [] expected_types = []
for section in [inp, body, out]: for section in [inp, body, out]:
section_expected_types = [] section_expected_types = []
for color in section: for color in section:
if color.isupper(): if color.isupper():
expected_type = types_pb2.DT_HALF expected_type = self._lower_precision_dtype(mode).as_datatype_enum
else: else:
expected_type = types_pb2.DT_FLOAT expected_type = types_pb2.DT_FLOAT
section_expected_types.append(expected_type) section_expected_types.append(expected_type)
expected_types.append(section_expected_types) expected_types.append(section_expected_types)
a = _build_simple_loop_graph(inp, body, out) a = _build_simple_loop_graph(inp, body, out)
output_val_ref, output_val, cost_graph = self._run(a) output_val_ref, output_val, cost_graph = self._run(mode, a)
node_map = _build_node_map(cost_graph.node) node_map = _build_node_map(cost_graph.node)
section_names = ['input', 'while/body', 'output'] section_names = ['input', 'while/body', 'output']
all_types_correct = True all_types_correct = True
for section_name, expected_types in zip(section_names, expected_types): for section_name, expected_types in zip(section_names, expected_types):
for i, expected_type in enumerate(expected_types): for i, expected_type in enumerate(expected_types):
node_name = section_name + '_%i' % i node_name = section_name + '_%i' % i
output_port = 0 output_port = 0
optimized_type = node_map[node_name].output_info[output_port].dtype optimized_type = node_map[node_name].output_info[output_port].dtype
if optimized_type != expected_type: if optimized_type != expected_type:
print('Expected node %s to have type %s but got type %s' % print('Expected node %s to have type %s but got type %s' %
(node_name, expected_type, optimized_type)) (node_name, expected_type, optimized_type))
all_types_correct = False all_types_correct = False
self.assertTrue(all_types_correct) self.assertTrue(all_types_correct)
if mode == 'mkl':
self.assertAllClose(output_val_ref, output_val, atol=2e-2, rtol=2e-2)
else:
self.assertAllClose(output_val_ref, output_val, atol=2e-3, rtol=1e-3) self.assertAllClose(output_val_ref, output_val, atol=2e-3, rtol=1e-3)
@parameterized.parameters(['cuda', 'mkl'])
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
@test_util.disable_xla('This test does not pass with XLA') @test_util.disable_xla('This test does not pass with XLA')
def test_conv_bn(self): def test_conv_bn(self, mode):
"""Test graph with convolution followed by batch norm.""" """Test graph with convolution followed by batch norm."""
if test.is_gpu_available(cuda_only=True): self._maybe_skip(mode)
random_seed.set_random_seed(0) random_seed.set_random_seed(0)
x = _input([2, 8, 8, 1]) x = _input([2, 8, 8, 1])
x = _conv_bn(x) x = _conv_bn(x)
output = _conv_bn(x) output = _conv_bn(x)
output_val_ref, output_val, cost_graph = self._run(output) output_val_ref, output_val, cost_graph = self._run(mode, output)
node_map = _build_node_map(cost_graph.node) node_map = _build_node_map(cost_graph.node)
num_to_fp16, num_to_fp32 = _count_casts(cost_graph.node) num_to_f16, num_to_fp32 = _count_casts(mode, cost_graph.node)
self._assert_output_fp16(node_map, 'Conv2D') self._assert_output_f16(mode, node_map, 'Conv2D')
self._assert_output_fp16(node_map, 'FusedBatchNormV3') self._assert_output_f16(mode, node_map, 'FusedBatchNormV3')
self._assert_output_fp16(node_map, 'Conv2D_1') self._assert_output_f16(mode, node_map, 'Conv2D_1')
self.assertEqual(num_to_fp16, self.assertEqual(num_to_f16, 3) # Before Conv2D:0, Conv2D:1, Conv2D_1:1
3) # Before Conv2D:0, Conv2D:1, Conv2D_1:1 self.assertEqual(num_to_fp32, 1) # After FusedBatchNormV3:0
self.assertEqual(num_to_fp32, 1) # After FusedBatchNormV3:0 if mode == 'mkl':
tol = 1e-2
elif test.is_built_with_rocm():
# Bump up the tolerance for the ROCm platform # Bump up the tolerance for the ROCm platform
# The default tolerance (1e-3) results in a tiny fraction (<1%) of # The default tolerance (1e-3) results in a tiny fraction (<1%) of
# miscompares on ROCm platform, and hence the tolerance bump # miscompares on ROCm platform, and hence the tolerance bump
tol = 2e-3 if test.is_built_with_rocm else 1e-3 tol = 2e-3
self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol) else:
tol = 1e-3
self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)
# TODO: enable these tests when cuDNN is upgraded to >= 7.6.2. Same with the @parameterized.parameters(['cuda', 'mkl'])
# test_conv3d() below.
@unittest.skip('Test case should be skipped when cuDNN < 7.6.2')
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
@test_util.disable_xla('This test does not pass with XLA') @test_util.disable_xla('This test does not pass with XLA')
def test_conv3d_bn(self): def test_conv3d_bn(self, mode):
"""Test graph with convolution followed by batch norm.""" """Test graph with convolution followed by batch norm."""
if test.is_gpu_available(cuda_only=True): self._maybe_skip(mode)
random_seed.set_random_seed(0) if mode == 'cuda':
x = _input([2, 8, 8, 8, 1]) # TODO: enable these tests when cuDNN is upgraded to >= 7.6.2.
x = _conv3d_bn(x) self.skipTest('Test case should be skipped when cuDNN < 7.6.2')
output = _conv3d_bn(x) random_seed.set_random_seed(0)
x = _input([2, 8, 8, 8, 1])
x = _conv3d_bn(x)
output = _conv3d_bn(x)
output_val_ref, output_val, cost_graph = self._run(output) output_val_ref, output_val, cost_graph = self._run(mode, output)
node_map = _build_node_map(cost_graph.node) node_map = _build_node_map(cost_graph.node)
num_to_fp16, num_to_fp32 = _count_casts(cost_graph.node) num_to_fp16, num_to_fp32 = _count_casts(mode, cost_graph.node)
self._assert_output_fp16(node_map, 'Conv3D') self._assert_output_f16(mode, node_map, 'Conv3D')
self._assert_output_fp16(node_map, 'FusedBatchNormV3') self._assert_output_f16(mode, node_map, 'FusedBatchNormV3')
self._assert_output_fp16(node_map, 'Conv3D_1') self._assert_output_f16(mode, node_map, 'Conv3D_1')
self.assertEqual(num_to_fp16, 3) # Before Conv3D:0, Conv3D:1, Conv3D_1:1 self.assertEqual(num_to_fp16, 3) # Before Conv3D:0, Conv3D:1, Conv3D_1:1
self.assertEqual(num_to_fp32, 1) # After FusedBatchNormV3:0 self.assertEqual(num_to_fp32, 1) # After FusedBatchNormV3:0
self.assertAllClose(output_val_ref, output_val, atol=1e-2, rtol=1e-2) self.assertAllClose(output_val_ref, output_val, atol=1e-2, rtol=1e-2)
@unittest.skip('Test case should be skipped when cuDNN < 7.6.2') @parameterized.parameters(['cuda', 'mkl'])
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
@test_util.disable_xla('This test does not pass with XLA') @test_util.disable_xla('This test does not pass with XLA')
def test_conv3d(self): def test_conv3d(self, mode):
"""Test grad ops with convolution3d graph.""" """Test grad ops with convolution3d graph."""
if test.is_gpu_available(cuda_only=True): self._maybe_skip(mode)
random_seed.set_random_seed(0) if mode == 'cuda':
x = _input([2, 8, 8, 8, 1]) # TODO: enable these tests when cuDNN is upgraded to >= 7.6.2.
f = _weight([3, 3, 3, 1, 6]) self.skipTest('Test case should be skipped when cuDNN < 7.6.2')
y = _conv3d(x, f) random_seed.set_random_seed(0)
y = array_ops.identity(y) x = _input([2, 8, 8, 8, 1])
optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01) f = _weight([3, 3, 3, 1, 6])
g = optimizer.compute_gradients(y, [x, f]) y = _conv3d(x, f)
output = (y, g) y = array_ops.identity(y)
optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
g = optimizer.compute_gradients(y, [x, f])
output = (y, g)
output_val_ref, output_val, cost_graph = self._run(output) output_val_ref, output_val, cost_graph = self._run(mode, output)
node_map = _build_node_map(cost_graph.node) node_map = _build_node_map(cost_graph.node)
self._assert_output_fp16(node_map, 'Conv3D') self._assert_output_f16(mode, node_map, 'Conv3D')
self._assert_output_fp16(node_map, self._assert_output_f16(mode, node_map,
'gradients/Conv3D_grad/Conv3DBackpropInputV2') 'gradients/Conv3D_grad/Conv3DBackpropInputV2')
self._assert_output_fp16(node_map, self._assert_output_f16(mode, node_map,
'gradients/Conv3D_grad/Conv3DBackpropFilterV2') 'gradients/Conv3D_grad/Conv3DBackpropFilterV2')
output_val_ref, output_val, cost_graph = self._run(output) output_val_ref, output_val, cost_graph = self._run(mode, output)
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3) tol = 5e-2 if mode == 'mkl' else 1e-3
self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)
# TODO(reedwm): Fix and enable this test with MKL. Currently this crashes with
# MKL
@parameterized.parameters(['cuda'])
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
@test_util.disable_xla('This test does not pass with XLA') @test_util.disable_xla('This test does not pass with XLA')
def test_conv_bn_dropout(self): def test_conv_bn_dropout(self, mode):
"""Test dropout precision of convolution batch norm graph.""" """Test dropout precision of convolution batch norm graph."""
if test.is_gpu_available(cuda_only=True): self._maybe_skip(mode)
random_seed.set_random_seed(0) random_seed.set_random_seed(0)
x = _input([2, 8, 8, 1]) x = _input([2, 8, 8, 1])
y = _conv_bn(x) y = _conv_bn(x)
y = nn.dropout(y, rate=0.5) y = nn.dropout(y, rate=0.5)
y = math_ops.add(y, 1, name='addition') y = math_ops.add(y, 1, name='addition')
y = _conv_bn(y) y = _conv_bn(y)
y = array_ops.identity(y) y = array_ops.identity(y)
optimizer = gradient_descent.GradientDescentOptimizer( optimizer = gradient_descent.GradientDescentOptimizer(
learning_rate=0.01) learning_rate=0.01)
g = optimizer.compute_gradients(y, [x]) g = optimizer.compute_gradients(y, [x])
output = (y, g) output = (y, g)
output_val_ref, output_val, cost_graph = self._run(output) output_val_ref, output_val, cost_graph = self._run(mode, output)
node_map = _build_node_map(cost_graph.node) node_map = _build_node_map(cost_graph.node)
self._assert_output_fp16(node_map, 'Conv2D') self._assert_output_f16(mode, node_map, 'Conv2D')
self._assert_output_fp16(node_map, 'FusedBatchNormV3') self._assert_output_f16(mode, node_map, 'FusedBatchNormV3')
# We do not assert dropout's dtype because we do not want to rely on the # We do not assert dropout's dtype because we do not want to rely on the
# node names of dropout's internal implementation. # node names of dropout's internal implementation.
self._assert_output_fp16(node_map, 'addition') self._assert_output_f16(mode, node_map, 'addition')
self._assert_output_fp16(node_map, 'Conv2D_1') self._assert_output_f16(mode, node_map, 'Conv2D_1')
output_val_ref, output_val, cost_graph = self._run(output) output_val_ref, output_val, cost_graph = self._run(mode, output)
# Bump up the tolerance for the ROCm platform # Bump up the tolerance for the ROCm platform
# The default tolerance (1e-3) results in a tiny fraction (<1%) of # The default tolerance (1e-3) results in a tiny fraction (<1%) of
# miscompares on ROCm platform, and hence the tolerance bump # miscompares on ROCm platform, and hence the tolerance bump
tol = 2e-3 if test.is_built_with_rocm() else 1e-3 tol = 2e-3 if test.is_built_with_rocm() else 1e-3
self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol) self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)
# TODO(reedwm): Fix and enable this test with MKL. Currently this crashes with
# MKL
@parameterized.parameters(['cuda'])
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
@test_util.disable_xla('This test does not pass with XLA') @test_util.disable_xla('This test does not pass with XLA')
def test_conv_pool(self): def test_conv_pool(self, mode):
"""Test graph with convolution followed by pooling.""" """Test graph with convolution followed by pooling."""
if test.is_gpu_available(cuda_only=True): self._maybe_skip(mode)
random_seed.set_random_seed(0) random_seed.set_random_seed(0)
x = _input([2, 8, 8, 1]) x = _input([2, 8, 8, 1])
output = _conv_pool(x) output = _conv_pool(x)
output_val_ref, output_val, cost_graph = self._run(output) output_val_ref, output_val, cost_graph = self._run(mode, output)
node_map = _build_node_map(cost_graph.node) node_map = _build_node_map(cost_graph.node)
num_to_fp16, num_to_fp32 = _count_casts(cost_graph.node) num_to_f16, num_to_fp32 = _count_casts(mode, cost_graph.node)
self._assert_output_fp16(node_map, 'Conv2D') self._assert_output_f16(mode, node_map, 'Conv2D')
self._assert_output_fp16(node_map, 'Relu') self._assert_output_f16(mode, node_map, 'Relu')
self._assert_output_fp16(node_map, 'MaxPool') self._assert_output_f16(mode, node_map, 'MaxPool')
self._assert_output_fp16(node_map, 'Conv2D_1') self._assert_output_f16(mode, node_map, 'Conv2D_1')
self.assertEqual(num_to_fp16, 4) self.assertEqual(num_to_f16, 4)
self.assertEqual(num_to_fp32, 1) self.assertEqual(num_to_fp32, 1)
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3) tol = 5e-3 if mode == 'mkl' else 1e-3
self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)
@parameterized.parameters(['cuda', 'mkl'])
@test_util.run_v1_only('b/138749235') @test_util.run_v1_only('b/138749235')
@test_util.disable_xla('This test does not pass with XLA') @test_util.disable_xla('This test does not pass with XLA')
def test_simple_loop(self): def test_simple_loop(self, mode):
"""Test graph with while loop.""" """Test graph with while loop."""
if test.is_gpu_available(cuda_only=True): self._maybe_skip(mode)
random_seed.set_random_seed(0) random_seed.set_random_seed(0)
x = _input([8, 8]) x = _input([8, 8])
y = _simple_loop(x, _matmul_act)[1] y = _simple_loop(x, _matmul_act)[1]
optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01) optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
g = optimizer.compute_gradients(y, [x]) g = optimizer.compute_gradients(y, [x])
output = (y, g) output = (y, g)
output_val_ref, output_val, cost_graph = self._run(output) output_val_ref, output_val, cost_graph = self._run(mode, output)
node_map = _build_node_map(cost_graph.node) node_map = _build_node_map(cost_graph.node)
self._assert_output_fp16(node_map, 'while/MatMul') self._assert_output_f16(mode, node_map, 'while/MatMul')
self._assert_output_fp16(node_map, 'while/Relu') self._assert_output_f16(mode, node_map, 'while/Relu')
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3) tol = 1e-2 if mode == 'mkl' else 1e-3
self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)
@parameterized.parameters(['cuda', 'mkl'])
@test_util.run_v1_only('b/138749235') @test_util.run_v1_only('b/138749235')
@test_util.disable_xla('This test does not pass with XLA') @test_util.disable_xla('This test does not pass with XLA')
def test_loop_with_vars_intertwined(self): def test_loop_with_vars_intertwined(self, mode):
"""Test graph with intertwined while loops.""" """Test graph with intertwined while loops."""
if test.is_gpu_available(cuda_only=True): self._maybe_skip(mode)
random_seed.set_random_seed(0) random_seed.set_random_seed(0)
x = _input([8, 8]) x = _input([8, 8])
_, _, k, l = _loop_vars_intertwined( _, _, k, l = _loop_vars_intertwined(
array_ops.ones(array_ops.shape(x)), x, _matmul_act, _matmul_act) array_ops.ones(array_ops.shape(x)), x, _matmul_act, _matmul_act)
optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01) optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
g = optimizer.compute_gradients(k, [x]) g = optimizer.compute_gradients(k, [x])
output = (k, l, g) output = (k, l, g)
output_val_ref, output_val, cost_graph = self._run(output) output_val_ref, output_val, cost_graph = self._run(mode, output)
node_map = _build_node_map(cost_graph.node) node_map = _build_node_map(cost_graph.node)
self._assert_output_fp16(node_map, 'while/MatMul') self._assert_output_f16(mode, node_map, 'while/MatMul')
self._assert_output_fp16(node_map, 'while/Relu') self._assert_output_f16(mode, node_map, 'while/Relu')
self._assert_output_fp16(node_map, 'while/MatMul_1') self._assert_output_f16(mode, node_map, 'while/MatMul_1')
self._assert_output_fp16(node_map, 'while/Relu_1') self._assert_output_f16(mode, node_map, 'while/Relu_1')
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3) tol = 5e-3 if mode == 'mkl' else 1e-3
self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)
@parameterized.parameters(['cuda'])
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
@test_util.disable_xla('This test does not pass with XLA') @test_util.disable_xla('This test does not pass with XLA')
def test_multi_paths(self): def test_multi_paths(self, mode):
"""Test graph with multiple paths.""" """Test graph with multiple paths."""
if test.is_gpu_available(cuda_only=True): self._maybe_skip(mode)
random_seed.set_random_seed(0) random_seed.set_random_seed(0)
x = _input([2, 8, 8, 3]) x = _input([2, 8, 8, 3])
x1, x2, x3 = array_ops.split(x, num_or_size_splits=3, axis=3) x1, x2, x3 = array_ops.split(x, num_or_size_splits=3, axis=3)
y1 = _conv_pool(x1) y1 = _conv_pool(x1)
y2 = _conv_pool(x2) y2 = _conv_pool(x2)
y3 = _conv_pool(x3) y3 = _conv_pool(x3)
y = array_ops.concat([y1, y2, y3], axis=3) y = array_ops.concat([y1, y2, y3], axis=3)
y = array_ops.identity(y) y = array_ops.identity(y)
optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01) optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
g = optimizer.compute_gradients(y, [x]) g = optimizer.compute_gradients(y, [x])
output = (y, g) output = (y, g)
output_val_ref, output_val, cost_graph = self._run(output) output_val_ref, output_val, cost_graph = self._run(mode, output)
node_map = _build_node_map(cost_graph.node) node_map = _build_node_map(cost_graph.node)
self._assert_output_fp16(node_map, 'split') self._assert_output_f16(mode, node_map, 'split')
for suffix in [''] + ['_%i' % i for i in range(1, 6)]: for suffix in [''] + ['_%i' % i for i in range(1, 6)]:
self._assert_output_fp16(node_map, 'Conv2D' + suffix) self._assert_output_f16(mode, node_map, 'Conv2D' + suffix)
self._assert_output_fp16(node_map, 'Relu' + suffix) self._assert_output_f16(mode, node_map, 'Relu' + suffix)
self._assert_output_fp16(node_map, 'MaxPool' + suffix) self._assert_output_f16(mode, node_map, 'MaxPool' + suffix)
self._assert_output_fp16(node_map, 'concat') self._assert_output_f16(mode, node_map, 'concat')
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3) self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
@parameterized.parameters(['cuda', 'mkl'])
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
@test_util.disable_xla('This test does not pass with XLA') @test_util.disable_xla('This test does not pass with XLA')
def test_multi_paths_2(self): def test_multi_paths_2(self, mode):
"""Test graph with multiple paths.""" """Test graph with multiple paths."""
if test.is_gpu_available(cuda_only=True): self._maybe_skip(mode)
random_seed.set_random_seed(0) random_seed.set_random_seed(0)
x = _input([8, 8]) x = _input([8, 8])
y1 = _matmul_act(x) y1 = _matmul_act(x)
y2 = _matmul_act(x) y2 = _matmul_act(x)
y = y1 + y2 + x y = y1 + y2 + x
optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01) optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
g = optimizer.compute_gradients(y, [x]) g = optimizer.compute_gradients(y, [x])
output = (g, y) output = (g, y)
output_val_ref, output_val, cost_graph = self._run(output) output_val_ref, output_val, cost_graph = self._run(mode, output)
node_map = _build_node_map(cost_graph.node) node_map = _build_node_map(cost_graph.node)
self._assert_output_fp16(node_map, 'MatMul') self._assert_output_f16(mode, node_map, 'MatMul')
self._assert_output_fp16(node_map, 'Relu') self._assert_output_f16(mode, node_map, 'Relu')
self._assert_output_fp16(node_map, 'MatMul_1') self._assert_output_f16(mode, node_map, 'MatMul_1')
self._assert_output_fp16(node_map, 'Relu_1') self._assert_output_f16(mode, node_map, 'Relu_1')
if mode == 'mkl':
tol = 2e-2
elif test.is_built_with_rocm():
# Bump up the tolerance for the ROCm platform # Bump up the tolerance for the ROCm platform
# The default tolerance (1e-3) results in a tiny fraction (<1%) of # The default tolerance (1e-3) results in a tiny fraction (<1%) of
# miscompares on ROCm platform, and hence the tolerance bump # miscompares on ROCm platform, and hence the tolerance bump
tol = 2e-3 if test.is_built_with_rocm else 1e-3 tol = 2e-3
self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol) else:
tol = 1e-3
self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)
@parameterized.parameters(['cuda']) # MKL doesn't support bf16 Sigmoid
@test_util.run_v1_only('b/138749235') @test_util.run_v1_only('b/138749235')
@test_util.disable_xla('This test does not pass with XLA') @test_util.disable_xla('This test does not pass with XLA')
def test_recurrent_lstm(self): def test_recurrent_lstm(self, mode):
"""Test graph with recurrent lstm.""" """Test graph with recurrent lstm."""
if test.is_gpu_available(cuda_only=True): self._maybe_skip(mode)
random_seed.set_random_seed(0) random_seed.set_random_seed(0)
init_c = _input([8, 4]) init_c = _input([8, 4])
init_h = _input([8, 4]) init_h = _input([8, 4])
_, _, h, _ = _recurrent_lstm(init_c, init_h) _, _, h, _ = _recurrent_lstm(init_c, init_h)
optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01) optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
g = optimizer.compute_gradients(h, [init_c, init_h]) g = optimizer.compute_gradients(h, [init_c, init_h])
output = (h, g) output = (h, g)
output_val_ref, output_val, cost_graph = self._run(output) output_val_ref, output_val, cost_graph = self._run(mode, output)
node_map = _build_node_map(cost_graph.node) node_map = _build_node_map(cost_graph.node)
self._assert_output_fp16(node_map, 'while/concat') self._assert_output_f16(mode, node_map, 'while/concat')
self._assert_output_fp16(node_map, 'while/MatMul') self._assert_output_f16(mode, node_map, 'while/MatMul')
self._assert_output_fp16(node_map, 'while/split') self._assert_output_f16(mode, node_map, 'while/split')
self._assert_output_fp16(node_map, 'while/Sigmoid') self._assert_output_f16(mode, node_map, 'while/Sigmoid')
self._assert_output_fp16(node_map, 'while/Sigmoid_1') self._assert_output_f16(mode, node_map, 'while/Sigmoid_1')
self._assert_output_fp16(node_map, 'while/Sigmoid_2') self._assert_output_f16(mode, node_map, 'while/Sigmoid_2')
self._assert_output_fp16(node_map, 'while/Tanh') self._assert_output_f16(mode, node_map, 'while/Tanh')
self._assert_output_fp16(node_map, 'while/Tanh_1') self._assert_output_f16(mode, node_map, 'while/Tanh_1')
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3) self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
@parameterized.parameters(['cuda', 'mkl'])
@test_util.run_v1_only('v1 loop test') @test_util.run_v1_only('v1 loop test')
@test_util.disable_xla('This test does not pass with XLA') @test_util.disable_xla('This test does not pass with XLA')
def test_propagation_through_simple_loop_1(self): def test_propagation_through_simple_loop_1(self, mode):
self._run_simple_loop_test('W', 'C', 'C') self._run_simple_loop_test(mode, 'W', 'C', 'C')
@parameterized.parameters(['cuda', 'mkl'])
@test_util.run_v1_only('v1 loop test') @test_util.run_v1_only('v1 loop test')
@test_util.disable_xla('This test does not pass with XLA') @test_util.disable_xla('This test does not pass with XLA')
def test_propagation_through_simple_loop_2(self): def test_propagation_through_simple_loop_2(self, mode):
self._run_simple_loop_test('C', 'C', 'W') self._run_simple_loop_test(mode, 'C', 'C', 'W')
@parameterized.parameters(['cuda', 'mkl'])
@test_util.run_v1_only('v1 loop test') @test_util.run_v1_only('v1 loop test')
@test_util.disable_xla('This test does not pass with XLA') @test_util.disable_xla('This test does not pass with XLA')
def test_propagation_through_simple_loop_3(self): def test_propagation_through_simple_loop_3(self, mode):
self._run_simple_loop_test('W', 'G', 'W') self._run_simple_loop_test(mode, 'W', 'G', 'W')
@parameterized.parameters(['cuda', 'mkl'])
@test_util.run_v1_only('v1 loop test') @test_util.run_v1_only('v1 loop test')
@test_util.disable_xla('This test does not pass with XLA') @test_util.disable_xla('This test does not pass with XLA')
def test_propagation_through_simple_loop_4(self): def test_propagation_through_simple_loop_4(self, mode):
self._run_simple_loop_test('W', 'gbg', 'W') self._run_simple_loop_test(mode, 'W', 'gbg', 'W')
@parameterized.parameters(['cuda', 'mkl'])
@test_util.run_v1_only('b/138749235') @test_util.run_v1_only('b/138749235')
@test_util.disable_xla('This test does not pass with XLA') @test_util.disable_xla('This test does not pass with XLA')
def test_propagation_through_simple_loop_5(self): def test_propagation_through_simple_loop_5(self, mode):
self._run_simple_loop_test('b', 'gWC', 'c') self._run_simple_loop_test(mode, 'b', 'gWC', 'c')
@parameterized.parameters(['cuda', 'mkl'])
@test_util.run_v1_only('b/138749235') @test_util.run_v1_only('b/138749235')
@test_util.disable_xla('This test does not pass with XLA') @test_util.disable_xla('This test does not pass with XLA')
def test_propagation_through_simple_loop_6(self): def test_propagation_through_simple_loop_6(self, mode):
self._run_simple_loop_test('b', 'CWCG', 'C') self._run_simple_loop_test(mode, 'b', 'CWCG', 'C')
@parameterized.parameters(['cuda', 'mkl'])
@test_util.run_v1_only('b/138749235') @test_util.run_v1_only('b/138749235')
@test_util.disable_xla('This test does not pass with XLA') @test_util.disable_xla('This test does not pass with XLA')
def test_propagation_through_simple_loop_7(self): def test_propagation_through_simple_loop_7(self, mode):
self._run_simple_loop_test('C', 'GWCG', 'C') self._run_simple_loop_test(mode, 'C', 'GWCG', 'C')
@parameterized.parameters(['cuda', 'mkl'])
@test_util.run_v1_only('b/138749235') @test_util.run_v1_only('b/138749235')
@test_util.disable_xla('This test does not pass with XLA') @test_util.disable_xla('This test does not pass with XLA')
def test_propagation_through_simple_loop_8(self): def test_propagation_through_simple_loop_8(self, mode):
self._run_simple_loop_test('C', 'CgbgWC', 'g') self._run_simple_loop_test(mode, 'C', 'CgbgWC', 'g')
@parameterized.parameters(['cuda', 'mkl'])
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
@test_util.disable_xla('This test does not pass with XLA') @test_util.disable_xla('This test does not pass with XLA')
def test_noninlined_funcdef(self): def test_noninlined_funcdef(self, mode):
"""Test graph with non-inlined function subgraph. """Test graph with non-inlined function subgraph.
This requires the grappler pass to handle an OpDef that only appears in the This requires the grappler pass to handle an OpDef that only appears in the
graph's function registry instead of the global op registry. graph's function registry instead of the global op registry.
""" """
if test.is_gpu_available(cuda_only=True): self._maybe_skip(mode)
random_seed.set_random_seed(0) random_seed.set_random_seed(0)
x = _input([8, 8]) x = _input([8, 8])
y = _matmul_act(x) y = _matmul_act(x)
y = _example_noninlined_funcdef(y) y = _example_noninlined_funcdef(y)
optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01) optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
g = optimizer.compute_gradients(y, [x]) g = optimizer.compute_gradients(y, [x])
output = (g, y) output = (g, y)
output_val_ref, output_val, cost_graph = self._run(output) output_val_ref, output_val, cost_graph = self._run(mode, output)
node_map = _build_node_map(cost_graph.node) node_map = _build_node_map(cost_graph.node)
self._assert_output_fp16(node_map, 'MatMul') self._assert_output_f16(mode, node_map, 'MatMul')
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3) tol = 1e-2 if mode == 'mkl' else 1e-3
self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)
@parameterized.parameters(['cuda', 'mkl'])
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
@test_util.disable_xla('This test does not pass with XLA') @test_util.disable_xla('This test does not pass with XLA')
def test_ingraph_train_loop(self): def test_ingraph_train_loop(self, mode):
"""Tests a graph containing a while loop around a training update. """Tests a graph containing a while loop around a training update.
This requires the grappler pass to take special care with its handling of This requires the grappler pass to take special care with its handling of
Enter ops that appear in front of reads from non-resource variables. See Enter ops that appear in front of reads from non-resource variables. See
the use of NodeImplicitlyReadsVariable in auto_mixed_precision.cc. the use of NodeImplicitlyReadsVariable in auto_mixed_precision.cc.
""" """
self._maybe_skip(mode)
if tf2.enabled(): if tf2.enabled():
# This test tests non-resource variables, which are only used in TF1. # This test tests non-resource variables, which are only used in TF1.
self.skipTest('TensorFlow 1 required') self.skipTest('TensorFlow 1 required')
if test.is_gpu_available(cuda_only=True): random_seed.set_random_seed(1234)
random_seed.set_random_seed(1234) np.random.seed(1234)
np.random.seed(1234) num_iter, bs, nchan, nclass = 100, 64, 32, 100
num_iter, bs, nchan, nclass = 100, 64, 32, 100
data = np.random.normal(size=(bs * num_iter, nchan)).astype(np.float32) data = np.random.normal(size=(bs * num_iter, nchan)).astype(np.float32)
labels = np.random.randint(nclass, size=(bs * num_iter,)) labels = np.random.randint(nclass, size=(bs * num_iter,))
ds = dataset_ops.Dataset.from_tensor_slices((data, labels)) ds = dataset_ops.Dataset.from_tensor_slices((data, labels))
ds = ds.batch(bs).prefetch(3) ds = ds.batch(bs).prefetch(3)
it = ds.make_one_shot_iterator() it = ds.make_one_shot_iterator()
def body(_, i): def body(_, i):
i += 1 i += 1
x, yt = it.get_next() x, yt = it.get_next()
dense = layers.Dense(nclass) dense = layers.Dense(nclass)
y = dense(x) y = dense(x)
loss = losses.sparse_softmax_cross_entropy(yt, y) loss = losses.sparse_softmax_cross_entropy(yt, y)
opt = adam.AdamOptimizer() opt = adam.AdamOptimizer()
train_op = opt.minimize(loss, var_list=dense.trainable_weights) train_op = opt.minimize(loss, var_list=dense.trainable_weights)
with ops.control_dependencies([train_op]): with ops.control_dependencies([train_op]):
loss = array_ops.identity(loss) loss = array_ops.identity(loss)
return loss, i return loss, i
begin, end = constant_op.constant(0), constant_op.constant(num_iter) begin, end = constant_op.constant(0), constant_op.constant(num_iter)
loss, _ = control_flow_ops.while_loop( loss, _ = control_flow_ops.while_loop(
lambda loss, i: math_ops.less(i, end), body, [0.0, begin]) lambda loss, i: math_ops.less(i, end), body, [0.0, begin])
output_val_ref, output_val, cost_graph = self._run(loss) output_val_ref, output_val, cost_graph = self._run(mode, loss)
node_map = _build_node_map(cost_graph.node) node_map = _build_node_map(cost_graph.node)
self._assert_output_fp16(node_map, 'while/dense/MatMul') self._assert_output_f16(mode, node_map, 'while/dense/MatMul')
self._assert_output_fp16( self._assert_output_f16(
node_map, 'while/gradients/while/dense/MatMul_grad/MatMul_1') mode, node_map, 'while/gradients/while/dense/MatMul_grad/MatMul_1')
self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3) self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
# TODO(benbarsdell): Add tests for list ops (TensorList*) that pass through # TODO(benbarsdell): Add tests for list ops (TensorList*) that pass through
# graph source/sink nodes, similar to the TensorListThroughFunction C++ test. # graph source/sink nodes, similar to the TensorListThroughFunction C++ test.