Add MKL support to auto_mixed_precision.
This extends the auto mixed precision grappler pass to support converting nodes to bfloat16 on MKL-supported CPUs.

Co-authored-by: Niranjan Hasabnis <niranjan.hasabnis@intel.com>
commit d8bfc935fd
parent 81041bcd82
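Sketch (not part of the diff; the helper name RunMklPass is hypothetical): the new AutoMixedPrecisionMode enum introduced below selects the target type. CUDA mode (the default) paints eligible nodes DT_HALF on Nvidia GPUs; MKL mode paints them DT_BFLOAT16 on CPUs, mirroring the usage in the updated unit tests at the end of this diff:

    #include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h"

    namespace tensorflow {
    namespace grappler {

    // Hypothetical helper: run the pass in the new MKL (bfloat16-on-CPU) mode.
    Status RunMklPass(Cluster* cluster, const GrapplerItem& item,
                      GraphDef* output) {
      // Requires a build with INTEL_MKL and ENABLE_INTEL_MKL_BFLOAT16;
      // otherwise Optimize() returns an Unimplemented error (see below).
      AutoMixedPrecision optimizer{AutoMixedPrecisionMode::MKL};
      return optimizer.Optimize(cluster, item, output);
    }

    }  // namespace grappler
    }  // namespace tensorflow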
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
@@ -1,5 +1,5 @@
 load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_cc_test_mkl", "tf_kernel_library")
-load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
+load("//tensorflow:tensorflow.bzl", "tf_copts", "tf_cuda_cc_test")
 
 # Platform specific build config
 load(
@@ -7,6 +7,11 @@ load(
     "if_static",
 )
 
+load(
+    "//third_party/mkl:build_defs.bzl",
+    "mkl_deps",
+)
+
 package(
     licenses = ["notice"],  # Apache 2.0
 )
@@ -611,6 +616,7 @@ cc_library(
         "auto_mixed_precision_lists.h",
     ],
     visibility = ["//visibility:public"],
+    copts = tf_copts(),
    deps = [
        ":custom_graph_optimizer_registry",
        ":graph_optimizer",
@@ -627,7 +633,7 @@ cc_library(
        "//tensorflow/core/grappler/costs:virtual_placer",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
-    ],
+    ] + mkl_deps(),
 )
 
 tf_cuda_cc_test(
diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc b/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h"
 
 #include <fstream>
+#include <memory>
 
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
@@ -52,6 +53,7 @@ const std::pair<int, int> kMinGPUArch = {0, 0};
 
 const char kSuffix[] = "AutoMixedPrecision";
 const char kCastToFp16[] = "CastToFp16";
+const char kCastToBf16[] = "CastToBf16";
 const char kCastToFp32[] = "CastToFp32";
 
 // Instances of this class represent unique type attribute identifiers within a
@@ -840,22 +842,6 @@ DataTypeSet AllowedDataTypes(const OpDef& op_def, const TypeAttrId& t_attr_id) {
   return AllowedDataTypes(*attr_def);
 }
 
-NodeDef BuildCastNode(const MutableGraphView::OutputPort& src, bool to_fp16,
-                      const string& device) {
-  const char* cast_string = to_fp16 ? kCastToFp16 : kCastToFp32;
-  string name = strings::StrCat(src.node->name(), "-", src.port_id, "-",
-                                cast_string, "-", kSuffix);
-  NodeDef node;
-  node.set_name(name);
-  node.set_op("Cast");
-  node.set_device(device);
-  node.add_input(strings::StrCat(src.node->name(), ":", src.port_id));
-  (*node.mutable_attr())["SrcT"].set_type(to_fp16 ? DT_FLOAT : DT_HALF);
-  (*node.mutable_attr())["DstT"].set_type(to_fp16 ? DT_HALF : DT_FLOAT);
-  (*node.mutable_attr())["Truncate"].set_b(false);
-  return node;
-}
-
 Status ValidateLists(const gtl::FlatSet<string>& white_list,
                      const gtl::FlatSet<string>& black_list,
                      const gtl::FlatSet<string>& gray_list,
@@ -941,7 +927,8 @@ class AutoMixedPrecisionImpl {
  public:
   AutoMixedPrecisionImpl(Cluster* cluster,
                          const std::unordered_set<string>& nodes_to_preserve,
-                         GraphDef* graph, string id)
+                         GraphDef* graph, string id,
+                         AutoMixedPrecisionMode mode)
       : virtual_placer_(cluster->GetDevices()),
         nodes_to_preserve_(nodes_to_preserve),
         graph_(graph),
@@ -949,23 +936,35 @@ class AutoMixedPrecisionImpl {
         id_(id),
         graph_view_(graph),
         cuda_version_(GetCudaVersion(*cluster)),
-        cudnn_version_(GetCudnnVersion(*cluster)) {}
+        cudnn_version_(GetCudnnVersion(*cluster)),
+        mode_(mode),
+        target_dtype_(mode_ == AutoMixedPrecisionMode::CUDA ? DT_HALF
+                                                            : DT_BFLOAT16) {}
 
   Status Optimize();
 
  private:
   typedef absl::flat_hash_set<NodeTypeId> NodeTypeIdSet;
 
+  std::unique_ptr<AutoMixedPrecisionLists> get_mixed_precision_lists() const {
+    switch (mode_) {
+      case AutoMixedPrecisionMode::CUDA:
+        return std::make_unique<AutoMixedPrecisionListsCuda>(cuda_version_,
+                                                             cudnn_version_);
+      case AutoMixedPrecisionMode::MKL:
+        return std::make_unique<AutoMixedPrecisionListsMkl>();
+    }
+  }
   Status PrintDebugLogs(bool preop, size_t timestamp);
   void LogSkippedNode(const NodeDef& node) const;
   bool MustPreserve(const NodeDef& node) const;
-  bool IsOnGPU(const NodeDef& node) const;
+  bool IsOnDevice(const NodeDef& node, const string& device_type) const;
   bool IsOnSuitableGPUArch(const NodeDef& node) const;
   bool ShouldProcess(const NodeDef& node) const;
-  bool NodeHasFP16KernelForTypeAttr(const NodeDef& node, TypeAttrId taid) const;
+  bool NodeHasF16KernelForTypeAttr(const NodeDef& node, TypeAttrId taid) const;
   bool NodeImplicitlyReadsNonResourceVariable(const NodeDef& node) const;
   void ConvertBatchNormOpsToV2();
-  bool SupportsFloat16(const NodeTypeId& node_type) const;
+  bool SupportsF16(const NodeTypeId& node_type) const;
   const NodeTypeId* GetTensorListFloat32NodeTypeId(const NodeDef& node) const;
   bool IsSourceOrSinkOp(const string& op) const;
   void FindFloat32TensorListOpClustersAndBlacklistUnsafe(
@@ -990,6 +989,8 @@ class AutoMixedPrecisionImpl {
       absl::flat_hash_set<int>* white_set) const;
   void MakeCastsWhiteIfAllOutputsWhite(
       absl::flat_hash_set<int>* white_set) const;
+  NodeDef BuildCastNode(const MutableGraphView::OutputPort& src, bool to_f16,
+                        const string& device) const;
   Status ChangeTypeAttrsAndAddCasts(const absl::flat_hash_set<int>& white_set);
 
   VirtualPlacer virtual_placer_;
@@ -1003,21 +1004,44 @@ class AutoMixedPrecisionImpl {
   NodeTypeAttrMap node_type_map_;
   GraphTypeTopologyView graph_type_view_;
   bool force_all_fp16_;
-  gtl::FlatSet<string> fp16_whitelist_;
-  gtl::FlatSet<string> fp16_blacklist_;
-  gtl::FlatSet<string> fp16_graylist_;
-  gtl::FlatSet<string> fp16_clearlist_;
+  AutoMixedPrecisionMode mode_;
+  gtl::FlatSet<string> f16_whitelist_;
+  gtl::FlatSet<string> f16_blacklist_;
+  gtl::FlatSet<string> f16_graylist_;
+  gtl::FlatSet<string> f16_clearlist_;
   absl::flat_hash_set<const NodeDef*> should_process_nodes_;
+  DataType target_dtype_;  // Either DT_HALF or DT_BFLOAT16
 };
 
-bool AutoMixedPrecisionImpl::NodeHasFP16KernelForTypeAttr(
+NodeDef AutoMixedPrecisionImpl::BuildCastNode(
+    const MutableGraphView::OutputPort& src, bool to_f16,
+    const string& device) const {
+  DataType src_type = to_f16 ? DT_FLOAT : target_dtype_;
+  DataType dst_type = to_f16 ? target_dtype_ : DT_FLOAT;
+  const char* cast_string =
+      !to_f16 ? kCastToFp32
+              : target_dtype_ == DT_HALF ? kCastToFp16 : kCastToBf16;
+  string name = strings::StrCat(src.node->name(), "-", src.port_id, "-",
+                                cast_string, "-", kSuffix);
+  NodeDef node;
+  node.set_name(name);
+  node.set_op("Cast");
+  node.set_device(device);
+  node.add_input(strings::StrCat(src.node->name(), ":", src.port_id));
+  (*node.mutable_attr())["SrcT"].set_type(src_type);
+  (*node.mutable_attr())["DstT"].set_type(dst_type);
+  (*node.mutable_attr())["Truncate"].set_b(false);
+  return node;
+}
+
+bool AutoMixedPrecisionImpl::NodeHasF16KernelForTypeAttr(
     const NodeDef& node, TypeAttrId taid) const {
   NodeDef node_copy(node);
   if (node.device().empty()) {
     string device_name = virtual_placer_.get_canonical_device_name(node);
     node_copy.set_device(device_name);
   }
-  if (!SetDataType(&node_copy, taid, DataType::DT_HALF)) {
+  if (!SetDataType(&node_copy, taid, target_dtype_)) {
     return false;
   }
   return IsKernelRegisteredForNode(node_copy).ok();
@@ -1053,21 +1077,22 @@ Status AutoMixedPrecisionImpl::PrintDebugLogs(bool preop, size_t timestamp) {
     fname = io::JoinPath(prepend_path,
                          strings::StrCat("paintbuckets", suffix, ".txt"));
     f.open(fname.c_str(), std::fstream::out);
+    std::unique_ptr<AutoMixedPrecisionLists> mp_lists =
+        get_mixed_precision_lists();
     f << "WhiteList:\n";
-    for (const auto& x :
-         AutoMixedPrecisionLists::WhiteList(cuda_version_, cudnn_version_)) {
+    for (const auto& x : mp_lists->WhiteList()) {
      f << x << "\n";
    }
    f << "\nBlackList:\n";
-    for (const auto& x : AutoMixedPrecisionLists::BlackList()) {
+    for (const auto& x : mp_lists->BlackList()) {
      f << x << "\n";
    }
    f << "\nGrayList:\n";
-    for (const auto& x : AutoMixedPrecisionLists::GrayList()) {
+    for (const auto& x : mp_lists->GrayList()) {
      f << x << "\n";
    }
    f << "\nClearList:\n";
-    for (const auto& x : AutoMixedPrecisionLists::ClearList()) {
+    for (const auto& x : mp_lists->ClearList()) {
      f << x << "\n";
    }
    f.close();
@@ -1088,7 +1113,8 @@ bool AutoMixedPrecisionImpl::MustPreserve(const NodeDef& node) const {
   return nodes_to_preserve_.count(node.name());
 }
 
-bool AutoMixedPrecisionImpl::IsOnGPU(const NodeDef& node) const {
+bool AutoMixedPrecisionImpl::IsOnDevice(const NodeDef& node,
+                                        const string& device_type) const {
   string device_name;
   if (node.device().empty()) {
     device_name = virtual_placer_.get_canonical_device_name(node);
@@ -1099,7 +1125,7 @@ bool AutoMixedPrecisionImpl::IsOnGPU(const NodeDef& node) const {
   string not_used;
   if (DeviceNameUtils::SplitDeviceName(device_name, &not_used, &device) &&
       absl::StrContains(absl::AsciiStrToLower(device),
-                        absl::AsciiStrToLower(DEVICE_GPU))) {
+                        absl::AsciiStrToLower(device_type))) {
     return true;
   }
   return false;
@@ -1164,15 +1190,14 @@ bool IsTensorListWriterOp(const string& op) {
   return tensor_list_writer_ops.count(op);
 }
 
-bool AutoMixedPrecisionImpl::SupportsFloat16(
-    const NodeTypeId& node_type) const {
+bool AutoMixedPrecisionImpl::SupportsF16(const NodeTypeId& node_type) const {
   const OpDef* op_def;
   Status status =
       OpRegistry::Global()->LookUpOpDef(node_type.node->op(), &op_def);
   if (!status.ok()) return false;
   return AllowedDataTypes(*op_def, node_type.type_attr)
-             .Contains(DataType::DT_HALF) &&
-         NodeHasFP16KernelForTypeAttr(*node_type.node, node_type.type_attr);
+             .Contains(target_dtype_) &&
+         NodeHasF16KernelForTypeAttr(*node_type.node, node_type.type_attr);
 }
 
 // TODO(mconley): Make this change the node's name (to aid debugging). Need to
@@ -1219,22 +1244,40 @@ Status AutoMixedPrecisionImpl::Optimize() {
       "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL", "", &optimization_level));
   optimization_level = absl::AsciiStrToUpper(optimization_level);
   force_all_fp16_ = optimization_level == "UNSAFE_FORCE_ALL";
+  if (force_all_fp16_ && mode_ == AutoMixedPrecisionMode::MKL) {
+    // Many ops do not support bfloat16 on the CPU, so we disallow forcing to
+    // bfloat16.
+    return errors::InvalidArgument(
+        "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL cannot be set to "
+        "UNSAFE_FORCE_ALL when MKL is used");
+  }
 
-  fp16_whitelist_ =
-      AutoMixedPrecisionLists::WhiteList(cuda_version_, cudnn_version_);
-  fp16_blacklist_ = AutoMixedPrecisionLists::BlackList();
-  fp16_graylist_ = AutoMixedPrecisionLists::GrayList();
-  fp16_clearlist_ = AutoMixedPrecisionLists::ClearList();
-  TF_RETURN_IF_ERROR(ValidateLists(fp16_whitelist_, fp16_blacklist_,
-                                   fp16_graylist_, fp16_clearlist_));
+  std::unique_ptr<AutoMixedPrecisionLists> mp_lists =
+      get_mixed_precision_lists();
+  f16_whitelist_ = mp_lists->WhiteList();
+  f16_blacklist_ = mp_lists->BlackList();
+  f16_graylist_ = mp_lists->GrayList();
+  f16_clearlist_ = mp_lists->ClearList();
+  TF_RETURN_IF_ERROR(ValidateLists(f16_whitelist_, f16_blacklist_,
+                                   f16_graylist_, f16_clearlist_));
 
   size_t timestamp = Env::Default()->NowMicros() / 1000;
   TF_RETURN_IF_ERROR(PrintDebugLogs(/* preop = */ true, timestamp));
 
   VLOG(2) << "Identifying nodes that should be processed";
   for (const NodeDef& node : graph_->node()) {
-    if (!MustPreserve(node) && IsOnGPU(node) &&
-        (ShouldIgnorePerformance() || IsOnSuitableGPUArch(node))) {
+    bool should_process;
+    switch (mode_) {
+      case AutoMixedPrecisionMode::CUDA:
+        should_process =
+            !MustPreserve(node) && IsOnDevice(node, DEVICE_GPU) &&
+            (ShouldIgnorePerformance() || IsOnSuitableGPUArch(node));
+        break;
+      case AutoMixedPrecisionMode::MKL:
+        should_process = !MustPreserve(node) && IsOnDevice(node, DEVICE_CPU);
+        break;
+    }
+    if (should_process) {
       should_process_nodes_.insert(&node);
     } else {
       LogSkippedNode(node);
@@ -1260,29 +1303,29 @@ Status AutoMixedPrecisionImpl::Optimize() {
   for (const auto& cluster : tensor_list_clusters) {
     VLOG(1) << "Found safe Tensor List cluster of size " << cluster.size();
     for (const NodeDef* node : cluster) {
-      VLOG(2) << "Cluster member: " << node->op() << " node " << node->name();
+      VLOG(2) << "  Cluster member: " << node->op() << " node " << node->name();
     }
     FindTensorListImplicitFloat32Edges(cluster, &ephemeral_edges);
   }
   TF_RETURN_IF_ERROR(graph_type_view_.AddEphemeralEdges(ephemeral_edges));
 
-  // The goal here is to change performance-critical ops to fp16, and to do so
-  // with the minimal number of casts, subject to the constraint that the
+  // The goal here is to change performance-critical ops to fp16 or bf16, and to
+  // do so with the minimal number of casts, subject to the constraint that the
   // model's convergence is not affected. This is achieved by first identifying
-  // which nodes should be changed to fp16 and then inserting casts at the
-  // boundaries between fp16/non-fp16 nodes.
+  // which nodes should be changed to f16 and then inserting casts at the
+  // boundaries between f16/non-f16 nodes.
 
-  // The algorithm for deciding which nodes to change to fp16 is as follows:
+  // The algorithm for deciding which nodes to change to f16 is as follows:
   // 1) Add all performance-critical ops (aka "whitelist" ops) to the white_set.
   //    This is done under the assumption that whitelist ops are always
-  //    numerically-safe in fp16 and that they are the most important ops for
+  //    numerically-safe in f16 and that they are the most important ops for
   //    improving performance.
   // 2) Add nodes to the black_set iff they are numerically-dangerous (aka
   //    "blacklist" ops) or they are on a forward path from a blacklist node to
   //    a black/gray node (including the node at the end of the path) through
   //    non-numerically-dangerous ops (aka "greylist" and "clearlist" ops).
   //    This is done to prevent numerically-dangerous ops and their downstream
-  //    effects from being changed to fp16, which would risk breaking the
+  //    effects from being changed to f16, which would risk breaking the
   //    numerical accuracy of the model.
   // 3) For all remaining nodes that are not considered dangerous (greylist
   //    and clearlist ops), find those that are between (i.e., both upstream
@@ -1480,7 +1523,7 @@ void AutoMixedPrecisionImpl::AddWhitelistOps(
     const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
     if (!ShouldProcess(*root.node)) continue;
     bool force_white = force_all_fp16_ && CanForceFP16(*root.node);
-    if (fp16_whitelist_.count(root.node->op()) || force_white) {
+    if (f16_whitelist_.count(root.node->op()) || force_white) {
       bool inserted = white_set->insert(root_idx).second;
       if (VLOG_IS_ON(2) && inserted) {
         VLOG(2) << "Painting type " << root.type_attr.DebugString()
@@ -1504,8 +1547,8 @@ void AutoMixedPrecisionImpl::PropagateBlackFwdThroughClearAndGray(
   absl::flat_hash_set<int> upstream_of_black_or_gray_set;
   for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
     const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
-    if (!(fp16_blacklist_.count(root.node->op()) ||
-          fp16_graylist_.count(root.node->op()))) {
+    if (!(f16_blacklist_.count(root.node->op()) ||
+          f16_graylist_.count(root.node->op()))) {
       continue;
     }
     DfsTypeTraversal(graph_type_view_, {&root},
@@ -1514,7 +1557,7 @@ void AutoMixedPrecisionImpl::PropagateBlackFwdThroughClearAndGray(
           const NodeTypeId& item = *graph_type_view_.GetNode(idx);
           return idx == root_idx ||
                  (!upstream_of_black_or_gray_set.count(idx) &&
-                  fp16_clearlist_.count(item.node->op()));
+                  f16_clearlist_.count(item.node->op()));
         }),
         DfsTypeCallbacks::PreOrder([&](int idx) {
           upstream_of_black_or_gray_set.insert(idx);
@@ -1524,7 +1567,7 @@ void AutoMixedPrecisionImpl::PropagateBlackFwdThroughClearAndGray(
   // Propagate black forward through nodes in upstream_of_black_or_gray_set.
   for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
     const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
-    if (black_set->count(root_idx) || !fp16_blacklist_.count(root.node->op())) {
+    if (black_set->count(root_idx) || !f16_blacklist_.count(root.node->op())) {
       continue;
     }
     DfsTypeTraversal(
@@ -1552,7 +1595,7 @@ void AutoMixedPrecisionImpl::AddClearAndGrayToWhiteIfBetweenWhite(
   absl::flat_hash_set<int> downstream_of_white_set;
   for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
     const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
-    if (!ShouldProcess(*root.node) || !fp16_whitelist_.count(root.node->op())) {
+    if (!ShouldProcess(*root.node) || !f16_whitelist_.count(root.node->op())) {
       continue;
     }
     DfsTypeTraversal(
@@ -1561,14 +1604,14 @@ void AutoMixedPrecisionImpl::AddClearAndGrayToWhiteIfBetweenWhite(
           const NodeTypeId& item = *graph_type_view_.GetNode(idx);
           return idx == root_idx ||
                  (!downstream_of_white_set.count(idx) &&
-                  !fp16_whitelist_.count(item.node->op()) &&
+                  !f16_whitelist_.count(item.node->op()) &&
                   !black_set.count(idx) && ShouldProcess(*item.node) &&
                   // TODO(benbarsdell): Consider allowing propagation through
                   // ops that are already float16 in order to reduce the number
                   // of casts.
-                  IsFloat32(item) && SupportsFloat16(item) &&
-                  (fp16_clearlist_.count(item.node->op()) ||
-                   fp16_graylist_.count(item.node->op())));
+                  IsFloat32(item) && SupportsF16(item) &&
+                  (f16_clearlist_.count(item.node->op()) ||
+                   f16_graylist_.count(item.node->op())));
         }),
         DfsTypeCallbacks::PreOrder(
             [&](int idx) { downstream_of_white_set.insert(idx); }));
@@ -1579,7 +1622,7 @@ void AutoMixedPrecisionImpl::AddClearAndGrayToWhiteIfBetweenWhite(
   for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
     const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
     if (!ShouldProcess(*root.node) || upstream_of_white_set.count(root_idx) ||
-        !fp16_whitelist_.count(root.node->op())) {
+        !f16_whitelist_.count(root.node->op())) {
       continue;
     }
     DfsTypeTraversal(
@@ -1620,8 +1663,8 @@ void AutoMixedPrecisionImpl::PropagateWhiteThroughClear(
           return idx == root_idx ||
                  (!white_set->count(idx) && !black_set.count(idx) &&
                   ShouldProcess(*item.node) && IsFloat32(item) &&
-                  SupportsFloat16(item) &&
-                  (fp16_clearlist_.count(item.node->op())) &&
+                  SupportsF16(item) &&
+                  (f16_clearlist_.count(item.node->op())) &&
                   // We don't propagate (backwards) through nodes that read
                   // Variables because it can break the behavior of TensorBoard
                   // visualization and/or (in the case of Enter nodes) the model
@@ -1806,13 +1849,13 @@ void AutoMixedPrecisionImpl::MakeCastsWhiteIfAllOutputsWhite(
   }
 }
 
-// Changes all white-painted type attributes to DT_HALF, and inserts Cast nodes
-// at node outputs for all edges that connect white-painted <->
-// non-white-painted type attributes.
+// Changes all white-painted type attributes to DT_HALF or DT_BFLOAT16, and
+// inserts Cast nodes at node outputs for all edges that connect
+// white-painted <-> non-white-painted type attributes.
 Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts(
     const absl::flat_hash_set<int>& white_set) {
   int num_nodes_changed = 0;
-  int num_nonvar_casts_to_fp16 = 0;
+  int num_nonvar_casts_to_f16 = 0;
   int num_nodes_preop = graph_->node_size();
   for (int node_idx = 0; node_idx < num_nodes_preop; ++node_idx) {
     NodeDef* node = graph_->mutable_node(node_idx);
@@ -1829,8 +1872,9 @@ Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts(
     bool src_is_white = white_set.count(node_type_idx);
     if (src_is_white) {
       VLOG(1) << "Changing type " << type_attr.DebugString() << " of "
-              << node->op() << " node " << node->name() << " to DT_HALF";
-      if (!SetDataType(node, type_attr, DT_HALF)) {
+              << node->op() << " node " << node->name() << " to "
+              << DataTypeString(target_dtype_);
+      if (!SetDataType(node, type_attr, target_dtype_)) {
         return errors::Internal("Failed to set type attribute");
       }
       ++num_nodes_changed;
@@ -1855,16 +1899,16 @@ Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts(
       bool dst_is_white = white_set.count(dst_type_idx);
       if (src_is_white != dst_is_white) {
         if (!added_cast_node) {
-          bool to_fp16 = dst_is_white;
+          bool to_f16 = dst_is_white;
           VLOG(1) << "Inserting cast to "
-                  << (to_fp16 ? "DT_HALF" : "DT_FLOAT") << " at "
-                  << src.node->op() << " " << src.node->name() << ":"
-                  << src.port_id;
+                  << (to_f16 ? DataTypeString(target_dtype_) : "DT_FLOAT")
+                  << " at " << src.node->op() << " " << src.node->name()
+                  << ":" << src.port_id;
           added_cast_node = graph_view_.AddNode(
-              BuildCastNode(src, to_fp16, src.node->device()));
-          if (to_fp16 && !IsConstant(*node) && !IsVariable(*node) &&
+              BuildCastNode(src, to_f16, src.node->device()));
+          if (to_f16 && !IsConstant(*node) && !IsVariable(*node) &&
              !NodeImplicitlyReadsNonResourceVariable(*node)) {
-            ++num_nonvar_casts_to_fp16;
+            ++num_nonvar_casts_to_f16;
          }
        }
        TF_RETURN_IF_ERROR(graph_view_.UpdateRegularFaninByPort(
@@ -1874,9 +1918,13 @@ Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts(
      }
    }
  }
+  // Use Python type names (e.g. float16) instead of C++ type names (e.g. half)
+  // since many Python users will see this message.
+  const char* type_str = target_dtype_ == DT_HALF ? "float16" : "bfloat16";
   LOG(INFO) << "Converted " << num_nodes_changed << "/" << num_nodes_preop
-            << " nodes to float16 precision using " << num_nonvar_casts_to_fp16
-            << " cast(s) to float16 (excluding Const and Variable casts)";
+            << " nodes to " << type_str << " precision using "
+            << num_nonvar_casts_to_f16 << " cast(s) to " << type_str
+            << " (excluding Const and Variable casts)";
   return Status::OK();
 }
 
@@ -1902,12 +1950,23 @@ Status AutoMixedPrecision::Optimize(Cluster* cluster, const GrapplerItem& item,
     return errors::InvalidArgument("cluster == nullptr");
   }
 
+#if !defined(INTEL_MKL) || !defined(ENABLE_INTEL_MKL_BFLOAT16)
+  if (mode_ == AutoMixedPrecisionMode::MKL) {
+    return errors::Unimplemented(
+        "The auto_mixed_precision_mkl optimizer cannot be used since "
+        "this build of TensorFlow is not compiled with MKL support for bfloat16. "
+        "For information on MKL builds, see: "
+        "https://software.intel.com/en-us/articles/intel-optimization-for-"
+        "tensorflow-installation-guide");
+  }
+#endif
+
   // Start by copying input graph to output.
   *output = item.graph;
 
   int num_gpus = ShouldIgnorePerformance() ? GetNumGPUs(*cluster)
                                            : GetNumGPUs(*cluster, kMinGPUArch);
-  if (num_gpus < 1) {
+  if (num_gpus < 1 && mode_ == AutoMixedPrecisionMode::CUDA) {
     // AutoMixedPrecision is currently only tuned for GPU.
     LOG(WARNING) << "No (suitable) GPUs detected, skipping " << name()
                  << " graph optimizer";
@@ -1916,7 +1975,7 @@ Status AutoMixedPrecision::Optimize(Cluster* cluster, const GrapplerItem& item,
 
   // Optimize the output graph in-place.
   AutoMixedPrecisionImpl optimizer(cluster, item.NodesToPreserve(), output,
-                                   item.id);
+                                   item.id, mode_);
   if (item.id == "tf_graph") {
     LOG(INFO) << "Running " << name() << " graph optimizer";
   } else {
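Worked example (not part of the diff), following the BuildCastNode change above: a float32 to bfloat16 cast inserted at output port 0 of a node named "wht1" in MKL mode is named by the same StrCat scheme,

    // "wht1" + "-" + "0" + "-" + "CastToBf16" + "-" + "AutoMixedPrecision"
    // == "wht1-0-CastToBf16-AutoMixedPrecision"
    // with attrs SrcT = DT_FLOAT and DstT = DT_BFLOAT16 (target_dtype_).
    // The reverse cast uses kCastToFp32: SrcT = DT_BFLOAT16, DstT = DT_FLOAT.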
diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision.h b/tensorflow/core/grappler/optimizers/auto_mixed_precision.h
@@ -22,16 +22,25 @@ limitations under the License.
 namespace tensorflow {
 namespace grappler {
 
-// Convert data types to float16 where appropriate to improve performance on
-// GPUs.
+enum class AutoMixedPrecisionMode { CUDA, MKL };
+
+// Convert data types to float16 or bfloat16 where appropriate to improve
+// performance on GPUs or CPUs.
 class AutoMixedPrecision : public GraphOptimizer {
  public:
+  // If 'mode' is CUDA, converts nodes to float16 on Nvidia GPUs. If MKL,
+  // converts nodes to bfloat16 on CPUs in order to take advantage of MKL
+  // performance improvements with bfloat16.
   explicit AutoMixedPrecision(
-      RewriterConfig::Toggle opt_level = RewriterConfig::ON) {}
+      AutoMixedPrecisionMode mode = AutoMixedPrecisionMode::CUDA)
+      : mode_(mode) {}
 
   ~AutoMixedPrecision() override {}
 
-  string name() const override { return "auto_mixed_precision"; };
+  string name() const override {
+    return mode_ == AutoMixedPrecisionMode::CUDA ? "auto_mixed_precision_cuda"
+                                                 : "auto_mixed_precision_mkl";
+  };
 
   bool UsesFunctionLibrary() const override { return false; }
 
@@ -40,6 +49,9 @@ class AutoMixedPrecision : public GraphOptimizer {
 
   void Feedback(Cluster* cluster, const GrapplerItem& item,
                 const GraphDef& optimize_output, double result) override;
+
+ private:
+  const AutoMixedPrecisionMode mode_;
 };
 
 }  // end namespace grappler
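Sketch (not part of the diff): with the header change above, the mode also determines the optimizer's reported name, so the "Running ... graph optimizer" log lines distinguish the two passes:

    AutoMixedPrecision cuda_pass;  // mode defaults to AutoMixedPrecisionMode::CUDA
    AutoMixedPrecision mkl_pass{AutoMixedPrecisionMode::MKL};
    cuda_pass.name();  // "auto_mixed_precision_cuda"
    mkl_pass.name();   // "auto_mixed_precision_mkl"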
diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h b/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h
@@ -23,10 +23,44 @@ limitations under the License.
 namespace tensorflow {
 namespace grappler {
 
+// Represents the four lists of ops: the white list, gray list, black list, and
+// clear list. These lists determine which ops are converted to fp16/bf16
+// (referred to as 'f16' for short) and which ops stay as fp32.
 class AutoMixedPrecisionLists {
- private:
-  static void UpdateList(gtl::FlatSet<string>* list, const string& to_add,
-                         const string& to_remove) {
+ public:
+  virtual ~AutoMixedPrecisionLists() {}
+
+  // Returns the set of ops that are considered numerically-safe (for execution
+  // in f16), performance-critical, and can run in f16. These ops are always
+  // converted to f16.
+  virtual gtl::FlatSet<string> WhiteList() = 0;
+  // Returns the set of ops that can run in f16 and are considered numerically-
+  // safe (for execution in f16), but which may be made unsafe by an upstream
+  // blacklist op.
+  virtual gtl::FlatSet<string> GrayList() = 0;
+  // Returns the set of ops that are considered numerically-dangerous (i.e.,
+  // unsafe for execution in f16) and whose effects may also be observed in
+  // downstream nodes (e.g. for f16, in Exp -> Add, the Add is unsafe due to
+  // the Exp).
+  virtual gtl::FlatSet<string> BlackList() = 0;
+  // Returns the set of ops that do not have numerically-significant effects
+  // (i.e., they are always considered safe for execution in f16 precision), and
+  // can run in f16.
+  virtual gtl::FlatSet<string> ClearList() = 0;
+
+ protected:
+  // Adds or removes ops from the list if certain environment variables are set.
+  static void UpdateList(const string& list_name, gtl::FlatSet<string>* list) {
+    CHECK(list_name == "WHITELIST" || list_name == "GRAYLIST" ||  // Crash OK.
+          list_name == "BLACKLIST" || list_name == "CLEARLIST");
+    string add_env_var =
+        "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_" + list_name + "_ADD";
+    string remove_env_var =
+        "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_" + list_name + "_REMOVE";
+    string to_add, to_remove;
+    TF_CHECK_OK(ReadStringFromEnvVar(add_env_var, "", &to_add));
+    TF_CHECK_OK(ReadStringFromEnvVar(remove_env_var, "", &to_remove));
     for (const auto& x : str_util::Split(to_add, ",")) {
       list->insert(x);
     }
@@ -35,6 +69,35 @@ class AutoMixedPrecisionLists {
     }
   }
 
+  // Subclasses should include these on the ClearList.
+  static void AddTensorListOps(gtl::FlatSet<string>* list) {
+    // Note: if a data structure op (such as TensorListPopBack) is added here,
+    // IsTensorListReaderOp or IsTensorListWriterOp may need to be modified.
+    constexpr const char* tensor_list_ops[] = {
+        "TensorListConcat",
+        "TensorListConcatLists",
+        "TensorListConcatV2",
+        "TensorListGather",
+        "TensorListGetItem",
+        "TensorListPopBack",
+        "TensorListPushBack",
+        "TensorListPushBackBatch",
+        "TensorListFromTensor",
+        "TensorListScatter",
+        "TensorListScatterV2",
+        "TensorListScatterIntoExistingList",
+        "TensorListSetItem",
+        "TensorListSplit",
+        "TensorListStack"
+    };
+    for (auto op : tensor_list_ops) {
+      list->insert(op);
+    }
+  }
+};
+
+class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists {
+ private:
   static bool IsPseudoFastMath() {
     string optimization_level;
     TF_CHECK_OK(
@@ -45,16 +108,10 @@
   }
 
  public:
-  // Returns the set of ops that are considered numerically-safe (for execution
-  // in fp16) and performance-critical. These ops are always converted to fp16.
-  static gtl::FlatSet<string> WhiteList(int cuda_version, int cudnn_version) {
-    string to_add, to_remove;
-    TF_CHECK_OK(ReadStringFromEnvVar(
-        "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_WHITELIST_ADD", "", &to_add));
-    TF_CHECK_OK(ReadStringFromEnvVar(
-        "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_WHITELIST_REMOVE", "",
-        &to_remove));
+  AutoMixedPrecisionListsCuda(int cuda_version, int cudnn_version)
+      : cuda_version_(cuda_version), cudnn_version_(cudnn_version) {}
 
+  gtl::FlatSet<string> WhiteList() override {
     auto list = gtl::FlatSet<string>{
         "BlockLSTM",
         "BlockLSTMV2",
@@ -81,12 +138,12 @@
         // "DepthwiseConv2dNativeBackpropInput",
         "MatMul",
     };
-    if (cuda_version >= 9010) {
+    if (cuda_version_ >= 9010) {
       // Fp16 BatchMatMul is slow before CUDA 9.1.
       list.insert("BatchMatMul");
       list.insert("BatchMatMulV2");
     }
-    if (cudnn_version >= 7602) {
+    if (cudnn_version_ >= 7602) {
       // Fp16 3D conv is slow before CUDNN 7.6.2.
       list.insert("Conv3D");
       list.insert("Conv3DBackpropFilter");
@@ -94,22 +151,14 @@
       list.insert("Conv3DBackpropInput");
       list.insert("Conv3DBackpropInputV2");
     }
-    UpdateList(&list, to_add, to_remove);
+    UpdateList("WHITELIST", &list);
     return list;
   }
 
-  // Returns the set of ops that are considered numerically-safe (for execution
-  // in fp16), but which may be made unsafe by an upstream blacklist op.
-  static gtl::FlatSet<string> GrayList() {
+  gtl::FlatSet<string> GrayList() override {
     if (IsPseudoFastMath()) {
       return gtl::FlatSet<string>{};
     }
-    string to_add, to_remove;
-    TF_CHECK_OK(ReadStringFromEnvVar(
-        "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_GRAYLIST_ADD", "", &to_add));
-    TF_CHECK_OK(ReadStringFromEnvVar(
-        "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_GRAYLIST_REMOVE", "",
-        &to_remove));
 
     auto list = gtl::FlatSet<string>{
         "Add",
@@ -156,23 +205,14 @@
         "Tanh",
         "TanhGrad",
     };
-    UpdateList(&list, to_add, to_remove);
+    UpdateList("GRAYLIST", &list);
     return list;
   }
 
-  // Returns the set of ops that are considered numerically-dangerous (i.e.,
-  // unsafe for execution in fp16) and whose effects may also be observed in
-  // downstream nodes (e.g., in Exp -> Add, the Add is unsafe due to the Exp).
-  static gtl::FlatSet<string> BlackList() {
+  gtl::FlatSet<string> BlackList() override {
     if (IsPseudoFastMath()) {
       return gtl::FlatSet<string>{};
     }
-    string to_add, to_remove;
-    TF_CHECK_OK(ReadStringFromEnvVar(
-        "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_BLACKLIST_ADD", "", &to_add));
-    TF_CHECK_OK(ReadStringFromEnvVar(
-        "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_BLACKLIST_REMOVE", "",
-        &to_remove));
 
     auto list = gtl::FlatSet<string>{
         "Exp",
@@ -185,22 +225,14 @@
         "SparseSoftmaxCrossEntropyWithLogits",
         "Sum",
     };
-    UpdateList(&list, to_add, to_remove);
+    UpdateList("BLACKLIST", &list);
     return list;
   }
 
-  // Returns the set of ops that do not have numerically-significant effects
-  // (i.e., they are always considered safe for execution in fp16 precision).
-  static gtl::FlatSet<string> ClearList() {
+  gtl::FlatSet<string> ClearList() override {
     if (IsPseudoFastMath()) {
       return gtl::FlatSet<string>{};
     }
-    string to_add, to_remove;
-    TF_CHECK_OK(ReadStringFromEnvVar(
-        "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_CLEARLIST_ADD", "", &to_add));
-    TF_CHECK_OK(ReadStringFromEnvVar(
-        "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_CLEARLIST_REMOVE", "",
-        &to_remove));
 
     auto list = gtl::FlatSet<string>{
         "Abs",
@@ -291,21 +323,6 @@
         "StridedSlice",
         "StridedSliceGrad",
         "Switch",
-        "TensorListConcat",
-        "TensorListConcatLists",
-        "TensorListConcatV2",
-        "TensorListGather",
-        "TensorListGetItem",
-        "TensorListPopBack",
-        "TensorListPushBack",
-        "TensorListPushBackBatch",
-        "TensorListFromTensor",
-        "TensorListScatter",
-        "TensorListScatterV2",
-        "TensorListScatterIntoExistingList",
-        "TensorListSetItem",
-        "TensorListSplit",
-        "TensorListStack",
         "Tile",
         "TopK",
         "TopKV2",
@@ -313,7 +330,125 @@
         "Where",
         "ZerosLike",
     };
-    UpdateList(&list, to_add, to_remove);
+    AddTensorListOps(&list);
+    UpdateList("CLEARLIST", &list);
+    return list;
+  }
+
+ private:
+  int cuda_version_;
+  int cudnn_version_;
+};
+
+class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists {
+ private:
+
+ public:
+  AutoMixedPrecisionListsMkl() {}
+
+  // Only ops which are supported by MKL in bfloat16 should be added to the
+  // white list, gray list, or clear list.
+  gtl::FlatSet<string> WhiteList() override {
+    auto list = gtl::FlatSet<string>{
+        "Conv2D",
+        "Conv2DBackpropFilter",
+        "Conv2DBackpropInput",
+        "Conv3D",
+        "Conv3DBackpropFilterV2",
+        "Conv3DBackpropInputV2",
+        "DepthwiseConv2dNative",
+        "DepthwiseConv2dNativeBackpropFilter",
+        "DepthwiseConv2dNativeBackpropInput",
+        "MatMul",
+        "BatchMatMul",
+        "BatchMatMulV2"
+    };
+
+    UpdateList("WHITELIST", &list);
+    return list;
+  }
+
+  gtl::FlatSet<string> GrayList() override {
+    auto list = gtl::FlatSet<string>{
+        "Add",
+        "AddN",
+        "AddV2",
+        "AvgPool",
+        "AvgPool3D",
+        "AvgPool3DGrad",
+        "AvgPoolGrad",
+        "BiasAdd",
+        "BiasAddGrad",
+        "BiasAddV1",
+        "FusedBatchNormV2",
+        "FusedBatchNormGradV2",
+        "FusedBatchNormV3",
+        "FusedBatchNormGradV3",
+        "LeakyRelu",
+        "LeakyReluGrad",
+        "Mul",
+        "Sub",
+    };
+    UpdateList("GRAYLIST", &list);
+    return list;
+  }
+
+  gtl::FlatSet<string> BlackList() override {
+    auto list = gtl::FlatSet<string>{
+        "Exp",
+        "Expm1",
+        "L2Loss",
+        "Mean",
+        "Pow",
+        "SaveV2",
+        "Softmax",
+        "SoftmaxCrossEntropyWithLogits",
+        "SparseSoftmaxCrossEntropyWithLogits",
+        "Sum",
+    };
+    UpdateList("BLACKLIST", &list);
+    return list;
+  }
+
+  gtl::FlatSet<string> ClearList() override {
+    auto list = gtl::FlatSet<string>{
+        "Concat",
+        "ConcatV2",
+        "Enter",
+        "EnsureShape",
+        "Equal",
+        "Exit",
+        "ExpandDims",
+        "Identity",
+        "MaxPool",
+        "MaxPool3D",
+        "MaxPool3DGrad",
+        "MaxPoolGrad",
+        "MaxPoolV2",
+        "Maximum",
+        "Merge",
+        "NextIteration",
+        "PreventGradient",
+        "Relu",
+        "Relu6",
+        "Relu6Grad",
+        "ReluGrad",
+        "Reshape",
+        "Select",
+        "SelectV2",
+        "Shape",
+        "ShapeN",
+        "Slice",
+        "Split",
+        "SplitV",
+        "Squeeze",
+        "StopGradient",
+        "Switch",
+        "Transpose",
+        "ZerosLike",
+    };
+    AddTensorListOps(&list);
+    UpdateList("CLEARLIST", &list);
     return list;
   }
 };
||||||
|
|||||||
@ -13,12 +13,6 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
// Currently, this test only passes when TensorFlow passes with CUDA, because
|
|
||||||
// otherwise the optimizer will not turn clearlist nodes to float16. When
|
|
||||||
// looking at clearlist nodes, this optimizer checks if the nodes have a float16
|
|
||||||
// GPU OpKernel, but without CUDA there are no GPU OpKernels at all.
|
|
||||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
|
||||||
|
|
||||||
#include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h"
|
#include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h"
|
||||||
|
|
||||||
#include <utility>
|
#include <utility>
|
||||||
@ -70,6 +64,31 @@ Tensor GenerateRandomTensorInRange(const TensorShape& shape, double minval,
|
|||||||
return tensor;
|
return tensor;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void VerifyGraphsEquivalent(const GraphDef& original_graph,
|
||||||
|
const GraphDef& optimized_graph,
|
||||||
|
const string& func) {
|
||||||
|
EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << func;
|
||||||
|
GraphView optimized_view(&optimized_graph);
|
||||||
|
for (int i = 0; i < original_graph.node_size(); ++i) {
|
||||||
|
const NodeDef& original = original_graph.node(i);
|
||||||
|
const NodeDef& optimized = *optimized_view.GetNode(original.name());
|
||||||
|
EXPECT_EQ(original.name(), optimized.name()) << func;
|
||||||
|
EXPECT_EQ(original.op(), optimized.op()) << func;
|
||||||
|
EXPECT_EQ(original.input_size(), optimized.input_size()) << func;
|
||||||
|
if (original.input_size() == optimized.input_size()) {
|
||||||
|
for (int j = 0; j < original.input_size(); ++j) {
|
||||||
|
EXPECT_EQ(original.input(j), optimized.input(j)) << func;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Currently, this test suite only passes when TensorFlow passes with CUDA,
|
||||||
|
// because otherwise the optimizer will not turn clearlist nodes to float16.
|
||||||
|
// When looking at clearlist nodes, this optimizer checks if the nodes have a
|
||||||
|
// float16 GPU OpKernel, but without CUDA there are no GPU OpKernels at all.
|
||||||
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
const std::pair<int, int> kMinGPUArch = {7, 0};
|
const std::pair<int, int> kMinGPUArch = {7, 0};
|
||||||
|
|
||||||
class AutoMixedPrecisionTest : public GrapplerTest {
|
class AutoMixedPrecisionTest : public GrapplerTest {
|
||||||
@ -184,25 +203,6 @@ class AutoMixedPrecisionTest : public GrapplerTest {
|
|||||||
bool gpu_available_;
|
bool gpu_available_;
|
||||||
};
|
};
|
||||||
|
|
||||||
void VerifyGraphsEquivalent(const GraphDef& original_graph,
|
|
||||||
const GraphDef& optimized_graph,
|
|
||||||
const string& func) {
|
|
||||||
EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << func;
|
|
||||||
GraphView optimized_view(&optimized_graph);
|
|
||||||
for (int i = 0; i < original_graph.node_size(); ++i) {
|
|
||||||
const NodeDef& original = original_graph.node(i);
|
|
||||||
const NodeDef& optimized = *optimized_view.GetNode(original.name());
|
|
||||||
EXPECT_EQ(original.name(), optimized.name()) << func;
|
|
||||||
EXPECT_EQ(original.op(), optimized.op()) << func;
|
|
||||||
EXPECT_EQ(original.input_size(), optimized.input_size()) << func;
|
|
||||||
if (original.input_size() == optimized.input_size()) {
|
|
||||||
for (int j = 0; j < original.input_size(); ++j) {
|
|
||||||
EXPECT_EQ(original.input(j), optimized.input(j)) << func;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(AutoMixedPrecisionTest, NoOp) {
|
TEST_F(AutoMixedPrecisionTest, NoOp) {
|
||||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||||
Output input = ops::Const(s.WithOpName("input"), 1.234f, {32});
|
Output input = ops::Const(s.WithOpName("input"), 1.234f, {32});
|
||||||
@@ -1164,8 +1164,188 @@ TEST_F(AutoMixedPrecisionTest, TanhOp) {
   });
 }
 
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
+#if INTEL_MKL
+#ifdef ENABLE_INTEL_MKL_BFLOAT16
+
+class AutoMixedPrecisionMklTest : public GrapplerTest {
+ protected:
+  void SetUp() override {
+    virtual_cluster_.reset(new SingleMachine(/* timeout_s = */ 10, 1, 0));
+    TF_CHECK_OK(virtual_cluster_->Provision());
+  }
+  void TearDown() override { TF_CHECK_OK(virtual_cluster_->Shutdown()); }
+
+  std::unique_ptr<Cluster> virtual_cluster_;
+};
+
+TEST_F(AutoMixedPrecisionMklTest, AlreadyBf16) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output input = ops::Const(s.WithOpName("input"), 1.f, {32, 32});
+  Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_BFLOAT16);
+  Output wht1 = ops::MatMul(s.WithOpName("wht1"), cst1, cst1);
+  Output clr1 = ops::Relu(s.WithOpName("clr1"), wht1);
+  Output cst2 = ops::Cast(s.WithOpName("cst2"), clr1, DT_FLOAT);
+  Output clr2 = ops::Relu(s.WithOpName("clr2"), cst2);
+  Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);
+
+  GrapplerItem item;
+  item.fetch = {"fetch"};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+  AutoMixedPrecision optimizer{AutoMixedPrecisionMode::MKL};
+  GraphDef output;
+  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
+  VLOG(1) << output.DebugString();
+
+  VerifyGraphsEquivalent(item.graph, output, __FUNCTION__);
+  GraphView output_view(&output);
+  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
+  EXPECT_EQ(output_view.GetNode("cst1")->attr().at("DstT").type(), DT_BFLOAT16);
+  EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_BFLOAT16);
+  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_BFLOAT16);
+  EXPECT_EQ(output_view.GetNode("cst2")->attr().at("SrcT").type(), DT_BFLOAT16);
+  EXPECT_EQ(output_view.GetNode("cst2")->attr().at("DstT").type(), DT_FLOAT);
+  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);
+
+  auto tensors = EvaluateNodes(output, item.fetch);
+  EXPECT_EQ(tensors.size(), tensors_expected.size());
+  EXPECT_EQ(tensors.size(), item.fetch.size());
+  for (int i = 0; i < item.fetch.size(); ++i) {
+    test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-6);
+  }
+}
+
+TEST_F(AutoMixedPrecisionMklTest, Simple) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
+  Output blk1 = ops::Exp(s.WithOpName("blk1"), input);
+  Output clr1 = ops::Relu(s.WithOpName("clr1"), blk1);
+  Output gry1 = ops::Sqrt(s.WithOpName("gry1"), clr1);
+  Output clr2 = ops::Relu(s.WithOpName("clr2"), gry1);
+  Output wht1 = ops::MatMul(s.WithOpName("wht1"), clr2, clr2);
+  Output clr3 = ops::Relu(s.WithOpName("clr3"), wht1);
+  Output blk2 = ops::Log(s.WithOpName("blk2"), clr3);
+  Output clr4 = ops::Relu(s.WithOpName("clr4"), blk2);
+  Output blk3 = ops::SparseMatMul(s.WithOpName("blk3"), clr4, clr4);
+  Output clr5 = ops::Relu(s.WithOpName("clr5"), blk3);
+  Output fetch = ops::Identity(s.WithOpName("fetch"), clr5);
+
+  GrapplerItem item;
+  item.fetch = {"fetch"};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+  AutoMixedPrecision optimizer{AutoMixedPrecisionMode::MKL};
+  GraphDef output;
+  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
+
+  VLOG(1) << output.DebugString();
+
+  GraphView output_view(&output);
+  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
+  EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
+  EXPECT_EQ(output_view.GetNode("blk1")->attr().at("T").type(), DT_FLOAT);
+  EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
+  EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT);
+  EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_BFLOAT16);
+  EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_BFLOAT16);
+  EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_BFLOAT16);
+  EXPECT_EQ(output_view.GetNode("blk2")->attr().at("T").type(), DT_FLOAT);
+  EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
+  EXPECT_EQ(output_view.GetNode("blk3")->attr().at("Ta").type(), DT_FLOAT);
+  EXPECT_EQ(output_view.GetNode("blk3")->attr().at("Tb").type(), DT_FLOAT);
+  EXPECT_EQ(output_view.GetNode("clr5")->attr().at("T").type(), DT_FLOAT);
+
+  auto tensors = EvaluateNodes(output, item.fetch);
+  EXPECT_EQ(tensors.size(), tensors_expected.size());
+  EXPECT_EQ(tensors.size(), item.fetch.size());
+  for (int i = 0; i < item.fetch.size(); ++i) {
+    test::ExpectClose(tensors_expected[i], tensors[i], -1, 5e-4);
+  }
+}
+
+TEST_F(AutoMixedPrecisionMklTest, TensorListSetGet) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  tensorflow::Input shape = {32, 32};
+  auto tl1 = ops::TensorListReserve(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
+  Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
+  Output idx1 = ops::Const(s.WithOpName("idx1"), 1);
+  Output idx2 = ops::Const(s.WithOpName("idx2"), 2);
+  Output idx3 = ops::Const(s.WithOpName("idx3"), 3);
+  auto tl1w1 =
+      ops::TensorListSetItem(s.WithOpName("tl1w1"), tl1.handle, idx1, input);
+  Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, input);
+  auto tl1w2 =
+      ops::TensorListSetItem(s.WithOpName("tl1w2"), tl1.handle, idx2, wht1);
+  // Ensure that TensorListResize doesn't cause any problems.
+  Output tl1rs =
+      ops::TensorListResize(s.WithOpName("tl1rs"), tl1w2.output_handle, 6);
+  Output tl1r1 = ops::TensorListGetItem(s.WithOpName("tl1r1"), tl1rs, idx2,
+                                        shape, DT_FLOAT)
+                     .item;
+  Output gry1 = ops::Mul(s.WithOpName("gry1"), tl1r1, tl1r1);
+  Output wht2 = ops::MatMul(s.WithOpName("wht2"), gry1, gry1);
+  auto tl1w3 =
+      ops::TensorListSetItem(s.WithOpName("tl1w3"), tl1.handle, idx3, wht2);
+  Output tl1r2 =
+      ops::TensorListGetItem(s.WithOpName("tl1r2"), tl1w3.output_handle, idx3,
+                             shape, DT_FLOAT)
+          .item;
+  auto tl2 = ops::TensorListReserve(s.WithOpName("tl2"), shape, 8, DT_FLOAT);
+  auto tl2w1 =
+      ops::TensorListSetItem(s.WithOpName("tl2w1"), tl2.handle, idx1, input);
+  Output tl2r1 =
+      ops::TensorListGetItem(s.WithOpName("tl2r1"), tl2w1.output_handle, idx1,
+                             shape, DT_FLOAT)
+          .item;
+  Output fetch1 = ops::Identity(s.WithOpName("fetch1"), tl1r2);
+  Output fetch2 = ops::Identity(s.WithOpName("fetch2"), tl2r1);
+
+  GrapplerItem item;
+  item.fetch = {"fetch1", "fetch2"};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+  AutoMixedPrecision optimizer{AutoMixedPrecisionMode::MKL};
+  GraphDef output;
+  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
+
+  VLOG(1) << output.DebugString();
+
+  GraphView output_view(&output);
+  EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
+  const char* type_key = "element_dtype";
+  EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(),
+            DT_BFLOAT16);
+  EXPECT_EQ(output_view.GetNode("tl1w1")->attr().at(type_key).type(),
+            DT_BFLOAT16);
+  EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_BFLOAT16);
+  EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(),
+            DT_BFLOAT16);
+  EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(),
+            DT_BFLOAT16);
+  EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_BFLOAT16);
+  EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_BFLOAT16);
+  EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(),
+            DT_BFLOAT16);
+  EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT);
+  EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_FLOAT);
+  EXPECT_EQ(output_view.GetNode("tl2r1")->attr().at(type_key).type(), DT_FLOAT);
+
+  auto tensors = EvaluateNodes(output, item.fetch);
+  EXPECT_EQ(tensors.size(), tensors_expected.size());
+  EXPECT_EQ(tensors.size(), item.fetch.size());
+  for (int i = 0; i < item.fetch.size(); ++i) {
+    test::ExpectClose(tensors_expected[i], tensors[i], -1, 1e-2);
+  }
+}
+
+#endif  // ENABLE_INTEL_MKL_BFLOAT16
+#endif  // INTEL_MKL
+
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow
-
-#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -188,7 +188,9 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
   MK_OPT("remap", new Remapper(cfg_.remapping()));
   MK_OPT("layout", new GenericLayoutOptimizer());
   MK_OPT("auto_mixed_precision",
-         new AutoMixedPrecision(cfg_.auto_mixed_precision()));
+         new AutoMixedPrecision(AutoMixedPrecisionMode::CUDA));
+  MK_OPT("auto_mixed_precision_mkl",
+         new AutoMixedPrecision(AutoMixedPrecisionMode::MKL));
   MK_OPT("memory", new MemoryOptimizer(RewriterConfig::MANUAL));
   MK_OPT("common_subgraph_elimination",
          new CommonSubgraphElimination(cfg_.common_subgraph_elimination()));
@@ -249,7 +251,11 @@ Status MetaOptimizer::InitializeOptimizers(
   }
   if (AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision())) {
     optimizers->push_back(
-        MakeUnique<AutoMixedPrecision>(cfg_.auto_mixed_precision()));
+        MakeUnique<AutoMixedPrecision>(AutoMixedPrecisionMode::CUDA));
+  }
+  if (AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision_mkl())) {
+    optimizers->push_back(
+        MakeUnique<AutoMixedPrecision>(AutoMixedPrecisionMode::MKL));
   }
   if (cfg_.pin_to_host_optimization() == RewriterConfig::ON) {
     optimizers->push_back(MakeUnique<PinToHostOptimizer>());
@@ -835,6 +841,7 @@ bool MetaOptimizerEnabled(const ConfigProto& cfg) {
         rewrite_cfg.scoped_allocator_optimization() == RewriterConfig::ON ||
         rewrite_cfg.pin_to_host_optimization() == RewriterConfig::ON ||
         AutoMixedPrecisionEnabled(rewrite_cfg.auto_mixed_precision()) ||
+        AutoMixedPrecisionEnabled(rewrite_cfg.auto_mixed_precision_mkl()) ||
         !rewrite_cfg.optimizers().empty() ||
         !rewrite_cfg.custom_optimizers().empty();
 }
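A note on the registration above: MK_OPT binds each optimizer to a string name, so besides the dedicated toggle added below in rewriter_config.proto, the pass could in principle also be requested by name through the generic RewriterConfig `optimizers` list. A minimal sketch under that assumption (the toggle is the intended interface):

    # Sketch: request the grappler pass by its MK_OPT-registered name.
    from tensorflow.core.protobuf import config_pb2
    from tensorflow.core.protobuf import rewriter_config_pb2

    rewrite_config = rewriter_config_pb2.RewriterConfig(
        optimizers=['auto_mixed_precision_mkl'])  # name registered above
    config = config_pb2.ConfigProto(
        graph_options=config_pb2.GraphOptions(rewrite_options=rewrite_config))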
diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto
@@ -85,11 +85,15 @@ message RewriterConfig {
   // Enable the swap of kernel implementations based on the device placement
   // (default is ON).
   Toggle implementation_selector = 22;
-  // Optimize data types (default is OFF).
-  // e.g., This will try to use float16 on GPU which is faster.
+  // Optimize data types for CUDA (default is OFF).
+  // This will try to use float16 on GPU which is faster.
   // Note that this can change the numerical stability of the graph and may
   // require the use of loss scaling to maintain model convergence.
   Toggle auto_mixed_precision = 23;
+  // Optimize data types for MKL (default is OFF).
+  // This will try to use bfloat16 on CPUs, which is faster.
+  // Note that this can change the numerical stability of the graph.
+  Toggle auto_mixed_precision_mkl = 25;
   // Disable the entire meta optimizer (off by default).
   bool disable_meta_optimizer = 19;
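A minimal sketch of turning the new toggle on from Python, mirroring how the `_get_config` helper in the Python test below builds its session config:

    from tensorflow.core.protobuf import config_pb2
    from tensorflow.core.protobuf import rewriter_config_pb2

    rewrite_config = rewriter_config_pb2.RewriterConfig(
        auto_mixed_precision_mkl=rewriter_config_pb2.RewriterConfig.ON)
    config = config_pb2.ConfigProto(
        graph_options=config_pb2.GraphOptions(rewrite_options=rewrite_config))
    # Sessions created with `config` run the bfloat16 rewrite on MKL builds
    # compiled with ENABLE_INTEL_MKL_BFLOAT16.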
diff --git a/tensorflow/python/grappler/auto_mixed_precision_test.py b/tensorflow/python/grappler/auto_mixed_precision_test.py
@@ -19,8 +19,8 @@ from __future__ import division
 from __future__ import print_function
 
 import os
-import unittest
 
+from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.core.framework import types_pb2
@@ -209,7 +209,7 @@ def _make_node_with_color(color, input_tensor, name=None):
   if color == 'c':  # Clear node
     return nn.relu(input_tensor, name=name)
   if color == 'b':  # Black node
-    return math_ops.sqrt(math_ops.pow(input_tensor, 2.), name=name)
+    return math_ops.pow(math_ops.pow(input_tensor, 2.), 0.5, name=name)
   raise ValueError('Invalid node color: ' + str(color))
@@ -231,18 +231,21 @@ def _build_simple_loop_graph(inp_colors, body_colors, out_colors):
   return a
 
 
-def _get_config(auto_mixed_precision=True):
+def _get_config(auto_mixed_precision_mode):
   """Returns a ConfigProto with auto mixed precision enabled if appropriate."""
-  if auto_mixed_precision:
-    rewrite_config = rewriter_config_pb2.RewriterConfig(
-        auto_mixed_precision=rewriter_config_pb2.RewriterConfig.ON,
-        # do not remove duplicated nodes
-        arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF)
+  rewrite_config = rewriter_config_pb2.RewriterConfig(
+      # do not remove duplicated nodes
+      arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
+      # do not turn Conv2D and other nodes into _FusedConv2D
+      remapping=rewriter_config_pb2.RewriterConfig.OFF,
+  )
+  if auto_mixed_precision_mode == 'cuda':
+    rewrite_config.auto_mixed_precision = rewriter_config_pb2.RewriterConfig.ON
+  elif auto_mixed_precision_mode == 'mkl':
+    rewrite_config.auto_mixed_precision_mkl = (
+        rewriter_config_pb2.RewriterConfig.ON)
   else:
-    rewrite_config = rewriter_config_pb2.RewriterConfig(
-        auto_mixed_precision=rewriter_config_pb2.RewriterConfig.OFF,
-        # do not remove duplicated nodes
-        arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF)
+    assert auto_mixed_precision_mode is None
   rewrite_config.min_graph_nodes = -1
   graph_options = config_pb2.GraphOptions(
       rewrite_options=rewrite_config, build_cost_model=1)
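For illustration, this helper feeds the reference-versus-optimized pattern the `_run` method below is built on: one session with the rewrite off, one with it on, same fetches. A condensed sketch of that pattern:

    # Sketch of the comparison pattern used by _run() further down.
    with session.Session(config=_get_config(None)) as sess:    # rewrite off
      sess.run(variables.global_variables_initializer())
      output_val_ref = sess.run(fetches)
    with session.Session(config=_get_config('mkl')) as sess:   # bf16 rewrite on
      sess.run(variables.global_variables_initializer())
      output_val = sess.run(fetches)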
@@ -255,19 +258,33 @@ def _is_cast_to_fp16(node_name):
   return node_name.endswith('-CastToFp16-AutoMixedPrecision')
 
 
+def _is_cast_to_bf16(node_name):
+  return node_name.endswith('-CastToBf16-AutoMixedPrecision')
+
+
 def _is_cast_to_fp32(node_name):
   return node_name.endswith('-CastToFp32-AutoMixedPrecision')
 
 
-def _count_casts(nodes):
+def _count_casts(mode, nodes):
+  """Counts the number of casts to f16 and fp32."""
   num_to_fp16 = 0
+  num_to_bf16 = 0
   num_to_fp32 = 0
   for node in nodes:
     if _is_cast_to_fp16(node.name):
       num_to_fp16 += 1
+    if _is_cast_to_bf16(node.name):
+      num_to_bf16 += 1
     elif _is_cast_to_fp32(node.name):
       num_to_fp32 += 1
-  return num_to_fp16, num_to_fp32
+  if mode == 'cuda':
+    assert num_to_bf16 == 0
+    return num_to_fp16, num_to_fp32
+  else:
+    assert mode == 'mkl'
+    assert num_to_fp16 == 0
+    return num_to_bf16, num_to_fp32
 
 
 def _build_node_map(nodes):
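The predicates above match the names the C++ pass gives the casts it inserts (source node name, then output port, then a suffix such as '-CastToBf16-AutoMixedPrecision'). A small sketch of applying them directly to a collected cost graph, assuming `cost_graph` came from a RunMetadata proto:

    # Sketch: list the casts the rewrite inserted, by name-suffix convention.
    bf16_casts = [n.name for n in cost_graph.node if _is_cast_to_bf16(n.name)]
    fp32_casts = [n.name for n in cost_graph.node if _is_cast_to_fp32(n.name)]
    print('%d casts to bf16, %d casts back to fp32'
          % (len(bf16_casts), len(fp32_casts)))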
@@ -303,7 +320,7 @@ def _example_noninlined_funcdef(features):
   return features * math_ops.sigmoid(features)
 
 
-class AutoMixedPrecisionTest(test.TestCase):
+class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
   """Tests the Grappler auto mixed precision optimizer."""
   IGNORE_PERF_VAR = 'TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_IGNORE_PERFORMANCE'
 
@@ -311,8 +328,8 @@ class AutoMixedPrecisionTest(test.TestCase):
 
   def setUp(self):
     super(AutoMixedPrecisionTest, self).setUp()
-    # Enable the tests to be run on pre-Volta GPUs by telling the grappler pass
-    # to ignore performance and always transform the graph.
+    # Enable the CUDA tests to be run on pre-Volta GPUs by telling the grappler
+    # pass to ignore performance and always transform the graph.
    self._original_ignore_perf_value = os.getenv(self.IGNORE_PERF_VAR)
     os.environ[self.IGNORE_PERF_VAR] = '1'
 
@@ -323,24 +340,33 @@ class AutoMixedPrecisionTest(test.TestCase):
     del os.environ[self.IGNORE_PERF_VAR]
     super(AutoMixedPrecisionTest, self).tearDown()
 
-  def _assert_output_fp16(self, node_map, node_name, output_port=0):
-    self.assertEqual(node_map[node_name].output_info[output_port].dtype,
-                     types_pb2.DT_HALF)
+  def _lower_precision_dtype(self, mode):
+    return dtypes.float16 if mode == 'cuda' else dtypes.bfloat16
 
-  def _run(self, fetches):
+  def _assert_output_f16(self, mode, node_map, node_name, output_port=0):
+    self.assertEqual(node_map[node_name].output_info[output_port].dtype,
+                     self._lower_precision_dtype(mode).as_datatype_enum)
+
+  def _run(self, mode, fetches):
     """Runs the graph and returns the evaluation of the fetches."""
-    with session.Session(config=_get_config(False)) as sess:
+    with session.Session(config=_get_config(None)) as sess:
       sess.run(variables.global_variables_initializer())
       output_val_ref = self.evaluate(fetches)
 
-    with session.Session(config=_get_config()) as sess:
+    with session.Session(config=_get_config(mode)) as sess:
       sess.run(variables.global_variables_initializer())
       metadata = config_pb2.RunMetadata()
       output_val = sess.run(fetches, run_metadata=metadata)
 
     return output_val_ref, output_val, metadata.cost_graph
 
-  def _run_simple_loop_test(self, inp, body, out):
+  def _maybe_skip(self, mode):
+    if mode == 'cuda' and not test.is_gpu_available(cuda_only=True):
+      self.skipTest('No GPU is available')
+    if mode == 'mkl' and not test_util.IsMklEnabled():
+      self.skipTest('MKL is not enabled')
+
+  def _run_simple_loop_test(self, mode, inp, body, out):
     """Runs a test of a simple loop.
 
     The loop has different node colors in different sections of the graph. The
@@ -359,398 +385,441 @@ class AutoMixedPrecisionTest(test.TestCase):
       out: A string of letters indicating the colors and expected dtypes of the
         output nodes.
     """
-    if test.is_gpu_available(cuda_only=True):
+    self._maybe_skip(mode)
     random_seed.set_random_seed(0)
     expected_types = []
     for section in [inp, body, out]:
       section_expected_types = []
       for color in section:
         if color.isupper():
-          expected_type = types_pb2.DT_HALF
+          expected_type = self._lower_precision_dtype(mode).as_datatype_enum
         else:
           expected_type = types_pb2.DT_FLOAT
         section_expected_types.append(expected_type)
       expected_types.append(section_expected_types)
 
     a = _build_simple_loop_graph(inp, body, out)
-    output_val_ref, output_val, cost_graph = self._run(a)
+    output_val_ref, output_val, cost_graph = self._run(mode, a)
     node_map = _build_node_map(cost_graph.node)
 
     section_names = ['input', 'while/body', 'output']
     all_types_correct = True
     for section_name, expected_types in zip(section_names, expected_types):
       for i, expected_type in enumerate(expected_types):
         node_name = section_name + '_%i' % i
         output_port = 0
         optimized_type = node_map[node_name].output_info[output_port].dtype
         if optimized_type != expected_type:
           print('Expected node %s to have type %s but got type %s' %
                 (node_name, expected_type, optimized_type))
           all_types_correct = False
     self.assertTrue(all_types_correct)
-    self.assertAllClose(output_val_ref, output_val, atol=2e-3, rtol=1e-3)
+    if mode == 'mkl':
+      self.assertAllClose(output_val_ref, output_val, atol=2e-2, rtol=2e-2)
+    else:
+      self.assertAllClose(output_val_ref, output_val, atol=2e-3, rtol=1e-3)
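As a worked example of the color-string convention used above: each letter selects a node color for `_make_node_with_color` (for instance 'c' for a clearlist Relu, 'b' for a blacklist node), and an uppercase letter asserts the node is expected to come out in the low-precision type for the mode (float16 for 'cuda', bfloat16 for 'mkl'). The first propagation test further down therefore amounts to:

    # Allowlist input node expected in f16/bf16 ('W'); clearlist loop body and
    # output nodes expected to be pulled along by propagation ('C').
    self._run_simple_loop_test('mkl', 'W', 'C', 'C')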
+  @parameterized.parameters(['cuda', 'mkl'])
   @test_util.run_deprecated_v1
   @test_util.disable_xla('This test does not pass with XLA')
-  def test_conv_bn(self):
+  def test_conv_bn(self, mode):
     """Test graph with convolution followed by batch norm."""
-    if test.is_gpu_available(cuda_only=True):
+    self._maybe_skip(mode)
     random_seed.set_random_seed(0)
     x = _input([2, 8, 8, 1])
     x = _conv_bn(x)
     output = _conv_bn(x)
 
-    output_val_ref, output_val, cost_graph = self._run(output)
+    output_val_ref, output_val, cost_graph = self._run(mode, output)
     node_map = _build_node_map(cost_graph.node)
-    num_to_fp16, num_to_fp32 = _count_casts(cost_graph.node)
+    num_to_f16, num_to_fp32 = _count_casts(mode, cost_graph.node)
 
-    self._assert_output_fp16(node_map, 'Conv2D')
-    self._assert_output_fp16(node_map, 'FusedBatchNormV3')
-    self._assert_output_fp16(node_map, 'Conv2D_1')
-    self.assertEqual(num_to_fp16,
-                     3)  # Before Conv2D:0, Conv2D:1, Conv2D_1:1
-    self.assertEqual(num_to_fp32, 1)  # After FusedBatchNormV3:0
+    self._assert_output_f16(mode, node_map, 'Conv2D')
+    self._assert_output_f16(mode, node_map, 'FusedBatchNormV3')
+    self._assert_output_f16(mode, node_map, 'Conv2D_1')
+    self.assertEqual(num_to_f16, 3)  # Before Conv2D:0, Conv2D:1, Conv2D_1:1
+    self.assertEqual(num_to_fp32, 1)  # After FusedBatchNormV3:0
+    if mode == 'mkl':
+      tol = 1e-2
+    elif test.is_built_with_rocm():
       # Bump up the tolerance for the ROCm platform
       # The default tolerance (1e-3) results in a tiny fraction (<1%) of
       # miscompares on ROCm platform, and hence the tolerance bump
-    tol = 2e-3 if test.is_built_with_rocm else 1e-3
-    self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)
+      tol = 2e-3
+    else:
+      tol = 1e-3
+    self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)
-  # TODO: enable these tests when cuDNN is upgraded to >= 7.6.2. Same with the
-  # test_conv3d() below.
-  @unittest.skip('Test case should be skipped when cuDNN < 7.6.2')
+  @parameterized.parameters(['cuda', 'mkl'])
   @test_util.run_deprecated_v1
   @test_util.disable_xla('This test does not pass with XLA')
-  def test_conv3d_bn(self):
+  def test_conv3d_bn(self, mode):
     """Test graph with convolution followed by batch norm."""
-    if test.is_gpu_available(cuda_only=True):
+    self._maybe_skip(mode)
+    if mode == 'cuda':
+      # TODO: enable these tests when cuDNN is upgraded to >= 7.6.2.
+      self.skipTest('Test case should be skipped when cuDNN < 7.6.2')
     random_seed.set_random_seed(0)
     x = _input([2, 8, 8, 8, 1])
     x = _conv3d_bn(x)
     output = _conv3d_bn(x)
 
-    output_val_ref, output_val, cost_graph = self._run(output)
+    output_val_ref, output_val, cost_graph = self._run(mode, output)
     node_map = _build_node_map(cost_graph.node)
-    num_to_fp16, num_to_fp32 = _count_casts(cost_graph.node)
+    num_to_fp16, num_to_fp32 = _count_casts(mode, cost_graph.node)
 
-    self._assert_output_fp16(node_map, 'Conv3D')
-    self._assert_output_fp16(node_map, 'FusedBatchNormV3')
-    self._assert_output_fp16(node_map, 'Conv3D_1')
+    self._assert_output_f16(mode, node_map, 'Conv3D')
+    self._assert_output_f16(mode, node_map, 'FusedBatchNormV3')
+    self._assert_output_f16(mode, node_map, 'Conv3D_1')
     self.assertEqual(num_to_fp16, 3)  # Before Conv3D:0, Conv3D:1, Conv3D_1:1
     self.assertEqual(num_to_fp32, 1)  # After FusedBatchNormV3:0
     self.assertAllClose(output_val_ref, output_val, atol=1e-2, rtol=1e-2)
-  @unittest.skip('Test case should be skipped when cuDNN < 7.6.2')
+  @parameterized.parameters(['cuda', 'mkl'])
   @test_util.run_deprecated_v1
   @test_util.disable_xla('This test does not pass with XLA')
-  def test_conv3d(self):
+  def test_conv3d(self, mode):
     """Test grad ops with convolution3d graph."""
-    if test.is_gpu_available(cuda_only=True):
+    self._maybe_skip(mode)
+    if mode == 'cuda':
+      # TODO: enable these tests when cuDNN is upgraded to >= 7.6.2.
+      self.skipTest('Test case should be skipped when cuDNN < 7.6.2')
     random_seed.set_random_seed(0)
     x = _input([2, 8, 8, 8, 1])
     f = _weight([3, 3, 3, 1, 6])
     y = _conv3d(x, f)
     y = array_ops.identity(y)
     optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
     g = optimizer.compute_gradients(y, [x, f])
     output = (y, g)
 
-    output_val_ref, output_val, cost_graph = self._run(output)
+    output_val_ref, output_val, cost_graph = self._run(mode, output)
     node_map = _build_node_map(cost_graph.node)
-    self._assert_output_fp16(node_map, 'Conv3D')
-    self._assert_output_fp16(node_map,
-                             'gradients/Conv3D_grad/Conv3DBackpropInputV2')
-    self._assert_output_fp16(node_map,
-                             'gradients/Conv3D_grad/Conv3DBackpropFilterV2')
+    self._assert_output_f16(mode, node_map, 'Conv3D')
+    self._assert_output_f16(mode, node_map,
+                            'gradients/Conv3D_grad/Conv3DBackpropInputV2')
+    self._assert_output_f16(mode, node_map,
+                            'gradients/Conv3D_grad/Conv3DBackpropFilterV2')
 
-    output_val_ref, output_val, cost_graph = self._run(output)
-    self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
+    output_val_ref, output_val, cost_graph = self._run(mode, output)
+    tol = 5e-2 if mode == 'mkl' else 1e-3
+    self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)
+  # TODO(reedwm): Fix and enable this test with MKL. Currently this crashes with
+  # MKL
+  @parameterized.parameters(['cuda'])
   @test_util.run_deprecated_v1
   @test_util.disable_xla('This test does not pass with XLA')
-  def test_conv_bn_dropout(self):
+  def test_conv_bn_dropout(self, mode):
     """Test dropout precision of convolution batch norm graph."""
-    if test.is_gpu_available(cuda_only=True):
+    self._maybe_skip(mode)
     random_seed.set_random_seed(0)
     x = _input([2, 8, 8, 1])
     y = _conv_bn(x)
     y = nn.dropout(y, rate=0.5)
     y = math_ops.add(y, 1, name='addition')
     y = _conv_bn(y)
     y = array_ops.identity(y)
     optimizer = gradient_descent.GradientDescentOptimizer(
         learning_rate=0.01)
     g = optimizer.compute_gradients(y, [x])
     output = (y, g)
 
-    output_val_ref, output_val, cost_graph = self._run(output)
+    output_val_ref, output_val, cost_graph = self._run(mode, output)
     node_map = _build_node_map(cost_graph.node)
-    self._assert_output_fp16(node_map, 'Conv2D')
-    self._assert_output_fp16(node_map, 'FusedBatchNormV3')
+    self._assert_output_f16(mode, node_map, 'Conv2D')
+    self._assert_output_f16(mode, node_map, 'FusedBatchNormV3')
     # We do not assert dropout's dtype because we do not want to rely on the
     # node names of dropout's internal implementation.
-    self._assert_output_fp16(node_map, 'addition')
-    self._assert_output_fp16(node_map, 'Conv2D_1')
+    self._assert_output_f16(mode, node_map, 'addition')
+    self._assert_output_f16(mode, node_map, 'Conv2D_1')
 
-    output_val_ref, output_val, cost_graph = self._run(output)
+    output_val_ref, output_val, cost_graph = self._run(mode, output)
     # Bump up the tolerance for the ROCm platform
     # The default tolerance (1e-3) results in a tiny fraction (<1%) of
     # miscompares on ROCm platform, and hence the tolerance bump
     tol = 2e-3 if test.is_built_with_rocm else 1e-3
     self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)
+  # TODO(reedwm): Fix and enable this test with MKL. Currently this crashes with
+  # MKL
+  @parameterized.parameters(['cuda'])
   @test_util.run_deprecated_v1
   @test_util.disable_xla('This test does not pass with XLA')
-  def test_conv_pool(self):
+  def test_conv_pool(self, mode):
     """Test graph with convolution followed by pooling."""
-    if test.is_gpu_available(cuda_only=True):
+    self._maybe_skip(mode)
     random_seed.set_random_seed(0)
     x = _input([2, 8, 8, 1])
     output = _conv_pool(x)
 
-    output_val_ref, output_val, cost_graph = self._run(output)
+    output_val_ref, output_val, cost_graph = self._run(mode, output)
     node_map = _build_node_map(cost_graph.node)
-    num_to_fp16, num_to_fp32 = _count_casts(cost_graph.node)
+    num_to_f16, num_to_fp32 = _count_casts(mode, cost_graph.node)
 
-    self._assert_output_fp16(node_map, 'Conv2D')
-    self._assert_output_fp16(node_map, 'Relu')
-    self._assert_output_fp16(node_map, 'MaxPool')
-    self._assert_output_fp16(node_map, 'Conv2D_1')
-    self.assertEqual(num_to_fp16, 4)
+    self._assert_output_f16(mode, node_map, 'Conv2D')
+    self._assert_output_f16(mode, node_map, 'Relu')
+    self._assert_output_f16(mode, node_map, 'MaxPool')
+    self._assert_output_f16(mode, node_map, 'Conv2D_1')
+    self.assertEqual(num_to_f16, 4)
     self.assertEqual(num_to_fp32, 1)
-    self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
+    tol = 5e-3 if mode == 'mkl' else 1e-3
+    self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)
+  @parameterized.parameters(['cuda', 'mkl'])
   @test_util.run_v1_only('b/138749235')
   @test_util.disable_xla('This test does not pass with XLA')
-  def test_simple_loop(self):
+  def test_simple_loop(self, mode):
     """Test graph with while loop."""
-    if test.is_gpu_available(cuda_only=True):
+    self._maybe_skip(mode)
     random_seed.set_random_seed(0)
     x = _input([8, 8])
     y = _simple_loop(x, _matmul_act)[1]
     optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
     g = optimizer.compute_gradients(y, [x])
     output = (y, g)
 
-    output_val_ref, output_val, cost_graph = self._run(output)
+    output_val_ref, output_val, cost_graph = self._run(mode, output)
     node_map = _build_node_map(cost_graph.node)
 
-    self._assert_output_fp16(node_map, 'while/MatMul')
-    self._assert_output_fp16(node_map, 'while/Relu')
-    self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
+    self._assert_output_f16(mode, node_map, 'while/MatMul')
+    self._assert_output_f16(mode, node_map, 'while/Relu')
+    tol = 1e-2 if mode == 'mkl' else 1e-3
+    self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)
+  @parameterized.parameters(['cuda', 'mkl'])
   @test_util.run_v1_only('b/138749235')
   @test_util.disable_xla('This test does not pass with XLA')
-  def test_loop_with_vars_intertwined(self):
+  def test_loop_with_vars_intertwined(self, mode):
     """Test graph with intertwined while loops."""
-    if test.is_gpu_available(cuda_only=True):
+    self._maybe_skip(mode)
     random_seed.set_random_seed(0)
     x = _input([8, 8])
     _, _, k, l = _loop_vars_intertwined(
         array_ops.ones(array_ops.shape(x)), x, _matmul_act, _matmul_act)
     optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
     g = optimizer.compute_gradients(k, [x])
     output = (k, l, g)
 
-    output_val_ref, output_val, cost_graph = self._run(output)
+    output_val_ref, output_val, cost_graph = self._run(mode, output)
     node_map = _build_node_map(cost_graph.node)
 
-    self._assert_output_fp16(node_map, 'while/MatMul')
-    self._assert_output_fp16(node_map, 'while/Relu')
-    self._assert_output_fp16(node_map, 'while/MatMul_1')
-    self._assert_output_fp16(node_map, 'while/Relu_1')
-    self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
+    self._assert_output_f16(mode, node_map, 'while/MatMul')
+    self._assert_output_f16(mode, node_map, 'while/Relu')
+    self._assert_output_f16(mode, node_map, 'while/MatMul_1')
+    self._assert_output_f16(mode, node_map, 'while/Relu_1')
+    tol = 5e-3 if mode == 'mkl' else 1e-3
+    self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)
+  @parameterized.parameters(['cuda'])
   @test_util.run_deprecated_v1
   @test_util.disable_xla('This test does not pass with XLA')
-  def test_multi_paths(self):
+  def test_multi_paths(self, mode):
     """Test graph with multiple paths."""
-    if test.is_gpu_available(cuda_only=True):
+    self._maybe_skip(mode)
     random_seed.set_random_seed(0)
     x = _input([2, 8, 8, 3])
     x1, x2, x3 = array_ops.split(x, num_or_size_splits=3, axis=3)
     y1 = _conv_pool(x1)
     y2 = _conv_pool(x2)
     y3 = _conv_pool(x3)
     y = array_ops.concat([y1, y2, y3], axis=3)
     y = array_ops.identity(y)
     optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
     g = optimizer.compute_gradients(y, [x])
     output = (y, g)
 
-    output_val_ref, output_val, cost_graph = self._run(output)
+    output_val_ref, output_val, cost_graph = self._run(mode, output)
     node_map = _build_node_map(cost_graph.node)
 
-    self._assert_output_fp16(node_map, 'split')
+    self._assert_output_f16(mode, node_map, 'split')
     for suffix in [''] + ['_%i' % i for i in range(1, 6)]:
-      self._assert_output_fp16(node_map, 'Conv2D' + suffix)
-      self._assert_output_fp16(node_map, 'Relu' + suffix)
-      self._assert_output_fp16(node_map, 'MaxPool' + suffix)
-    self._assert_output_fp16(node_map, 'concat')
+      self._assert_output_f16(mode, node_map, 'Conv2D' + suffix)
+      self._assert_output_f16(mode, node_map, 'Relu' + suffix)
+      self._assert_output_f16(mode, node_map, 'MaxPool' + suffix)
+    self._assert_output_f16(mode, node_map, 'concat')
     self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
+  @parameterized.parameters(['cuda', 'mkl'])
   @test_util.run_deprecated_v1
   @test_util.disable_xla('This test does not pass with XLA')
-  def test_multi_paths_2(self):
+  def test_multi_paths_2(self, mode):
     """Test graph with multiple paths."""
-    if test.is_gpu_available(cuda_only=True):
+    self._maybe_skip(mode)
     random_seed.set_random_seed(0)
     x = _input([8, 8])
     y1 = _matmul_act(x)
     y2 = _matmul_act(x)
     y = y1 + y2 + x
     optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
     g = optimizer.compute_gradients(y, [x])
     output = (g, y)
 
-    output_val_ref, output_val, cost_graph = self._run(output)
+    output_val_ref, output_val, cost_graph = self._run(mode, output)
     node_map = _build_node_map(cost_graph.node)
 
-    self._assert_output_fp16(node_map, 'MatMul')
-    self._assert_output_fp16(node_map, 'Relu')
-    self._assert_output_fp16(node_map, 'MatMul_1')
-    self._assert_output_fp16(node_map, 'Relu_1')
+    self._assert_output_f16(mode, node_map, 'MatMul')
+    self._assert_output_f16(mode, node_map, 'Relu')
+    self._assert_output_f16(mode, node_map, 'MatMul_1')
+    self._assert_output_f16(mode, node_map, 'Relu_1')
+    if mode == 'mkl':
+      tol = 2e-2
+    elif test.is_built_with_rocm():
       # Bump up the tolerance for the ROCm platform
       # The default tolerance (1e-3) results in a tiny fraction (<1%) of
       # miscompares on ROCm platform, and hence the tolerance bump
-    tol = 2e-3 if test.is_built_with_rocm else 1e-3
-    self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)
+      tol = 2e-3
+    else:
+      tol = 1e-3
+    self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)
+  @parameterized.parameters(['cuda'])  # MKL doesn't support bf16 Sigmoid
   @test_util.run_v1_only('b/138749235')
   @test_util.disable_xla('This test does not pass with XLA')
-  def test_recurrent_lstm(self):
+  def test_recurrent_lstm(self, mode):
     """Test graph with recurrent lstm."""
-    if test.is_gpu_available(cuda_only=True):
+    self._maybe_skip(mode)
     random_seed.set_random_seed(0)
     init_c = _input([8, 4])
     init_h = _input([8, 4])
     _, _, h, _ = _recurrent_lstm(init_c, init_h)
     optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
     g = optimizer.compute_gradients(h, [init_c, init_h])
     output = (h, g)
 
-    output_val_ref, output_val, cost_graph = self._run(output)
+    output_val_ref, output_val, cost_graph = self._run(mode, output)
     node_map = _build_node_map(cost_graph.node)
 
-    self._assert_output_fp16(node_map, 'while/concat')
-    self._assert_output_fp16(node_map, 'while/MatMul')
-    self._assert_output_fp16(node_map, 'while/split')
-    self._assert_output_fp16(node_map, 'while/Sigmoid')
-    self._assert_output_fp16(node_map, 'while/Sigmoid_1')
-    self._assert_output_fp16(node_map, 'while/Sigmoid_2')
-    self._assert_output_fp16(node_map, 'while/Tanh')
-    self._assert_output_fp16(node_map, 'while/Tanh_1')
+    self._assert_output_f16(mode, node_map, 'while/concat')
+    self._assert_output_f16(mode, node_map, 'while/MatMul')
+    self._assert_output_f16(mode, node_map, 'while/split')
+    self._assert_output_f16(mode, node_map, 'while/Sigmoid')
+    self._assert_output_f16(mode, node_map, 'while/Sigmoid_1')
+    self._assert_output_f16(mode, node_map, 'while/Sigmoid_2')
+    self._assert_output_f16(mode, node_map, 'while/Tanh')
+    self._assert_output_f16(mode, node_map, 'while/Tanh_1')
     self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
+  @parameterized.parameters(['cuda', 'mkl'])
   @test_util.run_v1_only('v1 loop test')
   @test_util.disable_xla('This test does not pass with XLA')
-  def test_propagation_through_simple_loop_1(self):
-    self._run_simple_loop_test('W', 'C', 'C')
+  def test_propagation_through_simple_loop_1(self, mode):
+    self._run_simple_loop_test(mode, 'W', 'C', 'C')
 
+  @parameterized.parameters(['cuda', 'mkl'])
   @test_util.run_v1_only('v1 loop test')
   @test_util.disable_xla('This test does not pass with XLA')
-  def test_propagation_through_simple_loop_2(self):
-    self._run_simple_loop_test('C', 'C', 'W')
+  def test_propagation_through_simple_loop_2(self, mode):
+    self._run_simple_loop_test(mode, 'C', 'C', 'W')
 
+  @parameterized.parameters(['cuda', 'mkl'])
   @test_util.run_v1_only('v1 loop test')
   @test_util.disable_xla('This test does not pass with XLA')
-  def test_propagation_through_simple_loop_3(self):
-    self._run_simple_loop_test('W', 'G', 'W')
+  def test_propagation_through_simple_loop_3(self, mode):
+    self._run_simple_loop_test(mode, 'W', 'G', 'W')
 
+  @parameterized.parameters(['cuda', 'mkl'])
   @test_util.run_v1_only('v1 loop test')
   @test_util.disable_xla('This test does not pass with XLA')
-  def test_propagation_through_simple_loop_4(self):
-    self._run_simple_loop_test('W', 'gbg', 'W')
+  def test_propagation_through_simple_loop_4(self, mode):
+    self._run_simple_loop_test(mode, 'W', 'gbg', 'W')
 
+  @parameterized.parameters(['cuda', 'mkl'])
   @test_util.run_v1_only('b/138749235')
   @test_util.disable_xla('This test does not pass with XLA')
-  def test_propagation_through_simple_loop_5(self):
-    self._run_simple_loop_test('b', 'gWC', 'c')
+  def test_propagation_through_simple_loop_5(self, mode):
+    self._run_simple_loop_test(mode, 'b', 'gWC', 'c')
 
+  @parameterized.parameters(['cuda', 'mkl'])
   @test_util.run_v1_only('b/138749235')
   @test_util.disable_xla('This test does not pass with XLA')
-  def test_propagation_through_simple_loop_6(self):
-    self._run_simple_loop_test('b', 'CWCG', 'C')
+  def test_propagation_through_simple_loop_6(self, mode):
    self._run_simple_loop_test(mode, 'b', 'CWCG', 'C')
 
+  @parameterized.parameters(['cuda', 'mkl'])
   @test_util.run_v1_only('b/138749235')
   @test_util.disable_xla('This test does not pass with XLA')
-  def test_propagation_through_simple_loop_7(self):
-    self._run_simple_loop_test('C', 'GWCG', 'C')
+  def test_propagation_through_simple_loop_7(self, mode):
+    self._run_simple_loop_test(mode, 'C', 'GWCG', 'C')
 
+  @parameterized.parameters(['cuda', 'mkl'])
   @test_util.run_v1_only('b/138749235')
   @test_util.disable_xla('This test does not pass with XLA')
-  def test_propagation_through_simple_loop_8(self):
-    self._run_simple_loop_test('C', 'CgbgWC', 'g')
+  def test_propagation_through_simple_loop_8(self, mode):
+    self._run_simple_loop_test(mode, 'C', 'CgbgWC', 'g')
+  @parameterized.parameters(['cuda', 'mkl'])
   @test_util.run_deprecated_v1
   @test_util.disable_xla('This test does not pass with XLA')
-  def test_noninlined_funcdef(self):
+  def test_noninlined_funcdef(self, mode):
     """Test graph with non-inlined function subgraph.
 
     This requires the grappler pass to handle an OpDef that only appears in the
     graph's function registry instead of the global op registry.
     """
-    if test.is_gpu_available(cuda_only=True):
+    self._maybe_skip(mode)
     random_seed.set_random_seed(0)
     x = _input([8, 8])
     y = _matmul_act(x)
     y = _example_noninlined_funcdef(y)
     optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.01)
     g = optimizer.compute_gradients(y, [x])
     output = (g, y)
 
-    output_val_ref, output_val, cost_graph = self._run(output)
+    output_val_ref, output_val, cost_graph = self._run(mode, output)
     node_map = _build_node_map(cost_graph.node)
 
-    self._assert_output_fp16(node_map, 'MatMul')
-    self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
+    self._assert_output_f16(mode, node_map, 'MatMul')
+    tol = 1e-2 if mode == 'mkl' else 1e-3
+    self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)
+  @parameterized.parameters(['cuda', 'mkl'])
   @test_util.run_deprecated_v1
   @test_util.disable_xla('This test does not pass with XLA')
-  def test_ingraph_train_loop(self):
+  def test_ingraph_train_loop(self, mode):
     """Tests a graph containing a while loop around a training update.
 
     This requires the grappler pass to take special care with its handling of
     Enter ops that appear in front of reads from non-resource variables. See
     the use of NodeImplicitlyReadsVariable in auto_mixed_precision.cc.
     """
+    self._maybe_skip(mode)
     if tf2.enabled():
       # This test tests non-resource variables, which are only used in TF1.
       self.skipTest('TensorFlow 1 required')
-    if test.is_gpu_available(cuda_only=True):
-      random_seed.set_random_seed(1234)
-      np.random.seed(1234)
-      num_iter, bs, nchan, nclass = 100, 64, 32, 100
+    random_seed.set_random_seed(1234)
+    np.random.seed(1234)
+    num_iter, bs, nchan, nclass = 100, 64, 32, 100
 
     data = np.random.normal(size=(bs * num_iter, nchan)).astype(np.float32)
     labels = np.random.randint(nclass, size=(bs * num_iter,))
     ds = dataset_ops.Dataset.from_tensor_slices((data, labels))
     ds = ds.batch(bs).prefetch(3)
     it = ds.make_one_shot_iterator()
 
     def body(_, i):
       i += 1
       x, yt = it.get_next()
       dense = layers.Dense(nclass)
       y = dense(x)
       loss = losses.sparse_softmax_cross_entropy(yt, y)
       opt = adam.AdamOptimizer()
       train_op = opt.minimize(loss, var_list=dense.trainable_weights)
       with ops.control_dependencies([train_op]):
         loss = array_ops.identity(loss)
       return loss, i
 
     begin, end = constant_op.constant(0), constant_op.constant(num_iter)
     loss, _ = control_flow_ops.while_loop(
         lambda loss, i: math_ops.less(i, end), body, [0.0, begin])
 
-    output_val_ref, output_val, cost_graph = self._run(loss)
+    output_val_ref, output_val, cost_graph = self._run(mode, loss)
     node_map = _build_node_map(cost_graph.node)
 
-    self._assert_output_fp16(node_map, 'while/dense/MatMul')
-    self._assert_output_fp16(
-        node_map, 'while/gradients/while/dense/MatMul_grad/MatMul_1')
+    self._assert_output_f16(mode, node_map, 'while/dense/MatMul')
+    self._assert_output_f16(
+        mode, node_map, 'while/gradients/while/dense/MatMul_grad/MatMul_1')
     self.assertAllClose(output_val_ref, output_val, atol=1e-3, rtol=1e-3)
 
 # TODO(benbarsdell): Add tests for list ops (TensorList*) that pass through
 # graph source/sink nodes, similar to the TensorListThroughFunction C++ test.