Replace instances of "whitelist" with "allowlist" where possible. See Google Developer guidelines at https://developers.google.com/style/word-list#blacklist for more information.
PiperOrigin-RevId: 320210110 Change-Id: I480d2b1c80d7d77fdd071b7642011758988f18c0
This commit is contained in:
parent
a85beef408
commit
7eab1f3bfe
@ -50,6 +50,9 @@
|
|||||||
* Tracing and Debugging:
|
* Tracing and Debugging:
|
||||||
* <ADD RELEASE NOTES HERE>
|
* <ADD RELEASE NOTES HERE>
|
||||||
* Other:
|
* Other:
|
||||||
|
* We have replaced uses of "whitelist" with "allowlist" where possible.
|
||||||
|
Please see https://developers.google.com/style/word-list#blacklist for more
|
||||||
|
context.
|
||||||
* <ADD RELEASE NOTES HERE>
|
* <ADD RELEASE NOTES HERE>
|
||||||
|
|
||||||
## Thanks to our Contributors
|
## Thanks to our Contributors
|
||||||
|
@ -44,7 +44,7 @@ Even if the untrusted party only supplies the serialized computation
|
|||||||
graph (in form of a `GraphDef`, `SavedModel`, or equivalent on-disk format), the
|
graph (in form of a `GraphDef`, `SavedModel`, or equivalent on-disk format), the
|
||||||
set of computation primitives available to TensorFlow is powerful enough that
|
set of computation primitives available to TensorFlow is powerful enough that
|
||||||
you should assume that the TensorFlow process effectively executes arbitrary
|
you should assume that the TensorFlow process effectively executes arbitrary
|
||||||
code. One common solution is to whitelist only a few safe Ops. While this is
|
code. One common solution is to allow only a few safe Ops. While this is
|
||||||
possible in theory, we still recommend you sandbox the execution.
|
possible in theory, we still recommend you sandbox the execution.
|
||||||
|
|
||||||
It depends on the computation graph whether a user provided checkpoint is safe.
|
It depends on the computation graph whether a user provided checkpoint is safe.
|
||||||
|
@ -1096,33 +1096,33 @@ StatusOr<bool> IsIdentityDrivingConstsInLoop(Node* node) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::flat_hash_set<string> GetOrCreateWhitelist() {
|
absl::flat_hash_set<string> GetOrCreateAllowlist() {
|
||||||
absl::flat_hash_map<string, std::vector<string>>* whitelist_table =
|
absl::flat_hash_map<string, std::vector<string>>* allowlist_table =
|
||||||
tensorflow::GetWhitelistTable();
|
tensorflow::GetAllowlistTable();
|
||||||
MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
|
MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
|
||||||
absl::flat_hash_set<string> whitelist;
|
absl::flat_hash_set<string> allowlist;
|
||||||
|
|
||||||
for (auto s : absl::StrSplit(flags->tf_xla_ops_to_cluster, ',')) {
|
for (auto s : absl::StrSplit(flags->tf_xla_ops_to_cluster, ',')) {
|
||||||
if (s == "FUSIBLE") {
|
if (s == "FUSIBLE") {
|
||||||
for (auto pair : *whitelist_table) {
|
for (auto pair : *allowlist_table) {
|
||||||
whitelist.insert(pair.second.begin(), pair.second.end());
|
allowlist.insert(pair.second.begin(), pair.second.end());
|
||||||
}
|
}
|
||||||
} else if (whitelist_table->contains(s)) {
|
} else if (allowlist_table->contains(s)) {
|
||||||
auto v = whitelist_table->at(s);
|
auto v = allowlist_table->at(s);
|
||||||
whitelist.insert(v.begin(), v.end());
|
allowlist.insert(v.begin(), v.end());
|
||||||
} else if (!s.empty()) {
|
} else if (!s.empty()) {
|
||||||
// Should be a user provided TF operation.
|
// Should be a user provided TF operation.
|
||||||
whitelist.insert(string(s));
|
allowlist.insert(string(s));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (VLOG_IS_ON(2) && !whitelist.empty()) {
|
if (VLOG_IS_ON(2) && !allowlist.empty()) {
|
||||||
std::vector<string> vwhitelist(whitelist.begin(), whitelist.end());
|
std::vector<string> vallowlist(allowlist.begin(), allowlist.end());
|
||||||
absl::c_sort(vwhitelist);
|
absl::c_sort(vallowlist);
|
||||||
VLOG(2) << "XLA clustering will only consider the following TF operations: "
|
VLOG(2) << "XLA clustering will only consider the following TF operations: "
|
||||||
<< absl::StrJoin(vwhitelist, " ");
|
<< absl::StrJoin(vallowlist, " ");
|
||||||
}
|
}
|
||||||
return whitelist;
|
return allowlist;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status MarkForCompilationPassImpl::FindCompilationCandidates() {
|
Status MarkForCompilationPassImpl::FindCompilationCandidates() {
|
||||||
@ -1156,12 +1156,12 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
|
|||||||
|
|
||||||
VLOG(2) << "sorted_nodes.size() = " << sorted_nodes.size();
|
VLOG(2) << "sorted_nodes.size() = " << sorted_nodes.size();
|
||||||
|
|
||||||
auto whitelist = GetOrCreateWhitelist();
|
auto allowlist = GetOrCreateAllowlist();
|
||||||
|
|
||||||
std::vector<string> vall_ops = XlaOpRegistry::GetAllRegisteredOps();
|
std::vector<string> vall_ops = XlaOpRegistry::GetAllRegisteredOps();
|
||||||
absl::flat_hash_set<string> all_ops(vall_ops.begin(), vall_ops.end());
|
absl::flat_hash_set<string> all_ops(vall_ops.begin(), vall_ops.end());
|
||||||
// Check that user's provided TF operation really exists.
|
// Check that user's provided TF operation really exists.
|
||||||
for (const auto& s : whitelist) {
|
for (const auto& s : allowlist) {
|
||||||
if (!all_ops.contains(string(s))) {
|
if (!all_ops.contains(string(s))) {
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
"The operation '", s,
|
"The operation '", s,
|
||||||
@ -1206,7 +1206,7 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!whitelist.empty() && !whitelist.contains(node->def().op())) {
|
if (!allowlist.empty() && !allowlist.contains(node->def().op())) {
|
||||||
VLOG(1) << "Rejecting TF operation " << node->def().op()
|
VLOG(1) << "Rejecting TF operation " << node->def().op()
|
||||||
<< " as it is not listed in --tf_xla_ops_to_cluster.";
|
<< " as it is not listed in --tf_xla_ops_to_cluster.";
|
||||||
continue;
|
continue;
|
||||||
@ -1781,7 +1781,7 @@ Status MarkForCompilationPass::RunForTest(
|
|||||||
return MarkForCompilation(options, debug_options);
|
return MarkForCompilation(options, debug_options);
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::flat_hash_map<string, std::vector<string>>* GetWhitelistTable() {
|
absl::flat_hash_map<string, std::vector<string>>* GetAllowlistTable() {
|
||||||
// Table format: category name: {list of TF operations in that category}
|
// Table format: category name: {list of TF operations in that category}
|
||||||
static absl::flat_hash_map<string, std::vector<string>>* result =
|
static absl::flat_hash_map<string, std::vector<string>>* result =
|
||||||
new absl::flat_hash_map<string, std::vector<string>>{
|
new absl::flat_hash_map<string, std::vector<string>>{
|
||||||
@ -1845,7 +1845,7 @@ absl::flat_hash_map<string, std::vector<string>>* GetWhitelistTable() {
|
|||||||
namespace testing {
|
namespace testing {
|
||||||
void ResetClusterSequenceNumber() { cluster_sequence_num = 0; }
|
void ResetClusterSequenceNumber() { cluster_sequence_num = 0; }
|
||||||
|
|
||||||
absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
|
absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
|
||||||
absl::flat_hash_set<string> result{"AdjustContrastv2",
|
absl::flat_hash_set<string> result{"AdjustContrastv2",
|
||||||
"AdjustHue",
|
"AdjustHue",
|
||||||
"AdjustSaturation",
|
"AdjustSaturation",
|
||||||
|
@ -58,7 +58,7 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
|
|||||||
RecursiveCompilabilityChecker::UncompilableNodesMap*
|
RecursiveCompilabilityChecker::UncompilableNodesMap*
|
||||||
uncompilable_node_info = nullptr);
|
uncompilable_node_info = nullptr);
|
||||||
|
|
||||||
absl::flat_hash_map<string, std::vector<string>>* GetWhitelistTable();
|
absl::flat_hash_map<string, std::vector<string>>* GetAllowlistTable();
|
||||||
|
|
||||||
namespace testing {
|
namespace testing {
|
||||||
// DO NOT USE IN PRODUCTION.
|
// DO NOT USE IN PRODUCTION.
|
||||||
@ -66,8 +66,8 @@ namespace testing {
|
|||||||
// Resets some internal state to let us write reliable unit tests.
|
// Resets some internal state to let us write reliable unit tests.
|
||||||
void ResetClusterSequenceNumber();
|
void ResetClusterSequenceNumber();
|
||||||
|
|
||||||
// Return a list of operation that we choose not to put into the whitelist.
|
// Return a list of operation that we choose not to put into the allowlist.
|
||||||
absl::flat_hash_set<string> GetKnownXLAWhitelistOp();
|
absl::flat_hash_set<string> GetKnownXLAAllowlistOp();
|
||||||
} // namespace testing
|
} // namespace testing
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
@ -1802,34 +1802,34 @@ TEST(XlaCompilationTest, StagePipelinePreservedByClusterScopingPass) {
|
|||||||
EXPECT_NE(clusters["relu0"], clusters["relu1"]);
|
EXPECT_NE(clusters["relu0"], clusters["relu1"]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TEST(XlaCompilationTest, XLALiteWhitelist) {
|
TEST(XlaCompilationTest, XLALiteAllowlist) {
|
||||||
auto* whitelist_table = tensorflow::GetWhitelistTable();
|
auto* allowlist_table = tensorflow::GetAllowlistTable();
|
||||||
absl::flat_hash_set<string> hwhitelist;
|
absl::flat_hash_set<string> hallowlist;
|
||||||
std::vector<string> vall_ops = XlaOpRegistry::GetAllRegisteredOps();
|
std::vector<string> vall_ops = XlaOpRegistry::GetAllRegisteredOps();
|
||||||
absl::flat_hash_set<string> all_ops(vall_ops.begin(), vall_ops.end());
|
absl::flat_hash_set<string> all_ops(vall_ops.begin(), vall_ops.end());
|
||||||
|
|
||||||
// Check that all the operations in the table are existing TF operations
|
// Check that all the operations in the table are existing TF operations
|
||||||
for (auto pair : *whitelist_table) {
|
for (auto pair : *allowlist_table) {
|
||||||
hwhitelist.insert(pair.second.begin(), pair.second.end());
|
hallowlist.insert(pair.second.begin(), pair.second.end());
|
||||||
for (auto op : pair.second) {
|
for (auto op : pair.second) {
|
||||||
ASSERT_TRUE(all_ops.contains(op));
|
ASSERT_TRUE(all_ops.contains(op));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check that all registered XLA operation are in the whitelist
|
// Check that all registered XLA operation are in the allowlist
|
||||||
// table or are known to not be in it.
|
// table or are known to not be in it.
|
||||||
|
|
||||||
absl::flat_hash_set<string> known_not_in_list =
|
absl::flat_hash_set<string> known_not_in_list =
|
||||||
tensorflow::testing::GetKnownXLAWhitelistOp();
|
tensorflow::testing::GetKnownXLAAllowlistOp();
|
||||||
std::vector<string> unknow_op;
|
std::vector<string> unknow_op;
|
||||||
for (string op : vall_ops) {
|
for (string op : vall_ops) {
|
||||||
if (!hwhitelist.contains(op) && !known_not_in_list.contains(op)) {
|
if (!hallowlist.contains(op) && !known_not_in_list.contains(op)) {
|
||||||
unknow_op.push_back(op);
|
unknow_op.push_back(op);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
EXPECT_TRUE(unknow_op.empty())
|
EXPECT_TRUE(unknow_op.empty())
|
||||||
<< "Someone added support for a new TF opeations inside XLA. They must "
|
<< "Someone added support for a new TF opeations inside XLA. They must "
|
||||||
"be included in the XLALite whitelist or blacklist:\n"
|
"be included in the XLALite allowlist or blacklist:\n"
|
||||||
<< absl::StrJoin(unknow_op, "\n");
|
<< absl::StrJoin(unknow_op, "\n");
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -30,7 +30,7 @@ struct PassConfig {
|
|||||||
explicit PassConfig(QuantizationSpecs specs)
|
explicit PassConfig(QuantizationSpecs specs)
|
||||||
: emit_builtin_tflite_ops(true),
|
: emit_builtin_tflite_ops(true),
|
||||||
lower_tensor_list_ops(false),
|
lower_tensor_list_ops(false),
|
||||||
trim_functions_whitelist({}),
|
trim_functions_allowlist({}),
|
||||||
quant_specs(std::move(specs)),
|
quant_specs(std::move(specs)),
|
||||||
form_clusters(false),
|
form_clusters(false),
|
||||||
unfold_batch_matmul(true),
|
unfold_batch_matmul(true),
|
||||||
@ -44,8 +44,8 @@ struct PassConfig {
|
|||||||
// If `lower_tensor_list_ops` is true, tensorlist ops will be lowered to basic
|
// If `lower_tensor_list_ops` is true, tensorlist ops will be lowered to basic
|
||||||
// TF ops before legalization to TF Lite dialect.
|
// TF ops before legalization to TF Lite dialect.
|
||||||
bool lower_tensor_list_ops;
|
bool lower_tensor_list_ops;
|
||||||
// The whitelist of functions that would be preserved after trimming.
|
// The allowlist of functions that would be preserved after trimming.
|
||||||
llvm::ArrayRef<std::string> trim_functions_whitelist;
|
llvm::ArrayRef<std::string> trim_functions_allowlist;
|
||||||
// All information about quantization.
|
// All information about quantization.
|
||||||
QuantizationSpecs quant_specs;
|
QuantizationSpecs quant_specs;
|
||||||
// If `form_clusters` is true , clusters are formed by grouping consecutive
|
// If `form_clusters` is true , clusters are formed by grouping consecutive
|
||||||
|
@ -71,7 +71,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/errors.h"
|
#include "tensorflow/core/platform/errors.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/status.h"
|
#include "tensorflow/core/platform/status.h"
|
||||||
#include "tensorflow/lite/delegates/flex/whitelisted_flex_ops.h"
|
#include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h"
|
||||||
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
|
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
|
||||||
#include "tensorflow/lite/schema/schema_generated.h"
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
#include "tensorflow/lite/string_util.h"
|
#include "tensorflow/lite/string_util.h"
|
||||||
@ -101,7 +101,7 @@ using mlir::Value;
|
|||||||
using tensorflow::OpOrArgLocNameMapper;
|
using tensorflow::OpOrArgLocNameMapper;
|
||||||
using tensorflow::OpOrArgNameMapper;
|
using tensorflow::OpOrArgNameMapper;
|
||||||
using tensorflow::Status;
|
using tensorflow::Status;
|
||||||
using tflite::flex::IsWhitelistedFlexOp;
|
using tflite::flex::IsAllowlistedFlexOp;
|
||||||
using xla::StatusOr;
|
using xla::StatusOr;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@ -972,7 +972,7 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
|
|||||||
// model is of an open op system.
|
// model is of an open op system.
|
||||||
//
|
//
|
||||||
// The following algorithm is followed:
|
// The following algorithm is followed:
|
||||||
// if flex is enabled and the op is whitelisted as flex
|
// if flex is enabled and the op is allowlisted as flex
|
||||||
// we emit op as flex.
|
// we emit op as flex.
|
||||||
// if custom is enabled
|
// if custom is enabled
|
||||||
// we emit the op as custom.
|
// we emit the op as custom.
|
||||||
@ -982,11 +982,11 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Flex op case
|
// Flex op case
|
||||||
// Eventually, the whitelist will go away and we will rely on some TF op
|
// Eventually, the allowlist will go away and we will rely on some TF op
|
||||||
// trait (e.g. No side effect) to determine if it is a supported "Flex"
|
// trait (e.g. No side effect) to determine if it is a supported "Flex"
|
||||||
// op or not.
|
// op or not.
|
||||||
if (enabled_op_types_.contains(OpType::kSelectTf) &&
|
if (enabled_op_types_.contains(OpType::kSelectTf) &&
|
||||||
IsWhitelistedFlexOp(node_def->op())) {
|
IsAllowlistedFlexOp(node_def->op())) {
|
||||||
// Construct ops as flex op encoding TensorFlow node definition
|
// Construct ops as flex op encoding TensorFlow node definition
|
||||||
// as custom options.
|
// as custom options.
|
||||||
// Flex ops are named with the kFlexOpNamePrefix prefix to the actual
|
// Flex ops are named with the kFlexOpNamePrefix prefix to the actual
|
||||||
@ -1037,7 +1037,7 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Insert failed op to `flex_ops` or `custom_ops`.
|
// Insert failed op to `flex_ops` or `custom_ops`.
|
||||||
if (IsWhitelistedFlexOp(node_def->op())) {
|
if (IsAllowlistedFlexOp(node_def->op())) {
|
||||||
failed_flex_ops_.insert(os.str());
|
failed_flex_ops_.insert(os.str());
|
||||||
} else {
|
} else {
|
||||||
failed_custom_ops_.insert(os.str());
|
failed_custom_ops_.insert(os.str());
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// RUN: tf-opt %s -tfl-prepare-quantize -tfl-test-quantize-whitelist="quantize_float_placeholder_only,not_reset_input" | FileCheck %s
|
// RUN: tf-opt %s -tfl-prepare-quantize -tfl-test-quantize-allowlist="quantize_float_placeholder_only,not_reset_input" | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: quantize_float_placeholder_only
|
// CHECK-LABEL: quantize_float_placeholder_only
|
||||||
func @quantize_float_placeholder_only(%arg0: tensor<f32>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xf32>) -> (tensor<f32>, tensor<2x3xi32>, tensor<2x3xf32>) {
|
func @quantize_float_placeholder_only(%arg0: tensor<f32>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xf32>) -> (tensor<f32>, tensor<2x3xi32>, tensor<2x3xf32>) {
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// RUN: tf-opt -tfl-trim-funcs-tf -tfl-trim-funcs-whitelist="bar,foobar" %s | FileCheck %s
|
// RUN: tf-opt -tfl-trim-funcs-tf -tfl-trim-funcs-allowlist="bar,foobar" %s | FileCheck %s
|
||||||
|
|
||||||
func @foo(%arg0: tensor<1x4xf32>, %arg1: tensor<1x4xf32>) -> tensor<1x4xf32> {
|
func @foo(%arg0: tensor<1x4xf32>, %arg1: tensor<1x4xf32>) -> tensor<1x4xf32> {
|
||||||
return %arg0 : tensor<1x4xf32>
|
return %arg0 : tensor<1x4xf32>
|
||||||
|
@ -61,7 +61,7 @@ std::unique_ptr<OperationPass<FuncOp>> CreatePostQuantizePass(
|
|||||||
// Creates an instance of the TensorFlow Lite dialect TrimFunctions
|
// Creates an instance of the TensorFlow Lite dialect TrimFunctions
|
||||||
// pass.
|
// pass.
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> CreateTrimFunctionsPass(
|
std::unique_ptr<OperationPass<ModuleOp>> CreateTrimFunctionsPass(
|
||||||
llvm::ArrayRef<std::string> trim_funcs_whitelist);
|
llvm::ArrayRef<std::string> trim_funcs_allowlist);
|
||||||
|
|
||||||
// Creates an instance of the TensorFlow Lite dialect PrepareCompositeFunctions
|
// Creates an instance of the TensorFlow Lite dialect PrepareCompositeFunctions
|
||||||
// pass.
|
// pass.
|
||||||
|
@ -35,9 +35,9 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/types.pb.h"
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
|
|
||||||
// NOLINTNEXTLINE
|
// NOLINTNEXTLINE
|
||||||
static llvm::cl::list<std::string> quantize_whitelist(
|
static llvm::cl::list<std::string> quantize_allowlist(
|
||||||
"tfl-test-quantize-whitelist", llvm::cl::value_desc("list"),
|
"tfl-test-quantize-allowlist", llvm::cl::value_desc("list"),
|
||||||
llvm::cl::desc("comma separated list of whitelisted functions to be "
|
llvm::cl::desc("comma separated list of allowlisted functions to be "
|
||||||
"quantized. Only used in tests"),
|
"quantized. Only used in tests"),
|
||||||
llvm::cl::CommaSeparated);
|
llvm::cl::CommaSeparated);
|
||||||
|
|
||||||
@ -108,7 +108,7 @@ class PrepareQuantizePass
|
|||||||
|
|
||||||
// Get the min and max values from the quantization specification for the
|
// Get the min and max values from the quantization specification for the
|
||||||
// current function function and argument index. Uses default values if
|
// current function function and argument index. Uses default values if
|
||||||
// the function is specified in the `quantize_whitelist`.
|
// the function is specified in the `quantize_allowlist`.
|
||||||
std::pair<llvm::Optional<double>, llvm::Optional<double>>
|
std::pair<llvm::Optional<double>, llvm::Optional<double>>
|
||||||
GetMinMaxValuesForArgument(llvm::StringRef func_name, int index) {
|
GetMinMaxValuesForArgument(llvm::StringRef func_name, int index) {
|
||||||
if (func_name == quant_specs_.target_func) {
|
if (func_name == quant_specs_.target_func) {
|
||||||
@ -132,7 +132,7 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
|
|||||||
// Skip this function because it isn't the target function from the spec or
|
// Skip this function because it isn't the target function from the spec or
|
||||||
// in the function while list.
|
// in the function while list.
|
||||||
if (target_func != func_name &&
|
if (target_func != func_name &&
|
||||||
!llvm::is_contained(quantize_whitelist, func_name)) {
|
!llvm::is_contained(quantize_allowlist, func_name)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -29,12 +29,12 @@ limitations under the License.
|
|||||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||||
|
|
||||||
// The cmd line flag to specify the whitelist of functions. Rest are trimmed
|
// The cmd line flag to specify the allowlist of functions. Rest are trimmed
|
||||||
// after this pass is run.
|
// after this pass is run.
|
||||||
// NOLINTNEXTLINE
|
// NOLINTNEXTLINE
|
||||||
static llvm::cl::list<std::string> trim_funcs_whitelist(
|
static llvm::cl::list<std::string> trim_funcs_allowlist(
|
||||||
"tfl-trim-funcs-whitelist", llvm::cl::value_desc("list"),
|
"tfl-trim-funcs-allowlist", llvm::cl::value_desc("list"),
|
||||||
llvm::cl::desc("comma separated list of whitelisted functions. The first "
|
llvm::cl::desc("comma separated list of allowlisted functions. The first "
|
||||||
"function specified will be used as main."),
|
"function specified will be used as main."),
|
||||||
llvm::cl::CommaSeparated);
|
llvm::cl::CommaSeparated);
|
||||||
|
|
||||||
@ -43,25 +43,25 @@ namespace TFL {
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// The pass to trim functions before we legalize to TFL
|
// The pass to trim functions before we legalize to TFL
|
||||||
// dialect using the specified whitelist.
|
// dialect using the specified allowlist.
|
||||||
class TrimFunctionsPass
|
class TrimFunctionsPass
|
||||||
: public mlir::PassWrapper<TrimFunctionsPass, OperationPass<ModuleOp>> {
|
: public mlir::PassWrapper<TrimFunctionsPass, OperationPass<ModuleOp>> {
|
||||||
public:
|
public:
|
||||||
explicit TrimFunctionsPass() : trim_funcs_whitelist_(trim_funcs_whitelist) {}
|
explicit TrimFunctionsPass() : trim_funcs_allowlist_(trim_funcs_allowlist) {}
|
||||||
explicit TrimFunctionsPass(llvm::ArrayRef<std::string> trim_funcs_whitelist)
|
explicit TrimFunctionsPass(llvm::ArrayRef<std::string> trim_funcs_allowlist)
|
||||||
: trim_funcs_whitelist_(trim_funcs_whitelist) {}
|
: trim_funcs_allowlist_(trim_funcs_allowlist) {}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void runOnOperation() override;
|
void runOnOperation() override;
|
||||||
bool TrimModule();
|
bool TrimModule();
|
||||||
void Verify();
|
void Verify();
|
||||||
|
|
||||||
llvm::ArrayRef<std::string> trim_funcs_whitelist_;
|
llvm::ArrayRef<std::string> trim_funcs_allowlist_;
|
||||||
};
|
};
|
||||||
|
|
||||||
void TrimFunctionsPass::runOnOperation() {
|
void TrimFunctionsPass::runOnOperation() {
|
||||||
// trim the functions in the module using the trim_funcs_whitelist_
|
// trim the functions in the module using the trim_funcs_allowlist_
|
||||||
// by removing functions not in the whitelist.
|
// by removing functions not in the allowlist.
|
||||||
if (TrimModule()) {
|
if (TrimModule()) {
|
||||||
// verify the updated module is still valid, if not signal the
|
// verify the updated module is still valid, if not signal the
|
||||||
// pass as failed.
|
// pass as failed.
|
||||||
@ -70,20 +70,20 @@ void TrimFunctionsPass::runOnOperation() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool TrimFunctionsPass::TrimModule() {
|
bool TrimFunctionsPass::TrimModule() {
|
||||||
// if no trim_funcs_whitelist_ is specified, this pass is a no-op.
|
// if no trim_funcs_allowlist_ is specified, this pass is a no-op.
|
||||||
if (trim_funcs_whitelist_.empty()) return false;
|
if (trim_funcs_allowlist_.empty()) return false;
|
||||||
|
|
||||||
llvm::SmallVector<FuncOp, 4> funcs_to_trim;
|
llvm::SmallVector<FuncOp, 4> funcs_to_trim;
|
||||||
for (auto func : getOperation().getOps<FuncOp>()) {
|
for (auto func : getOperation().getOps<FuncOp>()) {
|
||||||
if (llvm::is_contained(trim_funcs_whitelist_, func.getName())) {
|
if (llvm::is_contained(trim_funcs_allowlist_, func.getName())) {
|
||||||
// If no main is specified in the whitelist, use the 1st func
|
// If no main is specified in the allowlist, use the 1st func
|
||||||
// in trim_funcs_whitelist as the main.
|
// in trim_funcs_allowlist as the main.
|
||||||
// TODO(ashwinm): Currently tflite flatbuffer export assumes there is
|
// TODO(ashwinm): Currently tflite flatbuffer export assumes there is
|
||||||
// always a main. This is strictly not required for TFlite. We need to
|
// always a main. This is strictly not required for TFlite. We need to
|
||||||
// remove that restriction once we have support to attribute the main
|
// remove that restriction once we have support to attribute the main
|
||||||
// tensorflow function in MLIR TF import using an entry_point attr.
|
// tensorflow function in MLIR TF import using an entry_point attr.
|
||||||
if (!llvm::is_contained(trim_funcs_whitelist_, "main") &&
|
if (!llvm::is_contained(trim_funcs_allowlist_, "main") &&
|
||||||
func.getName() == trim_funcs_whitelist_[0]) {
|
func.getName() == trim_funcs_allowlist_[0]) {
|
||||||
func.setName("main");
|
func.setName("main");
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -99,7 +99,7 @@ bool TrimFunctionsPass::TrimModule() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// validate that all reachable functions from the remaining functions are
|
// validate that all reachable functions from the remaining functions are
|
||||||
// also in the whitelist.
|
// also in the allowlist.
|
||||||
void TrimFunctionsPass::Verify() {
|
void TrimFunctionsPass::Verify() {
|
||||||
// TODO(ashwinm): Instead, we should make sure that references to all
|
// TODO(ashwinm): Instead, we should make sure that references to all
|
||||||
// SymbolRefAttrs of all ops are present.
|
// SymbolRefAttrs of all ops are present.
|
||||||
@ -109,7 +109,7 @@ void TrimFunctionsPass::Verify() {
|
|||||||
auto walk_result = func.walk([&](CallOp op) -> WalkResult {
|
auto walk_result = func.walk([&](CallOp op) -> WalkResult {
|
||||||
if (!symbol_table.lookup<FuncOp>(op.getCallee()))
|
if (!symbol_table.lookup<FuncOp>(op.getCallee()))
|
||||||
return getOperation().emitError()
|
return getOperation().emitError()
|
||||||
<< func.getName() << " is not in the funcs whitelist";
|
<< func.getName() << " is not in the funcs allowlist";
|
||||||
return WalkResult::advance();
|
return WalkResult::advance();
|
||||||
});
|
});
|
||||||
if (walk_result.wasInterrupted()) return signalPassFailure();
|
if (walk_result.wasInterrupted()) return signalPassFailure();
|
||||||
@ -121,13 +121,13 @@ void TrimFunctionsPass::Verify() {
|
|||||||
// Creates an instance of the TensorFlow Lite dialect TrimFunctions
|
// Creates an instance of the TensorFlow Lite dialect TrimFunctions
|
||||||
/// pass.
|
/// pass.
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> CreateTrimFunctionsPass(
|
std::unique_ptr<OperationPass<ModuleOp>> CreateTrimFunctionsPass(
|
||||||
llvm::ArrayRef<std::string> trim_funcs_whitelist) {
|
llvm::ArrayRef<std::string> trim_funcs_allowlist) {
|
||||||
return std::make_unique<TrimFunctionsPass>(trim_funcs_whitelist);
|
return std::make_unique<TrimFunctionsPass>(trim_funcs_allowlist);
|
||||||
}
|
}
|
||||||
|
|
||||||
static PassRegistration<TrimFunctionsPass> pass(
|
static PassRegistration<TrimFunctionsPass> pass(
|
||||||
"tfl-trim-funcs-tf",
|
"tfl-trim-funcs-tf",
|
||||||
"Trim functions to restrict them to a specified whitelist prior to "
|
"Trim functions to restrict them to a specified allowlist prior to "
|
||||||
"legalization to TensorFlow lite dialect");
|
"legalization to TensorFlow lite dialect");
|
||||||
|
|
||||||
} // namespace TFL
|
} // namespace TFL
|
||||||
|
@ -23,8 +23,8 @@ func @unknown_op(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
|||||||
return %0 : tensor<2xf32>
|
return %0 : tensor<2xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: not_whitelisted_op
|
// CHECK-LABEL: not_allowlisted_op
|
||||||
func @not_whitelisted_op(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<?x?x?xf32> {
|
func @not_allowlisted_op(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<?x?x?xf32> {
|
||||||
// CHECK: tf.TensorListReserve
|
// CHECK: tf.TensorListReserve
|
||||||
%0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor<3xi32>, tensor<i32>) -> tensor<!tf.variant<tensor<?x?x?xf32>>>
|
%0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor<3xi32>, tensor<i32>) -> tensor<!tf.variant<tensor<?x?x?xf32>>>
|
||||||
// CHECK: tf.TensorListGetItem
|
// CHECK: tf.TensorListGetItem
|
||||||
|
@ -75,10 +75,10 @@ namespace {
|
|||||||
template <typename T, size_t N>
|
template <typename T, size_t N>
|
||||||
using InlinedVector = tensorflow::gtl::InlinedVector<T, N>; // non-absl ok
|
using InlinedVector = tensorflow::gtl::InlinedVector<T, N>; // non-absl ok
|
||||||
|
|
||||||
static bool IsOpWhitelisted(Operation* op) {
|
static bool IsOpAllowlisted(Operation* op) {
|
||||||
// White-listed TensorFlow ops are known to have well behaved tf2xla kernels
|
// White-listed TensorFlow ops are known to have well behaved tf2xla kernels
|
||||||
// building valid MLIR using MlirHloBuilder.
|
// building valid MLIR using MlirHloBuilder.
|
||||||
// TODO(hinsu): Drop explicit whitelist when MLIR based bridge is enabled for
|
// TODO(hinsu): Drop explicit allowlist when MLIR based bridge is enabled for
|
||||||
// all tf2xla kernels.
|
// all tf2xla kernels.
|
||||||
// clang-format off
|
// clang-format off
|
||||||
static llvm::SmallDenseSet<mlir::TypeID, 512> ops = {
|
static llvm::SmallDenseSet<mlir::TypeID, 512> ops = {
|
||||||
@ -342,7 +342,7 @@ LogicalResult FuncLegalizer::Legalize() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult FuncLegalizer::LegalizeOp(Operation* op) {
|
LogicalResult FuncLegalizer::LegalizeOp(Operation* op) {
|
||||||
if (!IsOpWhitelisted(op)) return success();
|
if (!IsOpAllowlisted(op)) return success();
|
||||||
|
|
||||||
// Only static shaped operands are supported in XLA builders for now.
|
// Only static shaped operands are supported in XLA builders for now.
|
||||||
for (Type ty : op->getOperandTypes()) {
|
for (Type ty : op->getOperandTypes()) {
|
||||||
|
@ -63,7 +63,7 @@ XlaOpRegistry::~XlaOpRegistry() = default;
|
|||||||
if (x.name != y.name) return true;
|
if (x.name != y.name) return true;
|
||||||
if (x.label != y.label) return true;
|
if (x.label != y.label) return true;
|
||||||
// The registrations refer to the same Op: ensures they are compatible and
|
// The registrations refer to the same Op: ensures they are compatible and
|
||||||
// are restricted to different device whitelists.
|
// are restricted to different device allowlists.
|
||||||
if (x.compilation_only != y.compilation_only) {
|
if (x.compilation_only != y.compilation_only) {
|
||||||
LOG(WARNING) << "Registrations of " << x.name
|
LOG(WARNING) << "Registrations of " << x.name
|
||||||
<< " have incompatible compilation_only settings.";
|
<< " have incompatible compilation_only settings.";
|
||||||
@ -84,14 +84,14 @@ XlaOpRegistry::~XlaOpRegistry() = default;
|
|||||||
<< " have incompatible allow_string_type settings.";
|
<< " have incompatible allow_string_type settings.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!x.has_device_whitelist && !y.has_device_whitelist) {
|
if (!x.has_device_allowlist && !y.has_device_allowlist) {
|
||||||
LOG(WARNING) << "Duplicate registrations of " << x.name
|
LOG(WARNING) << "Duplicate registrations of " << x.name
|
||||||
<< "with no device whitelists.";
|
<< "with no device allowlists.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (x.has_device_whitelist && y.has_device_whitelist) {
|
if (x.has_device_allowlist && y.has_device_allowlist) {
|
||||||
for (const auto& device : x.device_whitelist) {
|
for (const auto& device : x.device_allowlist) {
|
||||||
if (y.device_whitelist.count(device) != 0) {
|
if (y.device_allowlist.count(device) != 0) {
|
||||||
LOG(WARNING) << "Multiple registrations of " << x.name << " on device "
|
LOG(WARNING) << "Multiple registrations of " << x.name << " on device "
|
||||||
<< device;
|
<< device;
|
||||||
return false;
|
return false;
|
||||||
@ -185,28 +185,28 @@ void XlaOpRegistry::RegisterCompilationKernels() {
|
|||||||
// The goal is to allow the co-existence of backend-specific kernels and
|
// The goal is to allow the co-existence of backend-specific kernels and
|
||||||
// generic kernels. To achieve this, we enforce the following order of
|
// generic kernels. To achieve this, we enforce the following order of
|
||||||
// registrations for one op:
|
// registrations for one op:
|
||||||
// 1. Process op registration with device whitelists:
|
// 1. Process op registration with device allowlists:
|
||||||
// this pass registers backend-specific kernels for this op.
|
// this pass registers backend-specific kernels for this op.
|
||||||
// 2. Process op registration without device whitelists:
|
// 2. Process op registration without device allowlists:
|
||||||
// this pass registers the kernels for all the other supported backends.
|
// this pass registers the kernels for all the other supported backends.
|
||||||
for (auto& ops : registry.ops_) {
|
for (auto& ops : registry.ops_) {
|
||||||
const string& op_name = ops.first;
|
const string& op_name = ops.first;
|
||||||
std::vector<std::unique_ptr<OpRegistration>>& op_registrations = ops.second;
|
std::vector<std::unique_ptr<OpRegistration>>& op_registrations = ops.second;
|
||||||
// Partition the op registration so that the ones with device whitelists
|
// Partition the op registration so that the ones with device allowlists
|
||||||
// precede the one without device whitelist.
|
// precede the one without device allowlist.
|
||||||
std::partition(op_registrations.begin(), op_registrations.end(),
|
std::partition(op_registrations.begin(), op_registrations.end(),
|
||||||
[](const std::unique_ptr<OpRegistration>& op_reg) {
|
[](const std::unique_ptr<OpRegistration>& op_reg) {
|
||||||
return op_reg->has_device_whitelist;
|
return op_reg->has_device_allowlist;
|
||||||
});
|
});
|
||||||
|
|
||||||
// Collect a set of backend registered by ops with device whitelists.
|
// Collect a set of backend registered by ops with device allowlists.
|
||||||
// The op registration without whitelists will register a generic kernel
|
// The op registration without allowlists will register a generic kernel
|
||||||
// for all other backends not in this set.
|
// for all other backends not in this set.
|
||||||
std::unordered_set<string> whitelisted_backend;
|
std::unordered_set<string> allowlisted_backend;
|
||||||
for (auto& op_registration : op_registrations) {
|
for (auto& op_registration : op_registrations) {
|
||||||
if (op_registration->has_device_whitelist) {
|
if (op_registration->has_device_allowlist) {
|
||||||
whitelisted_backend.insert(op_registration->device_whitelist.begin(),
|
allowlisted_backend.insert(op_registration->device_allowlist.begin(),
|
||||||
op_registration->device_whitelist.end());
|
op_registration->device_allowlist.end());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -238,19 +238,19 @@ void XlaOpRegistry::RegisterCompilationKernels() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (auto& backend : registry.backends_) {
|
for (auto& backend : registry.backends_) {
|
||||||
// If the operator has a device whitelist, only register on whitelisted
|
// If the operator has a device allowlist, only register on allowlisted
|
||||||
// devices.
|
// devices.
|
||||||
if (op_registration->has_device_whitelist &&
|
if (op_registration->has_device_allowlist &&
|
||||||
op_registration->device_whitelist.find(backend.first) ==
|
op_registration->device_allowlist.find(backend.first) ==
|
||||||
op_registration->device_whitelist.end()) {
|
op_registration->device_allowlist.end()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the operator does NOT has a device whitelist, skip all devices
|
// If the operator does NOT has a device allowlist, skip all devices
|
||||||
// that has already been registered.
|
// that has already been registered.
|
||||||
if (!op_registration->has_device_whitelist &&
|
if (!op_registration->has_device_allowlist &&
|
||||||
whitelisted_backend.find(backend.first) !=
|
allowlisted_backend.find(backend.first) !=
|
||||||
whitelisted_backend.end()) {
|
allowlisted_backend.end()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -478,17 +478,17 @@ XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name(
|
|||||||
|
|
||||||
XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(
|
XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(
|
||||||
absl::Span<const absl::string_view> devices) {
|
absl::Span<const absl::string_view> devices) {
|
||||||
registration_->has_device_whitelist = true;
|
registration_->has_device_allowlist = true;
|
||||||
for (absl::string_view device : devices) {
|
for (absl::string_view device : devices) {
|
||||||
registration_->device_whitelist.emplace(device);
|
registration_->device_allowlist.emplace(device);
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(
|
XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(
|
||||||
absl::string_view device) {
|
absl::string_view device) {
|
||||||
registration_->has_device_whitelist = true;
|
registration_->has_device_allowlist = true;
|
||||||
registration_->device_whitelist.emplace(device);
|
registration_->device_allowlist.emplace(device);
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -258,10 +258,10 @@ class XlaOpRegistry {
|
|||||||
// Mapping from attribute name to a list of supported types.
|
// Mapping from attribute name to a list of supported types.
|
||||||
std::unordered_map<string, std::set<DataType>> type_constraints;
|
std::unordered_map<string, std::set<DataType>> type_constraints;
|
||||||
|
|
||||||
// An optional whitelist of devices. If there is no whitelist, all devices
|
// An optional allowlist of devices. If there is no allowlist, all devices
|
||||||
// are permitted.
|
// are permitted.
|
||||||
bool has_device_whitelist = false;
|
bool has_device_allowlist = false;
|
||||||
std::unordered_set<string> device_whitelist;
|
std::unordered_set<string> device_allowlist;
|
||||||
|
|
||||||
// Names of arguments that must be compile-time constants.
|
// Names of arguments that must be compile-time constants.
|
||||||
std::unordered_set<string> compile_time_constant_inputs;
|
std::unordered_set<string> compile_time_constant_inputs;
|
||||||
@ -279,8 +279,8 @@ class XlaOpRegistry {
|
|||||||
// Returns true if registrations x and y can both be added to the registry.
|
// Returns true if registrations x and y can both be added to the registry.
|
||||||
// This is always the case if they refer to different ops. If they refer to
|
// This is always the case if they refer to different ops. If they refer to
|
||||||
// the same op name, they must: have the same values for compilation_only,
|
// the same op name, they must: have the same values for compilation_only,
|
||||||
// allow_resource_types and allow_variant_types; use a device_whitelist; and
|
// allow_resource_types and allow_variant_types; use a device_allowlist; and
|
||||||
// their whitelists must not intersect.
|
// their allowlists must not intersect.
|
||||||
static bool IsCompatible(const OpRegistration& x, const OpRegistration& y);
|
static bool IsCompatible(const OpRegistration& x, const OpRegistration& y);
|
||||||
|
|
||||||
static Status CompileTimeConstantInputs(const NodeDef& node_def,
|
static Status CompileTimeConstantInputs(const NodeDef& node_def,
|
||||||
@ -319,7 +319,7 @@ class XlaOpRegistrationBuilder {
|
|||||||
// Starts an operator registration chain.
|
// Starts an operator registration chain.
|
||||||
static XlaOpRegistrationBuilder Name(absl::string_view name);
|
static XlaOpRegistrationBuilder Name(absl::string_view name);
|
||||||
|
|
||||||
// Specifies a whitelist of devices on which the operator may run.
|
// Specifies a allowlist of devices on which the operator may run.
|
||||||
XlaOpRegistrationBuilder& Device(absl::string_view devices);
|
XlaOpRegistrationBuilder& Device(absl::string_view devices);
|
||||||
XlaOpRegistrationBuilder& Device(absl::Span<const absl::string_view> devices);
|
XlaOpRegistrationBuilder& Device(absl::Span<const absl::string_view> devices);
|
||||||
|
|
||||||
|
@ -378,7 +378,7 @@ struct TensorAndDevice {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Tensors of some DataTypes cannot placed in device memory as feeds or
|
// Tensors of some DataTypes cannot placed in device memory as feeds or
|
||||||
// fetches. Validate against a whitelist of those known to work.
|
// fetches. Validate against a allowlist of those known to work.
|
||||||
bool IsFeedAndFetchSupported(DataType dtype, const string& device_type) {
|
bool IsFeedAndFetchSupported(DataType dtype, const string& device_type) {
|
||||||
// The mechanism for supporting feeds of device-backed Tensors requires
|
// The mechanism for supporting feeds of device-backed Tensors requires
|
||||||
// the _Arg kernel to be registered for the corresponding type (and that
|
// the _Arg kernel to be registered for the corresponding type (and that
|
||||||
@ -391,7 +391,7 @@ bool IsFeedAndFetchSupported(DataType dtype, const string& device_type) {
|
|||||||
// For now, we return true iff there are _Arg AND _Retval kernels for dtype on
|
// For now, we return true iff there are _Arg AND _Retval kernels for dtype on
|
||||||
// the device. False negatives are okay, false positives would be bad.
|
// the device. False negatives are okay, false positives would be bad.
|
||||||
//
|
//
|
||||||
// TODO(ashankar): Instead of a whitelist here, perhaps we could query
|
// TODO(ashankar): Instead of a allowlist here, perhaps we could query
|
||||||
// the kernel registry for _Arg and _Retval kernels instead.
|
// the kernel registry for _Arg and _Retval kernels instead.
|
||||||
if (device_type == DEVICE_CPU) return true;
|
if (device_type == DEVICE_CPU) return true;
|
||||||
if (device_type != DEVICE_GPU) return false;
|
if (device_type != DEVICE_GPU) return false;
|
||||||
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_WHITELIST_H_
|
#ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_ALLOWLIST_H_
|
||||||
#define TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_WHITELIST_H_
|
#define TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_ALLOWLIST_H_
|
||||||
|
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
@ -23,7 +23,7 @@ namespace tensorflow {
|
|||||||
namespace data {
|
namespace data {
|
||||||
// Registry for stateful ops that need to be used in dataset functions.
|
// Registry for stateful ops that need to be used in dataset functions.
|
||||||
// See below macro for usage details.
|
// See below macro for usage details.
|
||||||
class WhitelistedStatefulOpRegistry {
|
class AllowlistedStatefulOpRegistry {
|
||||||
public:
|
public:
|
||||||
Status Add(string op_name) {
|
Status Add(string op_name) {
|
||||||
op_names_.insert(std::move(op_name));
|
op_names_.insert(std::move(op_name));
|
||||||
@ -37,29 +37,29 @@ class WhitelistedStatefulOpRegistry {
|
|||||||
|
|
||||||
bool Contains(const string& op_name) { return op_names_.count(op_name); }
|
bool Contains(const string& op_name) { return op_names_.count(op_name); }
|
||||||
|
|
||||||
static WhitelistedStatefulOpRegistry* Global() {
|
static AllowlistedStatefulOpRegistry* Global() {
|
||||||
static auto* reg = new WhitelistedStatefulOpRegistry;
|
static auto* reg = new AllowlistedStatefulOpRegistry;
|
||||||
return reg;
|
return reg;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
WhitelistedStatefulOpRegistry() = default;
|
AllowlistedStatefulOpRegistry() = default;
|
||||||
WhitelistedStatefulOpRegistry(WhitelistedStatefulOpRegistry const& copy) =
|
AllowlistedStatefulOpRegistry(AllowlistedStatefulOpRegistry const& copy) =
|
||||||
delete;
|
delete;
|
||||||
WhitelistedStatefulOpRegistry operator=(
|
AllowlistedStatefulOpRegistry operator=(
|
||||||
WhitelistedStatefulOpRegistry const& copy) = delete;
|
AllowlistedStatefulOpRegistry const& copy) = delete;
|
||||||
|
|
||||||
std::unordered_set<string> op_names_;
|
std::unordered_set<string> op_names_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace data
|
} // namespace data
|
||||||
|
|
||||||
// Use this macro to whitelist an op that is marked stateful but needs to be
|
// Use this macro to allowlist an op that is marked stateful but needs to be
|
||||||
// used inside a map_fn in an input pipeline. This is only needed if you wish
|
// used inside a map_fn in an input pipeline. This is only needed if you wish
|
||||||
// to be able to checkpoint the state of the input pipeline. We currently
|
// to be able to checkpoint the state of the input pipeline. We currently
|
||||||
// do not allow stateful ops to be defined inside of map_fns since it is not
|
// do not allow stateful ops to be defined inside of map_fns since it is not
|
||||||
// possible to save their state.
|
// possible to save their state.
|
||||||
// Note that the state of the whitelisted ops inside functions will not be
|
// Note that the state of the allowlisted ops inside functions will not be
|
||||||
// saved during checkpointing, hence this should only be used if the op is
|
// saved during checkpointing, hence this should only be used if the op is
|
||||||
// marked stateful for reasons like to avoid constant folding during graph
|
// marked stateful for reasons like to avoid constant folding during graph
|
||||||
// optimization but is not stateful.
|
// optimization but is not stateful.
|
||||||
@ -73,9 +73,9 @@ class WhitelistedStatefulOpRegistry {
|
|||||||
#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(ctr, name) \
|
#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(ctr, name) \
|
||||||
WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name)
|
WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name)
|
||||||
#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name) \
|
#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name) \
|
||||||
static ::tensorflow::Status whitelist_op##ctr TF_ATTRIBUTE_UNUSED = \
|
static ::tensorflow::Status allowlist_op##ctr TF_ATTRIBUTE_UNUSED = \
|
||||||
::tensorflow::data::WhitelistedStatefulOpRegistry::Global()->Add(name)
|
::tensorflow::data::AllowlistedStatefulOpRegistry::Global()->Add(name)
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_WHITELIST_H_
|
#endif // TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_ALLOWLIST_H_
|
||||||
|
@ -542,8 +542,8 @@ bool IsNumericType(const DataType dtype) {
|
|||||||
return kRealNumberTypes->find(dtype) != kRealNumberTypes->end();
|
return kRealNumberTypes->find(dtype) != kRealNumberTypes->end();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsWhiteListedOpTypeForEvaluateNode(const string& op_type) {
|
bool IsAllowListedOpTypeForEvaluateNode(const string& op_type) {
|
||||||
static const gtl::FlatSet<string>* const kOpTpeWhitelist =
|
static const gtl::FlatSet<string>* const kOpTpeAllowlist =
|
||||||
CHECK_NOTNULL((new gtl::FlatSet<string>{
|
CHECK_NOTNULL((new gtl::FlatSet<string>{
|
||||||
// Unary arithmetic ops
|
// Unary arithmetic ops
|
||||||
"Floor",
|
"Floor",
|
||||||
@ -589,7 +589,7 @@ bool IsWhiteListedOpTypeForEvaluateNode(const string& op_type) {
|
|||||||
"Fill",
|
"Fill",
|
||||||
"Cast",
|
"Cast",
|
||||||
}));
|
}));
|
||||||
return kOpTpeWhitelist->find(op_type) != kOpTpeWhitelist->end();
|
return kOpTpeAllowlist->find(op_type) != kOpTpeAllowlist->end();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Negative shape size of '-1' represents unknown, while negative shape sizes
|
// Negative shape size of '-1' represents unknown, while negative shape sizes
|
||||||
@ -1441,7 +1441,7 @@ class SymbolicShapeRefiner {
|
|||||||
|
|
||||||
// Due to the cost of running EvaluateNode(), we limit only to white listed
|
// Due to the cost of running EvaluateNode(), we limit only to white listed
|
||||||
// op types.
|
// op types.
|
||||||
if (!IsWhiteListedOpTypeForEvaluateNode(c->op_data->op_def.name())) {
|
if (!IsAllowListedOpTypeForEvaluateNode(c->op_data->op_def.name())) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1008,7 +1008,7 @@ TEST_F(GraphPropertiesTest, IdentityPassingShape) {
|
|||||||
|
|
||||||
TEST_F(GraphPropertiesTest, SkippingValueInferenceForLargeTensors) {
|
TEST_F(GraphPropertiesTest, SkippingValueInferenceForLargeTensors) {
|
||||||
// When using aggressive_shape_inference, we run EvaluateNode() for
|
// When using aggressive_shape_inference, we run EvaluateNode() for
|
||||||
// whitelisted ops and small input / output tensors. For instance, Fill op is
|
// allowlisted ops and small input / output tensors. For instance, Fill op is
|
||||||
// evaluated and produces output tensor value if output tensor size is smal
|
// evaluated and produces output tensor value if output tensor size is smal
|
||||||
// (currently, fewer than 17 elements); otherwise we don't run EvaluateNode().
|
// (currently, fewer than 17 elements); otherwise we don't run EvaluateNode().
|
||||||
// This is to avoid wasting time and memory for producing huge tensors (e.g.,
|
// This is to avoid wasting time and memory for producing huge tensors (e.g.,
|
||||||
|
@ -842,11 +842,11 @@ DataTypeSet AllowedDataTypes(const OpDef& op_def, const TypeAttrId& t_attr_id) {
|
|||||||
return AllowedDataTypes(*attr_def);
|
return AllowedDataTypes(*attr_def);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ValidateLists(const gtl::FlatSet<string>& white_list,
|
Status ValidateLists(const gtl::FlatSet<string>& allow_list,
|
||||||
const gtl::FlatSet<string>& black_list,
|
const gtl::FlatSet<string>& black_list,
|
||||||
const gtl::FlatSet<string>& gray_list,
|
const gtl::FlatSet<string>& gray_list,
|
||||||
const gtl::FlatSet<string>& clear_list) {
|
const gtl::FlatSet<string>& clear_list) {
|
||||||
std::vector<gtl::FlatSet<string>> lists{white_list, black_list, gray_list,
|
std::vector<gtl::FlatSet<string>> lists{allow_list, black_list, gray_list,
|
||||||
clear_list};
|
clear_list};
|
||||||
std::multiset<string> counts;
|
std::multiset<string> counts;
|
||||||
for (const auto& list : lists) {
|
for (const auto& list : lists) {
|
||||||
@ -973,25 +973,25 @@ class AutoMixedPrecisionImpl {
|
|||||||
void FindTensorListImplicitFloat32Edges(
|
void FindTensorListImplicitFloat32Edges(
|
||||||
const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
|
const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
|
||||||
std::vector<NodeTypeIdEdge>* implicit_data_edges) const;
|
std::vector<NodeTypeIdEdge>* implicit_data_edges) const;
|
||||||
void AddWhitelistOps(absl::flat_hash_set<int>* white_set) const;
|
void AddAllowlistOps(absl::flat_hash_set<int>* allow_set) const;
|
||||||
void PropagateBlackFwdThroughClearAndGray(
|
void PropagateBlackFwdThroughClearAndGray(
|
||||||
absl::flat_hash_set<int>* black_set) const;
|
absl::flat_hash_set<int>* black_set) const;
|
||||||
void ForceColorMatchBetweenTensorListOps(
|
void ForceColorMatchBetweenTensorListOps(
|
||||||
const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
|
const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
|
||||||
absl::flat_hash_set<int>* white_set,
|
absl::flat_hash_set<int>* allow_set,
|
||||||
absl::flat_hash_set<int>* black_set) const;
|
absl::flat_hash_set<int>* black_set) const;
|
||||||
void AddClearAndGrayToWhiteIfBetweenWhite(
|
void AddClearAndGrayToAllowIfBetweenAllow(
|
||||||
const absl::flat_hash_set<int>& black_set,
|
const absl::flat_hash_set<int>& black_set,
|
||||||
absl::flat_hash_set<int>* white_set) const;
|
absl::flat_hash_set<int>* allow_set) const;
|
||||||
void PropagateWhiteThroughClear(const absl::flat_hash_set<int>& black_set,
|
void PropagateAllowThroughClear(const absl::flat_hash_set<int>& black_set,
|
||||||
absl::flat_hash_set<int>* white_set) const;
|
absl::flat_hash_set<int>* allow_set) const;
|
||||||
Status ForceColorMatchOnRecurrentEdges(
|
Status ForceColorMatchOnRecurrentEdges(
|
||||||
absl::flat_hash_set<int>* white_set) const;
|
absl::flat_hash_set<int>* allow_set) const;
|
||||||
void MakeCastsWhiteIfAllOutputsWhite(
|
void MakeCastsAllowIfAllOutputsAllow(
|
||||||
absl::flat_hash_set<int>* white_set) const;
|
absl::flat_hash_set<int>* allow_set) const;
|
||||||
NodeDef BuildCastNode(const MutableGraphView::OutputPort& src, bool to_f16,
|
NodeDef BuildCastNode(const MutableGraphView::OutputPort& src, bool to_f16,
|
||||||
const string& device) const;
|
const string& device) const;
|
||||||
Status ChangeTypeAttrsAndAddCasts(const absl::flat_hash_set<int>& white_set);
|
Status ChangeTypeAttrsAndAddCasts(const absl::flat_hash_set<int>& allow_set);
|
||||||
|
|
||||||
VirtualPlacer virtual_placer_;
|
VirtualPlacer virtual_placer_;
|
||||||
std::unordered_set<string> nodes_to_preserve_;
|
std::unordered_set<string> nodes_to_preserve_;
|
||||||
@ -1005,7 +1005,7 @@ class AutoMixedPrecisionImpl {
|
|||||||
GraphTypeTopologyView graph_type_view_;
|
GraphTypeTopologyView graph_type_view_;
|
||||||
bool force_all_fp16_;
|
bool force_all_fp16_;
|
||||||
AutoMixedPrecisionMode mode_;
|
AutoMixedPrecisionMode mode_;
|
||||||
gtl::FlatSet<string> f16_whitelist_;
|
gtl::FlatSet<string> f16_allowlist_;
|
||||||
gtl::FlatSet<string> f16_blacklist_;
|
gtl::FlatSet<string> f16_blacklist_;
|
||||||
gtl::FlatSet<string> f16_graylist_;
|
gtl::FlatSet<string> f16_graylist_;
|
||||||
gtl::FlatSet<string> f16_clearlist_;
|
gtl::FlatSet<string> f16_clearlist_;
|
||||||
@ -1079,8 +1079,8 @@ Status AutoMixedPrecisionImpl::PrintDebugLogs(bool preop, size_t timestamp) {
|
|||||||
f.open(fname.c_str(), std::fstream::out);
|
f.open(fname.c_str(), std::fstream::out);
|
||||||
std::unique_ptr<AutoMixedPrecisionLists> mp_lists =
|
std::unique_ptr<AutoMixedPrecisionLists> mp_lists =
|
||||||
get_mixed_precision_lists();
|
get_mixed_precision_lists();
|
||||||
f << "WhiteList:\n";
|
f << "AllowList:\n";
|
||||||
for (const auto& x : mp_lists->WhiteList()) {
|
for (const auto& x : mp_lists->AllowList()) {
|
||||||
f << x << "\n";
|
f << x << "\n";
|
||||||
}
|
}
|
||||||
f << "\nBlackList:\n";
|
f << "\nBlackList:\n";
|
||||||
@ -1254,11 +1254,11 @@ Status AutoMixedPrecisionImpl::Optimize() {
|
|||||||
|
|
||||||
std::unique_ptr<AutoMixedPrecisionLists> mp_lists =
|
std::unique_ptr<AutoMixedPrecisionLists> mp_lists =
|
||||||
get_mixed_precision_lists();
|
get_mixed_precision_lists();
|
||||||
f16_whitelist_ = mp_lists->WhiteList();
|
f16_allowlist_ = mp_lists->AllowList();
|
||||||
f16_blacklist_ = mp_lists->BlackList();
|
f16_blacklist_ = mp_lists->BlackList();
|
||||||
f16_graylist_ = mp_lists->GrayList();
|
f16_graylist_ = mp_lists->GrayList();
|
||||||
f16_clearlist_ = mp_lists->ClearList();
|
f16_clearlist_ = mp_lists->ClearList();
|
||||||
TF_RETURN_IF_ERROR(ValidateLists(f16_whitelist_, f16_blacklist_,
|
TF_RETURN_IF_ERROR(ValidateLists(f16_allowlist_, f16_blacklist_,
|
||||||
f16_graylist_, f16_clearlist_));
|
f16_graylist_, f16_clearlist_));
|
||||||
|
|
||||||
size_t timestamp = Env::Default()->NowMicros() / 1000;
|
size_t timestamp = Env::Default()->NowMicros() / 1000;
|
||||||
@ -1316,8 +1316,8 @@ Status AutoMixedPrecisionImpl::Optimize() {
|
|||||||
// boundaries between f16/non-f16 nodes.
|
// boundaries between f16/non-f16 nodes.
|
||||||
|
|
||||||
// The algorithm for deciding which nodes to change to f16 is as follows:
|
// The algorithm for deciding which nodes to change to f16 is as follows:
|
||||||
// 1) Add all performance-critical ops (aka "whitelist" ops) to the white_set.
|
// 1) Add all performance-critical ops (aka "allowlist" ops) to the allow_set.
|
||||||
// This is done under the assumption that whitelist ops are always
|
// This is done under the assumption that allowlist ops are always
|
||||||
// numerically-safe in f16 and that they are the most important ops for
|
// numerically-safe in f16 and that they are the most important ops for
|
||||||
// improving performance.
|
// improving performance.
|
||||||
// 2) Add nodes to the black_set iff they are numerically-dangerous (aka
|
// 2) Add nodes to the black_set iff they are numerically-dangerous (aka
|
||||||
@ -1329,20 +1329,20 @@ Status AutoMixedPrecisionImpl::Optimize() {
|
|||||||
// numerical accuracy of the model.
|
// numerical accuracy of the model.
|
||||||
// 3) For all remaining nodes that are not considered dangerous (greylist
|
// 3) For all remaining nodes that are not considered dangerous (greylist
|
||||||
// and clearlist ops), find those that are between (i.e., both upstream
|
// and clearlist ops), find those that are between (i.e., both upstream
|
||||||
// and downstream of) white nodes, and add them to the white_set.
|
// and downstream of) allow nodes, and add them to the allow_set.
|
||||||
// This is done to avoid unnecessary casts between whitelist ops.
|
// This is done to avoid unnecessary casts between allowlist ops.
|
||||||
// 4) For all remaining clearlist nodes, add them to the white_set if they are
|
// 4) For all remaining clearlist nodes, add them to the allow_set if they are
|
||||||
// connected to a node in the white_set via other clearlist nodes.
|
// connected to a node in the allow_set via other clearlist nodes.
|
||||||
// This is done to increase the number of ops in the white_set without
|
// This is done to increase the number of ops in the allow_set without
|
||||||
// affecting numerical stability.
|
// affecting numerical stability.
|
||||||
|
|
||||||
absl::flat_hash_set<int> white_set;
|
absl::flat_hash_set<int> allow_set;
|
||||||
VLOG(2) << "Beginning pass 1 to add whitelist ops";
|
VLOG(2) << "Beginning pass 1 to add allowlist ops";
|
||||||
AddWhitelistOps(&white_set);
|
AddAllowlistOps(&allow_set);
|
||||||
VLOG(2) << "Finished pass 1";
|
VLOG(2) << "Finished pass 1";
|
||||||
|
|
||||||
if (white_set.empty()) {
|
if (allow_set.empty()) {
|
||||||
LOG(INFO) << "No whitelist ops found, nothing to do";
|
LOG(INFO) << "No allowlist ops found, nothing to do";
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1353,33 +1353,33 @@ Status AutoMixedPrecisionImpl::Optimize() {
|
|||||||
|
|
||||||
VLOG(2) << "Forcing color match between data structure ops";
|
VLOG(2) << "Forcing color match between data structure ops";
|
||||||
for (const auto& cluster : tensor_list_clusters) {
|
for (const auto& cluster : tensor_list_clusters) {
|
||||||
ForceColorMatchBetweenTensorListOps(cluster, &white_set, &black_set);
|
ForceColorMatchBetweenTensorListOps(cluster, &allow_set, &black_set);
|
||||||
}
|
}
|
||||||
|
|
||||||
VLOG(2) << "Beginning pass 3 to set clear and gray nodes to white if they "
|
VLOG(2) << "Beginning pass 3 to set clear and gray nodes to allow if they "
|
||||||
"are between white ops";
|
"are between allow ops";
|
||||||
AddClearAndGrayToWhiteIfBetweenWhite(black_set, &white_set);
|
AddClearAndGrayToAllowIfBetweenAllow(black_set, &allow_set);
|
||||||
VLOG(2) << "Finished pass 3";
|
VLOG(2) << "Finished pass 3";
|
||||||
|
|
||||||
VLOG(2) << "Beginning pass 4 to propagate white from white nodes through "
|
VLOG(2) << "Beginning pass 4 to propagate allow from allow nodes through "
|
||||||
"clearlist ops";
|
"clearlist ops";
|
||||||
PropagateWhiteThroughClear(black_set, &white_set);
|
PropagateAllowThroughClear(black_set, &allow_set);
|
||||||
VLOG(2) << "Finished pass 4";
|
VLOG(2) << "Finished pass 4";
|
||||||
|
|
||||||
VLOG(2) << "Forcing color match between data structure ops";
|
VLOG(2) << "Forcing color match between data structure ops";
|
||||||
for (const auto& cluster : tensor_list_clusters) {
|
for (const auto& cluster : tensor_list_clusters) {
|
||||||
ForceColorMatchBetweenTensorListOps(cluster, &white_set, &black_set);
|
ForceColorMatchBetweenTensorListOps(cluster, &allow_set, &black_set);
|
||||||
}
|
}
|
||||||
|
|
||||||
VLOG(2) << "Forcing color match on loop edges";
|
VLOG(2) << "Forcing color match on loop edges";
|
||||||
TF_RETURN_IF_ERROR(ForceColorMatchOnRecurrentEdges(&white_set));
|
TF_RETURN_IF_ERROR(ForceColorMatchOnRecurrentEdges(&allow_set));
|
||||||
|
|
||||||
VLOG(2) << "Finding existing casts that can be made white";
|
VLOG(2) << "Finding existing casts that can be made allow";
|
||||||
MakeCastsWhiteIfAllOutputsWhite(&white_set);
|
MakeCastsAllowIfAllOutputsAllow(&allow_set);
|
||||||
|
|
||||||
VLOG(2) << "Beginning final pass to change type attributes and insert Cast "
|
VLOG(2) << "Beginning final pass to change type attributes and insert Cast "
|
||||||
"ops at paint boundaries";
|
"ops at paint boundaries";
|
||||||
TF_RETURN_IF_ERROR(ChangeTypeAttrsAndAddCasts(white_set));
|
TF_RETURN_IF_ERROR(ChangeTypeAttrsAndAddCasts(allow_set));
|
||||||
VLOG(2) << "Finished final pass";
|
VLOG(2) << "Finished final pass";
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(PrintDebugLogs(/* preop = */ false, timestamp));
|
TF_RETURN_IF_ERROR(PrintDebugLogs(/* preop = */ false, timestamp));
|
||||||
@ -1516,19 +1516,19 @@ void AutoMixedPrecisionImpl::FindTensorListImplicitFloat32Edges(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void AutoMixedPrecisionImpl::AddWhitelistOps(
|
void AutoMixedPrecisionImpl::AddAllowlistOps(
|
||||||
absl::flat_hash_set<int>* white_set) const {
|
absl::flat_hash_set<int>* allow_set) const {
|
||||||
// Add whitelisted ops to white_set.
|
// Add allowlisted ops to allow_set.
|
||||||
for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
|
for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
|
||||||
const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
|
const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
|
||||||
if (!ShouldProcess(*root.node)) continue;
|
if (!ShouldProcess(*root.node)) continue;
|
||||||
bool force_white = force_all_fp16_ && CanForceFP16(*root.node);
|
bool force_allow = force_all_fp16_ && CanForceFP16(*root.node);
|
||||||
if (f16_whitelist_.count(root.node->op()) || force_white) {
|
if (f16_allowlist_.count(root.node->op()) || force_allow) {
|
||||||
bool inserted = white_set->insert(root_idx).second;
|
bool inserted = allow_set->insert(root_idx).second;
|
||||||
if (VLOG_IS_ON(2) && inserted) {
|
if (VLOG_IS_ON(2) && inserted) {
|
||||||
VLOG(2) << "Painting type " << root.type_attr.DebugString()
|
VLOG(2) << "Painting type " << root.type_attr.DebugString()
|
||||||
<< " of node " << root.node->name() << " WHITE because its op "
|
<< " of node " << root.node->name() << " ALLOW because its op "
|
||||||
<< root.node->op() << " is on the whitelist";
|
<< root.node->op() << " is on the allowlist";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1537,8 +1537,8 @@ void AutoMixedPrecisionImpl::AddWhitelistOps(
|
|||||||
// Adds nodes to black_set iff they are on the blacklist or they are on a
|
// Adds nodes to black_set iff they are on the blacklist or they are on a
|
||||||
// forward path from a blacklist node to a black/gray node (including the node
|
// forward path from a blacklist node to a black/gray node (including the node
|
||||||
// at the end of the path) through clear and gray nodes.
|
// at the end of the path) through clear and gray nodes.
|
||||||
// E.g., black -> gray -> clear -> gray -> clear -> white -> gray
|
// E.g., black -> gray -> clear -> gray -> clear -> allow -> gray
|
||||||
// becomes: black -> black -> black -> black -> clear -> white -> gray.
|
// becomes: black -> black -> black -> black -> clear -> allow -> gray.
|
||||||
void AutoMixedPrecisionImpl::PropagateBlackFwdThroughClearAndGray(
|
void AutoMixedPrecisionImpl::PropagateBlackFwdThroughClearAndGray(
|
||||||
absl::flat_hash_set<int>* black_set) const {
|
absl::flat_hash_set<int>* black_set) const {
|
||||||
if (force_all_fp16_) return;
|
if (force_all_fp16_) return;
|
||||||
@ -1588,14 +1588,14 @@ void AutoMixedPrecisionImpl::PropagateBlackFwdThroughClearAndGray(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void AutoMixedPrecisionImpl::AddClearAndGrayToWhiteIfBetweenWhite(
|
void AutoMixedPrecisionImpl::AddClearAndGrayToAllowIfBetweenAllow(
|
||||||
const absl::flat_hash_set<int>& black_set,
|
const absl::flat_hash_set<int>& black_set,
|
||||||
absl::flat_hash_set<int>* white_set) const {
|
absl::flat_hash_set<int>* allow_set) const {
|
||||||
// Find clear/graylist ops that are downstream of white ops.
|
// Find clear/graylist ops that are downstream of allow ops.
|
||||||
absl::flat_hash_set<int> downstream_of_white_set;
|
absl::flat_hash_set<int> downstream_of_allow_set;
|
||||||
for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
|
for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
|
||||||
const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
|
const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
|
||||||
if (!ShouldProcess(*root.node) || !f16_whitelist_.count(root.node->op())) {
|
if (!ShouldProcess(*root.node) || !f16_allowlist_.count(root.node->op())) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
DfsTypeTraversal(
|
DfsTypeTraversal(
|
||||||
@ -1603,8 +1603,8 @@ void AutoMixedPrecisionImpl::AddClearAndGrayToWhiteIfBetweenWhite(
|
|||||||
DfsTypePredicates::Enter([&](int idx) -> bool {
|
DfsTypePredicates::Enter([&](int idx) -> bool {
|
||||||
const NodeTypeId& item = *graph_type_view_.GetNode(idx);
|
const NodeTypeId& item = *graph_type_view_.GetNode(idx);
|
||||||
return idx == root_idx ||
|
return idx == root_idx ||
|
||||||
(!downstream_of_white_set.count(idx) &&
|
(!downstream_of_allow_set.count(idx) &&
|
||||||
!f16_whitelist_.count(item.node->op()) &&
|
!f16_allowlist_.count(item.node->op()) &&
|
||||||
!black_set.count(idx) && ShouldProcess(*item.node) &&
|
!black_set.count(idx) && ShouldProcess(*item.node) &&
|
||||||
// TODO(benbarsdell): Consider allowing propagation through
|
// TODO(benbarsdell): Consider allowing propagation through
|
||||||
// ops that are already float16 in order to reduce the number
|
// ops that are already float16 in order to reduce the number
|
||||||
@ -1614,45 +1614,45 @@ void AutoMixedPrecisionImpl::AddClearAndGrayToWhiteIfBetweenWhite(
|
|||||||
f16_graylist_.count(item.node->op())));
|
f16_graylist_.count(item.node->op())));
|
||||||
}),
|
}),
|
||||||
DfsTypeCallbacks::PreOrder(
|
DfsTypeCallbacks::PreOrder(
|
||||||
[&](int idx) { downstream_of_white_set.insert(idx); }));
|
[&](int idx) { downstream_of_allow_set.insert(idx); }));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set nodes that are both downstream and upstream of white ops to white.
|
// Set nodes that are both downstream and upstream of allow ops to allow.
|
||||||
absl::flat_hash_set<int> upstream_of_white_set;
|
absl::flat_hash_set<int> upstream_of_allow_set;
|
||||||
for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
|
for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
|
||||||
const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
|
const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
|
||||||
if (!ShouldProcess(*root.node) || upstream_of_white_set.count(root_idx) ||
|
if (!ShouldProcess(*root.node) || upstream_of_allow_set.count(root_idx) ||
|
||||||
!f16_whitelist_.count(root.node->op())) {
|
!f16_allowlist_.count(root.node->op())) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
DfsTypeTraversal(
|
DfsTypeTraversal(
|
||||||
graph_type_view_, {&root}, TypeTraversalDirection::kFollowInputs,
|
graph_type_view_, {&root}, TypeTraversalDirection::kFollowInputs,
|
||||||
DfsTypePredicates::Enter([&](int idx) -> bool {
|
DfsTypePredicates::Enter([&](int idx) -> bool {
|
||||||
return idx == root_idx || (!upstream_of_white_set.count(idx) &&
|
return idx == root_idx || (!upstream_of_allow_set.count(idx) &&
|
||||||
downstream_of_white_set.count(idx));
|
downstream_of_allow_set.count(idx));
|
||||||
}),
|
}),
|
||||||
DfsTypeCallbacks::PreOrder([&](int idx) {
|
DfsTypeCallbacks::PreOrder([&](int idx) {
|
||||||
upstream_of_white_set.insert(idx);
|
upstream_of_allow_set.insert(idx);
|
||||||
bool inserted = white_set->insert(idx).second;
|
bool inserted = allow_set->insert(idx).second;
|
||||||
if (VLOG_IS_ON(2) && inserted) {
|
if (VLOG_IS_ON(2) && inserted) {
|
||||||
const NodeTypeId& item = *graph_type_view_.GetNode(idx);
|
const NodeTypeId& item = *graph_type_view_.GetNode(idx);
|
||||||
VLOG(2) << "Painting type " << item.type_attr.DebugString()
|
VLOG(2) << "Painting type " << item.type_attr.DebugString()
|
||||||
<< " of " << item.node->op() << " node "
|
<< " of " << item.node->op() << " node "
|
||||||
<< item.node->name() << " WHITE";
|
<< item.node->name() << " ALLOW";
|
||||||
}
|
}
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void AutoMixedPrecisionImpl::PropagateWhiteThroughClear(
|
void AutoMixedPrecisionImpl::PropagateAllowThroughClear(
|
||||||
const absl::flat_hash_set<int>& black_set,
|
const absl::flat_hash_set<int>& black_set,
|
||||||
absl::flat_hash_set<int>* white_set) const {
|
absl::flat_hash_set<int>* allow_set) const {
|
||||||
// Propagate white from white nodes through clearlist ops.
|
// Propagate allow from allow nodes through clearlist ops.
|
||||||
absl::flat_hash_set<int> clear_prop_set;
|
absl::flat_hash_set<int> clear_prop_set;
|
||||||
for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
|
for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
|
||||||
const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
|
const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
|
||||||
if (!ShouldProcess(*root.node) || clear_prop_set.count(root_idx) ||
|
if (!ShouldProcess(*root.node) || clear_prop_set.count(root_idx) ||
|
||||||
!white_set->count(root_idx)) {
|
!allow_set->count(root_idx)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
DfsTypeTraversal(
|
DfsTypeTraversal(
|
||||||
@ -1661,7 +1661,7 @@ void AutoMixedPrecisionImpl::PropagateWhiteThroughClear(
|
|||||||
DfsTypePredicates::Enter([&](int idx) -> bool {
|
DfsTypePredicates::Enter([&](int idx) -> bool {
|
||||||
const NodeTypeId& item = *graph_type_view_.GetNode(idx);
|
const NodeTypeId& item = *graph_type_view_.GetNode(idx);
|
||||||
return idx == root_idx ||
|
return idx == root_idx ||
|
||||||
(!white_set->count(idx) && !black_set.count(idx) &&
|
(!allow_set->count(idx) && !black_set.count(idx) &&
|
||||||
ShouldProcess(*item.node) && IsFloat32(item) &&
|
ShouldProcess(*item.node) && IsFloat32(item) &&
|
||||||
SupportsF16(item) &&
|
SupportsF16(item) &&
|
||||||
(f16_clearlist_.count(item.node->op())) &&
|
(f16_clearlist_.count(item.node->op())) &&
|
||||||
@ -1673,30 +1673,30 @@ void AutoMixedPrecisionImpl::PropagateWhiteThroughClear(
|
|||||||
}),
|
}),
|
||||||
DfsTypeCallbacks::PreOrder([&](int idx) {
|
DfsTypeCallbacks::PreOrder([&](int idx) {
|
||||||
clear_prop_set.insert(idx);
|
clear_prop_set.insert(idx);
|
||||||
bool inserted = white_set->insert(idx).second;
|
bool inserted = allow_set->insert(idx).second;
|
||||||
if (VLOG_IS_ON(2) && inserted) {
|
if (VLOG_IS_ON(2) && inserted) {
|
||||||
const NodeTypeId& item = *graph_type_view_.GetNode(idx);
|
const NodeTypeId& item = *graph_type_view_.GetNode(idx);
|
||||||
VLOG(2) << "Painting type " << item.type_attr.DebugString()
|
VLOG(2) << "Painting type " << item.type_attr.DebugString()
|
||||||
<< " of " << item.node->op() << " node "
|
<< " of " << item.node->op() << " node "
|
||||||
<< item.node->name() << " WHITE";
|
<< item.node->name() << " ALLOW";
|
||||||
}
|
}
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Forces NextIteration nodes and their output Merge node(s) to have the same
|
// Forces NextIteration nodes and their output Merge node(s) to have the same
|
||||||
// color. Specifically, it removes them all from white_set if any of the Merge
|
// color. Specifically, it removes them all from allow_set if any of the Merge
|
||||||
// nodes is not in white_set, otherwise it adds the NextIteration node to
|
// nodes is not in allow_set, otherwise it adds the NextIteration node to
|
||||||
// white_set.
|
// allow_set.
|
||||||
Status AutoMixedPrecisionImpl::ForceColorMatchOnRecurrentEdges(
|
Status AutoMixedPrecisionImpl::ForceColorMatchOnRecurrentEdges(
|
||||||
absl::flat_hash_set<int>* white_set) const {
|
absl::flat_hash_set<int>* allow_set) const {
|
||||||
for (const NodeDef& node : graph_->node()) {
|
for (const NodeDef& node : graph_->node()) {
|
||||||
if (node.op() == "NextIteration") {
|
if (node.op() == "NextIteration") {
|
||||||
GraphView::OutputPort output_port(&node, 0);
|
GraphView::OutputPort output_port(&node, 0);
|
||||||
const auto& fanout = graph_view_.GetFanout(output_port);
|
const auto& fanout = graph_view_.GetFanout(output_port);
|
||||||
std::vector<int> merge_idxs;
|
std::vector<int> merge_idxs;
|
||||||
merge_idxs.reserve(fanout.size());
|
merge_idxs.reserve(fanout.size());
|
||||||
bool any_merge_is_not_white = false;
|
bool any_merge_is_not_allow = false;
|
||||||
for (const auto& output : fanout) {
|
for (const auto& output : fanout) {
|
||||||
const NodeDef& merge_node = *output.node;
|
const NodeDef& merge_node = *output.node;
|
||||||
if (merge_node.op() != "Merge") {
|
if (merge_node.op() != "Merge") {
|
||||||
@ -1712,8 +1712,8 @@ Status AutoMixedPrecisionImpl::ForceColorMatchOnRecurrentEdges(
|
|||||||
}
|
}
|
||||||
int merge_idx = maybe_merge_idx.value();
|
int merge_idx = maybe_merge_idx.value();
|
||||||
merge_idxs.push_back(merge_idx);
|
merge_idxs.push_back(merge_idx);
|
||||||
any_merge_is_not_white =
|
any_merge_is_not_allow =
|
||||||
any_merge_is_not_white || !white_set->count(merge_idx);
|
any_merge_is_not_allow || !allow_set->count(merge_idx);
|
||||||
}
|
}
|
||||||
const absl::optional<int> maybe_nextiter_idx =
|
const absl::optional<int> maybe_nextiter_idx =
|
||||||
graph_type_view_.GetNodeIndex(node.name(), TypeAttrId("T"));
|
graph_type_view_.GetNodeIndex(node.name(), TypeAttrId("T"));
|
||||||
@ -1722,9 +1722,9 @@ Status AutoMixedPrecisionImpl::ForceColorMatchOnRecurrentEdges(
|
|||||||
node.name(), " not found in graph view");
|
node.name(), " not found in graph view");
|
||||||
}
|
}
|
||||||
int nextiter_idx = maybe_nextiter_idx.value();
|
int nextiter_idx = maybe_nextiter_idx.value();
|
||||||
if (any_merge_is_not_white) {
|
if (any_merge_is_not_allow) {
|
||||||
for (int merge_idx : merge_idxs) {
|
for (int merge_idx : merge_idxs) {
|
||||||
if (white_set->erase(merge_idx)) {
|
if (allow_set->erase(merge_idx)) {
|
||||||
VLOG(2) << "Painting type T of Merge node "
|
VLOG(2) << "Painting type T of Merge node "
|
||||||
<< graph_type_view_.GetNode(merge_idx)->node->name()
|
<< graph_type_view_.GetNode(merge_idx)->node->name()
|
||||||
<< " BLACK to match the color of its sibling Merge nodes "
|
<< " BLACK to match the color of its sibling Merge nodes "
|
||||||
@ -1732,14 +1732,14 @@ Status AutoMixedPrecisionImpl::ForceColorMatchOnRecurrentEdges(
|
|||||||
<< node.name();
|
<< node.name();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (white_set->erase(nextiter_idx)) {
|
if (allow_set->erase(nextiter_idx)) {
|
||||||
VLOG(2) << "Painting type T of NextIteration node " << node.name()
|
VLOG(2) << "Painting type T of NextIteration node " << node.name()
|
||||||
<< " BLACK to match the color of its output Merge node(s)";
|
<< " BLACK to match the color of its output Merge node(s)";
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (white_set->insert(nextiter_idx).second) {
|
if (allow_set->insert(nextiter_idx).second) {
|
||||||
VLOG(2) << "Painting type T of NextIteration node " << node.name()
|
VLOG(2) << "Painting type T of NextIteration node " << node.name()
|
||||||
<< " WHITE to match the color of its output Merge node(s)";
|
<< " ALLOW to match the color of its output Merge node(s)";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1750,10 +1750,10 @@ Status AutoMixedPrecisionImpl::ForceColorMatchOnRecurrentEdges(
|
|||||||
// Forces all of the given Tensor List nodes into the same color set.
|
// Forces all of the given Tensor List nodes into the same color set.
|
||||||
void AutoMixedPrecisionImpl::ForceColorMatchBetweenTensorListOps(
|
void AutoMixedPrecisionImpl::ForceColorMatchBetweenTensorListOps(
|
||||||
const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
|
const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
|
||||||
absl::flat_hash_set<int>* white_set,
|
absl::flat_hash_set<int>* allow_set,
|
||||||
absl::flat_hash_set<int>* black_set) const {
|
absl::flat_hash_set<int>* black_set) const {
|
||||||
bool any_black = false;
|
bool any_black = false;
|
||||||
bool any_white = false;
|
bool any_allow = false;
|
||||||
std::vector<int> node_type_idxs;
|
std::vector<int> node_type_idxs;
|
||||||
node_type_idxs.reserve(tensor_list_nodes.size());
|
node_type_idxs.reserve(tensor_list_nodes.size());
|
||||||
for (const NodeDef* node : tensor_list_nodes) {
|
for (const NodeDef* node : tensor_list_nodes) {
|
||||||
@ -1769,23 +1769,23 @@ void AutoMixedPrecisionImpl::ForceColorMatchBetweenTensorListOps(
|
|||||||
if (black_set->count(node_type_idx)) {
|
if (black_set->count(node_type_idx)) {
|
||||||
any_black = true;
|
any_black = true;
|
||||||
break;
|
break;
|
||||||
} else if (white_set->count(node_type_idx)) {
|
} else if (allow_set->count(node_type_idx)) {
|
||||||
any_white = true;
|
any_allow = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!any_black && !any_white) return;
|
if (!any_black && !any_allow) return;
|
||||||
for (int node_type_idx : node_type_idxs) {
|
for (int node_type_idx : node_type_idxs) {
|
||||||
const NodeTypeId& node_type = *graph_type_view_.GetNode(node_type_idx);
|
const NodeTypeId& node_type = *graph_type_view_.GetNode(node_type_idx);
|
||||||
VLOG(2) << "Painting type " << node_type.type_attr.DebugString() << " of "
|
VLOG(2) << "Painting type " << node_type.type_attr.DebugString() << " of "
|
||||||
<< node_type.node->op() << " node " << node_type.node->name() << " "
|
<< node_type.node->op() << " node " << node_type.node->name() << " "
|
||||||
<< (any_black ? "BLACK" : "WHITE")
|
<< (any_black ? "BLACK" : "ALLOW")
|
||||||
<< " because at least one of its siblings is "
|
<< " because at least one of its siblings is "
|
||||||
<< (any_black ? "BLACK" : "WHITE");
|
<< (any_black ? "BLACK" : "ALLOW");
|
||||||
if (any_black) {
|
if (any_black) {
|
||||||
white_set->erase(node_type_idx);
|
allow_set->erase(node_type_idx);
|
||||||
black_set->insert(node_type_idx);
|
black_set->insert(node_type_idx);
|
||||||
} else {
|
} else {
|
||||||
white_set->insert(node_type_idx);
|
allow_set->insert(node_type_idx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1807,10 +1807,10 @@ bool AutoMixedPrecisionImpl::NodeImplicitlyReadsNonResourceVariable(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// This adds existing Cast nodes to white_set if all of their outputs are white,
|
// This adds existing Cast nodes to allow_set if all of their outputs are allow,
|
||||||
// avoiding the need to add a new Cast node after an existing Cast.
|
// avoiding the need to add a new Cast node after an existing Cast.
|
||||||
void AutoMixedPrecisionImpl::MakeCastsWhiteIfAllOutputsWhite(
|
void AutoMixedPrecisionImpl::MakeCastsAllowIfAllOutputsAllow(
|
||||||
absl::flat_hash_set<int>* white_set) const {
|
absl::flat_hash_set<int>* allow_set) const {
|
||||||
int num_nodes_preop = graph_->node_size();
|
int num_nodes_preop = graph_->node_size();
|
||||||
for (int node_idx = 0; node_idx < num_nodes_preop; ++node_idx) {
|
for (int node_idx = 0; node_idx < num_nodes_preop; ++node_idx) {
|
||||||
NodeDef* node = graph_->mutable_node(node_idx);
|
NodeDef* node = graph_->mutable_node(node_idx);
|
||||||
@ -1818,7 +1818,7 @@ void AutoMixedPrecisionImpl::MakeCastsWhiteIfAllOutputsWhite(
|
|||||||
if (node->op() != "Cast" || !IsFloat32(node_type)) {
|
if (node->op() != "Cast" || !IsFloat32(node_type)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
bool all_fanouts_white = true;
|
bool all_fanouts_allow = true;
|
||||||
MutableGraphView::OutputPort src(node, 0);
|
MutableGraphView::OutputPort src(node, 0);
|
||||||
const auto& fanout = graph_view_.GetFanout(src);
|
const auto& fanout = graph_view_.GetFanout(src);
|
||||||
for (const MutableGraphView::InputPort& dst : fanout) {
|
for (const MutableGraphView::InputPort& dst : fanout) {
|
||||||
@ -1830,13 +1830,13 @@ void AutoMixedPrecisionImpl::MakeCastsWhiteIfAllOutputsWhite(
|
|||||||
<< "Type attribute " << dst_type_attr.DebugString() << " of node "
|
<< "Type attribute " << dst_type_attr.DebugString() << " of node "
|
||||||
<< dst.node->name() << " not found in graph view";
|
<< dst.node->name() << " not found in graph view";
|
||||||
int dst_type_idx = maybe_dst_type_idx.value();
|
int dst_type_idx = maybe_dst_type_idx.value();
|
||||||
bool dst_is_white = white_set->count(dst_type_idx);
|
bool dst_is_allow = allow_set->count(dst_type_idx);
|
||||||
if (!dst_is_white) {
|
if (!dst_is_allow) {
|
||||||
all_fanouts_white = false;
|
all_fanouts_allow = false;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!fanout.empty() && all_fanouts_white) {
|
if (!fanout.empty() && all_fanouts_allow) {
|
||||||
const absl::optional<int> maybe_node_type_idx =
|
const absl::optional<int> maybe_node_type_idx =
|
||||||
graph_type_view_.GetNodeIndex(node_type);
|
graph_type_view_.GetNodeIndex(node_type);
|
||||||
DCHECK(maybe_node_type_idx.has_value())
|
DCHECK(maybe_node_type_idx.has_value())
|
||||||
@ -1844,16 +1844,16 @@ void AutoMixedPrecisionImpl::MakeCastsWhiteIfAllOutputsWhite(
|
|||||||
<< " of node " << node_type.node->name()
|
<< " of node " << node_type.node->name()
|
||||||
<< " not found in graph view";
|
<< " not found in graph view";
|
||||||
int node_type_idx = maybe_node_type_idx.value();
|
int node_type_idx = maybe_node_type_idx.value();
|
||||||
white_set->insert(node_type_idx);
|
allow_set->insert(node_type_idx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Changes all white-painted type attributes to DT_HALF or DT_BFLOAT16, and
|
// Changes all allow-painted type attributes to DT_HALF or DT_BFLOAT16, and
|
||||||
// inserts Cast nodes at node outputs for all edges that connect
|
// inserts Cast nodes at node outputs for all edges that connect
|
||||||
// white-painted <-> non-white-painted type attributes.
|
// allow-painted <-> non-allow-painted type attributes.
|
||||||
Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts(
|
Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts(
|
||||||
const absl::flat_hash_set<int>& white_set) {
|
const absl::flat_hash_set<int>& allow_set) {
|
||||||
int num_nodes_changed = 0;
|
int num_nodes_changed = 0;
|
||||||
int num_nonvar_casts_to_f16 = 0;
|
int num_nonvar_casts_to_f16 = 0;
|
||||||
int num_nodes_preop = graph_->node_size();
|
int num_nodes_preop = graph_->node_size();
|
||||||
@ -1869,8 +1869,8 @@ Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts(
|
|||||||
}
|
}
|
||||||
int node_type_idx = maybe_node_type_idx.value();
|
int node_type_idx = maybe_node_type_idx.value();
|
||||||
if (!IsFloat32(*graph_type_view_.GetNode(node_type_idx))) continue;
|
if (!IsFloat32(*graph_type_view_.GetNode(node_type_idx))) continue;
|
||||||
bool src_is_white = white_set.count(node_type_idx);
|
bool src_is_allow = allow_set.count(node_type_idx);
|
||||||
if (src_is_white) {
|
if (src_is_allow) {
|
||||||
VLOG(1) << "Changing type " << type_attr.DebugString() << " of "
|
VLOG(1) << "Changing type " << type_attr.DebugString() << " of "
|
||||||
<< node->op() << " node " << node->name() << " to "
|
<< node->op() << " node " << node->name() << " to "
|
||||||
<< DataTypeString(target_dtype_);
|
<< DataTypeString(target_dtype_);
|
||||||
@ -1896,10 +1896,10 @@ Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts(
|
|||||||
" not found in graph view");
|
" not found in graph view");
|
||||||
}
|
}
|
||||||
int dst_type_idx = maybe_dst_type_idx.value();
|
int dst_type_idx = maybe_dst_type_idx.value();
|
||||||
bool dst_is_white = white_set.count(dst_type_idx);
|
bool dst_is_allow = allow_set.count(dst_type_idx);
|
||||||
if (src_is_white != dst_is_white) {
|
if (src_is_allow != dst_is_allow) {
|
||||||
if (!added_cast_node) {
|
if (!added_cast_node) {
|
||||||
bool to_f16 = dst_is_white;
|
bool to_f16 = dst_is_allow;
|
||||||
VLOG(1) << "Inserting cast to "
|
VLOG(1) << "Inserting cast to "
|
||||||
<< (to_f16 ? DataTypeString(target_dtype_) : "DT_FLOAT")
|
<< (to_f16 ? DataTypeString(target_dtype_) : "DT_FLOAT")
|
||||||
<< " at " << src.node->op() << " " << src.node->name()
|
<< " at " << src.node->op() << " " << src.node->name()
|
||||||
|
@ -23,7 +23,7 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace grappler {
|
namespace grappler {
|
||||||
|
|
||||||
// Represents the four lists of ops: the white list, gray list, black list, and
|
// Represents the four lists of ops: the allow list, gray list, black list, and
|
||||||
// clear list. These lists determine which ops are converted to fp16/bf16
|
// clear list. These lists determine which ops are converted to fp16/bf16
|
||||||
// (referred to as 'f16' for short) and which ops stay as fp32.
|
// (referred to as 'f16' for short) and which ops stay as fp32.
|
||||||
class AutoMixedPrecisionLists {
|
class AutoMixedPrecisionLists {
|
||||||
@ -33,7 +33,7 @@ class AutoMixedPrecisionLists {
|
|||||||
// Returns the set of ops that are considered numerically-safe (for execution
|
// Returns the set of ops that are considered numerically-safe (for execution
|
||||||
// in f16), performance-critical, and can run in f16. These ops are always
|
// in f16), performance-critical, and can run in f16. These ops are always
|
||||||
// converted to f16.
|
// converted to f16.
|
||||||
virtual gtl::FlatSet<string> WhiteList() = 0;
|
virtual gtl::FlatSet<string> AllowList() = 0;
|
||||||
// Returns the set of ops that can run in f16 and are considered numerically-
|
// Returns the set of ops that can run in f16 and are considered numerically-
|
||||||
// safe (for execution in f16), but which may be made unsafe by an upstream
|
// safe (for execution in f16), but which may be made unsafe by an upstream
|
||||||
// blacklist op.
|
// blacklist op.
|
||||||
@ -51,8 +51,10 @@ class AutoMixedPrecisionLists {
|
|||||||
protected:
|
protected:
|
||||||
// Adds or removes ops from list if certain environmental variables are set.
|
// Adds or removes ops from list if certain environmental variables are set.
|
||||||
static void UpdateList(const string& list_name, gtl::FlatSet<string>* list) {
|
static void UpdateList(const string& list_name, gtl::FlatSet<string>* list) {
|
||||||
CHECK(list_name == "WHITELIST" || list_name == "GRAYLIST" || // Crash OK.
|
CHECK(list_name == "ALLOWLIST" || list_name == "GRAYLIST" || // Crash OK.
|
||||||
list_name == "BLACKLIST" || list_name == "CLEARLIST");
|
list_name == "BLACKLIST" || list_name == "CLEARLIST" ||
|
||||||
|
// TODO(reedwm): for bkwds compat; remove when no longer necessary:
|
||||||
|
list_name == "WHITELIST");
|
||||||
string add_env_var =
|
string add_env_var =
|
||||||
"TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_" + list_name + "_ADD";
|
"TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_" + list_name + "_ADD";
|
||||||
string remove_env_var =
|
string remove_env_var =
|
||||||
@ -104,7 +106,7 @@ class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists {
|
|||||||
AutoMixedPrecisionListsCuda(int cuda_version, int cudnn_version)
|
AutoMixedPrecisionListsCuda(int cuda_version, int cudnn_version)
|
||||||
: cuda_version_(cuda_version), cudnn_version_(cudnn_version) {}
|
: cuda_version_(cuda_version), cudnn_version_(cudnn_version) {}
|
||||||
|
|
||||||
gtl::FlatSet<string> WhiteList() override {
|
gtl::FlatSet<string> AllowList() override {
|
||||||
auto list = gtl::FlatSet<string>{
|
auto list = gtl::FlatSet<string>{
|
||||||
"BlockLSTM",
|
"BlockLSTM",
|
||||||
"BlockLSTMV2",
|
"BlockLSTMV2",
|
||||||
@ -144,7 +146,11 @@ class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists {
|
|||||||
list.insert("Conv3DBackpropInput");
|
list.insert("Conv3DBackpropInput");
|
||||||
list.insert("Conv3DBackpropInputV2");
|
list.insert("Conv3DBackpropInputV2");
|
||||||
}
|
}
|
||||||
|
UpdateList("ALLOWLIST", &list);
|
||||||
|
// For backwards compatibility, keeping the original env variable here.
|
||||||
|
// TODO(reedwm): This should be removed if we don't have active users.
|
||||||
UpdateList("WHITELIST", &list);
|
UpdateList("WHITELIST", &list);
|
||||||
|
|
||||||
return list;
|
return list;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -338,8 +344,8 @@ class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists {
|
|||||||
AutoMixedPrecisionListsMkl() {}
|
AutoMixedPrecisionListsMkl() {}
|
||||||
|
|
||||||
// Only ops which are supported by MKL in bfloat16 should be added to the
|
// Only ops which are supported by MKL in bfloat16 should be added to the
|
||||||
// white list, gray list, or clear list.
|
// allow list, gray list, or clear list.
|
||||||
gtl::FlatSet<string> WhiteList() override {
|
gtl::FlatSet<string> AllowList() override {
|
||||||
auto list = gtl::FlatSet<string>{"Conv2D",
|
auto list = gtl::FlatSet<string>{"Conv2D",
|
||||||
"Conv2DBackpropFilter",
|
"Conv2DBackpropFilter",
|
||||||
"Conv2DBackpropInput",
|
"Conv2DBackpropInput",
|
||||||
@ -353,7 +359,7 @@ class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists {
|
|||||||
"BatchMatMul",
|
"BatchMatMul",
|
||||||
"BatchMatMulV2"};
|
"BatchMatMulV2"};
|
||||||
|
|
||||||
UpdateList("WHITELIST", &list);
|
UpdateList("ALLOWLIST", &list);
|
||||||
return list;
|
return list;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -169,10 +169,10 @@ class AutoMixedPrecisionTest : public GrapplerTest {
|
|||||||
Output eye = ops::Const(s.WithOpName("eye"),
|
Output eye = ops::Const(s.WithOpName("eye"),
|
||||||
GenerateIdentityMatrix<DT_FLOAT>(size, size));
|
GenerateIdentityMatrix<DT_FLOAT>(size, size));
|
||||||
Output input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
|
Output input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
|
||||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, eye);
|
Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, eye);
|
||||||
Output gry1 = test_op_factory(s.WithOpName("gry1"), wht1);
|
Output gry1 = test_op_factory(s.WithOpName("gry1"), allow1);
|
||||||
Output wht2 = ops::MatMul(s.WithOpName("wht2"), gry1, eye);
|
Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, eye);
|
||||||
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), wht2);
|
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);
|
||||||
GrapplerItem item;
|
GrapplerItem item;
|
||||||
item.fetch = {"fetch1"};
|
item.fetch = {"fetch1"};
|
||||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||||
@ -190,9 +190,9 @@ class AutoMixedPrecisionTest : public GrapplerTest {
|
|||||||
GraphView output_view(&output);
|
GraphView output_view(&output);
|
||||||
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(),
|
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(),
|
||||||
DT_FLOAT);
|
DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
|
||||||
|
|
||||||
auto tensors = EvaluateNodes(output, item.fetch, feed);
|
auto tensors = EvaluateNodes(output, item.fetch, feed);
|
||||||
EXPECT_EQ(tensors.size(), tensors_expected.size());
|
EXPECT_EQ(tensors.size(), tensors_expected.size());
|
||||||
@ -247,8 +247,8 @@ TEST_F(AutoMixedPrecisionTest, AlreadyFp16) {
|
|||||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||||
Output input = ops::Const(s.WithOpName("input"), 1.f, {32, 32});
|
Output input = ops::Const(s.WithOpName("input"), 1.f, {32, 32});
|
||||||
Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_HALF);
|
Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_HALF);
|
||||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), cst1, cst1);
|
Output allow1 = ops::MatMul(s.WithOpName("allow1"), cst1, cst1);
|
||||||
Output clr1 = ops::Relu(s.WithOpName("clr1"), wht1);
|
Output clr1 = ops::Relu(s.WithOpName("clr1"), allow1);
|
||||||
Output cst2 = ops::Cast(s.WithOpName("cst2"), clr1, DT_FLOAT);
|
Output cst2 = ops::Cast(s.WithOpName("cst2"), clr1, DT_FLOAT);
|
||||||
Output clr2 = ops::Relu(s.WithOpName("clr2"), cst2);
|
Output clr2 = ops::Relu(s.WithOpName("clr2"), cst2);
|
||||||
Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);
|
Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);
|
||||||
@ -267,7 +267,7 @@ TEST_F(AutoMixedPrecisionTest, AlreadyFp16) {
|
|||||||
GraphView output_view(&output);
|
GraphView output_view(&output);
|
||||||
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("cst1")->attr().at("DstT").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("cst1")->attr().at("DstT").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("cst2")->attr().at("SrcT").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("cst2")->attr().at("SrcT").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("cst2")->attr().at("DstT").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("cst2")->attr().at("DstT").type(), DT_FLOAT);
|
||||||
@ -288,8 +288,8 @@ TEST_F(AutoMixedPrecisionTest, Simple) {
|
|||||||
Output clr1 = ops::Relu(s.WithOpName("clr1"), blk1);
|
Output clr1 = ops::Relu(s.WithOpName("clr1"), blk1);
|
||||||
Output gry1 = ops::Sqrt(s.WithOpName("gry1"), clr1);
|
Output gry1 = ops::Sqrt(s.WithOpName("gry1"), clr1);
|
||||||
Output clr2 = ops::Relu(s.WithOpName("clr2"), gry1);
|
Output clr2 = ops::Relu(s.WithOpName("clr2"), gry1);
|
||||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), clr2, clr2);
|
Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr2, clr2);
|
||||||
Output clr3 = ops::Relu(s.WithOpName("clr3"), wht1);
|
Output clr3 = ops::Relu(s.WithOpName("clr3"), allow1);
|
||||||
Output gry2 = ops::Log(s.WithOpName("gry2"), clr3);
|
Output gry2 = ops::Log(s.WithOpName("gry2"), clr3);
|
||||||
Output clr4 = ops::Relu(s.WithOpName("clr4"), gry2);
|
Output clr4 = ops::Relu(s.WithOpName("clr4"), gry2);
|
||||||
Output blk2 = ops::SparseMatMul(s.WithOpName("blk2"), clr4, clr4);
|
Output blk2 = ops::SparseMatMul(s.WithOpName("blk2"), clr4, clr4);
|
||||||
@ -314,7 +314,7 @@ TEST_F(AutoMixedPrecisionTest, Simple) {
|
|||||||
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("gry2")->attr().at("T").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("gry2")->attr().at("T").type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
|
||||||
@ -335,10 +335,10 @@ TEST_F(AutoMixedPrecisionTest, BidirectionalClearChain) {
|
|||||||
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
|
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
|
||||||
Output clr1 = ops::Relu(s.WithOpName("clr1"), input);
|
Output clr1 = ops::Relu(s.WithOpName("clr1"), input);
|
||||||
Output clr2 = ops::Relu(s.WithOpName("clr2"), input);
|
Output clr2 = ops::Relu(s.WithOpName("clr2"), input);
|
||||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), clr1, clr1);
|
Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr1, clr1);
|
||||||
auto clr3 = ops::ShapeN(s.WithOpName("clr3"), {clr1, clr2});
|
auto clr3 = ops::ShapeN(s.WithOpName("clr3"), {clr1, clr2});
|
||||||
Output clr4 = ops::Relu(s.WithOpName("clr4"), clr2);
|
Output clr4 = ops::Relu(s.WithOpName("clr4"), clr2);
|
||||||
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), wht1);
|
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow1);
|
||||||
Output fetch2 = ops::Identity(s.WithOpName("fetch2"), clr4);
|
Output fetch2 = ops::Identity(s.WithOpName("fetch2"), clr4);
|
||||||
|
|
||||||
GrapplerItem item;
|
GrapplerItem item;
|
||||||
@ -357,7 +357,7 @@ TEST_F(AutoMixedPrecisionTest, BidirectionalClearChain) {
|
|||||||
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_HALF);
|
||||||
|
|
||||||
@ -372,18 +372,18 @@ TEST_F(AutoMixedPrecisionTest, BidirectionalClearChain) {
|
|||||||
TEST_F(AutoMixedPrecisionTest, PreserveFetches) {
|
TEST_F(AutoMixedPrecisionTest, PreserveFetches) {
|
||||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||||
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
|
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
|
||||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, input);
|
Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
|
||||||
Output clr1 = ops::Relu(s.WithOpName("clr1"), wht1);
|
Output clr1 = ops::Relu(s.WithOpName("clr1"), allow1);
|
||||||
Output gry1 = ops::Sqrt(s.WithOpName("gry1"), clr1);
|
Output gry1 = ops::Sqrt(s.WithOpName("gry1"), clr1);
|
||||||
Output blk1 = ops::Exp(s.WithOpName("blk1"), gry1);
|
Output blk1 = ops::Exp(s.WithOpName("blk1"), gry1);
|
||||||
Output clr2 = ops::Relu(s.WithOpName("clr2"), blk1);
|
Output clr2 = ops::Relu(s.WithOpName("clr2"), blk1);
|
||||||
Output wht2 = ops::MatMul(s.WithOpName("wht2"), clr2, clr2);
|
Output allow2 = ops::MatMul(s.WithOpName("allow2"), clr2, clr2);
|
||||||
Output clr3 = ops::Relu(s.WithOpName("clr3"), wht2);
|
Output clr3 = ops::Relu(s.WithOpName("clr3"), allow2);
|
||||||
Output blk2 = ops::Exp(s.WithOpName("blk2"), clr3);
|
Output blk2 = ops::Exp(s.WithOpName("blk2"), clr3);
|
||||||
Output clr4 = ops::Relu(s.WithOpName("clr4"), blk2);
|
Output clr4 = ops::Relu(s.WithOpName("clr4"), blk2);
|
||||||
|
|
||||||
GrapplerItem item;
|
GrapplerItem item;
|
||||||
item.fetch = {"wht1", "clr2", "clr3"};
|
item.fetch = {"allow1", "clr2", "clr3"};
|
||||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||||
auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
|
auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
|
||||||
|
|
||||||
@ -396,12 +396,12 @@ TEST_F(AutoMixedPrecisionTest, PreserveFetches) {
|
|||||||
GraphView output_view(&output);
|
GraphView output_view(&output);
|
||||||
EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
|
EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
|
||||||
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("blk1")->attr().at("T").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("blk1")->attr().at("T").type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("blk2")->attr().at("T").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("blk2")->attr().at("T").type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
|
||||||
@ -418,12 +418,13 @@ TEST_F(AutoMixedPrecisionTest, PreserveCPUNodes) {
|
|||||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||||
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
|
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
|
||||||
Output clr1 = ops::Relu(s.WithOpName("clr1"), input);
|
Output clr1 = ops::Relu(s.WithOpName("clr1"), input);
|
||||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), clr1, clr1);
|
Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr1, clr1);
|
||||||
Output gry1 = ops::Tanh(s.WithOpName("gry1"), wht1);
|
Output gry1 = ops::Tanh(s.WithOpName("gry1"), allow1);
|
||||||
Output wht2 = ops::MatMul(s.WithOpName("wht2").WithDevice(
|
Output allow2 =
|
||||||
"/job:localhost/replica:0/task:0/device:CPU:0"),
|
ops::MatMul(s.WithOpName("allow2").WithDevice(
|
||||||
gry1, gry1);
|
"/job:localhost/replica:0/task:0/device:CPU:0"),
|
||||||
Output clr2 = ops::Relu(s.WithOpName("clr2"), wht2);
|
gry1, gry1);
|
||||||
|
Output clr2 = ops::Relu(s.WithOpName("clr2"), allow2);
|
||||||
Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);
|
Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);
|
||||||
|
|
||||||
GrapplerItem item;
|
GrapplerItem item;
|
||||||
@ -441,9 +442,9 @@ TEST_F(AutoMixedPrecisionTest, PreserveCPUNodes) {
|
|||||||
EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
|
EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
|
||||||
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);
|
||||||
|
|
||||||
auto tensors = EvaluateNodes(output, item.fetch);
|
auto tensors = EvaluateNodes(output, item.fetch);
|
||||||
@ -459,12 +460,12 @@ TEST_F(AutoMixedPrecisionTest, PreserveIdentityAfterVariable) {
|
|||||||
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
|
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
|
||||||
Output var1 = ops::Variable(s.WithOpName("var1"), {32, 32}, DT_FLOAT);
|
Output var1 = ops::Variable(s.WithOpName("var1"), {32, 32}, DT_FLOAT);
|
||||||
Output clr1 = ops::Identity(s.WithOpName("clr1"), var1);
|
Output clr1 = ops::Identity(s.WithOpName("clr1"), var1);
|
||||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, clr1);
|
Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, clr1);
|
||||||
Output input2 = ops::Const(s.WithOpName("input2"), 1.f / 32, {32, 32});
|
Output input2 = ops::Const(s.WithOpName("input2"), 1.f / 32, {32, 32});
|
||||||
Output clr2 = ops::Identity(s.WithOpName("clr2"), input2);
|
Output clr2 = ops::Identity(s.WithOpName("clr2"), input2);
|
||||||
Output wht2 = ops::MatMul(s.WithOpName("wht2"), input, clr2);
|
Output allow2 = ops::MatMul(s.WithOpName("allow2"), input, clr2);
|
||||||
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), wht1);
|
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow1);
|
||||||
Output fetch2 = ops::Identity(s.WithOpName("fetch2"), wht2);
|
Output fetch2 = ops::Identity(s.WithOpName("fetch2"), allow2);
|
||||||
|
|
||||||
GrapplerItem item;
|
GrapplerItem item;
|
||||||
item.fetch = {"fetch1", "fetch2"};
|
item.fetch = {"fetch1", "fetch2"};
|
||||||
@ -485,10 +486,10 @@ TEST_F(AutoMixedPrecisionTest, PreserveIdentityAfterVariable) {
|
|||||||
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("var1")->attr().at("dtype").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("var1")->attr().at("dtype").type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("input2")->attr().at("dtype").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("input2")->attr().at("dtype").type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
|
||||||
|
|
||||||
auto tensors = EvaluateNodes(output, item.fetch, feed);
|
auto tensors = EvaluateNodes(output, item.fetch, feed);
|
||||||
EXPECT_EQ(tensors.size(), tensors_expected.size());
|
EXPECT_EQ(tensors.size(), tensors_expected.size());
|
||||||
@ -507,22 +508,24 @@ TEST_F(AutoMixedPrecisionTest, FusedBatchNorm) {
|
|||||||
Output offset = ops::Const(s.WithOpName("offset"), 4.f, {16});
|
Output offset = ops::Const(s.WithOpName("offset"), 4.f, {16});
|
||||||
Output mean = ops::Const(s.WithOpName("mean"), 5.f, {0});
|
Output mean = ops::Const(s.WithOpName("mean"), 5.f, {0});
|
||||||
Output variance = ops::Const(s.WithOpName("variance"), 6.f, {0});
|
Output variance = ops::Const(s.WithOpName("variance"), 6.f, {0});
|
||||||
Output wht1 = ops::Conv2D(s.WithOpName("wht1"), input, weight, {1, 1, 1, 1},
|
Output allow1 =
|
||||||
"SAME", ops::Conv2D::DataFormat("NHWC"));
|
ops::Conv2D(s.WithOpName("allow1"), input, weight, {1, 1, 1, 1}, "SAME",
|
||||||
|
ops::Conv2D::DataFormat("NHWC"));
|
||||||
auto fbn1_op =
|
auto fbn1_op =
|
||||||
ops::FusedBatchNorm(s.WithOpName("fbn1"), wht1, scale, offset, mean,
|
ops::FusedBatchNorm(s.WithOpName("fbn1"), allow1, scale, offset, mean,
|
||||||
variance, ops::FusedBatchNorm::DataFormat("NHWC"));
|
variance, ops::FusedBatchNorm::DataFormat("NHWC"));
|
||||||
Output fbn1 = fbn1_op.y;
|
Output fbn1 = fbn1_op.y;
|
||||||
Output fbn1_rs1 = fbn1_op.reserve_space_1;
|
Output fbn1_rs1 = fbn1_op.reserve_space_1;
|
||||||
Output fbn1_rs2 = fbn1_op.reserve_space_2;
|
Output fbn1_rs2 = fbn1_op.reserve_space_2;
|
||||||
Output bng1 = ops::FusedBatchNormGrad(
|
Output bng1 = ops::FusedBatchNormGrad(
|
||||||
s.WithOpName("bng1"), fbn1, wht1, scale, fbn1_rs1, fbn1_rs2,
|
s.WithOpName("bng1"), fbn1, allow1, scale, fbn1_rs1,
|
||||||
ops::FusedBatchNormGrad::DataFormat("NHWC"))
|
fbn1_rs2, ops::FusedBatchNormGrad::DataFormat("NHWC"))
|
||||||
.x_backprop;
|
.x_backprop;
|
||||||
Output gry1 = ops::Add(s.WithOpName("gry1"), fbn1, bng1);
|
Output gry1 = ops::Add(s.WithOpName("gry1"), fbn1, bng1);
|
||||||
Output wht2 = ops::Conv2D(s.WithOpName("wht2"), gry1, weight, {1, 1, 1, 1},
|
Output allow2 =
|
||||||
"SAME", ops::Conv2D::DataFormat("NHWC"));
|
ops::Conv2D(s.WithOpName("allow2"), gry1, weight, {1, 1, 1, 1}, "SAME",
|
||||||
Output fetch = ops::Identity(s.WithOpName("fetch"), wht2);
|
ops::Conv2D::DataFormat("NHWC"));
|
||||||
|
Output fetch = ops::Identity(s.WithOpName("fetch"), allow2);
|
||||||
|
|
||||||
GrapplerItem item;
|
GrapplerItem item;
|
||||||
item.fetch = {"fetch"};
|
item.fetch = {"fetch"};
|
||||||
@ -537,7 +540,7 @@ TEST_F(AutoMixedPrecisionTest, FusedBatchNorm) {
|
|||||||
|
|
||||||
GraphView output_view(&output);
|
GraphView output_view(&output);
|
||||||
EXPECT_EQ(output.node_size(), item.graph.node_size() + 3);
|
EXPECT_EQ(output.node_size(), item.graph.node_size() + 3);
|
||||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("fbn1")->op(), "FusedBatchNormV2");
|
EXPECT_EQ(output_view.GetNode("fbn1")->op(), "FusedBatchNormV2");
|
||||||
EXPECT_EQ(output_view.GetNode("fbn1")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("fbn1")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("fbn1")->attr().at("U").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("fbn1")->attr().at("U").type(), DT_FLOAT);
|
||||||
@ -545,7 +548,7 @@ TEST_F(AutoMixedPrecisionTest, FusedBatchNorm) {
|
|||||||
EXPECT_EQ(output_view.GetNode("bng1")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("bng1")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("bng1")->attr().at("U").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("bng1")->attr().at("U").type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
|
||||||
|
|
||||||
auto tensors = EvaluateNodes(output, item.fetch);
|
auto tensors = EvaluateNodes(output, item.fetch);
|
||||||
EXPECT_EQ(tensors.size(), tensors_expected.size());
|
EXPECT_EQ(tensors.size(), tensors_expected.size());
|
||||||
@ -558,13 +561,13 @@ TEST_F(AutoMixedPrecisionTest, FusedBatchNorm) {
|
|||||||
TEST_F(AutoMixedPrecisionTest, RepeatedAndListTypeAttrs) {
|
TEST_F(AutoMixedPrecisionTest, RepeatedAndListTypeAttrs) {
|
||||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||||
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
|
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
|
||||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, input);
|
Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
|
||||||
auto clr1_op = ops::IdentityN(s.WithOpName("clr1"), {wht1, wht1, wht1});
|
auto clr1_op = ops::IdentityN(s.WithOpName("clr1"), {allow1, allow1, allow1});
|
||||||
Output gry1 =
|
Output gry1 =
|
||||||
ops::AddN(s.WithOpName("gry1"),
|
ops::AddN(s.WithOpName("gry1"),
|
||||||
{clr1_op.output[0], clr1_op.output[1], clr1_op.output[2]});
|
{clr1_op.output[0], clr1_op.output[1], clr1_op.output[2]});
|
||||||
Output wht2 = ops::MatMul(s.WithOpName("wht2"), gry1, gry1);
|
Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, gry1);
|
||||||
Output fetch = ops::Identity(s.WithOpName("fetch"), wht2);
|
Output fetch = ops::Identity(s.WithOpName("fetch"), allow2);
|
||||||
|
|
||||||
GrapplerItem item;
|
GrapplerItem item;
|
||||||
item.fetch = {"fetch"};
|
item.fetch = {"fetch"};
|
||||||
@ -580,12 +583,12 @@ TEST_F(AutoMixedPrecisionTest, RepeatedAndListTypeAttrs) {
|
|||||||
GraphView output_view(&output);
|
GraphView output_view(&output);
|
||||||
EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
|
EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
|
||||||
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||||
for (auto type : output_view.GetNode("clr1")->attr().at("T").list().type()) {
|
for (auto type : output_view.GetNode("clr1")->attr().at("T").list().type()) {
|
||||||
EXPECT_EQ(type, DT_HALF);
|
EXPECT_EQ(type, DT_HALF);
|
||||||
}
|
}
|
||||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
|
||||||
|
|
||||||
auto tensors = EvaluateNodes(output, item.fetch);
|
auto tensors = EvaluateNodes(output, item.fetch);
|
||||||
EXPECT_EQ(tensors.size(), tensors_expected.size());
|
EXPECT_EQ(tensors.size(), tensors_expected.size());
|
||||||
@ -599,8 +602,8 @@ TEST_F(AutoMixedPrecisionTest, ExistingCast) {
|
|||||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||||
Output input = ops::Const(s.WithOpName("input"), true, {32, 32});
|
Output input = ops::Const(s.WithOpName("input"), true, {32, 32});
|
||||||
Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_FLOAT);
|
Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_FLOAT);
|
||||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), cst1, cst1);
|
Output allow1 = ops::MatMul(s.WithOpName("allow1"), cst1, cst1);
|
||||||
Output fetch = ops::Identity(s.WithOpName("fetch"), wht1);
|
Output fetch = ops::Identity(s.WithOpName("fetch"), allow1);
|
||||||
|
|
||||||
GrapplerItem item;
|
GrapplerItem item;
|
||||||
item.fetch = {"fetch"};
|
item.fetch = {"fetch"};
|
||||||
@ -617,7 +620,7 @@ TEST_F(AutoMixedPrecisionTest, ExistingCast) {
|
|||||||
EXPECT_EQ(output.node_size(), item.graph.node_size() + 1);
|
EXPECT_EQ(output.node_size(), item.graph.node_size() + 1);
|
||||||
EXPECT_EQ(output_view.GetNode("cst1")->attr().at("SrcT").type(), DT_BOOL);
|
EXPECT_EQ(output_view.GetNode("cst1")->attr().at("SrcT").type(), DT_BOOL);
|
||||||
EXPECT_EQ(output_view.GetNode("cst1")->attr().at("DstT").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("cst1")->attr().at("DstT").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||||
|
|
||||||
auto tensors = EvaluateNodes(output, item.fetch);
|
auto tensors = EvaluateNodes(output, item.fetch);
|
||||||
EXPECT_EQ(tensors.size(), tensors_expected.size());
|
EXPECT_EQ(tensors.size(), tensors_expected.size());
|
||||||
@ -640,8 +643,8 @@ TEST_F(AutoMixedPrecisionTest, RecurrentEdgeColorMismatch) {
|
|||||||
Output lpc1 = ops::LoopCond(s.WithOpName("lpc1"), con1).output;
|
Output lpc1 = ops::LoopCond(s.WithOpName("lpc1"), con1).output;
|
||||||
auto swt1 = ops::Switch(s.WithOpName("swt1"), mrg1, lpc1);
|
auto swt1 = ops::Switch(s.WithOpName("swt1"), mrg1, lpc1);
|
||||||
Output gry1 = ops::Sqrt(s.WithOpName("gry1"), swt1.output_true);
|
Output gry1 = ops::Sqrt(s.WithOpName("gry1"), swt1.output_true);
|
||||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), gry1, gry1);
|
Output allow1 = ops::MatMul(s.WithOpName("allow1"), gry1, gry1);
|
||||||
Output nxt1 = ops::NextIteration(s.WithOpName("nxt1"), wht1);
|
Output nxt1 = ops::NextIteration(s.WithOpName("nxt1"), allow1);
|
||||||
Output ext1 = ops::internal::Exit(s.WithOpName("ext1"), swt1.output_false);
|
Output ext1 = ops::internal::Exit(s.WithOpName("ext1"), swt1.output_false);
|
||||||
Output fetch = ops::Identity(s.WithOpName("fetch"), ext1);
|
Output fetch = ops::Identity(s.WithOpName("fetch"), ext1);
|
||||||
// Add a second merge node from the same NextIteration node. This case arises
|
// Add a second merge node from the same NextIteration node. This case arises
|
||||||
@ -670,13 +673,13 @@ TEST_F(AutoMixedPrecisionTest, RecurrentEdgeColorMismatch) {
|
|||||||
EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
|
EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
|
||||||
// Note that mrg1 gets painted black because it is between blk1 and gry1. This
|
// Note that mrg1 gets painted black because it is between blk1 and gry1. This
|
||||||
// forces nxt1 and mrg2 to be painted black as well (they would otherwise be
|
// forces nxt1 and mrg2 to be painted black as well (they would otherwise be
|
||||||
// painted white because they are clear and have a direct path to wht1).
|
// painted allow because they are clear and have a direct path to allow1).
|
||||||
EXPECT_EQ(output_view.GetNode("blk1")->attr().at("T").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("blk1")->attr().at("T").type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("ent1")->attr().at("T").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("ent1")->attr().at("T").type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("mrg1")->attr().at("T").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("mrg1")->attr().at("T").type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("swt1")->attr().at("T").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("swt1")->attr().at("T").type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("nxt1")->attr().at("T").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("nxt1")->attr().at("T").type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("ext1")->attr().at("T").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("ext1")->attr().at("T").type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("mrg2")->attr().at("T").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("mrg2")->attr().at("T").type(), DT_FLOAT);
|
||||||
@ -699,9 +702,9 @@ TEST_F(AutoMixedPrecisionTest, TensorListSetGet) {
|
|||||||
Output idx3 = ops::Const(s.WithOpName("idx3"), 3);
|
Output idx3 = ops::Const(s.WithOpName("idx3"), 3);
|
||||||
auto tl1w1 =
|
auto tl1w1 =
|
||||||
ops::TensorListSetItem(s.WithOpName("tl1w1"), tl1.handle, idx1, input);
|
ops::TensorListSetItem(s.WithOpName("tl1w1"), tl1.handle, idx1, input);
|
||||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, input);
|
Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
|
||||||
auto tl1w2 =
|
auto tl1w2 =
|
||||||
ops::TensorListSetItem(s.WithOpName("tl1w2"), tl1.handle, idx2, wht1);
|
ops::TensorListSetItem(s.WithOpName("tl1w2"), tl1.handle, idx2, allow1);
|
||||||
// Ensure that TensorListResize doesn't cause any problems.
|
// Ensure that TensorListResize doesn't cause any problems.
|
||||||
Output tl1rs =
|
Output tl1rs =
|
||||||
ops::TensorListResize(s.WithOpName("tl1rs"), tl1w2.output_handle, 6);
|
ops::TensorListResize(s.WithOpName("tl1rs"), tl1w2.output_handle, 6);
|
||||||
@ -709,9 +712,9 @@ TEST_F(AutoMixedPrecisionTest, TensorListSetGet) {
|
|||||||
shape, DT_FLOAT)
|
shape, DT_FLOAT)
|
||||||
.item;
|
.item;
|
||||||
Output gry1 = ops::Tanh(s.WithOpName("gry1"), tl1r1);
|
Output gry1 = ops::Tanh(s.WithOpName("gry1"), tl1r1);
|
||||||
Output wht2 = ops::MatMul(s.WithOpName("wht2"), gry1, gry1);
|
Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, gry1);
|
||||||
auto tl1w3 =
|
auto tl1w3 =
|
||||||
ops::TensorListSetItem(s.WithOpName("tl1w3"), tl1.handle, idx3, wht2);
|
ops::TensorListSetItem(s.WithOpName("tl1w3"), tl1.handle, idx3, allow2);
|
||||||
Output tl1r2 =
|
Output tl1r2 =
|
||||||
ops::TensorListGetItem(s.WithOpName("tl1r2"), tl1w3.output_handle, idx3,
|
ops::TensorListGetItem(s.WithOpName("tl1r2"), tl1w3.output_handle, idx3,
|
||||||
shape, DT_FLOAT)
|
shape, DT_FLOAT)
|
||||||
@ -742,11 +745,11 @@ TEST_F(AutoMixedPrecisionTest, TensorListSetGet) {
|
|||||||
const char* type_key = "element_dtype";
|
const char* type_key = "element_dtype";
|
||||||
EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("tl1w1")->attr().at(type_key).type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("tl1w1")->attr().at(type_key).type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_FLOAT);
|
||||||
@ -767,15 +770,16 @@ TEST_F(AutoMixedPrecisionTest, TensorListPushPop) {
|
|||||||
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
|
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
|
||||||
auto tl1w1 =
|
auto tl1w1 =
|
||||||
ops::TensorListPushBack(s.WithOpName("tl1w1"), tl1.handle, input);
|
ops::TensorListPushBack(s.WithOpName("tl1w1"), tl1.handle, input);
|
||||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, input);
|
Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
|
||||||
auto tl1w2 =
|
auto tl1w2 = ops::TensorListPushBack(s.WithOpName("tl1w2"),
|
||||||
ops::TensorListPushBack(s.WithOpName("tl1w2"), tl1w1.output_handle, wht1);
|
tl1w1.output_handle, allow1);
|
||||||
Output tl1r1 = ops::TensorListPopBack(s.WithOpName("tl1r1"),
|
Output tl1r1 = ops::TensorListPopBack(s.WithOpName("tl1r1"),
|
||||||
tl1w2.output_handle, shape, DT_FLOAT)
|
tl1w2.output_handle, shape, DT_FLOAT)
|
||||||
.tensor;
|
.tensor;
|
||||||
Output gry1 = ops::Tanh(s.WithOpName("gry1"), tl1r1);
|
Output gry1 = ops::Tanh(s.WithOpName("gry1"), tl1r1);
|
||||||
Output wht2 = ops::MatMul(s.WithOpName("wht2"), gry1, gry1);
|
Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, gry1);
|
||||||
auto tl1w3 = ops::TensorListPushBack(s.WithOpName("tl1w3"), tl1.handle, wht2);
|
auto tl1w3 =
|
||||||
|
ops::TensorListPushBack(s.WithOpName("tl1w3"), tl1.handle, allow2);
|
||||||
Output tl1r2 = ops::TensorListPopBack(s.WithOpName("tl1r2"),
|
Output tl1r2 = ops::TensorListPopBack(s.WithOpName("tl1r2"),
|
||||||
tl1w3.output_handle, shape, DT_FLOAT)
|
tl1w3.output_handle, shape, DT_FLOAT)
|
||||||
.tensor;
|
.tensor;
|
||||||
@ -804,11 +808,11 @@ TEST_F(AutoMixedPrecisionTest, TensorListPushPop) {
|
|||||||
const char* type_key = "element_dtype";
|
const char* type_key = "element_dtype";
|
||||||
EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("tl1w1")->attr().at(type_key).type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("tl1w1")->attr().at(type_key).type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_FLOAT);
|
||||||
@ -826,19 +830,19 @@ TEST_F(AutoMixedPrecisionTest, TensorListFromTensor) {
|
|||||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||||
tensorflow::Input shape = {32};
|
tensorflow::Input shape = {32};
|
||||||
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
|
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
|
||||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, input);
|
Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
|
||||||
auto tl1 = ops::TensorListFromTensor(s.WithOpName("tl1"), wht1, shape);
|
auto tl1 = ops::TensorListFromTensor(s.WithOpName("tl1"), allow1, shape);
|
||||||
Output tl1r1 = ops::TensorListStack(s.WithOpName("tl1r1"), tl1.output_handle,
|
Output tl1r1 = ops::TensorListStack(s.WithOpName("tl1r1"), tl1.output_handle,
|
||||||
shape, DT_FLOAT)
|
shape, DT_FLOAT)
|
||||||
.tensor;
|
.tensor;
|
||||||
Output gry1 = ops::Tanh(s.WithOpName("gry1"), tl1r1);
|
Output gry1 = ops::Tanh(s.WithOpName("gry1"), tl1r1);
|
||||||
Output wht2 = ops::MatMul(s.WithOpName("wht2"), gry1, gry1);
|
Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, gry1);
|
||||||
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), wht2);
|
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);
|
||||||
|
|
||||||
// This tests that a white-painted object node (tl2) will force an unpainted
|
// This tests that a allow-painted object node (tl2) will force an unpainted
|
||||||
// client node (tl2w1) to be painted white as well. (Without the force, tl2w1
|
// client node (tl2w1) to be painted allow as well. (Without the force, tl2w1
|
||||||
// would remain unpainted, producing an invalid graph).
|
// would remain unpainted, producing an invalid graph).
|
||||||
auto tl2 = ops::TensorListFromTensor(s.WithOpName("tl2"), wht1, shape);
|
auto tl2 = ops::TensorListFromTensor(s.WithOpName("tl2"), allow1, shape);
|
||||||
auto tl2w1 =
|
auto tl2w1 =
|
||||||
ops::TensorListPushBack(s.WithOpName("tl2w1"), tl2.output_handle, input);
|
ops::TensorListPushBack(s.WithOpName("tl2w1"), tl2.output_handle, input);
|
||||||
|
|
||||||
@ -856,11 +860,11 @@ TEST_F(AutoMixedPrecisionTest, TensorListFromTensor) {
|
|||||||
GraphView output_view(&output);
|
GraphView output_view(&output);
|
||||||
EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
|
EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
|
||||||
const char* type_key = "element_dtype";
|
const char* type_key = "element_dtype";
|
||||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_HALF);
|
||||||
|
|
||||||
@ -878,12 +882,13 @@ TEST_F(AutoMixedPrecisionTest, TensorListPushBackBatchAndConcatLists) {
|
|||||||
auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
|
auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
|
||||||
auto tl2 = ops::EmptyTensorList(s.WithOpName("tl2"), {32, 32}, 8, DT_FLOAT);
|
auto tl2 = ops::EmptyTensorList(s.WithOpName("tl2"), {32, 32}, 8, DT_FLOAT);
|
||||||
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
|
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
|
||||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, input);
|
Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
|
||||||
Output tl1_tl2 =
|
Output tl1_tl2 =
|
||||||
ops::Stack(s.WithOpName("tl1_tl2"), {tl1.handle, tl2.handle});
|
ops::Stack(s.WithOpName("tl1_tl2"), {tl1.handle, tl2.handle});
|
||||||
Output wht1_wht1 = ops::Stack(s.WithOpName("wht1_wht1"), {wht1, wht1});
|
Output allow1_allow1 =
|
||||||
auto tl12w1 =
|
ops::Stack(s.WithOpName("allow1_allow1"), {allow1, allow1});
|
||||||
ops::TensorListPushBackBatch(s.WithOpName("tl12w1"), tl1_tl2, wht1_wht1);
|
auto tl12w1 = ops::TensorListPushBackBatch(s.WithOpName("tl12w1"), tl1_tl2,
|
||||||
|
allow1_allow1);
|
||||||
OutputList tl12w1_outputs =
|
OutputList tl12w1_outputs =
|
||||||
ops::Split(s.WithOpName("tl12w1_outputs"), 0, tl12w1.output_handles, 2)
|
ops::Split(s.WithOpName("tl12w1_outputs"), 0, tl12w1.output_handles, 2)
|
||||||
.output;
|
.output;
|
||||||
@ -898,8 +903,8 @@ TEST_F(AutoMixedPrecisionTest, TensorListPushBackBatchAndConcatLists) {
|
|||||||
ops::TensorListPopBack(s.WithOpName("tl3r1"), tl3, shape, DT_FLOAT)
|
ops::TensorListPopBack(s.WithOpName("tl3r1"), tl3, shape, DT_FLOAT)
|
||||||
.tensor;
|
.tensor;
|
||||||
Output gry1 = ops::Tanh(s.WithOpName("gry1"), tl3r1);
|
Output gry1 = ops::Tanh(s.WithOpName("gry1"), tl3r1);
|
||||||
Output wht2 = ops::MatMul(s.WithOpName("wht2"), gry1, gry1);
|
Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, gry1);
|
||||||
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), wht2);
|
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);
|
||||||
|
|
||||||
GrapplerItem item;
|
GrapplerItem item;
|
||||||
item.fetch = {"fetch1"};
|
item.fetch = {"fetch1"};
|
||||||
@ -915,8 +920,8 @@ TEST_F(AutoMixedPrecisionTest, TensorListPushBackBatchAndConcatLists) {
|
|||||||
GraphView output_view(&output);
|
GraphView output_view(&output);
|
||||||
EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
|
EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
|
||||||
const char* type_key = "element_dtype";
|
const char* type_key = "element_dtype";
|
||||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
|
||||||
@ -961,8 +966,8 @@ TEST_F(AutoMixedPrecisionTest, TensorListThroughFunction) {
|
|||||||
TF_CHECK_OK(s.graph()->AddFunctionLibrary(function_lib));
|
TF_CHECK_OK(s.graph()->AddFunctionLibrary(function_lib));
|
||||||
tensorflow::Input shape = {32, 32};
|
tensorflow::Input shape = {32, 32};
|
||||||
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
|
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
|
||||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, input);
|
Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
|
||||||
Output gry1 = ops::Tanh(s.WithOpName("gry1"), wht1);
|
Output gry1 = ops::Tanh(s.WithOpName("gry1"), allow1);
|
||||||
auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
|
auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
|
||||||
auto tl1w1 = ops::TensorListPushBack(s.WithOpName("tl1w1"), tl1.handle, gry1);
|
auto tl1w1 = ops::TensorListPushBack(s.WithOpName("tl1w1"), tl1.handle, gry1);
|
||||||
auto _gry1 = tensorflow::ops::AsNodeOut(s, gry1);
|
auto _gry1 = tensorflow::ops::AsNodeOut(s, gry1);
|
||||||
@ -981,8 +986,8 @@ TEST_F(AutoMixedPrecisionTest, TensorListThroughFunction) {
|
|||||||
Output tl2r1 = ops::TensorListPopBack(s.WithOpName("tl2r1"),
|
Output tl2r1 = ops::TensorListPopBack(s.WithOpName("tl2r1"),
|
||||||
tl2w1.output_handle, shape, DT_FLOAT)
|
tl2w1.output_handle, shape, DT_FLOAT)
|
||||||
.tensor;
|
.tensor;
|
||||||
Output wht2 = ops::MatMul(s.WithOpName("wht2"), tl1r1, tl2r1);
|
Output allow2 = ops::MatMul(s.WithOpName("allow2"), tl1r1, tl2r1);
|
||||||
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), wht2);
|
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);
|
||||||
|
|
||||||
GrapplerItem item;
|
GrapplerItem item;
|
||||||
item.fetch = {"fetch1"};
|
item.fetch = {"fetch1"};
|
||||||
@ -997,8 +1002,8 @@ TEST_F(AutoMixedPrecisionTest, TensorListThroughFunction) {
|
|||||||
|
|
||||||
GraphView output_view(&output);
|
GraphView output_view(&output);
|
||||||
const char* type_key = "element_dtype";
|
const char* type_key = "element_dtype";
|
||||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
|
||||||
EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_HALF);
|
||||||
@ -1031,8 +1036,8 @@ int GetCudaVersion(const Cluster& cluster) {
|
|||||||
TEST_F(AutoMixedPrecisionTest, BatchMatMul) {
|
TEST_F(AutoMixedPrecisionTest, BatchMatMul) {
|
||||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||||
Output input = ops::Const(s.WithOpName("input"), 1.f / 33, {64, 32, 32});
|
Output input = ops::Const(s.WithOpName("input"), 1.f / 33, {64, 32, 32});
|
||||||
Output wht1 = ops::BatchMatMul(s.WithOpName("wht1"), input, input);
|
Output allow1 = ops::BatchMatMul(s.WithOpName("allow1"), input, input);
|
||||||
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), wht1);
|
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow1);
|
||||||
|
|
||||||
GrapplerItem item;
|
GrapplerItem item;
|
||||||
item.fetch = {"fetch1"};
|
item.fetch = {"fetch1"};
|
||||||
@ -1049,10 +1054,10 @@ TEST_F(AutoMixedPrecisionTest, BatchMatMul) {
|
|||||||
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
|
||||||
if (GetCudaVersion(*virtual_cluster_.get()) >= 9010) {
|
if (GetCudaVersion(*virtual_cluster_.get()) >= 9010) {
|
||||||
EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
|
EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
|
||||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
|
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||||
} else {
|
} else {
|
||||||
EXPECT_EQ(output.node_size(), item.graph.node_size());
|
EXPECT_EQ(output.node_size(), item.graph.node_size());
|
||||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_FLOAT);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto tensors = EvaluateNodes(output, item.fetch);
|
auto tensors = EvaluateNodes(output, item.fetch);
|
||||||
@ -1187,8 +1192,8 @@ TEST_F(AutoMixedPrecisionMklTest, AlreadyBf16) {
|
|||||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||||
Output input = ops::Const(s.WithOpName("input"), 1.f, {32, 32});
|
Output input = ops::Const(s.WithOpName("input"), 1.f, {32, 32});
|
||||||
Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_BFLOAT16);
|
Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_BFLOAT16);
|
||||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), cst1, cst1);
|
Output allow1 = ops::MatMul(s.WithOpName("allow1"), cst1, cst1);
|
||||||
Output clr1 = ops::Relu(s.WithOpName("clr1"), wht1);
|
Output clr1 = ops::Relu(s.WithOpName("clr1"), allow1);
|
||||||
Output cst2 = ops::Cast(s.WithOpName("cst2"), clr1, DT_FLOAT);
|
Output cst2 = ops::Cast(s.WithOpName("cst2"), clr1, DT_FLOAT);
|
||||||
Output clr2 = ops::Relu(s.WithOpName("clr2"), cst2);
|
Output clr2 = ops::Relu(s.WithOpName("clr2"), cst2);
|
||||||
Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);
|
Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);
|
||||||
@ -1207,7 +1212,7 @@ TEST_F(AutoMixedPrecisionMklTest, AlreadyBf16) {
|
|||||||
GraphView output_view(&output);
|
GraphView output_view(&output);
|
||||||
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("cst1")->attr().at("DstT").type(), DT_BFLOAT16);
|
EXPECT_EQ(output_view.GetNode("cst1")->attr().at("DstT").type(), DT_BFLOAT16);
|
||||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_BFLOAT16);
|
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_BFLOAT16);
|
||||||
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_BFLOAT16);
|
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_BFLOAT16);
|
||||||
EXPECT_EQ(output_view.GetNode("cst2")->attr().at("SrcT").type(), DT_BFLOAT16);
|
EXPECT_EQ(output_view.GetNode("cst2")->attr().at("SrcT").type(), DT_BFLOAT16);
|
||||||
EXPECT_EQ(output_view.GetNode("cst2")->attr().at("DstT").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("cst2")->attr().at("DstT").type(), DT_FLOAT);
|
||||||
@ -1228,8 +1233,8 @@ TEST_F(AutoMixedPrecisionMklTest, Simple) {
|
|||||||
Output clr1 = ops::Relu(s.WithOpName("clr1"), blk1);
|
Output clr1 = ops::Relu(s.WithOpName("clr1"), blk1);
|
||||||
Output gry1 = ops::Sqrt(s.WithOpName("gry1"), clr1);
|
Output gry1 = ops::Sqrt(s.WithOpName("gry1"), clr1);
|
||||||
Output clr2 = ops::Relu(s.WithOpName("clr2"), gry1);
|
Output clr2 = ops::Relu(s.WithOpName("clr2"), gry1);
|
||||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), clr2, clr2);
|
Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr2, clr2);
|
||||||
Output clr3 = ops::Relu(s.WithOpName("clr3"), wht1);
|
Output clr3 = ops::Relu(s.WithOpName("clr3"), allow1);
|
||||||
Output blk2 = ops::Log(s.WithOpName("blk2"), clr3);
|
Output blk2 = ops::Log(s.WithOpName("blk2"), clr3);
|
||||||
Output clr4 = ops::Relu(s.WithOpName("clr4"), blk2);
|
Output clr4 = ops::Relu(s.WithOpName("clr4"), blk2);
|
||||||
Output blk3 = ops::SparseMatMul(s.WithOpName("blk3"), clr4, clr4);
|
Output blk3 = ops::SparseMatMul(s.WithOpName("blk3"), clr4, clr4);
|
||||||
@ -1254,7 +1259,7 @@ TEST_F(AutoMixedPrecisionMklTest, Simple) {
|
|||||||
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_BFLOAT16);
|
EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_BFLOAT16);
|
||||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_BFLOAT16);
|
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_BFLOAT16);
|
||||||
EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_BFLOAT16);
|
EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_BFLOAT16);
|
||||||
EXPECT_EQ(output_view.GetNode("blk2")->attr().at("T").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("blk2")->attr().at("T").type(), DT_FLOAT);
|
||||||
EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
|
||||||
@ -1280,9 +1285,9 @@ TEST_F(AutoMixedPrecisionMklTest, TensorListSetGet) {
|
|||||||
Output idx3 = ops::Const(s.WithOpName("idx3"), 3);
|
Output idx3 = ops::Const(s.WithOpName("idx3"), 3);
|
||||||
auto tl1w1 =
|
auto tl1w1 =
|
||||||
ops::TensorListSetItem(s.WithOpName("tl1w1"), tl1.handle, idx1, input);
|
ops::TensorListSetItem(s.WithOpName("tl1w1"), tl1.handle, idx1, input);
|
||||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, input);
|
Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
|
||||||
auto tl1w2 =
|
auto tl1w2 =
|
||||||
ops::TensorListSetItem(s.WithOpName("tl1w2"), tl1.handle, idx2, wht1);
|
ops::TensorListSetItem(s.WithOpName("tl1w2"), tl1.handle, idx2, allow1);
|
||||||
// Ensure that TensorListResize doesn't cause any problems.
|
// Ensure that TensorListResize doesn't cause any problems.
|
||||||
Output tl1rs =
|
Output tl1rs =
|
||||||
ops::TensorListResize(s.WithOpName("tl1rs"), tl1w2.output_handle, 6);
|
ops::TensorListResize(s.WithOpName("tl1rs"), tl1w2.output_handle, 6);
|
||||||
@ -1290,9 +1295,9 @@ TEST_F(AutoMixedPrecisionMklTest, TensorListSetGet) {
|
|||||||
shape, DT_FLOAT)
|
shape, DT_FLOAT)
|
||||||
.item;
|
.item;
|
||||||
Output gry1 = ops::Mul(s.WithOpName("gry1"), tl1r1, tl1r1);
|
Output gry1 = ops::Mul(s.WithOpName("gry1"), tl1r1, tl1r1);
|
||||||
Output wht2 = ops::MatMul(s.WithOpName("wht2"), gry1, gry1);
|
Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, gry1);
|
||||||
auto tl1w3 =
|
auto tl1w3 =
|
||||||
ops::TensorListSetItem(s.WithOpName("tl1w3"), tl1.handle, idx3, wht2);
|
ops::TensorListSetItem(s.WithOpName("tl1w3"), tl1.handle, idx3, allow2);
|
||||||
Output tl1r2 =
|
Output tl1r2 =
|
||||||
ops::TensorListGetItem(s.WithOpName("tl1r2"), tl1w3.output_handle, idx3,
|
ops::TensorListGetItem(s.WithOpName("tl1r2"), tl1w3.output_handle, idx3,
|
||||||
shape, DT_FLOAT)
|
shape, DT_FLOAT)
|
||||||
@ -1325,13 +1330,13 @@ TEST_F(AutoMixedPrecisionMklTest, TensorListSetGet) {
|
|||||||
DT_BFLOAT16);
|
DT_BFLOAT16);
|
||||||
EXPECT_EQ(output_view.GetNode("tl1w1")->attr().at(type_key).type(),
|
EXPECT_EQ(output_view.GetNode("tl1w1")->attr().at(type_key).type(),
|
||||||
DT_BFLOAT16);
|
DT_BFLOAT16);
|
||||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_BFLOAT16);
|
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_BFLOAT16);
|
||||||
EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(),
|
EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(),
|
||||||
DT_BFLOAT16);
|
DT_BFLOAT16);
|
||||||
EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(),
|
EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(),
|
||||||
DT_BFLOAT16);
|
DT_BFLOAT16);
|
||||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_BFLOAT16);
|
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_BFLOAT16);
|
||||||
EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_BFLOAT16);
|
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_BFLOAT16);
|
||||||
EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(),
|
EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(),
|
||||||
DT_BFLOAT16);
|
DT_BFLOAT16);
|
||||||
EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT);
|
EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT);
|
||||||
|
@ -1020,9 +1020,9 @@ bool ConstantFolding::MaybeFoldable(const NodeDef& node,
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Skips nodes that must be preserved except whitelisted nodes.
|
// Skips nodes that must be preserved except allowlisted nodes.
|
||||||
if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end() &&
|
if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end() &&
|
||||||
nodes_whitelist_.find(node.name()) == nodes_whitelist_.end()) {
|
nodes_allowlist_.find(node.name()) == nodes_allowlist_.end()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1082,13 +1082,13 @@ bool ConstantFolding::MaybeFoldable(const NodeDef& node,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Don't fold nodes that have no outgoing edges except whitelisted nodes.
|
// Don't fold nodes that have no outgoing edges except allowlisted nodes.
|
||||||
// Such nodes could be introduced by an earlier constant folding pass and are
|
// Such nodes could be introduced by an earlier constant folding pass and are
|
||||||
// preserved in case users want to fetch their values; re-processing them
|
// preserved in case users want to fetch their values; re-processing them
|
||||||
// would lead to an error of adding a duplicated node to graph.
|
// would lead to an error of adding a duplicated node to graph.
|
||||||
const auto& outputs = node_map_->GetOutputs(node.name());
|
const auto& outputs = node_map_->GetOutputs(node.name());
|
||||||
if (outputs.empty() &&
|
if (outputs.empty() &&
|
||||||
nodes_whitelist_.find(node.name()) == nodes_whitelist_.end()) {
|
nodes_allowlist_.find(node.name()) == nodes_allowlist_.end()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
@ -3874,7 +3874,7 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
|
|||||||
GraphDef* optimized_graph) {
|
GraphDef* optimized_graph) {
|
||||||
graph_ = &item->graph;
|
graph_ = &item->graph;
|
||||||
node_map_.reset(new NodeMap(graph_));
|
node_map_.reset(new NodeMap(graph_));
|
||||||
nodes_whitelist_.clear();
|
nodes_allowlist_.clear();
|
||||||
// Fold fetch nodes iff it has a single fanout. Note that if a fetch node
|
// Fold fetch nodes iff it has a single fanout. Note that if a fetch node
|
||||||
// has a single fanout, it would be rewritten as a constant with the same
|
// has a single fanout, it would be rewritten as a constant with the same
|
||||||
// node name, and therefore users are still able to fetch it. This is not
|
// node name, and therefore users are still able to fetch it. This is not
|
||||||
@ -3885,7 +3885,7 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
|
|||||||
for (const auto& fetch : item->fetch) {
|
for (const auto& fetch : item->fetch) {
|
||||||
const NodeDef* fetch_node = node_map_->GetNode(fetch);
|
const NodeDef* fetch_node = node_map_->GetNode(fetch);
|
||||||
if (fetch_node && NumOutputs(*fetch_node, graph_) == 1) {
|
if (fetch_node && NumOutputs(*fetch_node, graph_) == 1) {
|
||||||
nodes_whitelist_.insert(fetch_node->name());
|
nodes_allowlist_.insert(fetch_node->name());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -328,7 +328,7 @@ class ConstantFolding : public GraphOptimizer {
|
|||||||
std::unique_ptr<NodeMap> node_map_;
|
std::unique_ptr<NodeMap> node_map_;
|
||||||
std::unordered_set<string> nodes_to_preserve_;
|
std::unordered_set<string> nodes_to_preserve_;
|
||||||
// TODO(rmlarsen): Could these be keyed on absl::string_view?
|
// TODO(rmlarsen): Could these be keyed on absl::string_view?
|
||||||
absl::flat_hash_set<string> nodes_whitelist_;
|
absl::flat_hash_set<string> nodes_allowlist_;
|
||||||
absl::flat_hash_set<string> feed_nodes_;
|
absl::flat_hash_set<string> feed_nodes_;
|
||||||
absl::flat_hash_map<string, bool> maybe_foldable_nodes_;
|
absl::flat_hash_map<string, bool> maybe_foldable_nodes_;
|
||||||
bool has_fetch_;
|
bool has_fetch_;
|
||||||
|
@ -232,16 +232,16 @@ Status IsFunctionStateful(const FunctionLibraryDefinition& library,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns whether an op has been whitelisted as stateless. Uses a heuristic to
|
// Returns whether an op has been allowlisted as stateless. Uses a heuristic to
|
||||||
// whitelist source dataset ops which have been marked stateful due to
|
// allowlist source dataset ops which have been marked stateful due to
|
||||||
// b/65524810. Also looks up the `op_def->name` in the global
|
// b/65524810. Also looks up the `op_def->name` in the global
|
||||||
// `WhitelistedStatefulOpRegistry`.
|
// `AllowlistedStatefulOpRegistry`.
|
||||||
bool IsOpWhitelisted(const OpDef* op_def) {
|
bool IsOpAllowlisted(const OpDef* op_def) {
|
||||||
return (op_def->output_arg_size() == 1 &&
|
return (op_def->output_arg_size() == 1 &&
|
||||||
op_def->output_arg(0).type() == DT_VARIANT &&
|
op_def->output_arg(0).type() == DT_VARIANT &&
|
||||||
(absl::EndsWith(op_def->name(), "Dataset") ||
|
(absl::EndsWith(op_def->name(), "Dataset") ||
|
||||||
absl::EndsWith(op_def->name(), "DatasetV2"))) ||
|
absl::EndsWith(op_def->name(), "DatasetV2"))) ||
|
||||||
WhitelistedStatefulOpRegistry::Global()->Contains(op_def->name());
|
AllowlistedStatefulOpRegistry::Global()->Contains(op_def->name());
|
||||||
}
|
}
|
||||||
|
|
||||||
Status LookupFunction(const FunctionLibraryDefinition& lib_def,
|
Status LookupFunction(const FunctionLibraryDefinition& lib_def,
|
||||||
@ -389,7 +389,7 @@ Status IsNodeStateful(const FunctionLibraryDefinition& library,
|
|||||||
// TODO(jsimsa): Fix C++ unit tests so that we do not have to ignore
|
// TODO(jsimsa): Fix C++ unit tests so that we do not have to ignore
|
||||||
// `LookUpOpDef` errors here.
|
// `LookUpOpDef` errors here.
|
||||||
if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok() ||
|
if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok() ||
|
||||||
IsOpWhitelisted(op_def) || !op_def->is_stateful() ||
|
IsOpAllowlisted(op_def) || !op_def->is_stateful() ||
|
||||||
op_def->name() == "Assert") {
|
op_def->name() == "Assert") {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -41478,13 +41478,13 @@ func ParseExample(scope *Scope, serialized tf.Output, names tf.Output, sparse_ke
|
|||||||
// DatasetToGraphAttr is an optional argument to DatasetToGraph.
|
// DatasetToGraphAttr is an optional argument to DatasetToGraph.
|
||||||
type DatasetToGraphAttr func(optionalAttr)
|
type DatasetToGraphAttr func(optionalAttr)
|
||||||
|
|
||||||
// DatasetToGraphStatefulWhitelist sets the optional stateful_whitelist attribute to value.
|
// DatasetToGraphStatefulAllowlist sets the optional stateful_allowlist attribute to value.
|
||||||
// If not specified, defaults to <>
|
// If not specified, defaults to <>
|
||||||
//
|
//
|
||||||
// REQUIRES: len(value) >= 0
|
// REQUIRES: len(value) >= 0
|
||||||
func DatasetToGraphStatefulWhitelist(value []string) DatasetToGraphAttr {
|
func DatasetToGraphStatefulAllowlist(value []string) DatasetToGraphAttr {
|
||||||
return func(m optionalAttr) {
|
return func(m optionalAttr) {
|
||||||
m["stateful_whitelist"] = value
|
m["stateful_allowlist"] = value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -233,10 +233,10 @@ tf_cc_test(
|
|||||||
cc_library(
|
cc_library(
|
||||||
name = "whitelisted_flex_ops_lib",
|
name = "whitelisted_flex_ops_lib",
|
||||||
srcs = [
|
srcs = [
|
||||||
"whitelisted_flex_ops.cc",
|
"allowlisted_flex_ops.cc",
|
||||||
],
|
],
|
||||||
hdrs = [
|
hdrs = [
|
||||||
"whitelisted_flex_ops.h",
|
"allowlisted_flex_ops.h",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -12,15 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#include "tensorflow/lite/delegates/flex/whitelisted_flex_ops.h"
|
|
||||||
|
|
||||||
#include <set>
|
#include <set>
|
||||||
|
|
||||||
|
#include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h"
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace flex {
|
namespace flex {
|
||||||
|
|
||||||
bool IsWhitelistedFlexOp(const std::string& tensorflow_op_name) {
|
bool IsAllowlistedFlexOp(const std::string& tensorflow_op_name) {
|
||||||
static const std::set<std::string>* whitelisted_flex_ops =
|
static const std::set<std::string>* allowlisted_flex_ops =
|
||||||
new std::set<std::string>({
|
new std::set<std::string>({
|
||||||
// go/keep-sorted start
|
// go/keep-sorted start
|
||||||
"Abort",
|
"Abort",
|
||||||
@ -538,8 +538,8 @@ bool IsWhitelistedFlexOp(const std::string& tensorflow_op_name) {
|
|||||||
"_Send",
|
"_Send",
|
||||||
// go/keep-sorted end
|
// go/keep-sorted end
|
||||||
});
|
});
|
||||||
return whitelisted_flex_ops->find(tensorflow_op_name) !=
|
return allowlisted_flex_ops->find(tensorflow_op_name) !=
|
||||||
whitelisted_flex_ops->end();
|
allowlisted_flex_ops->end();
|
||||||
// Prevent lint error about this function being too long. This function
|
// Prevent lint error about this function being too long. This function
|
||||||
// is a set of ops, and making it shorter won't help readbility.
|
// is a set of ops, and making it shorter won't help readbility.
|
||||||
// NOLINTNEXTLINE
|
// NOLINTNEXTLINE
|
@ -12,24 +12,24 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#ifndef TENSORFLOW_LITE_DELEGATES_FLEX_WHITELISTED_FLEX_OPS_H_
|
#ifndef TENSORFLOW_LITE_DELEGATES_FLEX_ALLOWLISTED_FLEX_OPS_H_
|
||||||
#define TENSORFLOW_LITE_DELEGATES_FLEX_WHITELISTED_FLEX_OPS_H_
|
#define TENSORFLOW_LITE_DELEGATES_FLEX_ALLOWLISTED_FLEX_OPS_H_
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace flex {
|
namespace flex {
|
||||||
|
|
||||||
// Whether the given op has been statically whitelisted for flex export.
|
// Whether the given op has been statically allowlisted for flex export.
|
||||||
//
|
//
|
||||||
// This static whitelist is formed by the intersection of ops supported by
|
// This static allowlist is formed by the intersection of ops supported by
|
||||||
// TensorFlowMobile on both iOS and Android. As the converter is likely running
|
// TensorFlowMobile on both iOS and Android. As the converter is likely running
|
||||||
// on a host that has the full suite of TensorFlow ops available, we use this
|
// on a host that has the full suite of TensorFlow ops available, we use this
|
||||||
// static whitelist to ensure compatibility when deploying to a mobile device.
|
// static allowlist to ensure compatibility when deploying to a mobile device.
|
||||||
// TODO(b/118389105): Automate generation of the whitelisted flex ops.
|
// TODO(b/118389105): Automate generation of the allowlisted flex ops.
|
||||||
bool IsWhitelistedFlexOp(const std::string& tensorflow_op_name);
|
bool IsAllowlistedFlexOp(const std::string& tensorflow_op_name);
|
||||||
|
|
||||||
} // namespace flex
|
} // namespace flex
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
#endif // TENSORFLOW_LITE_DELEGATES_FLEX_WHITELISTED_FLEX_OPS_H_
|
#endif // TENSORFLOW_LITE_DELEGATES_FLEX_ALLOWLISTED_FLEX_OPS_H_
|
@ -70,7 +70,7 @@ TfLiteStatus Get4DShape(unsigned int* batch_size, unsigned int* height_size,
|
|||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
|
|
||||||
// We maintain an op-version whitelist here to ensure we don't accept unintended
|
// We maintain an op-version allowlist here to ensure we don't accept unintended
|
||||||
// ops.
|
// ops.
|
||||||
bool CheckOpVersion(const TfLiteRegistration* registration) {
|
bool CheckOpVersion(const TfLiteRegistration* registration) {
|
||||||
switch (registration->builtin_code) {
|
switch (registration->builtin_code) {
|
||||||
|
@ -18,7 +18,7 @@ namespace tflite {
|
|||||||
|
|
||||||
const constexpr char* NnapiAccelerationTestParams::kAccelerationTestConfig =
|
const constexpr char* NnapiAccelerationTestParams::kAccelerationTestConfig =
|
||||||
R"(
|
R"(
|
||||||
## Every Test can be whitelisted or blacklisted using a regexp on its test_id
|
## Every Test can be allowlisted or blacklisted using a regexp on its test_id
|
||||||
|
|
||||||
## Test_id
|
## Test_id
|
||||||
#
|
#
|
||||||
@ -28,7 +28,7 @@ const constexpr char* NnapiAccelerationTestParams::kAccelerationTestConfig =
|
|||||||
# the ordinal is the position in the list of parameters generated by the
|
# the ordinal is the position in the list of parameters generated by the
|
||||||
# cardinal product of all the different parameter sets
|
# cardinal product of all the different parameter sets
|
||||||
|
|
||||||
# Blacklist/Whitelist
|
# Blacklist/Allowlist
|
||||||
# To blacklist an element simply add - before the test_id regex
|
# To blacklist an element simply add - before the test_id regex
|
||||||
|
|
||||||
## Rules evaluation
|
## Rules evaluation
|
||||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
|
|
||||||
// NNAPI specific configuration for the validation whitelist.
|
// NNAPI specific configuration for the validation allowlist.
|
||||||
class NnapiAccelerationTestParams {
|
class NnapiAccelerationTestParams {
|
||||||
public:
|
public:
|
||||||
// Content in nnapi_acceleration_test_list.cc.
|
// Content in nnapi_acceleration_test_list.cc.
|
||||||
|
@ -4526,7 +4526,7 @@ TfLiteStatus StatefulNnApiDelegate::DoPrepare(TfLiteContext* context,
|
|||||||
} else {
|
} else {
|
||||||
// If no accelerator is specified, only use NNAPI if an accelerator is
|
// If no accelerator is specified, only use NNAPI if an accelerator is
|
||||||
// available. Any available accelerator will make the device_count larger
|
// available. Any available accelerator will make the device_count larger
|
||||||
// than 1. More sophisticated check and whitelisting can be added later.
|
// than 1. More sophisticated check and allowlisting can be added later.
|
||||||
uint32_t device_count = 0;
|
uint32_t device_count = 0;
|
||||||
RETURN_TFLITE_ERROR_IF_NN_ERROR(
|
RETURN_TFLITE_ERROR_IF_NN_ERROR(
|
||||||
context, nnapi->ANeuralNetworks_getDeviceCount(&device_count),
|
context, nnapi->ANeuralNetworks_getDeviceCount(&device_count),
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
# Accelerator whitelisting
|
# Accelerator allowlisting
|
||||||
|
|
||||||
Experimental library and tools for determining whether an accelerator engine
|
Experimental library and tools for determining whether an accelerator engine
|
||||||
works well on a given device, and for a given model.
|
works well on a given device, and for a given model.
|
||||||
@ -6,7 +6,7 @@ works well on a given device, and for a given model.
|
|||||||
## Platform-agnostic, Android-first
|
## Platform-agnostic, Android-first
|
||||||
|
|
||||||
Android-focused, since the much smaller set of configurations on iOS means there
|
Android-focused, since the much smaller set of configurations on iOS means there
|
||||||
is much less need for whitelisting on iOS.
|
is much less need for allowlisting on iOS.
|
||||||
|
|
||||||
## Not just for TfLite
|
## Not just for TfLite
|
||||||
|
|
||||||
|
@ -32,7 +32,7 @@ package tflite.proto;
|
|||||||
// compatibility list entries have been developed for and what settings are used
|
// compatibility list entries have been developed for and what settings are used
|
||||||
// for NNAPI.
|
// for NNAPI.
|
||||||
enum ExecutionPreference {
|
enum ExecutionPreference {
|
||||||
// Match any selected preference. Whitelist (semantically - value is same as
|
// Match any selected preference. Allowlist (semantically - value is same as
|
||||||
// on input).
|
// on input).
|
||||||
ANY = 0;
|
ANY = 0;
|
||||||
// Match low latency preference. Both compatibility list and input.
|
// Match low latency preference. Both compatibility list and input.
|
||||||
|
@ -39,8 +39,8 @@ for `target_spec.supported_ops`:
|
|||||||
|
|
||||||
* `TFLITE_BUILTINS` - Converts models using TensorFlow Lite builtin ops.
|
* `TFLITE_BUILTINS` - Converts models using TensorFlow Lite builtin ops.
|
||||||
* `SELECT_TF_OPS` - Converts models using TensorFlow ops. The exact subset of
|
* `SELECT_TF_OPS` - Converts models using TensorFlow ops. The exact subset of
|
||||||
supported ops can be found in the whitelist at
|
supported ops can be found in the allowlist at
|
||||||
`lite/delegates/flex/whitelisted_flex_ops.cc`.
|
`lite/delegates/flex/allowlisted_flex_ops.cc`.
|
||||||
|
|
||||||
Note: `target_spec.supported_ops` was previously `target_ops` in the Python API.
|
Note: `target_spec.supported_ops` was previously `target_ops` in the Python API.
|
||||||
|
|
||||||
|
@ -27,8 +27,8 @@ public final class CompatibilityListTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBasic() throws Exception {
|
public void testBasic() throws Exception {
|
||||||
try (CompatibilityList whitelist = new CompatibilityList()) {
|
try (CompatibilityList allowlist = new CompatibilityList()) {
|
||||||
assertThat(whitelist.isDelegateSupportedOnThisDevice()).isTrue();
|
assertThat(allowlist.isDelegateSupportedOnThisDevice()).isTrue();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
|||||||
namespace tflite {
|
namespace tflite {
|
||||||
|
|
||||||
// Returns the test id to use to retrieve the acceleration configuration
|
// Returns the test id to use to retrieve the acceleration configuration
|
||||||
// in the acceleration whitelist.
|
// in the acceleration allowlist.
|
||||||
std::string GetCurrentTestId();
|
std::string GetCurrentTestId();
|
||||||
|
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
@ -51,14 +51,14 @@ struct SimpleConfig {
|
|||||||
|
|
||||||
class ReadAccelerationConfigTest : public ::testing::Test {
|
class ReadAccelerationConfigTest : public ::testing::Test {
|
||||||
public:
|
public:
|
||||||
std::unordered_map<std::string, SimpleConfig> whitelist_;
|
std::unordered_map<std::string, SimpleConfig> allowlist_;
|
||||||
std::unordered_map<std::string, SimpleConfig> blacklist_;
|
std::unordered_map<std::string, SimpleConfig> blacklist_;
|
||||||
std::function<void(std::string, std::string, bool)> consumer_ =
|
std::function<void(std::string, std::string, bool)> consumer_ =
|
||||||
[this](std::string key, std::string value, bool is_blacklist) {
|
[this](std::string key, std::string value, bool is_blacklist) {
|
||||||
if (is_blacklist) {
|
if (is_blacklist) {
|
||||||
blacklist_[key] = {value};
|
blacklist_[key] = {value};
|
||||||
} else {
|
} else {
|
||||||
whitelist_[key] = {value};
|
allowlist_[key] = {value};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
@ -66,21 +66,21 @@ class ReadAccelerationConfigTest : public ::testing::Test {
|
|||||||
TEST_F(ReadAccelerationConfigTest, ReadsAKeyOnlyLine) {
|
TEST_F(ReadAccelerationConfigTest, ReadsAKeyOnlyLine) {
|
||||||
ReadAccelerationConfig("key", consumer_);
|
ReadAccelerationConfig("key", consumer_);
|
||||||
|
|
||||||
EXPECT_THAT(whitelist_.find("key"), Not(Eq(whitelist_.end())));
|
EXPECT_THAT(allowlist_.find("key"), Not(Eq(allowlist_.end())));
|
||||||
EXPECT_TRUE(blacklist_.empty());
|
EXPECT_TRUE(blacklist_.empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ReadAccelerationConfigTest, ReadsABlacklistKeyOnlyLine) {
|
TEST_F(ReadAccelerationConfigTest, ReadsABlacklistKeyOnlyLine) {
|
||||||
ReadAccelerationConfig("-key", consumer_);
|
ReadAccelerationConfig("-key", consumer_);
|
||||||
|
|
||||||
EXPECT_THAT(blacklist_.find("key"), Not(Eq(whitelist_.end())));
|
EXPECT_THAT(blacklist_.find("key"), Not(Eq(allowlist_.end())));
|
||||||
EXPECT_TRUE(whitelist_.empty());
|
EXPECT_TRUE(allowlist_.empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ReadAccelerationConfigTest, ReadsAKeyValueLine) {
|
TEST_F(ReadAccelerationConfigTest, ReadsAKeyValueLine) {
|
||||||
ReadAccelerationConfig("key,value", consumer_);
|
ReadAccelerationConfig("key,value", consumer_);
|
||||||
|
|
||||||
EXPECT_THAT(whitelist_["key"].value, Eq("value"));
|
EXPECT_THAT(allowlist_["key"].value, Eq("value"));
|
||||||
EXPECT_TRUE(blacklist_.empty());
|
EXPECT_TRUE(blacklist_.empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -88,13 +88,13 @@ TEST_F(ReadAccelerationConfigTest, ReadsABlackListKeyValueLine) {
|
|||||||
ReadAccelerationConfig("-key,value", consumer_);
|
ReadAccelerationConfig("-key,value", consumer_);
|
||||||
|
|
||||||
EXPECT_THAT(blacklist_["key"].value, Eq("value"));
|
EXPECT_THAT(blacklist_["key"].value, Eq("value"));
|
||||||
EXPECT_TRUE(whitelist_.empty());
|
EXPECT_TRUE(allowlist_.empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ReadAccelerationConfigTest, KeysAreLeftTrimmed) {
|
TEST_F(ReadAccelerationConfigTest, KeysAreLeftTrimmed) {
|
||||||
ReadAccelerationConfig(" key,value", consumer_);
|
ReadAccelerationConfig(" key,value", consumer_);
|
||||||
|
|
||||||
EXPECT_THAT(whitelist_["key"].value, Eq("value"));
|
EXPECT_THAT(allowlist_["key"].value, Eq("value"));
|
||||||
EXPECT_TRUE(blacklist_.empty());
|
EXPECT_TRUE(blacklist_.empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -102,58 +102,58 @@ TEST_F(ReadAccelerationConfigTest, BlKeysAreLeftTrimmed) {
|
|||||||
ReadAccelerationConfig(" -key,value", consumer_);
|
ReadAccelerationConfig(" -key,value", consumer_);
|
||||||
|
|
||||||
EXPECT_THAT(blacklist_["key"].value, Eq("value"));
|
EXPECT_THAT(blacklist_["key"].value, Eq("value"));
|
||||||
EXPECT_TRUE(whitelist_.empty());
|
EXPECT_TRUE(allowlist_.empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ReadAccelerationConfigTest, IgnoresCommentedLines) {
|
TEST_F(ReadAccelerationConfigTest, IgnoresCommentedLines) {
|
||||||
ReadAccelerationConfig("#key,value", consumer_);
|
ReadAccelerationConfig("#key,value", consumer_);
|
||||||
|
|
||||||
EXPECT_TRUE(whitelist_.empty());
|
EXPECT_TRUE(allowlist_.empty());
|
||||||
EXPECT_TRUE(blacklist_.empty());
|
EXPECT_TRUE(blacklist_.empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ReadAccelerationConfigTest, CommentCanHaveTrailingBlanks) {
|
TEST_F(ReadAccelerationConfigTest, CommentCanHaveTrailingBlanks) {
|
||||||
ReadAccelerationConfig(" #key,value", consumer_);
|
ReadAccelerationConfig(" #key,value", consumer_);
|
||||||
|
|
||||||
EXPECT_TRUE(whitelist_.empty());
|
EXPECT_TRUE(allowlist_.empty());
|
||||||
EXPECT_TRUE(blacklist_.empty());
|
EXPECT_TRUE(blacklist_.empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ReadAccelerationConfigTest, CommentsAreOnlyForTheFullLine) {
|
TEST_F(ReadAccelerationConfigTest, CommentsAreOnlyForTheFullLine) {
|
||||||
ReadAccelerationConfig("key,value #comment", consumer_);
|
ReadAccelerationConfig("key,value #comment", consumer_);
|
||||||
|
|
||||||
EXPECT_THAT(whitelist_["key"].value, Eq("value #comment"));
|
EXPECT_THAT(allowlist_["key"].value, Eq("value #comment"));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ReadAccelerationConfigTest, IgnoresEmptyLines) {
|
TEST_F(ReadAccelerationConfigTest, IgnoresEmptyLines) {
|
||||||
ReadAccelerationConfig("", consumer_);
|
ReadAccelerationConfig("", consumer_);
|
||||||
|
|
||||||
EXPECT_TRUE(whitelist_.empty());
|
EXPECT_TRUE(allowlist_.empty());
|
||||||
EXPECT_TRUE(blacklist_.empty());
|
EXPECT_TRUE(blacklist_.empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ReadAccelerationConfigTest, ParsesMultipleLines) {
|
TEST_F(ReadAccelerationConfigTest, ParsesMultipleLines) {
|
||||||
ReadAccelerationConfig("key1,value1\nkey2,value2\n-key3,value3", consumer_);
|
ReadAccelerationConfig("key1,value1\nkey2,value2\n-key3,value3", consumer_);
|
||||||
|
|
||||||
EXPECT_THAT(whitelist_["key1"].value, Eq("value1"));
|
EXPECT_THAT(allowlist_["key1"].value, Eq("value1"));
|
||||||
EXPECT_THAT(whitelist_["key2"].value, Eq("value2"));
|
EXPECT_THAT(allowlist_["key2"].value, Eq("value2"));
|
||||||
EXPECT_THAT(blacklist_["key3"].value, Eq("value3"));
|
EXPECT_THAT(blacklist_["key3"].value, Eq("value3"));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ReadAccelerationConfigTest, ParsesMultipleLinesWithCommentsAndSpaces) {
|
TEST_F(ReadAccelerationConfigTest, ParsesMultipleLinesWithCommentsAndSpaces) {
|
||||||
ReadAccelerationConfig("key1,value1\n#comment\n\nkey2,value2", consumer_);
|
ReadAccelerationConfig("key1,value1\n#comment\n\nkey2,value2", consumer_);
|
||||||
|
|
||||||
EXPECT_THAT(whitelist_["key1"].value, Eq("value1"));
|
EXPECT_THAT(allowlist_["key1"].value, Eq("value1"));
|
||||||
EXPECT_THAT(whitelist_["key2"].value, Eq("value2"));
|
EXPECT_THAT(allowlist_["key2"].value, Eq("value2"));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ReadAccelerationConfigTest, ParsesMultipleLinesWithMissingConfigValues) {
|
TEST_F(ReadAccelerationConfigTest, ParsesMultipleLinesWithMissingConfigValues) {
|
||||||
ReadAccelerationConfig("key1\nkey2,value2\nkey3\nkey4,value4", consumer_);
|
ReadAccelerationConfig("key1\nkey2,value2\nkey3\nkey4,value4", consumer_);
|
||||||
|
|
||||||
EXPECT_THAT(whitelist_["key1"].value, Eq(""));
|
EXPECT_THAT(allowlist_["key1"].value, Eq(""));
|
||||||
EXPECT_THAT(whitelist_["key2"].value, Eq("value2"));
|
EXPECT_THAT(allowlist_["key2"].value, Eq("value2"));
|
||||||
EXPECT_THAT(whitelist_["key3"].value, Eq(""));
|
EXPECT_THAT(allowlist_["key3"].value, Eq(""));
|
||||||
EXPECT_THAT(whitelist_["key4"].value, Eq("value4"));
|
EXPECT_THAT(allowlist_["key4"].value, Eq("value4"));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(GetAccelerationTestParam, LoadsTestConfig) {
|
TEST(GetAccelerationTestParam, LoadsTestConfig) {
|
||||||
|
@ -27,7 +27,7 @@ import six
|
|||||||
|
|
||||||
|
|
||||||
def sanitize_xml(unsanitized):
|
def sanitize_xml(unsanitized):
|
||||||
"""Uses a whitelist to avoid generating bad XML."""
|
"""Uses a allowlist to avoid generating bad XML."""
|
||||||
return re.sub(r'[^a-zA-Z0-9+_\-/\\.]', '', six.ensure_str(unsanitized))
|
return re.sub(r'[^a-zA-Z0-9+_\-/\\.]', '', six.ensure_str(unsanitized))
|
||||||
|
|
||||||
|
|
||||||
|
@ -794,7 +794,7 @@ TEST(OperatorKeyTest, TestFlexWithUnsupportedOp) {
|
|||||||
EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_CUSTOM);
|
EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_CUSTOM);
|
||||||
EXPECT_EQ(key.custom_code(), "HashTableV2");
|
EXPECT_EQ(key.custom_code(), "HashTableV2");
|
||||||
EXPECT_EQ(key.version(), 1);
|
EXPECT_EQ(key.version(), 1);
|
||||||
// While HashTableV2 is excluded from the whitelisted flex op list, eventually
|
// While HashTableV2 is excluded from the allowlisted flex op list, eventually
|
||||||
// it won't be, and the following expectations will need to change as the op
|
// it won't be, and the following expectations will need to change as the op
|
||||||
// is explicitly blacklisted due to lack of asset support.
|
// is explicitly blacklisted due to lack of asset support.
|
||||||
EXPECT_FALSE(key.is_flex_op());
|
EXPECT_FALSE(key.is_flex_op());
|
||||||
|
@ -24,7 +24,7 @@ limitations under the License.
|
|||||||
|
|
||||||
// TODO(ycling): Consider refactoring to extract the LSTM definition out of
|
// TODO(ycling): Consider refactoring to extract the LSTM definition out of
|
||||||
// graph_transformation module.
|
// graph_transformation module.
|
||||||
#include "tensorflow/lite/delegates/flex/whitelisted_flex_ops.h"
|
#include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h"
|
||||||
#include "tensorflow/lite/schema/schema_generated.h"
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
#include "tensorflow/lite/toco/graph_transformations/lstm_utils.h"
|
#include "tensorflow/lite/toco/graph_transformations/lstm_utils.h"
|
||||||
#include "tensorflow/lite/toco/model.h"
|
#include "tensorflow/lite/toco/model.h"
|
||||||
@ -2116,7 +2116,7 @@ bool ShouldExportAsFlexOp(bool enable_select_tf_ops,
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
// Check if we can find the `OpDef` for the TensorFlow op. If we can find
|
// Check if we can find the `OpDef` for the TensorFlow op. If we can find
|
||||||
// it and it has been whitelisted, export the op as an Flex op. Otherwise,
|
// it and it has been allowlisted, export the op as an Flex op. Otherwise,
|
||||||
// export it as a regular custom op.
|
// export it as a regular custom op.
|
||||||
const tensorflow::OpDef* op_def = nullptr;
|
const tensorflow::OpDef* op_def = nullptr;
|
||||||
if (!tensorflow::OpRegistry::Global()
|
if (!tensorflow::OpRegistry::Global()
|
||||||
@ -2125,9 +2125,9 @@ bool ShouldExportAsFlexOp(bool enable_select_tf_ops,
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!::tflite::flex::IsWhitelistedFlexOp(tensorflow_op_name)) {
|
if (!::tflite::flex::IsAllowlistedFlexOp(tensorflow_op_name)) {
|
||||||
LOG(WARNING) << "Op " << tensorflow_op_name
|
LOG(WARNING) << "Op " << tensorflow_op_name
|
||||||
<< " is a valid TensorFlow op but has not been whitelisted for"
|
<< " is a valid TensorFlow op but has not been allowlisted for"
|
||||||
" the TensorFlow Lite flex op set.";
|
" the TensorFlow Lite flex op set.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -156,7 +156,7 @@ To do so, we utilize the `preprocess_coco_minival` Python binary as follows:
|
|||||||
bazel run //tensorflow/lite/tools/evaluation/tasks/coco_object_detection:preprocess_coco_minival -- \
|
bazel run //tensorflow/lite/tools/evaluation/tasks/coco_object_detection:preprocess_coco_minival -- \
|
||||||
--images_folder=/path/to/val2014 \
|
--images_folder=/path/to/val2014 \
|
||||||
--instances_file=/path/to/instances_val2014.json \
|
--instances_file=/path/to/instances_val2014.json \
|
||||||
--whitelist_file=/path/to/minival_whitelist.txt \
|
--allowlist_file=/path/to/minival_allowlist.txt \
|
||||||
--output_folder=/path/to/output/folder
|
--output_folder=/path/to/output/folder
|
||||||
|
|
||||||
```
|
```
|
||||||
|
@ -16,13 +16,13 @@
|
|||||||
|
|
||||||
The 2014 validation images & annotations can be downloaded from:
|
The 2014 validation images & annotations can be downloaded from:
|
||||||
http://cocodataset.org/#download
|
http://cocodataset.org/#download
|
||||||
The minival image ID whitelist, a subset of the 2014 validation set, can be
|
The minival image ID allowlist, a subset of the 2014 validation set, can be
|
||||||
found here:
|
found here:
|
||||||
https://github.com/tensorflow/models/blob/master/research/object_detection/data/mscoco_minival_ids.txt.
|
https://github.com/tensorflow/models/blob/master/research/object_detection/data/mscoco_minival_ids.txt.
|
||||||
|
|
||||||
This script takes in the original images folder, instances JSON file and
|
This script takes in the original images folder, instances JSON file and
|
||||||
image ID whitelist and produces the following in the specified output folder:
|
image ID allowlist and produces the following in the specified output folder:
|
||||||
A subfolder for whitelisted images (images/), and a file (ground_truth.pbtxt)
|
A subfolder for allowlisted images (images/), and a file (ground_truth.pbtxt)
|
||||||
containing an instance of tflite::evaluation::ObjectDetectionGroundTruth.
|
containing an instance of tflite::evaluation::ObjectDetectionGroundTruth.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -40,17 +40,17 @@ from tensorflow.lite.tools.evaluation.proto import evaluation_stages_pb2
|
|||||||
|
|
||||||
|
|
||||||
def _get_ground_truth_detections(instances_file,
|
def _get_ground_truth_detections(instances_file,
|
||||||
whitelist_file=None,
|
allowlist_file=None,
|
||||||
num_images=None):
|
num_images=None):
|
||||||
"""Processes the annotations JSON file and returns ground truth data corresponding to whitelisted image IDs.
|
"""Processes the annotations JSON file and returns ground truth data corresponding to allowlisted image IDs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
instances_file: COCO instances JSON file, usually named as
|
instances_file: COCO instances JSON file, usually named as
|
||||||
instances_val20xx.json.
|
instances_val20xx.json.
|
||||||
whitelist_file: File containing COCO minival image IDs to whitelist for
|
allowlist_file: File containing COCO minival image IDs to allowlist for
|
||||||
evaluation, one per line.
|
evaluation, one per line.
|
||||||
num_images: Number of whitelisted images to pre-process. First num_images
|
num_images: Number of allowlisted images to pre-process. First num_images
|
||||||
are chosen based on sorted list of filenames. If None, all whitelisted
|
are chosen based on sorted list of filenames. If None, all allowlisted
|
||||||
files are preprocessed.
|
files are preprocessed.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -70,17 +70,17 @@ def _get_ground_truth_detections(instances_file,
|
|||||||
image_data = collections.OrderedDict()
|
image_data = collections.OrderedDict()
|
||||||
all_file_names = []
|
all_file_names = []
|
||||||
|
|
||||||
# Read whitelist.
|
# Read allowlist.
|
||||||
if whitelist_file is not None:
|
if allowlist_file is not None:
|
||||||
with open(whitelist_file, 'r') as whitelist:
|
with open(allowlist_file, 'r') as allowlist:
|
||||||
image_id_whitelist = set([int(x) for x in whitelist.readlines()])
|
image_id_allowlist = set([int(x) for x in allowlist.readlines()])
|
||||||
else:
|
else:
|
||||||
image_id_whitelist = [image['id'] for image in data_dict['images']]
|
image_id_allowlist = [image['id'] for image in data_dict['images']]
|
||||||
|
|
||||||
# Get image names and dimensions.
|
# Get image names and dimensions.
|
||||||
for image_dict in data_dict['images']:
|
for image_dict in data_dict['images']:
|
||||||
image_id = image_dict['id']
|
image_id = image_dict['id']
|
||||||
if image_id not in image_id_whitelist:
|
if image_id not in image_id_allowlist:
|
||||||
continue
|
continue
|
||||||
image_data_dict = {}
|
image_data_dict = {}
|
||||||
image_data_dict['id'] = image_dict['id']
|
image_data_dict['id'] = image_dict['id']
|
||||||
@ -99,7 +99,7 @@ def _get_ground_truth_detections(instances_file,
|
|||||||
# Get detected object annotations per image.
|
# Get detected object annotations per image.
|
||||||
for annotation_dict in data_dict['annotations']:
|
for annotation_dict in data_dict['annotations']:
|
||||||
image_id = annotation_dict['image_id']
|
image_id = annotation_dict['image_id']
|
||||||
if image_id not in image_id_whitelist:
|
if image_id not in image_id_allowlist:
|
||||||
continue
|
continue
|
||||||
if image_id not in image_data:
|
if image_id not in image_data:
|
||||||
continue
|
continue
|
||||||
@ -133,7 +133,7 @@ def _dump_data(ground_truth_detections, images_folder_path, output_folder_path):
|
|||||||
"""Dumps images & data from ground-truth objects into output_folder_path.
|
"""Dumps images & data from ground-truth objects into output_folder_path.
|
||||||
|
|
||||||
The following are created in output_folder_path:
|
The following are created in output_folder_path:
|
||||||
images/: sub-folder for whitelisted validation images.
|
images/: sub-folder for allowlisted validation images.
|
||||||
ground_truth.pb: A binary proto file containing all ground-truth
|
ground_truth.pb: A binary proto file containing all ground-truth
|
||||||
object-sets.
|
object-sets.
|
||||||
|
|
||||||
@ -193,14 +193,14 @@ def _parse_args():
|
|||||||
help='Full path of the input JSON file, like instances_val20xx.json.',
|
help='Full path of the input JSON file, like instances_val20xx.json.',
|
||||||
required=True)
|
required=True)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--whitelist_file',
|
'--allowlist_file',
|
||||||
type=str,
|
type=str,
|
||||||
help='File with COCO image ids to preprocess, one on each line.',
|
help='File with COCO image ids to preprocess, one on each line.',
|
||||||
required=False)
|
required=False)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--num_images',
|
'--num_images',
|
||||||
type=int,
|
type=int,
|
||||||
help='Number of whitelisted images to preprocess into the output folder.',
|
help='Number of allowlisted images to preprocess into the output folder.',
|
||||||
required=False)
|
required=False)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--output_folder',
|
'--output_folder',
|
||||||
@ -213,6 +213,6 @@ def _parse_args():
|
|||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
args = _parse_args()
|
args = _parse_args()
|
||||||
ground_truths = _get_ground_truth_detections(args.instances_file,
|
ground_truths = _get_ground_truth_detections(args.instances_file,
|
||||||
args.whitelist_file,
|
args.allowlist_file,
|
||||||
args.num_images)
|
args.num_images)
|
||||||
_dump_data(ground_truths, args.images_folder, args.output_folder)
|
_dump_data(ground_truths, args.images_folder, args.output_folder)
|
||||||
|
@ -55,7 +55,7 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
|
|||||||
const TensorType& output_type, bool allow_float,
|
const TensorType& output_type, bool allow_float,
|
||||||
ErrorReporter* error_reporter);
|
ErrorReporter* error_reporter);
|
||||||
|
|
||||||
// Same as above, but enables only quantizing a whitelist of operations,
|
// Same as above, but enables only quantizing an allowlist of operations,
|
||||||
// specified by their operator output name.
|
// specified by their operator output name.
|
||||||
//
|
//
|
||||||
// Note: This is a private API, subject to change.
|
// Note: This is a private API, subject to change.
|
||||||
|
@ -158,6 +158,6 @@ _exported_dunders = set([
|
|||||||
'__monolithic_build__',
|
'__monolithic_build__',
|
||||||
])
|
])
|
||||||
|
|
||||||
# Expose symbols minus dunders, unless they are whitelisted above.
|
# Expose symbols minus dunders, unless they are allowlisted above.
|
||||||
# This is necessary to export our dunders.
|
# This is necessary to export our dunders.
|
||||||
__all__ = [s for s in dir() if s in _exported_dunders or not s.startswith('_')]
|
__all__ = [s for s in dir() if s in _exported_dunders or not s.startswith('_')]
|
||||||
|
@ -177,7 +177,7 @@ class CallTreeTransformer(converter.Base):
|
|||||||
# Calls to pdb.set_trace or ipdb.set_trace are never converted. We don't use
|
# Calls to pdb.set_trace or ipdb.set_trace are never converted. We don't use
|
||||||
# the normal mechanisms to bypass these literals because they are sensitive
|
# the normal mechanisms to bypass these literals because they are sensitive
|
||||||
# to the frame they are being called from.
|
# to the frame they are being called from.
|
||||||
# TODO(mdan): Generalize this to a "static whitelist" config.
|
# TODO(mdan): Generalize this to a "static allowlist" config.
|
||||||
if full_name in ('pdb.set_trace', 'ipdb.set_trace', 'breakpoint'):
|
if full_name in ('pdb.set_trace', 'ipdb.set_trace', 'breakpoint'):
|
||||||
global set_trace_warned
|
global set_trace_warned
|
||||||
if not set_trace_warned:
|
if not set_trace_warned:
|
||||||
|
@ -32,16 +32,16 @@ from tensorflow.python.framework import ops
|
|||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
def whitelist(f):
|
def allowlist(f):
|
||||||
"""Helper that marks a callable as whtelitisted."""
|
"""Helper that marks a callable as whtelitisted."""
|
||||||
if 'whitelisted_module_for_testing' not in sys.modules:
|
if 'allowlisted_module_for_testing' not in sys.modules:
|
||||||
whitelisted_mod = imp.new_module('whitelisted_module_for_testing')
|
allowlisted_mod = imp.new_module('allowlisted_module_for_testing')
|
||||||
sys.modules['whitelisted_module_for_testing'] = whitelisted_mod
|
sys.modules['allowlisted_module_for_testing'] = allowlisted_mod
|
||||||
config.CONVERSION_RULES = (
|
config.CONVERSION_RULES = (
|
||||||
(config.DoNotConvert('whitelisted_module_for_testing'),) +
|
(config.DoNotConvert('allowlisted_module_for_testing'),) +
|
||||||
config.CONVERSION_RULES)
|
config.CONVERSION_RULES)
|
||||||
|
|
||||||
f.__module__ = 'whitelisted_module_for_testing'
|
f.__module__ = 'allowlisted_module_for_testing'
|
||||||
|
|
||||||
|
|
||||||
def is_inside_generated_code():
|
def is_inside_generated_code():
|
||||||
|
@ -44,18 +44,18 @@ are handled correctly.
|
|||||||
|
|
||||||
The following types of functions are not converted:
|
The following types of functions are not converted:
|
||||||
|
|
||||||
* functions already converted
|
* functions already converted
|
||||||
* functions defined in in a whitelisted module (see autograph/core/config.py)
|
* functions defined in in a allowlisted module (see autograph/core/config.py)
|
||||||
* non-Python functions (such as native bindings)
|
* non-Python functions (such as native bindings)
|
||||||
* `print`, `pdb.set_trace`, `ipdb.set_trace`
|
* `print`, `pdb.set_trace`, `ipdb.set_trace`
|
||||||
* most built-in functions (exceptions are listed in
|
* most built-in functions (exceptions are listed in
|
||||||
autograph/operators/py_builtins.py)
|
autograph/operators/py_builtins.py)
|
||||||
* constructors
|
* constructors
|
||||||
* functions without source code attached (prints a warning)(see
|
* functions without source code attached (prints a warning)(see
|
||||||
[limitations](limitations.md))
|
[limitations](limitations.md))
|
||||||
* generator functions (prints a warning)
|
* generator functions (prints a warning)
|
||||||
* iterator protocol methods (`__next__`, `__iter__`)
|
* iterator protocol methods (`__next__`, `__iter__`)
|
||||||
* context manager methods (`__enter__`, `__exit__`)
|
* context manager methods (`__enter__`, `__exit__`)
|
||||||
|
|
||||||
When AutoGraph encounters a function that it cannot convert outside of this
|
When AutoGraph encounters a function that it cannot convert outside of this
|
||||||
list, it prints a warning.
|
list, it prints a warning.
|
||||||
|
@ -342,16 +342,16 @@ def converted_call(f,
|
|||||||
raise ValueError('either caller_fn_scope or options must have a value')
|
raise ValueError('either caller_fn_scope or options must have a value')
|
||||||
options = caller_fn_scope.callopts
|
options = caller_fn_scope.callopts
|
||||||
|
|
||||||
if conversion.is_in_whitelist_cache(f, options):
|
if conversion.is_in_allowlist_cache(f, options):
|
||||||
logging.log(2, 'Whitelisted %s: from cache', f)
|
logging.log(2, 'Allowlisted %s: from cache', f)
|
||||||
return _call_unconverted(f, args, kwargs, options, False)
|
return _call_unconverted(f, args, kwargs, options, False)
|
||||||
|
|
||||||
if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
|
if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
|
||||||
logging.log(2, 'Whitelisted: %s: AutoGraph is disabled in context', f)
|
logging.log(2, 'Allowlisted: %s: AutoGraph is disabled in context', f)
|
||||||
return _call_unconverted(f, args, kwargs, options, False)
|
return _call_unconverted(f, args, kwargs, options, False)
|
||||||
|
|
||||||
if is_autograph_artifact(f):
|
if is_autograph_artifact(f):
|
||||||
logging.log(2, 'Permanently whitelisted: %s: AutoGraph artifact', f)
|
logging.log(2, 'Permanently allowed: %s: AutoGraph artifact', f)
|
||||||
return _call_unconverted(f, args, kwargs, options)
|
return _call_unconverted(f, args, kwargs, options)
|
||||||
|
|
||||||
# If this is a partial, unwrap it and redo all the checks.
|
# If this is a partial, unwrap it and redo all the checks.
|
||||||
@ -385,7 +385,7 @@ def converted_call(f,
|
|||||||
if conversion.is_unsupported(f):
|
if conversion.is_unsupported(f):
|
||||||
return _call_unconverted(f, args, kwargs, options)
|
return _call_unconverted(f, args, kwargs, options)
|
||||||
|
|
||||||
if not options.user_requested and conversion.is_whitelisted(f):
|
if not options.user_requested and conversion.is_allowlisted(f):
|
||||||
return _call_unconverted(f, args, kwargs, options)
|
return _call_unconverted(f, args, kwargs, options)
|
||||||
|
|
||||||
# internal_convert_user_code is for example turned off when issuing a dynamic
|
# internal_convert_user_code is for example turned off when issuing a dynamic
|
||||||
@ -425,13 +425,13 @@ def converted_call(f,
|
|||||||
return _fall_back_unconverted(f, args, kwargs, options, e)
|
return _fall_back_unconverted(f, args, kwargs, options, e)
|
||||||
|
|
||||||
if not hasattr(target_entity, '__code__'):
|
if not hasattr(target_entity, '__code__'):
|
||||||
logging.log(2, 'Permanently whitelisted: %s: native binding',
|
logging.log(2, 'Permanently allowed: %s: native binding',
|
||||||
target_entity)
|
target_entity)
|
||||||
return _call_unconverted(f, args, kwargs, options)
|
return _call_unconverted(f, args, kwargs, options)
|
||||||
elif (hasattr(target_entity.__code__, 'co_filename') and
|
elif (hasattr(target_entity.__code__, 'co_filename') and
|
||||||
target_entity.__code__.co_filename == '<string>'):
|
target_entity.__code__.co_filename == '<string>'):
|
||||||
# TODO(mdan): __globals__['txt'] might work in Py3.
|
# TODO(mdan): __globals__['txt'] might work in Py3.
|
||||||
logging.log(2, 'Permanently whitelisted: %s: dynamic code (exec?)',
|
logging.log(2, 'Permanently allowed: %s: dynamic code (exec?)',
|
||||||
target_entity)
|
target_entity)
|
||||||
return _call_unconverted(f, args, kwargs, options)
|
return _call_unconverted(f, args, kwargs, options)
|
||||||
|
|
||||||
@ -462,7 +462,7 @@ def converted_call(f,
|
|||||||
def _call_unconverted(f, args, kwargs, options, update_cache=True):
|
def _call_unconverted(f, args, kwargs, options, update_cache=True):
|
||||||
"""Calls the original function without converting with AutoGraph."""
|
"""Calls the original function without converting with AutoGraph."""
|
||||||
if update_cache:
|
if update_cache:
|
||||||
conversion.cache_whitelisted(f, options)
|
conversion.cache_allowlisted(f, options)
|
||||||
|
|
||||||
if inspect.ismethod(f) and isinstance(f.__self__, function.TfMethodTarget):
|
if inspect.ismethod(f) and isinstance(f.__self__, function.TfMethodTarget):
|
||||||
return f.__self__.call(args, kwargs)
|
return f.__self__.call(args, kwargs)
|
||||||
@ -482,7 +482,7 @@ def _fall_back_unconverted(f, args, kwargs, options, exc):
|
|||||||
'To silence this warning, decorate the function with'
|
'To silence this warning, decorate the function with'
|
||||||
' @tf.autograph.experimental.do_not_convert')
|
' @tf.autograph.experimental.do_not_convert')
|
||||||
if isinstance(exc, errors.UnsupportedLanguageElementError):
|
if isinstance(exc, errors.UnsupportedLanguageElementError):
|
||||||
if not conversion.is_in_whitelist_cache(f, options):
|
if not conversion.is_in_allowlist_cache(f, options):
|
||||||
logging.warn(warning_template, f, '', exc)
|
logging.warn(warning_template, f, '', exc)
|
||||||
else:
|
else:
|
||||||
file_bug_message = (
|
file_bug_message = (
|
||||||
@ -516,7 +516,7 @@ def tf_convert(f, ctx, convert_by_default=True, user_requested=False):
|
|||||||
ctx: ag_ctx.ControlStatusCtx, the Autograph context in which `f` is used.
|
ctx: ag_ctx.ControlStatusCtx, the Autograph context in which `f` is used.
|
||||||
convert_by_default: bool, whether to use AutoGraph when the context doesn't
|
convert_by_default: bool, whether to use AutoGraph when the context doesn't
|
||||||
specify.
|
specify.
|
||||||
user_requested: bool, whether to ignore the conversion whitelist. See
|
user_requested: bool, whether to ignore the conversion allowlist. See
|
||||||
ConversionOptions.user_requested.
|
ConversionOptions.user_requested.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -203,14 +203,14 @@ class ApiTest(test.TestCase):
|
|||||||
z = x + y
|
z = x + y
|
||||||
return z
|
return z
|
||||||
|
|
||||||
test_method_whitelisted = api.do_not_convert(test_method)
|
test_method_allowlisted = api.do_not_convert(test_method)
|
||||||
|
|
||||||
tc = TestClass()
|
tc = TestClass()
|
||||||
self.assertTrue(tf_inspect.ismethod(tc.test_method_whitelisted))
|
self.assertTrue(tf_inspect.ismethod(tc.test_method_allowlisted))
|
||||||
# Because the wrapped function is not generated, we can't preserve its
|
# Because the wrapped function is not generated, we can't preserve its
|
||||||
# arg spec.
|
# arg spec.
|
||||||
self.assertEqual((),
|
self.assertEqual((),
|
||||||
tuple(function_utils.fn_args(tc.test_method_whitelisted)))
|
tuple(function_utils.fn_args(tc.test_method_allowlisted)))
|
||||||
|
|
||||||
def test_do_not_convert_callable_object(self):
|
def test_do_not_convert_callable_object(self):
|
||||||
|
|
||||||
@ -521,12 +521,12 @@ class ApiTest(test.TestCase):
|
|||||||
ag_logging.set_verbosity(0, False)
|
ag_logging.set_verbosity(0, False)
|
||||||
os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1'
|
os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1'
|
||||||
|
|
||||||
def test_converted_call_partial_of_whitelisted_function(self):
|
def test_converted_call_partial_of_allowlisted_function(self):
|
||||||
|
|
||||||
def test_fn(_):
|
def test_fn(_):
|
||||||
self.assertFalse(converter_testing.is_inside_generated_code())
|
self.assertFalse(converter_testing.is_inside_generated_code())
|
||||||
|
|
||||||
converter_testing.whitelist(test_fn)
|
converter_testing.allowlist(test_fn)
|
||||||
api.converted_call(
|
api.converted_call(
|
||||||
functools.partial(test_fn, None), (), None, options=DEFAULT_RECURSIVE)
|
functools.partial(test_fn, None), (), None, options=DEFAULT_RECURSIVE)
|
||||||
|
|
||||||
@ -563,7 +563,7 @@ class ApiTest(test.TestCase):
|
|||||||
f, (g, constant_op.constant(1)), None, options=DEFAULT_RECURSIVE)
|
f, (g, constant_op.constant(1)), None, options=DEFAULT_RECURSIVE)
|
||||||
self.assertEqual(self.evaluate(x), 1)
|
self.assertEqual(self.evaluate(x), 1)
|
||||||
|
|
||||||
def test_converted_call_forced_when_explicitly_whitelisted(self):
|
def test_converted_call_forced_when_explicitly_allowlisted(self):
|
||||||
|
|
||||||
@api.do_not_convert()
|
@api.do_not_convert()
|
||||||
def f(x):
|
def f(x):
|
||||||
@ -606,7 +606,7 @@ class ApiTest(test.TestCase):
|
|||||||
self.assertIsNotNone(
|
self.assertIsNotNone(
|
||||||
api.converted_call(f, (1, 2, 3, 4), None, options=opts))
|
api.converted_call(f, (1, 2, 3, 4), None, options=opts))
|
||||||
|
|
||||||
def test_converted_call_whitelisted_method(self):
|
def test_converted_call_allowlisted_method(self):
|
||||||
|
|
||||||
class TestClass(object):
|
class TestClass(object):
|
||||||
|
|
||||||
@ -614,19 +614,19 @@ class ApiTest(test.TestCase):
|
|||||||
return converter_testing.is_inside_generated_code()
|
return converter_testing.is_inside_generated_code()
|
||||||
|
|
||||||
obj = TestClass()
|
obj = TestClass()
|
||||||
converter_testing.whitelist(obj.method.__func__)
|
converter_testing.allowlist(obj.method.__func__)
|
||||||
|
|
||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
api.converted_call(obj.method, (), {}, options=DEFAULT_RECURSIVE))
|
api.converted_call(obj.method, (), {}, options=DEFAULT_RECURSIVE))
|
||||||
|
|
||||||
def test_converted_call_whitelisted_method_via_owner(self):
|
def test_converted_call_allowlisted_method_via_owner(self):
|
||||||
|
|
||||||
class TestClass(object):
|
class TestClass(object):
|
||||||
|
|
||||||
def method(self):
|
def method(self):
|
||||||
return converter_testing.is_inside_generated_code()
|
return converter_testing.is_inside_generated_code()
|
||||||
|
|
||||||
converter_testing.whitelist(TestClass)
|
converter_testing.allowlist(TestClass)
|
||||||
|
|
||||||
obj = TestClass()
|
obj = TestClass()
|
||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
@ -852,7 +852,7 @@ class ApiTest(test.TestCase):
|
|||||||
# invocation would fail.
|
# invocation would fail.
|
||||||
self.assertEqual(self.evaluate(call_in_default_context()), 1)
|
self.assertEqual(self.evaluate(call_in_default_context()), 1)
|
||||||
|
|
||||||
def test_converted_call_caching_of_whitelisted_bound_methods(self):
|
def test_converted_call_caching_of_allowlisted_bound_methods(self):
|
||||||
|
|
||||||
class TestClass(object):
|
class TestClass(object):
|
||||||
|
|
||||||
@ -863,7 +863,7 @@ class ApiTest(test.TestCase):
|
|||||||
return self.__private
|
return self.__private
|
||||||
|
|
||||||
# TODO(mdan): Refactor to avoid this use of global state.
|
# TODO(mdan): Refactor to avoid this use of global state.
|
||||||
cache_size_before = len(conversion._WHITELIST_CACHE)
|
cache_size_before = len(conversion._ALLOWLIST_CACHE)
|
||||||
|
|
||||||
# First invocation with fallback on, to allow recording it into cache.
|
# First invocation with fallback on, to allow recording it into cache.
|
||||||
os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '0'
|
os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '0'
|
||||||
@ -871,15 +871,15 @@ class ApiTest(test.TestCase):
|
|||||||
api.converted_call(tc.test_method, (), None, options=DEFAULT_RECURSIVE)
|
api.converted_call(tc.test_method, (), None, options=DEFAULT_RECURSIVE)
|
||||||
os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1'
|
os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1'
|
||||||
|
|
||||||
# Entry should be added to the whitelist cache.
|
# Entry should be added to the allowlist cache.
|
||||||
self.assertEqual(len(conversion._WHITELIST_CACHE), cache_size_before + 1)
|
self.assertEqual(len(conversion._ALLOWLIST_CACHE), cache_size_before + 1)
|
||||||
|
|
||||||
# A second invocation should go through even with fallback off.
|
# A second invocation should go through even with fallback off.
|
||||||
tc = TestClass()
|
tc = TestClass()
|
||||||
api.converted_call(tc.test_method, (), None, options=DEFAULT_RECURSIVE)
|
api.converted_call(tc.test_method, (), None, options=DEFAULT_RECURSIVE)
|
||||||
|
|
||||||
# No new entries should appear in the whitelist cache.
|
# No new entries should appear in the allowlist cache.
|
||||||
self.assertEqual(len(conversion._WHITELIST_CACHE), cache_size_before + 1)
|
self.assertEqual(len(conversion._ALLOWLIST_CACHE), cache_size_before + 1)
|
||||||
|
|
||||||
def test_context_tracking_direct_calls(self):
|
def test_context_tracking_direct_calls(self):
|
||||||
|
|
||||||
@ -1102,7 +1102,7 @@ class ApiTest(test.TestCase):
|
|||||||
|
|
||||||
test_fn(ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED))
|
test_fn(ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED))
|
||||||
|
|
||||||
def test_tf_convert_whitelisted_method(self):
|
def test_tf_convert_allowlisted_method(self):
|
||||||
|
|
||||||
if six.PY2:
|
if six.PY2:
|
||||||
self.skipTest('Test bank not comptible with Python 2.')
|
self.skipTest('Test bank not comptible with Python 2.')
|
||||||
@ -1112,7 +1112,7 @@ class ApiTest(test.TestCase):
|
|||||||
def method(self):
|
def method(self):
|
||||||
return converter_testing.is_inside_generated_code()
|
return converter_testing.is_inside_generated_code()
|
||||||
|
|
||||||
converter_testing.whitelist(TestClass.method)
|
converter_testing.allowlist(TestClass.method)
|
||||||
|
|
||||||
obj = TestClass()
|
obj = TestClass()
|
||||||
converted_call = api.tf_convert(
|
converted_call = api.tf_convert(
|
||||||
|
@ -31,7 +31,7 @@ from tensorflow.python.eager import function
|
|||||||
from tensorflow.python.util import tf_inspect
|
from tensorflow.python.util import tf_inspect
|
||||||
|
|
||||||
|
|
||||||
_WHITELIST_CACHE = cache.UnboundInstanceCache()
|
_ALLOWLIST_CACHE = cache.UnboundInstanceCache()
|
||||||
|
|
||||||
|
|
||||||
def _is_of_known_loaded_module(f, module_name):
|
def _is_of_known_loaded_module(f, module_name):
|
||||||
@ -80,53 +80,53 @@ def is_unsupported(o):
|
|||||||
'{} appears to be decorated by wrapt, which is not yet supported'
|
'{} appears to be decorated by wrapt, which is not yet supported'
|
||||||
' by AutoGraph. The function will run as-is.'
|
' by AutoGraph. The function will run as-is.'
|
||||||
' You may still apply AutoGraph before the wrapt decorator.'.format(o))
|
' You may still apply AutoGraph before the wrapt decorator.'.format(o))
|
||||||
logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', o)
|
logging.log(2, 'Permanently allowed: %s: wrapt decorated', o)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if _is_known_loaded_type(o, 'functools', '_lru_cache_wrapper'):
|
if _is_known_loaded_type(o, 'functools', '_lru_cache_wrapper'):
|
||||||
logging.log(2, 'Permanently whitelisted: %s: lru_cache', o)
|
logging.log(2, 'Permanently allowed: %s: lru_cache', o)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Constructors are permanently whitelisted.
|
# Constructors are permanently allowed.
|
||||||
# TODO(mdan): Toggle as experimental feature instead.
|
# TODO(mdan): Toggle as experimental feature instead.
|
||||||
# TODO(b/124016764): Remove this limitation.
|
# TODO(b/124016764): Remove this limitation.
|
||||||
if inspect_utils.isconstructor(o):
|
if inspect_utils.isconstructor(o):
|
||||||
logging.log(2, 'Permanently whitelisted: %s: constructor', o)
|
logging.log(2, 'Permanently allowed: %s: constructor', o)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Other built-in modules are permanently whitelisted.
|
# Other built-in modules are permanently allowed.
|
||||||
# TODO(mdan): Figure out how to do this consistently for all stdlib modules.
|
# TODO(mdan): Figure out how to do this consistently for all stdlib modules.
|
||||||
if any(
|
if any(
|
||||||
_is_of_known_loaded_module(o, m)
|
_is_of_known_loaded_module(o, m)
|
||||||
for m in ('collections', 'pdb', 'copy', 'inspect', 're')):
|
for m in ('collections', 'pdb', 'copy', 'inspect', 're')):
|
||||||
logging.log(2, 'Permanently whitelisted: %s: part of builtin module', o)
|
logging.log(2, 'Permanently allowed: %s: part of builtin module', o)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Custom ops and kernels are also permanently whitelisted.
|
# Custom ops and kernels are also permanently allowed.
|
||||||
# See tensorflow.framework.load_library.
|
# See tensorflow.framework.load_library.
|
||||||
if (hasattr(o, '__module__') and
|
if (hasattr(o, '__module__') and
|
||||||
hasattr(o.__module__, '_IS_TENSORFLOW_PLUGIN')):
|
hasattr(o.__module__, '_IS_TENSORFLOW_PLUGIN')):
|
||||||
logging.log(2, 'Permanently whitelisted: %s: TensorFlow plugin', o)
|
logging.log(2, 'Permanently allowed: %s: TensorFlow plugin', o)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
# TODO(mdan): allow_namedtuple_subclass should be hardcoded to True.
|
# TODO(mdan): allow_namedtuple_subclass should be hardcoded to True.
|
||||||
def is_whitelisted(
|
def is_allowlisted(
|
||||||
o, check_call_override=True, allow_namedtuple_subclass=False):
|
o, check_call_override=True, allow_namedtuple_subclass=False):
|
||||||
"""Checks whether an entity is whitelisted for use in graph mode.
|
"""Checks whether an entity is allowed for use in graph mode.
|
||||||
|
|
||||||
Examples of whitelisted entities include all members of the tensorflow
|
Examples of allowed entities include all members of the tensorflow
|
||||||
package.
|
package.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
o: A Python entity.
|
o: A Python entity.
|
||||||
check_call_override: Reserved for internal use. When set to `False`, it
|
check_call_override: Reserved for internal use. When set to `False`, it
|
||||||
disables the rule according to which classes are whitelisted if their
|
disables the rule according to which classes are allowed if their
|
||||||
__call__ method is whitelisted.
|
__call__ method is allowed.
|
||||||
allow_namedtuple_subclass: Reserved for internal use. When `True`,
|
allow_namedtuple_subclass: Reserved for internal use. When `True`,
|
||||||
namedtuple subclasses are not whitelisted.
|
namedtuple subclasses are not allowed.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Boolean
|
Boolean
|
||||||
@ -144,10 +144,10 @@ def is_whitelisted(
|
|||||||
for rule in config.CONVERSION_RULES:
|
for rule in config.CONVERSION_RULES:
|
||||||
action = rule.get_action(m)
|
action = rule.get_action(m)
|
||||||
if action == config.Action.CONVERT:
|
if action == config.Action.CONVERT:
|
||||||
logging.log(2, 'Not whitelisted: %s: %s', o, rule)
|
logging.log(2, 'Not allowed: %s: %s', o, rule)
|
||||||
return False
|
return False
|
||||||
elif action == config.Action.DO_NOT_CONVERT:
|
elif action == config.Action.DO_NOT_CONVERT:
|
||||||
logging.log(2, 'Whitelisted: %s: %s', o, rule)
|
logging.log(2, 'Allowlisted: %s: %s', o, rule)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# The check for __code__ below is because isgeneratorfunction crashes
|
# The check for __code__ below is because isgeneratorfunction crashes
|
||||||
@ -156,26 +156,26 @@ def is_whitelisted(
|
|||||||
logging.warn(
|
logging.warn(
|
||||||
'Entity %s appears to be a generator function. It will not be converted'
|
'Entity %s appears to be a generator function. It will not be converted'
|
||||||
' by AutoGraph.', o)
|
' by AutoGraph.', o)
|
||||||
logging.log(2, 'Whitelisted: %s: generator functions are not converted', o)
|
logging.log(2, 'Allowlisted: %s: generator functions are not converted', o)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if (check_call_override and not tf_inspect.isclass(o) and
|
if (check_call_override and not tf_inspect.isclass(o) and
|
||||||
hasattr(o, '__call__')):
|
hasattr(o, '__call__')):
|
||||||
# Callable objects: whitelisted if their __call__ method is.
|
# Callable objects: allowed if their __call__ method is.
|
||||||
# The type check avoids infinite recursion around the __call__ method
|
# The type check avoids infinite recursion around the __call__ method
|
||||||
# of function objects.
|
# of function objects.
|
||||||
if (type(o) != type(o.__call__)) and is_whitelisted(o.__call__): # pylint: disable=unidiomatic-typecheck
|
if (type(o) != type(o.__call__)) and is_allowlisted(o.__call__): # pylint: disable=unidiomatic-typecheck
|
||||||
logging.log(2, 'Whitelisted: %s: object __call__ whitelisted', o)
|
logging.log(2, 'Allowlisted: %s: object __call__ allowed', o)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
owner_class = None
|
owner_class = None
|
||||||
if tf_inspect.ismethod(o):
|
if tf_inspect.ismethod(o):
|
||||||
# Methods of whitelisted classes are also whitelisted, even if they are
|
# Methods of allowed classes are also allowed, even if they are
|
||||||
# bound via user subclasses.
|
# bound via user subclasses.
|
||||||
#
|
#
|
||||||
# For example, suppose `tf.Foo` has a method called `bar`, and `baz` is
|
# For example, suppose `tf.Foo` has a method called `bar`, and `baz` is
|
||||||
# defined as below. `tf.Foo` is whitelisted. Then `baz.bar` is also
|
# defined as below. `tf.Foo` is allowed. Then `baz.bar` is also
|
||||||
# whitelisted.
|
# allowed.
|
||||||
#
|
#
|
||||||
# class Custom(tf.Foo):
|
# class Custom(tf.Foo):
|
||||||
# pass
|
# pass
|
||||||
@ -183,22 +183,22 @@ def is_whitelisted(
|
|||||||
# baz = Custom()
|
# baz = Custom()
|
||||||
#
|
#
|
||||||
# For the example above, if `Custom` did overload `bar`, then it would no
|
# For the example above, if `Custom` did overload `bar`, then it would no
|
||||||
# longer be whitelisted.
|
# longer be allowed.
|
||||||
|
|
||||||
owner_class = inspect_utils.getmethodclass(o)
|
owner_class = inspect_utils.getmethodclass(o)
|
||||||
if owner_class is function.TfMethodTarget:
|
if owner_class is function.TfMethodTarget:
|
||||||
owner_class = o.__self__.target_class
|
owner_class = o.__self__.target_class
|
||||||
if owner_class is not None:
|
if owner_class is not None:
|
||||||
if issubclass(owner_class, unittest.TestCase):
|
if issubclass(owner_class, unittest.TestCase):
|
||||||
logging.log(2, 'Whitelisted: %s: method of TestCase subclass', o)
|
logging.log(2, 'Allowlisted: %s: method of TestCase subclass', o)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
owner_class = inspect_utils.getdefiningclass(o, owner_class)
|
owner_class = inspect_utils.getdefiningclass(o, owner_class)
|
||||||
if is_whitelisted(
|
if is_allowlisted(
|
||||||
owner_class,
|
owner_class,
|
||||||
check_call_override=False,
|
check_call_override=False,
|
||||||
allow_namedtuple_subclass=True):
|
allow_namedtuple_subclass=True):
|
||||||
logging.log(2, 'Whitelisted: %s: owner is whitelisted %s', o,
|
logging.log(2, 'Allowlisted: %s: owner is allowed %s', o,
|
||||||
owner_class)
|
owner_class)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -208,27 +208,27 @@ def is_whitelisted(
|
|||||||
# graph mode since they are just containers.
|
# graph mode since they are just containers.
|
||||||
if allow_namedtuple_subclass:
|
if allow_namedtuple_subclass:
|
||||||
if not any(inspect_utils.isnamedtuple(base) for base in o.__bases__):
|
if not any(inspect_utils.isnamedtuple(base) for base in o.__bases__):
|
||||||
logging.log(2, 'Whitelisted: %s: named tuple', o)
|
logging.log(2, 'Allowlisted: %s: named tuple', o)
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
logging.log(2, 'Whitelisted: %s: named tuple or subclass', o)
|
logging.log(2, 'Allowlisted: %s: named tuple or subclass', o)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
logging.log(2, 'Not whitelisted: %s: default rule', o)
|
logging.log(2, 'Not allowed: %s: default rule', o)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def is_in_whitelist_cache(entity, options):
|
def is_in_allowlist_cache(entity, options):
|
||||||
try:
|
try:
|
||||||
return _WHITELIST_CACHE.has(entity, options)
|
return _ALLOWLIST_CACHE.has(entity, options)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
# Catch-all for entities that are unhashable or don't allow weakrefs.
|
# Catch-all for entities that are unhashable or don't allow weakrefs.
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def cache_whitelisted(entity, options):
|
def cache_allowlisted(entity, options):
|
||||||
try:
|
try:
|
||||||
_WHITELIST_CACHE[entity][options] = True
|
_ALLOWLIST_CACHE[entity][options] = True
|
||||||
except TypeError:
|
except TypeError:
|
||||||
# Catch-all for entities that are unhashable or don't allow weakrefs.
|
# Catch-all for entities that are unhashable or don't allow weakrefs.
|
||||||
pass
|
pass
|
||||||
|
@ -43,16 +43,16 @@ class ConversionTest(test.TestCase):
|
|||||||
options=converter.ConversionOptions(recursive=True),
|
options=converter.ConversionOptions(recursive=True),
|
||||||
autograph_module=api)
|
autograph_module=api)
|
||||||
|
|
||||||
def test_is_whitelisted(self):
|
def test_is_allowlisted(self):
|
||||||
|
|
||||||
def test_fn():
|
def test_fn():
|
||||||
return constant_op.constant(1)
|
return constant_op.constant(1)
|
||||||
|
|
||||||
self.assertFalse(conversion.is_whitelisted(test_fn))
|
self.assertFalse(conversion.is_allowlisted(test_fn))
|
||||||
self.assertTrue(conversion.is_whitelisted(utils))
|
self.assertTrue(conversion.is_allowlisted(utils))
|
||||||
self.assertTrue(conversion.is_whitelisted(constant_op.constant))
|
self.assertTrue(conversion.is_allowlisted(constant_op.constant))
|
||||||
|
|
||||||
def test_is_whitelisted_tensorflow_like(self):
|
def test_is_allowlisted_tensorflow_like(self):
|
||||||
|
|
||||||
tf_like = imp.new_module('tensorflow_foo')
|
tf_like = imp.new_module('tensorflow_foo')
|
||||||
def test_fn():
|
def test_fn():
|
||||||
@ -60,13 +60,13 @@ class ConversionTest(test.TestCase):
|
|||||||
tf_like.test_fn = test_fn
|
tf_like.test_fn = test_fn
|
||||||
test_fn.__module__ = tf_like
|
test_fn.__module__ = tf_like
|
||||||
|
|
||||||
self.assertFalse(conversion.is_whitelisted(tf_like.test_fn))
|
self.assertFalse(conversion.is_allowlisted(tf_like.test_fn))
|
||||||
|
|
||||||
def test_is_whitelisted_callable_whitelisted_call(self):
|
def test_is_allowlisted_callable_allowlisted_call(self):
|
||||||
|
|
||||||
whitelisted_mod = imp.new_module('test_whitelisted_call')
|
allowlisted_mod = imp.new_module('test_allowlisted_call')
|
||||||
sys.modules['test_whitelisted_call'] = whitelisted_mod
|
sys.modules['test_allowlisted_call'] = allowlisted_mod
|
||||||
config.CONVERSION_RULES = ((config.DoNotConvert('test_whitelisted_call'),) +
|
config.CONVERSION_RULES = ((config.DoNotConvert('test_allowlisted_call'),) +
|
||||||
config.CONVERSION_RULES)
|
config.CONVERSION_RULES)
|
||||||
|
|
||||||
class TestClass(object):
|
class TestClass(object):
|
||||||
@ -74,14 +74,14 @@ class ConversionTest(test.TestCase):
|
|||||||
def __call__(self):
|
def __call__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def whitelisted_method(self):
|
def allowlisted_method(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
TestClass.__module__ = 'test_whitelisted_call'
|
TestClass.__module__ = 'test_allowlisted_call'
|
||||||
if six.PY2:
|
if six.PY2:
|
||||||
TestClass.__call__.__func__.__module__ = 'test_whitelisted_call'
|
TestClass.__call__.__func__.__module__ = 'test_allowlisted_call'
|
||||||
else:
|
else:
|
||||||
TestClass.__call__.__module__ = 'test_whitelisted_call'
|
TestClass.__call__.__module__ = 'test_allowlisted_call'
|
||||||
|
|
||||||
class Subclass(TestClass):
|
class Subclass(TestClass):
|
||||||
|
|
||||||
@ -90,20 +90,21 @@ class ConversionTest(test.TestCase):
|
|||||||
|
|
||||||
tc = Subclass()
|
tc = Subclass()
|
||||||
|
|
||||||
self.assertTrue(conversion.is_whitelisted(TestClass.__call__))
|
self.assertTrue(conversion.is_allowlisted(TestClass.__call__))
|
||||||
self.assertTrue(conversion.is_whitelisted(tc))
|
self.assertTrue(conversion.is_allowlisted(tc))
|
||||||
self.assertTrue(conversion.is_whitelisted(tc.__call__))
|
self.assertTrue(conversion.is_allowlisted(tc.__call__))
|
||||||
self.assertTrue(conversion.is_whitelisted(tc.whitelisted_method))
|
self.assertTrue(conversion.is_allowlisted(tc.allowlisted_method))
|
||||||
self.assertFalse(conversion.is_whitelisted(Subclass))
|
self.assertFalse(conversion.is_allowlisted(Subclass))
|
||||||
self.assertFalse(conversion.is_whitelisted(tc.converted_method))
|
self.assertFalse(conversion.is_allowlisted(tc.converted_method))
|
||||||
|
|
||||||
|
def test_is_allowlisted_tfmethodwrapper(self):
|
||||||
|
|
||||||
def test_is_whitelisted_tfmethodwrapper(self):
|
|
||||||
class TestClass(object):
|
class TestClass(object):
|
||||||
|
|
||||||
def member_function(self):
|
def member_function(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
TestClass.__module__ = 'test_whitelisted_call'
|
TestClass.__module__ = 'test_allowlisted_call'
|
||||||
test_obj = TestClass()
|
test_obj = TestClass()
|
||||||
|
|
||||||
def test_fn(self):
|
def test_fn(self):
|
||||||
@ -114,14 +115,14 @@ class ConversionTest(test.TestCase):
|
|||||||
function.TfMethodTarget(
|
function.TfMethodTarget(
|
||||||
weakref.ref(test_obj), test_obj.member_function))
|
weakref.ref(test_obj), test_obj.member_function))
|
||||||
|
|
||||||
self.assertTrue(conversion.is_whitelisted(bound_method))
|
self.assertTrue(conversion.is_allowlisted(bound_method))
|
||||||
|
|
||||||
def test_is_whitelisted_pybind(self):
|
def test_is_allowlisted_pybind(self):
|
||||||
test_object = pybind_for_testing.TestClassDef()
|
test_object = pybind_for_testing.TestClassDef()
|
||||||
with test.mock.patch.object(config, 'CONVERSION_RULES', ()):
|
with test.mock.patch.object(config, 'CONVERSION_RULES', ()):
|
||||||
# TODO(mdan): This should return True for functions and methods.
|
# TODO(mdan): This should return True for functions and methods.
|
||||||
# Note: currently, native bindings are whitelisted by a separate check.
|
# Note: currently, native bindings are allowlisted by a separate check.
|
||||||
self.assertFalse(conversion.is_whitelisted(test_object.method))
|
self.assertFalse(conversion.is_allowlisted(test_object.method))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -477,12 +477,14 @@ class AnfConfiguredTest(AnfTestBase):
|
|||||||
def test_anf_some_function_calls(self):
|
def test_anf_some_function_calls(self):
|
||||||
# Another example specific configuration that differs from the default:
|
# Another example specific configuration that differs from the default:
|
||||||
# Moving all arguments out of some function calls but leaving others be.
|
# Moving all arguments out of some function calls but leaving others be.
|
||||||
whitelist = ['foo']
|
allowlist = ['foo']
|
||||||
|
|
||||||
def transform(parent, field, child):
|
def transform(parent, field, child):
|
||||||
del field
|
del field
|
||||||
del child
|
del child
|
||||||
func_name = parent.func.id
|
func_name = parent.func.id
|
||||||
return str(func_name) in whitelist
|
return str(func_name) in allowlist
|
||||||
|
|
||||||
config = [(anf.ASTEdgePattern(gast.Call, anf.ANY, anf.ANY), transform)]
|
config = [(anf.ASTEdgePattern(gast.Call, anf.ANY, anf.ANY), transform)]
|
||||||
|
|
||||||
def test_function(x, foo, bar):
|
def test_function(x, foo, bar):
|
||||||
|
@ -24,10 +24,9 @@ from tensorflow.python.autograph.pyct import origin_info
|
|||||||
|
|
||||||
|
|
||||||
class FrameInfo(
|
class FrameInfo(
|
||||||
collections.namedtuple(
|
collections.namedtuple('FrameInfo',
|
||||||
'FrameInfo',
|
('filename', 'lineno', 'function_name', 'code',
|
||||||
('filename', 'lineno', 'function_name', 'code', 'is_converted',
|
'is_converted', 'is_allowlisted'))):
|
||||||
'is_whitelisted'))):
|
|
||||||
|
|
||||||
__slots__ = ()
|
__slots__ = ()
|
||||||
|
|
||||||
@ -75,7 +74,7 @@ def _stack_trace_inside_mapped_code(tb, source_map, converter_filename):
|
|||||||
origin_info.create_source_map.
|
origin_info.create_source_map.
|
||||||
converter_filename: str, the file path of the converted module. Call frames
|
converter_filename: str, the file path of the converted module. Call frames
|
||||||
corresponding to this module are elided and their preceding frames are
|
corresponding to this module are elided and their preceding frames are
|
||||||
marked as whitelisted. Note that frames enclosing converted code are
|
marked as allowlisted. Note that frames enclosing converted code are
|
||||||
dropped using a different mechanism.
|
dropped using a different mechanism.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -93,7 +92,7 @@ def _stack_trace_inside_mapped_code(tb, source_map, converter_filename):
|
|||||||
function_name=origin.function_name,
|
function_name=origin.function_name,
|
||||||
code=origin.source_code_line,
|
code=origin.source_code_line,
|
||||||
is_converted=True,
|
is_converted=True,
|
||||||
is_whitelisted=False)
|
is_allowlisted=False)
|
||||||
result_frames.append(fi)
|
result_frames.append(fi)
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -107,7 +106,7 @@ def _stack_trace_inside_mapped_code(tb, source_map, converter_filename):
|
|||||||
function_name=prev.function_name,
|
function_name=prev.function_name,
|
||||||
code=prev.code,
|
code=prev.code,
|
||||||
is_converted=False,
|
is_converted=False,
|
||||||
is_whitelisted=True)
|
is_allowlisted=True)
|
||||||
result_frames[-1] = fi
|
result_frames[-1] = fi
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -117,7 +116,7 @@ def _stack_trace_inside_mapped_code(tb, source_map, converter_filename):
|
|||||||
function_name=function_name,
|
function_name=function_name,
|
||||||
code=text,
|
code=text,
|
||||||
is_converted=False,
|
is_converted=False,
|
||||||
is_whitelisted=False)
|
is_allowlisted=False)
|
||||||
result_frames.append(fi)
|
result_frames.append(fi)
|
||||||
|
|
||||||
return tuple(result_frames)
|
return tuple(result_frames)
|
||||||
@ -188,7 +187,7 @@ class ErrorMetadataBase(object):
|
|||||||
frame_info.function_name)
|
frame_info.function_name)
|
||||||
if frame_info.is_converted:
|
if frame_info.is_converted:
|
||||||
formatted_line += ' *'
|
formatted_line += ' *'
|
||||||
elif frame_info.is_whitelisted:
|
elif frame_info.is_allowlisted:
|
||||||
formatted_line += ' **'
|
formatted_line += ' **'
|
||||||
lines.append(formatted_line)
|
lines.append(formatted_line)
|
||||||
|
|
||||||
|
@ -2250,7 +2250,7 @@ class DatasetV1(DatasetV2):
|
|||||||
# by value _make_dataset() function would try to capture these variant
|
# by value _make_dataset() function would try to capture these variant
|
||||||
# tensor dataset inputs, which are marked as stateful ops and would throw
|
# tensor dataset inputs, which are marked as stateful ops and would throw
|
||||||
# an error if we try and capture them. We therefore traverse the graph
|
# an error if we try and capture them. We therefore traverse the graph
|
||||||
# to find all these ops and whitelist them so that the capturing
|
# to find all these ops and allowlist them so that the capturing
|
||||||
# logic instead of throwing an error recreates these ops which is what was
|
# logic instead of throwing an error recreates these ops which is what was
|
||||||
# happening before.
|
# happening before.
|
||||||
all_ds_ops = traverse.obtain_all_variant_tensor_ops(self)
|
all_ds_ops = traverse.obtain_all_variant_tensor_ops(self)
|
||||||
@ -2258,7 +2258,7 @@ class DatasetV1(DatasetV2):
|
|||||||
|
|
||||||
# NOTE(mrry): We capture by value here to ensure that `_make_dataset()` is
|
# NOTE(mrry): We capture by value here to ensure that `_make_dataset()` is
|
||||||
# a 0-argument function.
|
# a 0-argument function.
|
||||||
@function.Defun(capture_by_value=True, whitelisted_stateful_ops=all_ds_ops)
|
@function.Defun(capture_by_value=True, allowlisted_stateful_ops=all_ds_ops)
|
||||||
def _make_dataset():
|
def _make_dataset():
|
||||||
"""Factory function for a dataset."""
|
"""Factory function for a dataset."""
|
||||||
# NOTE(mrry): `Defun` does not capture the graph-level seed from the
|
# NOTE(mrry): `Defun` does not capture the graph-level seed from the
|
||||||
|
@ -1246,8 +1246,8 @@ class DebugAnalyzer(object):
|
|||||||
parsed = self._arg_parsers["list_source"].parse_args(args)
|
parsed = self._arg_parsers["list_source"].parse_args(args)
|
||||||
source_list = source_utils.list_source_files_against_dump(
|
source_list = source_utils.list_source_files_against_dump(
|
||||||
self._debug_dump,
|
self._debug_dump,
|
||||||
path_regex_whitelist=parsed.path_filter,
|
path_regex_allowlist=parsed.path_filter,
|
||||||
node_name_regex_whitelist=parsed.node_name_filter)
|
node_name_regex_allowlist=parsed.node_name_filter)
|
||||||
|
|
||||||
top_lines = [
|
top_lines = [
|
||||||
RL("List of source files that created nodes in this run", "bold")]
|
RL("List of source files that created nodes in this run", "bold")]
|
||||||
|
@ -1578,9 +1578,9 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
def testListSourceWithCompiledPythonSourceWorks(self):
|
def testListSourceWithCompiledPythonSourceWorks(self):
|
||||||
def fake_list_source_files_against_dump(dump,
|
def fake_list_source_files_against_dump(dump,
|
||||||
path_regex_whitelist=None,
|
path_regex_allowlist=None,
|
||||||
node_name_regex_whitelist=None):
|
node_name_regex_allowlist=None):
|
||||||
del dump, path_regex_whitelist, node_name_regex_whitelist
|
del dump, path_regex_allowlist, node_name_regex_allowlist
|
||||||
return [("compiled_1.pyc", False, 10, 20, 30, 4),
|
return [("compiled_1.pyc", False, 10, 20, 30, 4),
|
||||||
("compiled_2.pyo", False, 10, 20, 30, 5),
|
("compiled_2.pyo", False, 10, 20, 30, 5),
|
||||||
("uncompiled.py", False, 10, 20, 30, 6)]
|
("uncompiled.py", False, 10, 20, 30, 6)]
|
||||||
|
@ -38,7 +38,7 @@ from tensorflow.python.util.tf_export import tf_export
|
|||||||
|
|
||||||
# Many ops have benign NaN outputs, and running them with check_numerics
|
# Many ops have benign NaN outputs, and running them with check_numerics
|
||||||
# on will create unwanted errors
|
# on will create unwanted errors
|
||||||
# TODO(b/142497024): Replace this whitelist with function decorators in the ops
|
# TODO(b/142497024): Replace this allowlist with function decorators in the ops
|
||||||
IGNORE_OP_OUTPUTS = (
|
IGNORE_OP_OUTPUTS = (
|
||||||
# For FusedBatchNorm, if the input tensor is empty then batch_mean and
|
# For FusedBatchNorm, if the input tensor is empty then batch_mean and
|
||||||
# batch_variance will be NaN. reserve_space holds intermediate values
|
# batch_variance will be NaN. reserve_space holds intermediate values
|
||||||
|
@ -83,16 +83,16 @@ def watch_graph(run_options,
|
|||||||
graph,
|
graph,
|
||||||
debug_ops="DebugIdentity",
|
debug_ops="DebugIdentity",
|
||||||
debug_urls=None,
|
debug_urls=None,
|
||||||
node_name_regex_whitelist=None,
|
node_name_regex_allowlist=None,
|
||||||
op_type_regex_whitelist=None,
|
op_type_regex_allowlist=None,
|
||||||
tensor_dtype_regex_whitelist=None,
|
tensor_dtype_regex_allowlist=None,
|
||||||
tolerate_debug_op_creation_failures=False,
|
tolerate_debug_op_creation_failures=False,
|
||||||
global_step=-1,
|
global_step=-1,
|
||||||
reset_disk_byte_usage=False):
|
reset_disk_byte_usage=False):
|
||||||
"""Add debug watches to `RunOptions` for a TensorFlow graph.
|
"""Add debug watches to `RunOptions` for a TensorFlow graph.
|
||||||
|
|
||||||
To watch all `Tensor`s on the graph, let both `node_name_regex_whitelist`
|
To watch all `Tensor`s on the graph, let both `node_name_regex_allowlist`
|
||||||
and `op_type_regex_whitelist` be the default (`None`).
|
and `op_type_regex_allowlist` be the default (`None`).
|
||||||
|
|
||||||
N.B.:
|
N.B.:
|
||||||
1. Under certain circumstances, the `Tensor` may not get actually watched
|
1. Under certain circumstances, the `Tensor` may not get actually watched
|
||||||
@ -114,17 +114,17 @@ def watch_graph(run_options,
|
|||||||
For debug op types with customizable attributes, each debug op name string
|
For debug op types with customizable attributes, each debug op name string
|
||||||
can optionally contain a list of attribute names, in the syntax of:
|
can optionally contain a list of attribute names, in the syntax of:
|
||||||
debug_op_name(attr_name_1=attr_value_1;attr_name_2=attr_value_2;...)
|
debug_op_name(attr_name_1=attr_value_1;attr_name_2=attr_value_2;...)
|
||||||
node_name_regex_whitelist: Regular-expression whitelist for node_name,
|
node_name_regex_allowlist: Regular-expression allowlist for node_name,
|
||||||
e.g., `"(weight_[0-9]+|bias_.*)"`
|
e.g., `"(weight_[0-9]+|bias_.*)"`
|
||||||
op_type_regex_whitelist: Regular-expression whitelist for the op type of
|
op_type_regex_allowlist: Regular-expression allowlist for the op type of
|
||||||
nodes, e.g., `"(Variable|Add)"`.
|
nodes, e.g., `"(Variable|Add)"`.
|
||||||
If both `node_name_regex_whitelist` and `op_type_regex_whitelist`
|
If both `node_name_regex_allowlist` and `op_type_regex_allowlist`
|
||||||
are set, the two filtering operations will occur in a logical `AND`
|
are set, the two filtering operations will occur in a logical `AND`
|
||||||
relation. In other words, a node will be included if and only if it
|
relation. In other words, a node will be included if and only if it
|
||||||
hits both whitelists.
|
hits both allowlists.
|
||||||
tensor_dtype_regex_whitelist: Regular-expression whitelist for Tensor
|
tensor_dtype_regex_allowlist: Regular-expression allowlist for Tensor
|
||||||
data type, e.g., `"^int.*"`.
|
data type, e.g., `"^int.*"`.
|
||||||
This whitelist operates in logical `AND` relations to the two whitelists
|
This allowlist operates in logical `AND` relations to the two allowlists
|
||||||
above.
|
above.
|
||||||
tolerate_debug_op_creation_failures: (`bool`) whether debug op creation
|
tolerate_debug_op_creation_failures: (`bool`) whether debug op creation
|
||||||
failures (e.g., due to dtype incompatibility) are to be tolerated by not
|
failures (e.g., due to dtype incompatibility) are to be tolerated by not
|
||||||
@ -142,12 +142,14 @@ def watch_graph(run_options,
|
|||||||
if isinstance(debug_ops, str):
|
if isinstance(debug_ops, str):
|
||||||
debug_ops = [debug_ops]
|
debug_ops = [debug_ops]
|
||||||
|
|
||||||
node_name_pattern = (re.compile(node_name_regex_whitelist)
|
node_name_pattern = (
|
||||||
if node_name_regex_whitelist else None)
|
re.compile(node_name_regex_allowlist)
|
||||||
op_type_pattern = (re.compile(op_type_regex_whitelist)
|
if node_name_regex_allowlist else None)
|
||||||
if op_type_regex_whitelist else None)
|
op_type_pattern = (
|
||||||
tensor_dtype_pattern = (re.compile(tensor_dtype_regex_whitelist)
|
re.compile(op_type_regex_allowlist) if op_type_regex_allowlist else None)
|
||||||
if tensor_dtype_regex_whitelist else None)
|
tensor_dtype_pattern = (
|
||||||
|
re.compile(tensor_dtype_regex_allowlist)
|
||||||
|
if tensor_dtype_regex_allowlist else None)
|
||||||
|
|
||||||
ops = graph.get_operations()
|
ops = graph.get_operations()
|
||||||
for op in ops:
|
for op in ops:
|
||||||
@ -210,7 +212,7 @@ def watch_graph_with_blacklists(run_options,
|
|||||||
"""Add debug tensor watches, blacklisting nodes and op types.
|
"""Add debug tensor watches, blacklisting nodes and op types.
|
||||||
|
|
||||||
This is similar to `watch_graph()`, but the node names and op types are
|
This is similar to `watch_graph()`, but the node names and op types are
|
||||||
blacklisted, instead of whitelisted.
|
blacklisted, instead of allowlisted.
|
||||||
|
|
||||||
N.B.:
|
N.B.:
|
||||||
1. Under certain circumstances, the `Tensor` may not get actually watched
|
1. Under certain circumstances, the `Tensor` may not get actually watched
|
||||||
@ -238,7 +240,7 @@ def watch_graph_with_blacklists(run_options,
|
|||||||
neither of the blacklists.
|
neither of the blacklists.
|
||||||
tensor_dtype_regex_blacklist: Regular-expression blacklist for Tensor
|
tensor_dtype_regex_blacklist: Regular-expression blacklist for Tensor
|
||||||
data type, e.g., `"^int.*"`.
|
data type, e.g., `"^int.*"`.
|
||||||
This blacklist operates in logical `OR` relations to the two whitelists
|
This blacklist operates in logical `OR` relations to the two allowlists
|
||||||
above.
|
above.
|
||||||
tolerate_debug_op_creation_failures: (`bool`) whether debug op creation
|
tolerate_debug_op_creation_failures: (`bool`) whether debug op creation
|
||||||
failures (e.g., due to dtype incompatibility) are to be tolerated by not
|
failures (e.g., due to dtype incompatibility) are to be tolerated by not
|
||||||
|
@ -227,12 +227,12 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
|
|||||||
# Assert that the wildcard node name has been created.
|
# Assert that the wildcard node name has been created.
|
||||||
self.assertIn("*", node_names)
|
self.assertIn("*", node_names)
|
||||||
|
|
||||||
def testWatchGraph_nodeNameWhitelist(self):
|
def testWatchGraph_nodeNameAllowlist(self):
|
||||||
debug_utils.watch_graph(
|
debug_utils.watch_graph(
|
||||||
self._run_options,
|
self._run_options,
|
||||||
self._graph,
|
self._graph,
|
||||||
debug_urls="file:///tmp/tfdbg_1",
|
debug_urls="file:///tmp/tfdbg_1",
|
||||||
node_name_regex_whitelist="(a1$|a1_init$|a1/.*|p1$)")
|
node_name_regex_allowlist="(a1$|a1_init$|a1/.*|p1$)")
|
||||||
|
|
||||||
node_names = self._verify_watches(
|
node_names = self._verify_watches(
|
||||||
self._run_options.debug_options.debug_tensor_watch_opts, 0,
|
self._run_options.debug_options.debug_tensor_watch_opts, 0,
|
||||||
@ -241,50 +241,50 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
|
|||||||
sorted(["a1_init", "a1", "a1/Assign", "a1/read", "p1"]),
|
sorted(["a1_init", "a1", "a1/Assign", "a1/read", "p1"]),
|
||||||
sorted(node_names))
|
sorted(node_names))
|
||||||
|
|
||||||
def testWatchGraph_opTypeWhitelist(self):
|
def testWatchGraph_opTypeAllowlist(self):
|
||||||
debug_utils.watch_graph(
|
debug_utils.watch_graph(
|
||||||
self._run_options,
|
self._run_options,
|
||||||
self._graph,
|
self._graph,
|
||||||
debug_urls="file:///tmp/tfdbg_1",
|
debug_urls="file:///tmp/tfdbg_1",
|
||||||
op_type_regex_whitelist="(Variable|MatMul)")
|
op_type_regex_allowlist="(Variable|MatMul)")
|
||||||
|
|
||||||
node_names = self._verify_watches(
|
node_names = self._verify_watches(
|
||||||
self._run_options.debug_options.debug_tensor_watch_opts, 0,
|
self._run_options.debug_options.debug_tensor_watch_opts, 0,
|
||||||
["DebugIdentity"], ["file:///tmp/tfdbg_1"])
|
["DebugIdentity"], ["file:///tmp/tfdbg_1"])
|
||||||
self.assertEqual(sorted(["a1", "b", "p1"]), sorted(node_names))
|
self.assertEqual(sorted(["a1", "b", "p1"]), sorted(node_names))
|
||||||
|
|
||||||
def testWatchGraph_nodeNameAndOpTypeWhitelists(self):
|
def testWatchGraph_nodeNameAndOpTypeAllowlists(self):
|
||||||
debug_utils.watch_graph(
|
debug_utils.watch_graph(
|
||||||
self._run_options,
|
self._run_options,
|
||||||
self._graph,
|
self._graph,
|
||||||
debug_urls="file:///tmp/tfdbg_1",
|
debug_urls="file:///tmp/tfdbg_1",
|
||||||
node_name_regex_whitelist="([a-z]+1$)",
|
node_name_regex_allowlist="([a-z]+1$)",
|
||||||
op_type_regex_whitelist="(MatMul)")
|
op_type_regex_allowlist="(MatMul)")
|
||||||
|
|
||||||
node_names = self._verify_watches(
|
node_names = self._verify_watches(
|
||||||
self._run_options.debug_options.debug_tensor_watch_opts, 0,
|
self._run_options.debug_options.debug_tensor_watch_opts, 0,
|
||||||
["DebugIdentity"], ["file:///tmp/tfdbg_1"])
|
["DebugIdentity"], ["file:///tmp/tfdbg_1"])
|
||||||
self.assertEqual(["p1"], node_names)
|
self.assertEqual(["p1"], node_names)
|
||||||
|
|
||||||
def testWatchGraph_tensorDTypeWhitelist(self):
|
def testWatchGraph_tensorDTypeAllowlist(self):
|
||||||
debug_utils.watch_graph(
|
debug_utils.watch_graph(
|
||||||
self._run_options,
|
self._run_options,
|
||||||
self._graph,
|
self._graph,
|
||||||
debug_urls="file:///tmp/tfdbg_1",
|
debug_urls="file:///tmp/tfdbg_1",
|
||||||
tensor_dtype_regex_whitelist=".*_ref")
|
tensor_dtype_regex_allowlist=".*_ref")
|
||||||
|
|
||||||
node_names = self._verify_watches(
|
node_names = self._verify_watches(
|
||||||
self._run_options.debug_options.debug_tensor_watch_opts, 0,
|
self._run_options.debug_options.debug_tensor_watch_opts, 0,
|
||||||
["DebugIdentity"], ["file:///tmp/tfdbg_1"])
|
["DebugIdentity"], ["file:///tmp/tfdbg_1"])
|
||||||
self.assertItemsEqual(["a1", "a1/Assign", "b", "b/Assign"], node_names)
|
self.assertItemsEqual(["a1", "a1/Assign", "b", "b/Assign"], node_names)
|
||||||
|
|
||||||
def testWatchGraph_nodeNameAndTensorDTypeWhitelists(self):
|
def testWatchGraph_nodeNameAndTensorDTypeAllowlists(self):
|
||||||
debug_utils.watch_graph(
|
debug_utils.watch_graph(
|
||||||
self._run_options,
|
self._run_options,
|
||||||
self._graph,
|
self._graph,
|
||||||
debug_urls="file:///tmp/tfdbg_1",
|
debug_urls="file:///tmp/tfdbg_1",
|
||||||
node_name_regex_whitelist="^a.*",
|
node_name_regex_allowlist="^a.*",
|
||||||
tensor_dtype_regex_whitelist=".*_ref")
|
tensor_dtype_regex_allowlist=".*_ref")
|
||||||
|
|
||||||
node_names = self._verify_watches(
|
node_names = self._verify_watches(
|
||||||
self._run_options.debug_options.debug_tensor_watch_opts, 0,
|
self._run_options.debug_options.debug_tensor_watch_opts, 0,
|
||||||
|
@ -143,7 +143,7 @@ class DistributedSessionDebugTest(test_util.TensorFlowTestCase):
|
|||||||
debug_utils.watch_graph(
|
debug_utils.watch_graph(
|
||||||
run_options,
|
run_options,
|
||||||
sess.graph,
|
sess.graph,
|
||||||
node_name_regex_whitelist=r"a",
|
node_name_regex_allowlist=r"a",
|
||||||
debug_ops=["DebugIdentity"],
|
debug_ops=["DebugIdentity"],
|
||||||
debug_urls=[self.debug_server_url])
|
debug_urls=[self.debug_server_url])
|
||||||
|
|
||||||
@ -155,7 +155,7 @@ class DistributedSessionDebugTest(test_util.TensorFlowTestCase):
|
|||||||
debug_utils.watch_graph(
|
debug_utils.watch_graph(
|
||||||
run_options,
|
run_options,
|
||||||
sess.graph,
|
sess.graph,
|
||||||
node_name_regex_whitelist=r"p",
|
node_name_regex_allowlist=r"p",
|
||||||
debug_ops=["DebugIdentity(gated_grpc=True)"],
|
debug_ops=["DebugIdentity(gated_grpc=True)"],
|
||||||
debug_urls=[self.debug_server_url])
|
debug_urls=[self.debug_server_url])
|
||||||
|
|
||||||
@ -209,8 +209,8 @@ class DistributedSessionDebugTest(test_util.TensorFlowTestCase):
|
|||||||
def watch_fn(feeds, fetch_keys):
|
def watch_fn(feeds, fetch_keys):
|
||||||
del feeds, fetch_keys
|
del feeds, fetch_keys
|
||||||
return framework.WatchOptions(
|
return framework.WatchOptions(
|
||||||
debug_ops=["DebugIdentity"],
|
debug_ops=["DebugIdentity"], node_name_regex_allowlist=r"p")
|
||||||
node_name_regex_whitelist=r"p")
|
|
||||||
sess = grpc_wrapper.GrpcDebugWrapperSession(
|
sess = grpc_wrapper.GrpcDebugWrapperSession(
|
||||||
sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
|
sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
|
||||||
|
|
||||||
|
@ -71,7 +71,8 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
|
|||||||
del fetches, feeds
|
del fetches, feeds
|
||||||
return framework.WatchOptions(
|
return framework.WatchOptions(
|
||||||
debug_ops=["DebugIdentity"],
|
debug_ops=["DebugIdentity"],
|
||||||
node_name_regex_whitelist=r"original_u")
|
node_name_regex_allowlist=r"original_u")
|
||||||
|
|
||||||
sess = grpc_wrapper.GrpcDebugWrapperSession(
|
sess = grpc_wrapper.GrpcDebugWrapperSession(
|
||||||
sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
|
sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
|
||||||
self.assertAllClose(42.0, sess.run(u))
|
self.assertAllClose(42.0, sess.run(u))
|
||||||
@ -101,8 +102,8 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
|
|||||||
def watch_fn(fetches, feeds):
|
def watch_fn(fetches, feeds):
|
||||||
del fetches, feeds # Unused by this watch_fn.
|
del fetches, feeds # Unused by this watch_fn.
|
||||||
return framework.WatchOptions(
|
return framework.WatchOptions(
|
||||||
debug_ops=["DebugIdentity"],
|
debug_ops=["DebugIdentity"], node_name_regex_allowlist=r"u_init")
|
||||||
node_name_regex_whitelist=r"u_init")
|
|
||||||
sess = grpc_wrapper.GrpcDebugWrapperSession(
|
sess = grpc_wrapper.GrpcDebugWrapperSession(
|
||||||
sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
|
sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
|
||||||
sess.run(u.initializer)
|
sess.run(u.initializer)
|
||||||
@ -125,8 +126,8 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
|
|||||||
def watch_fn(fetches, feeds):
|
def watch_fn(fetches, feeds):
|
||||||
del fetches, feeds
|
del fetches, feeds
|
||||||
return framework.WatchOptions(
|
return framework.WatchOptions(
|
||||||
debug_ops=["DebugIdentity"],
|
debug_ops=["DebugIdentity"], node_name_regex_allowlist=r"u_init")
|
||||||
node_name_regex_whitelist=r"u_init")
|
|
||||||
sess = grpc_wrapper.GrpcDebugWrapperSession(
|
sess = grpc_wrapper.GrpcDebugWrapperSession(
|
||||||
sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
|
sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
|
||||||
sess.run(u.initializer)
|
sess.run(u.initializer)
|
||||||
@ -155,8 +156,8 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
|
|||||||
def watch_fn(fetches, feeds):
|
def watch_fn(fetches, feeds):
|
||||||
del fetches, feeds
|
del fetches, feeds
|
||||||
return framework.WatchOptions(
|
return framework.WatchOptions(
|
||||||
debug_ops=["DebugIdentity"],
|
debug_ops=["DebugIdentity"], node_name_regex_allowlist=r"u_init")
|
||||||
node_name_regex_whitelist=r"u_init")
|
|
||||||
sess = grpc_wrapper.GrpcDebugWrapperSession(
|
sess = grpc_wrapper.GrpcDebugWrapperSession(
|
||||||
sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
|
sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
|
||||||
sess.run(u.initializer)
|
sess.run(u.initializer)
|
||||||
@ -177,8 +178,8 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
|
|||||||
def watch_fn(fetches, feeds):
|
def watch_fn(fetches, feeds):
|
||||||
del fetches, feeds
|
del fetches, feeds
|
||||||
return framework.WatchOptions(
|
return framework.WatchOptions(
|
||||||
debug_ops=["DebugIdentity"],
|
debug_ops=["DebugIdentity"], node_name_regex_allowlist=r"u_init")
|
||||||
node_name_regex_whitelist=r"u_init")
|
|
||||||
sess = grpc_wrapper.GrpcDebugWrapperSession(
|
sess = grpc_wrapper.GrpcDebugWrapperSession(
|
||||||
sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
|
sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
|
||||||
sess.run(u.initializer)
|
sess.run(u.initializer)
|
||||||
@ -200,8 +201,8 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
|
|||||||
def watch_fn(fetches, feeds):
|
def watch_fn(fetches, feeds):
|
||||||
del fetches, feeds
|
del fetches, feeds
|
||||||
return framework.WatchOptions(
|
return framework.WatchOptions(
|
||||||
debug_ops=["DebugIdentity"],
|
debug_ops=["DebugIdentity"], node_name_regex_allowlist=r"u_init")
|
||||||
node_name_regex_whitelist=r"u_init")
|
|
||||||
sess = grpc_wrapper.GrpcDebugWrapperSession(
|
sess = grpc_wrapper.GrpcDebugWrapperSession(
|
||||||
sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
|
sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
|
||||||
sess.run(u.initializer)
|
sess.run(u.initializer)
|
||||||
|
@ -207,8 +207,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
|
|||||||
del feeds, fetch_keys
|
del feeds, fetch_keys
|
||||||
return framework.WatchOptions(
|
return framework.WatchOptions(
|
||||||
debug_ops=["DebugIdentity", "DebugNumericSummary"],
|
debug_ops=["DebugIdentity", "DebugNumericSummary"],
|
||||||
node_name_regex_whitelist=r".*/read",
|
node_name_regex_allowlist=r".*/read",
|
||||||
op_type_regex_whitelist=None,
|
op_type_regex_allowlist=None,
|
||||||
tolerate_debug_op_creation_failures=True)
|
tolerate_debug_op_creation_failures=True)
|
||||||
|
|
||||||
u = variables.VariableV1(2.1, name="u")
|
u = variables.VariableV1(2.1, name="u")
|
||||||
|
@ -221,15 +221,15 @@ def annotate_source(dump,
|
|||||||
|
|
||||||
|
|
||||||
def list_source_files_against_dump(dump,
|
def list_source_files_against_dump(dump,
|
||||||
path_regex_whitelist=None,
|
path_regex_allowlist=None,
|
||||||
node_name_regex_whitelist=None):
|
node_name_regex_allowlist=None):
|
||||||
"""Generate a list of source files with information regarding ops and tensors.
|
"""Generate a list of source files with information regarding ops and tensors.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dump: (`DebugDumpDir`) A `DebugDumpDir` object of which the Python graph
|
dump: (`DebugDumpDir`) A `DebugDumpDir` object of which the Python graph
|
||||||
has been loaded.
|
has been loaded.
|
||||||
path_regex_whitelist: A regular-expression filter for source file path.
|
path_regex_allowlist: A regular-expression filter for source file path.
|
||||||
node_name_regex_whitelist: A regular-expression filter for node names.
|
node_name_regex_allowlist: A regular-expression filter for node names.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of tuples regarding the Python source files involved in constructing
|
A list of tuples regarding the Python source files involved in constructing
|
||||||
@ -264,10 +264,11 @@ def list_source_files_against_dump(dump,
|
|||||||
path_to_first_line = {}
|
path_to_first_line = {}
|
||||||
tensor_name_to_num_dumps = {}
|
tensor_name_to_num_dumps = {}
|
||||||
|
|
||||||
path_regex = (re.compile(path_regex_whitelist)
|
path_regex = (
|
||||||
if path_regex_whitelist else None)
|
re.compile(path_regex_allowlist) if path_regex_allowlist else None)
|
||||||
node_name_regex = (re.compile(node_name_regex_whitelist)
|
node_name_regex = (
|
||||||
if node_name_regex_whitelist else None)
|
re.compile(node_name_regex_allowlist)
|
||||||
|
if node_name_regex_allowlist else None)
|
||||||
|
|
||||||
to_skip_file_paths = set()
|
to_skip_file_paths = set()
|
||||||
for op in py_graph.get_operations():
|
for op in py_graph.get_operations():
|
||||||
|
@ -406,7 +406,7 @@ class ListSourceAgainstDumpTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
def testGenerateSourceListWithNodeNameFilter(self):
|
def testGenerateSourceListWithNodeNameFilter(self):
|
||||||
source_list = source_utils.list_source_files_against_dump(
|
source_list = source_utils.list_source_files_against_dump(
|
||||||
self.dump, node_name_regex_whitelist=r"while/Add.*")
|
self.dump, node_name_regex_allowlist=r"while/Add.*")
|
||||||
|
|
||||||
# Assert that the file paths are sorted.
|
# Assert that the file paths are sorted.
|
||||||
file_paths = [item[0] for item in source_list]
|
file_paths = [item[0] for item in source_list]
|
||||||
@ -433,8 +433,8 @@ class ListSourceAgainstDumpTest(test_util.TensorFlowTestCase):
|
|||||||
curr_file_basename = os.path.basename(self.curr_file_path)
|
curr_file_basename = os.path.basename(self.curr_file_path)
|
||||||
source_list = source_utils.list_source_files_against_dump(
|
source_list = source_utils.list_source_files_against_dump(
|
||||||
self.dump,
|
self.dump,
|
||||||
path_regex_whitelist=(
|
path_regex_allowlist=(".*" + curr_file_basename.replace(".", "\\.") +
|
||||||
".*" + curr_file_basename.replace(".", "\\.") + "$"))
|
"$"))
|
||||||
|
|
||||||
self.assertEqual(1, len(source_list))
|
self.assertEqual(1, len(source_list))
|
||||||
(file_path, is_tf_py_library, num_nodes, num_tensors, num_dumps,
|
(file_path, is_tf_py_library, num_nodes, num_tensors, num_dumps,
|
||||||
|
@ -169,7 +169,7 @@ class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase):
|
|||||||
log_usage=False)
|
log_usage=False)
|
||||||
|
|
||||||
def testDumpingWithLegacyWatchFnOnFetchesWorks(self):
|
def testDumpingWithLegacyWatchFnOnFetchesWorks(self):
|
||||||
"""Use a watch_fn that returns different whitelists for different runs."""
|
"""Use a watch_fn that returns different allowlists for different runs."""
|
||||||
|
|
||||||
def watch_fn(fetches, feeds):
|
def watch_fn(fetches, feeds):
|
||||||
del feeds
|
del feeds
|
||||||
@ -240,9 +240,9 @@ class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase):
|
|||||||
del fetches, feeds
|
del fetches, feeds
|
||||||
return framework.WatchOptions(
|
return framework.WatchOptions(
|
||||||
debug_ops=["DebugIdentity", "DebugNumericSummary"],
|
debug_ops=["DebugIdentity", "DebugNumericSummary"],
|
||||||
node_name_regex_whitelist=r"^v.*",
|
node_name_regex_allowlist=r"^v.*",
|
||||||
op_type_regex_whitelist=r".*",
|
op_type_regex_allowlist=r".*",
|
||||||
tensor_dtype_regex_whitelist=".*_ref")
|
tensor_dtype_regex_allowlist=".*_ref")
|
||||||
|
|
||||||
sess = dumping_wrapper.DumpingDebugWrapperSession(
|
sess = dumping_wrapper.DumpingDebugWrapperSession(
|
||||||
self.sess,
|
self.sess,
|
||||||
@ -288,14 +288,13 @@ class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase):
|
|||||||
if watch_fn_state["run_counter"] % 2 == 1:
|
if watch_fn_state["run_counter"] % 2 == 1:
|
||||||
# If odd-index run (1-based), watch every ref-type tensor.
|
# If odd-index run (1-based), watch every ref-type tensor.
|
||||||
return framework.WatchOptions(
|
return framework.WatchOptions(
|
||||||
debug_ops="DebugIdentity",
|
debug_ops="DebugIdentity", tensor_dtype_regex_allowlist=".*_ref")
|
||||||
tensor_dtype_regex_whitelist=".*_ref")
|
|
||||||
else:
|
else:
|
||||||
# If even-index run, watch nothing.
|
# If even-index run, watch nothing.
|
||||||
return framework.WatchOptions(
|
return framework.WatchOptions(
|
||||||
debug_ops="DebugIdentity",
|
debug_ops="DebugIdentity",
|
||||||
node_name_regex_whitelist=r"^$",
|
node_name_regex_allowlist=r"^$",
|
||||||
op_type_regex_whitelist=r"^$")
|
op_type_regex_allowlist=r"^$")
|
||||||
|
|
||||||
dumping_hook = hooks.DumpingDebugHook(
|
dumping_hook = hooks.DumpingDebugHook(
|
||||||
self.session_root, watch_fn=counting_watch_fn, log_usage=False)
|
self.session_root, watch_fn=counting_watch_fn, log_usage=False)
|
||||||
|
@ -234,9 +234,9 @@ class OnRunStartResponse(object):
|
|||||||
action,
|
action,
|
||||||
debug_urls,
|
debug_urls,
|
||||||
debug_ops="DebugIdentity",
|
debug_ops="DebugIdentity",
|
||||||
node_name_regex_whitelist=None,
|
node_name_regex_allowlist=None,
|
||||||
op_type_regex_whitelist=None,
|
op_type_regex_allowlist=None,
|
||||||
tensor_dtype_regex_whitelist=None,
|
tensor_dtype_regex_allowlist=None,
|
||||||
tolerate_debug_op_creation_failures=False):
|
tolerate_debug_op_creation_failures=False):
|
||||||
"""Constructor of `OnRunStartResponse`.
|
"""Constructor of `OnRunStartResponse`.
|
||||||
|
|
||||||
@ -247,10 +247,10 @@ class OnRunStartResponse(object):
|
|||||||
during the run() call.
|
during the run() call.
|
||||||
debug_ops: (`str` or `list` of `str`) Debug op(s) to be used by the
|
debug_ops: (`str` or `list` of `str`) Debug op(s) to be used by the
|
||||||
debugger.
|
debugger.
|
||||||
node_name_regex_whitelist: Regular-expression whitelist for node
|
node_name_regex_allowlist: Regular-expression allowlist for node
|
||||||
name.
|
name.
|
||||||
op_type_regex_whitelist: Regular-expression whitelist for op type.
|
op_type_regex_allowlist: Regular-expression allowlist for op type.
|
||||||
tensor_dtype_regex_whitelist: Regular-expression whitelist for tensor
|
tensor_dtype_regex_allowlist: Regular-expression allowlist for tensor
|
||||||
dtype.
|
dtype.
|
||||||
tolerate_debug_op_creation_failures: Whether debug op creation failures
|
tolerate_debug_op_creation_failures: Whether debug op creation failures
|
||||||
are to be tolerated.
|
are to be tolerated.
|
||||||
@ -264,9 +264,9 @@ class OnRunStartResponse(object):
|
|||||||
|
|
||||||
self.debug_ops = debug_ops
|
self.debug_ops = debug_ops
|
||||||
|
|
||||||
self.node_name_regex_whitelist = node_name_regex_whitelist
|
self.node_name_regex_allowlist = node_name_regex_allowlist
|
||||||
self.op_type_regex_whitelist = op_type_regex_whitelist
|
self.op_type_regex_allowlist = op_type_regex_allowlist
|
||||||
self.tensor_dtype_regex_whitelist = tensor_dtype_regex_whitelist
|
self.tensor_dtype_regex_allowlist = tensor_dtype_regex_allowlist
|
||||||
self.tolerate_debug_op_creation_failures = (
|
self.tolerate_debug_op_creation_failures = (
|
||||||
tolerate_debug_op_creation_failures)
|
tolerate_debug_op_creation_failures)
|
||||||
|
|
||||||
@ -329,7 +329,7 @@ class BaseDebugWrapperSession(session.SessionInterface):
|
|||||||
Args:
|
Args:
|
||||||
sess: An (unwrapped) TensorFlow session instance. It should be a subtype
|
sess: An (unwrapped) TensorFlow session instance. It should be a subtype
|
||||||
of `BaseSession` or `tf.MonitoredSession`.
|
of `BaseSession` or `tf.MonitoredSession`.
|
||||||
thread_name_filter: Regular-expression filter (whitelist) for name(s) of
|
thread_name_filter: Regular-expression filter (allowlist) for name(s) of
|
||||||
thread(s) on which the wrapper session will be active. This regular
|
thread(s) on which the wrapper session will be active. This regular
|
||||||
expression is used in a start-anchored fashion on the thread name, i.e.,
|
expression is used in a start-anchored fashion on the thread name, i.e.,
|
||||||
by applying the `match` method of the compiled pattern. The default
|
by applying the `match` method of the compiled pattern. The default
|
||||||
@ -545,11 +545,10 @@ class BaseDebugWrapperSession(session.SessionInterface):
|
|||||||
decorated_run_options,
|
decorated_run_options,
|
||||||
run_start_resp.debug_urls,
|
run_start_resp.debug_urls,
|
||||||
debug_ops=run_start_resp.debug_ops,
|
debug_ops=run_start_resp.debug_ops,
|
||||||
node_name_regex_whitelist=(
|
node_name_regex_allowlist=(run_start_resp.node_name_regex_allowlist),
|
||||||
run_start_resp.node_name_regex_whitelist),
|
op_type_regex_allowlist=run_start_resp.op_type_regex_allowlist,
|
||||||
op_type_regex_whitelist=run_start_resp.op_type_regex_whitelist,
|
tensor_dtype_regex_allowlist=(
|
||||||
tensor_dtype_regex_whitelist=(
|
run_start_resp.tensor_dtype_regex_allowlist),
|
||||||
run_start_resp.tensor_dtype_regex_whitelist),
|
|
||||||
tolerate_debug_op_creation_failures=(
|
tolerate_debug_op_creation_failures=(
|
||||||
run_start_resp.tolerate_debug_op_creation_failures))
|
run_start_resp.tolerate_debug_op_creation_failures))
|
||||||
|
|
||||||
@ -707,9 +706,9 @@ class BaseDebugWrapperSession(session.SessionInterface):
|
|||||||
run_options,
|
run_options,
|
||||||
debug_urls,
|
debug_urls,
|
||||||
debug_ops="DebugIdentity",
|
debug_ops="DebugIdentity",
|
||||||
node_name_regex_whitelist=None,
|
node_name_regex_allowlist=None,
|
||||||
op_type_regex_whitelist=None,
|
op_type_regex_allowlist=None,
|
||||||
tensor_dtype_regex_whitelist=None,
|
tensor_dtype_regex_allowlist=None,
|
||||||
tolerate_debug_op_creation_failures=False):
|
tolerate_debug_op_creation_failures=False):
|
||||||
"""Modify a RunOptions object for debug tensor watching.
|
"""Modify a RunOptions object for debug tensor watching.
|
||||||
|
|
||||||
@ -721,10 +720,10 @@ class BaseDebugWrapperSession(session.SessionInterface):
|
|||||||
debug_urls: (list of str) debug URLs to be entered in run_options.
|
debug_urls: (list of str) debug URLs to be entered in run_options.
|
||||||
debug_tensor_watch_opts.
|
debug_tensor_watch_opts.
|
||||||
debug_ops: (str or list of str) debug op(s) to be used by the debugger.
|
debug_ops: (str or list of str) debug op(s) to be used by the debugger.
|
||||||
node_name_regex_whitelist: Regular-expression whitelist for node
|
node_name_regex_allowlist: Regular-expression allowlist for node
|
||||||
name.
|
name.
|
||||||
op_type_regex_whitelist: Regular-expression whitelist for op type.
|
op_type_regex_allowlist: Regular-expression allowlist for op type.
|
||||||
tensor_dtype_regex_whitelist: Regular-expression whitelist for tensor
|
tensor_dtype_regex_allowlist: Regular-expression allowlist for tensor
|
||||||
dtype.
|
dtype.
|
||||||
tolerate_debug_op_creation_failures: Whether debug op creation failures
|
tolerate_debug_op_creation_failures: Whether debug op creation failures
|
||||||
are to be tolerated.
|
are to be tolerated.
|
||||||
@ -736,9 +735,9 @@ class BaseDebugWrapperSession(session.SessionInterface):
|
|||||||
self._sess.graph,
|
self._sess.graph,
|
||||||
debug_urls=debug_urls,
|
debug_urls=debug_urls,
|
||||||
debug_ops=debug_ops,
|
debug_ops=debug_ops,
|
||||||
node_name_regex_whitelist=node_name_regex_whitelist,
|
node_name_regex_allowlist=node_name_regex_allowlist,
|
||||||
op_type_regex_whitelist=op_type_regex_whitelist,
|
op_type_regex_allowlist=op_type_regex_allowlist,
|
||||||
tensor_dtype_regex_whitelist=tensor_dtype_regex_whitelist,
|
tensor_dtype_regex_allowlist=tensor_dtype_regex_allowlist,
|
||||||
tolerate_debug_op_creation_failures=tolerate_debug_op_creation_failures,
|
tolerate_debug_op_creation_failures=tolerate_debug_op_creation_failures,
|
||||||
reset_disk_byte_usage=(self._run_call_count == 1 or
|
reset_disk_byte_usage=(self._run_call_count == 1 or
|
||||||
self._is_disk_usage_reset_each_run()))
|
self._is_disk_usage_reset_each_run()))
|
||||||
@ -821,8 +820,8 @@ class BaseDebugWrapperSession(session.SessionInterface):
|
|||||||
def close(self):
|
def close(self):
|
||||||
self._sess.close()
|
self._sess.close()
|
||||||
|
|
||||||
# TODO(cais): Add _node_name_regex_whitelist and
|
# TODO(cais): Add _node_name_regex_allowlist and
|
||||||
# _node_op_type_regex_whitelist.
|
# _node_op_type_regex_allowlist.
|
||||||
|
|
||||||
def should_stop(self):
|
def should_stop(self):
|
||||||
if hasattr(self._sess, "should_stop"):
|
if hasattr(self._sess, "should_stop"):
|
||||||
@ -838,9 +837,9 @@ class WatchOptions(object):
|
|||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
debug_ops=None,
|
debug_ops=None,
|
||||||
node_name_regex_whitelist=None,
|
node_name_regex_allowlist=None,
|
||||||
op_type_regex_whitelist=None,
|
op_type_regex_allowlist=None,
|
||||||
tensor_dtype_regex_whitelist=None,
|
tensor_dtype_regex_allowlist=None,
|
||||||
tolerate_debug_op_creation_failures=False):
|
tolerate_debug_op_creation_failures=False):
|
||||||
"""Constructor of WatchOptions: Debug watch options.
|
"""Constructor of WatchOptions: Debug watch options.
|
||||||
|
|
||||||
@ -848,17 +847,17 @@ class WatchOptions(object):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
debug_ops: (`str` or `list of str`) Debug ops to be used.
|
debug_ops: (`str` or `list of str`) Debug ops to be used.
|
||||||
node_name_regex_whitelist: Regular-expression whitelist for node_name,
|
node_name_regex_allowlist: Regular-expression allowlist for node_name,
|
||||||
e.g., `"(weight_[0-9]+|bias_.*)"`
|
e.g., `"(weight_[0-9]+|bias_.*)"`
|
||||||
op_type_regex_whitelist: Regular-expression whitelist for the op type of
|
op_type_regex_allowlist: Regular-expression allowlist for the op type of
|
||||||
nodes, e.g., `"(Variable|Add)"`.
|
nodes, e.g., `"(Variable|Add)"`.
|
||||||
If both `node_name_regex_whitelist` and `op_type_regex_whitelist`
|
If both `node_name_regex_allowlist` and `op_type_regex_allowlist`
|
||||||
are set, the two filtering operations will occur in a logical `AND`
|
are set, the two filtering operations will occur in a logical `AND`
|
||||||
relation. In other words, a node will be included if and only if it
|
relation. In other words, a node will be included if and only if it
|
||||||
hits both whitelists.
|
hits both allowlists.
|
||||||
tensor_dtype_regex_whitelist: Regular-expression whitelist for Tensor
|
tensor_dtype_regex_allowlist: Regular-expression allowlist for Tensor
|
||||||
data type, e.g., `"^int.*"`.
|
data type, e.g., `"^int.*"`.
|
||||||
This whitelist operates in logical `AND` relations to the two whitelists
|
This allowlist operates in logical `AND` relations to the two allowlists
|
||||||
above.
|
above.
|
||||||
tolerate_debug_op_creation_failures: (`bool`) whether debug op creation
|
tolerate_debug_op_creation_failures: (`bool`) whether debug op creation
|
||||||
failures (e.g., due to dtype incompatibility) are to be tolerated by not
|
failures (e.g., due to dtype incompatibility) are to be tolerated by not
|
||||||
@ -868,19 +867,19 @@ class WatchOptions(object):
|
|||||||
self.debug_ops = debug_ops
|
self.debug_ops = debug_ops
|
||||||
else:
|
else:
|
||||||
self.debug_ops = ["DebugIdentity"]
|
self.debug_ops = ["DebugIdentity"]
|
||||||
self.node_name_regex_whitelist = node_name_regex_whitelist
|
self.node_name_regex_allowlist = node_name_regex_allowlist
|
||||||
self.op_type_regex_whitelist = op_type_regex_whitelist
|
self.op_type_regex_allowlist = op_type_regex_allowlist
|
||||||
self.tensor_dtype_regex_whitelist = tensor_dtype_regex_whitelist
|
self.tensor_dtype_regex_allowlist = tensor_dtype_regex_allowlist
|
||||||
self.tolerate_debug_op_creation_failures = (
|
self.tolerate_debug_op_creation_failures = (
|
||||||
tolerate_debug_op_creation_failures)
|
tolerate_debug_op_creation_failures)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return ("WatchOptions(debug_ops=%r, node_name_regex_whitelist=%r, "
|
return ("WatchOptions(debug_ops=%r, node_name_regex_allowlist=%r, "
|
||||||
"op_type_regex_whitelist=%r, tensor_dtype_regex_whitelist=%r, "
|
"op_type_regex_allowlist=%r, tensor_dtype_regex_allowlist=%r, "
|
||||||
"tolerate_debug_op_creation_failures=%r)" % (
|
"tolerate_debug_op_creation_failures=%r)" %
|
||||||
self.debug_ops, self.node_name_regex_whitelist,
|
(self.debug_ops, self.node_name_regex_allowlist,
|
||||||
self.op_type_regex_whitelist, self.tensor_dtype_regex_whitelist,
|
self.op_type_regex_allowlist, self.tensor_dtype_regex_allowlist,
|
||||||
self.tolerate_debug_op_creation_failures))
|
self.tolerate_debug_op_creation_failures))
|
||||||
|
|
||||||
|
|
||||||
class NonInteractiveDebugWrapperSession(BaseDebugWrapperSession):
|
class NonInteractiveDebugWrapperSession(BaseDebugWrapperSession):
|
||||||
@ -952,14 +951,14 @@ class NonInteractiveDebugWrapperSession(BaseDebugWrapperSession):
|
|||||||
OnRunStartAction.DEBUG_RUN,
|
OnRunStartAction.DEBUG_RUN,
|
||||||
debug_urls,
|
debug_urls,
|
||||||
debug_ops=watch_opts.debug_ops,
|
debug_ops=watch_opts.debug_ops,
|
||||||
node_name_regex_whitelist=watch_opts.node_name_regex_whitelist,
|
node_name_regex_allowlist=watch_opts.node_name_regex_allowlist,
|
||||||
op_type_regex_whitelist=watch_opts.op_type_regex_whitelist,
|
op_type_regex_allowlist=watch_opts.op_type_regex_allowlist,
|
||||||
tensor_dtype_regex_whitelist=watch_opts.tensor_dtype_regex_whitelist,
|
tensor_dtype_regex_allowlist=watch_opts.tensor_dtype_regex_allowlist,
|
||||||
tolerate_debug_op_creation_failures=(
|
tolerate_debug_op_creation_failures=(
|
||||||
watch_opts.tolerate_debug_op_creation_failures))
|
watch_opts.tolerate_debug_op_creation_failures))
|
||||||
|
|
||||||
def _prepare_run_watch_config(self, fetches, feed_dict):
|
def _prepare_run_watch_config(self, fetches, feed_dict):
|
||||||
"""Get the debug_urls, and node/op whitelists for the current run() call.
|
"""Get the debug_urls, and node/op allowlists for the current run() call.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
fetches: Same as the `fetches` argument to `Session.run()`.
|
fetches: Same as the `fetches` argument to `Session.run()`.
|
||||||
@ -969,7 +968,7 @@ class NonInteractiveDebugWrapperSession(BaseDebugWrapperSession):
|
|||||||
debug_urls: (str or list of str) Debug URLs for the current run() call.
|
debug_urls: (str or list of str) Debug URLs for the current run() call.
|
||||||
Currently, the list consists of only one URL that is a file:// URL.
|
Currently, the list consists of only one URL that is a file:// URL.
|
||||||
watch_options: (WatchOptions) The return value of a watch_fn, containing
|
watch_options: (WatchOptions) The return value of a watch_fn, containing
|
||||||
options including debug_ops, and whitelists.
|
options including debug_ops, and allowlists.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
debug_urls = self.prepare_run_debug_urls(fetches, feed_dict)
|
debug_urls = self.prepare_run_debug_urls(fetches, feed_dict)
|
||||||
|
@ -124,12 +124,12 @@ class LocalCLIDebugHook(session_run_hook.SessionRunHook):
|
|||||||
run_args.options,
|
run_args.options,
|
||||||
on_run_start_response.debug_urls,
|
on_run_start_response.debug_urls,
|
||||||
debug_ops=on_run_start_response.debug_ops,
|
debug_ops=on_run_start_response.debug_ops,
|
||||||
node_name_regex_whitelist=(
|
node_name_regex_allowlist=(
|
||||||
on_run_start_response.node_name_regex_whitelist),
|
on_run_start_response.node_name_regex_allowlist),
|
||||||
op_type_regex_whitelist=(
|
op_type_regex_allowlist=(
|
||||||
on_run_start_response.op_type_regex_whitelist),
|
on_run_start_response.op_type_regex_allowlist),
|
||||||
tensor_dtype_regex_whitelist=(
|
tensor_dtype_regex_allowlist=(
|
||||||
on_run_start_response.tensor_dtype_regex_whitelist),
|
on_run_start_response.tensor_dtype_regex_allowlist),
|
||||||
tolerate_debug_op_creation_failures=(
|
tolerate_debug_op_creation_failures=(
|
||||||
on_run_start_response.tolerate_debug_op_creation_failures))
|
on_run_start_response.tolerate_debug_op_creation_failures))
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
@ -205,9 +205,9 @@ class DumpingDebugHook(session_run_hook.SessionRunHook):
|
|||||||
run_context.session.graph,
|
run_context.session.graph,
|
||||||
debug_urls=debug_urls,
|
debug_urls=debug_urls,
|
||||||
debug_ops=watch_options.debug_ops,
|
debug_ops=watch_options.debug_ops,
|
||||||
node_name_regex_whitelist=watch_options.node_name_regex_whitelist,
|
node_name_regex_allowlist=watch_options.node_name_regex_allowlist,
|
||||||
op_type_regex_whitelist=watch_options.op_type_regex_whitelist,
|
op_type_regex_allowlist=watch_options.op_type_regex_allowlist,
|
||||||
tensor_dtype_regex_whitelist=watch_options.tensor_dtype_regex_whitelist,
|
tensor_dtype_regex_allowlist=watch_options.tensor_dtype_regex_allowlist,
|
||||||
tolerate_debug_op_creation_failures=(
|
tolerate_debug_op_creation_failures=(
|
||||||
watch_options.tolerate_debug_op_creation_failures),
|
watch_options.tolerate_debug_op_creation_failures),
|
||||||
reset_disk_byte_usage=reset_disk_byte_usage)
|
reset_disk_byte_usage=reset_disk_byte_usage)
|
||||||
@ -292,9 +292,9 @@ class GrpcDebugHook(session_run_hook.SessionRunHook):
|
|||||||
debug_urls=self._grpc_debug_wrapper_session.prepare_run_debug_urls(
|
debug_urls=self._grpc_debug_wrapper_session.prepare_run_debug_urls(
|
||||||
fetches, feed_dict),
|
fetches, feed_dict),
|
||||||
debug_ops=watch_options.debug_ops,
|
debug_ops=watch_options.debug_ops,
|
||||||
node_name_regex_whitelist=watch_options.node_name_regex_whitelist,
|
node_name_regex_allowlist=watch_options.node_name_regex_allowlist,
|
||||||
op_type_regex_whitelist=watch_options.op_type_regex_whitelist,
|
op_type_regex_allowlist=watch_options.op_type_regex_allowlist,
|
||||||
tensor_dtype_regex_whitelist=watch_options.tensor_dtype_regex_whitelist,
|
tensor_dtype_regex_allowlist=watch_options.tensor_dtype_regex_allowlist,
|
||||||
tolerate_debug_op_creation_failures=(
|
tolerate_debug_op_creation_failures=(
|
||||||
watch_options.tolerate_debug_op_creation_failures))
|
watch_options.tolerate_debug_op_creation_failures))
|
||||||
|
|
||||||
|
@ -552,9 +552,9 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
|
|||||||
run_start_response = framework.OnRunStartResponse(
|
run_start_response = framework.OnRunStartResponse(
|
||||||
action,
|
action,
|
||||||
debug_urls,
|
debug_urls,
|
||||||
node_name_regex_whitelist=parsed.node_name_filter,
|
node_name_regex_allowlist=parsed.node_name_filter,
|
||||||
op_type_regex_whitelist=parsed.op_type_filter,
|
op_type_regex_allowlist=parsed.op_type_filter,
|
||||||
tensor_dtype_regex_whitelist=parsed.tensor_dtype_filter)
|
tensor_dtype_regex_allowlist=parsed.tensor_dtype_filter)
|
||||||
|
|
||||||
if parsed.till_filter_pass:
|
if parsed.till_filter_pass:
|
||||||
# For the run-till-filter-pass (run -f) mode, use the DEBUG_RUN
|
# For the run-till-filter-pass (run -f) mode, use the DEBUG_RUN
|
||||||
|
@ -88,7 +88,7 @@ def call_for_each_replica(strategy, fn, args=None, kwargs=None):
|
|||||||
else:
|
else:
|
||||||
# When a tf.function is wrapped to trigger _call_for_each_replica (see
|
# When a tf.function is wrapped to trigger _call_for_each_replica (see
|
||||||
# the other branch above), AutoGraph stops conversion at
|
# the other branch above), AutoGraph stops conversion at
|
||||||
# _call_for_each_replica itself (TF library functions are whitelisted).
|
# _call_for_each_replica itself (TF library functions are allowlisted).
|
||||||
# This makes sure that the Python function that originally passed to
|
# This makes sure that the Python function that originally passed to
|
||||||
# the tf.function is still converted.
|
# the tf.function is still converted.
|
||||||
fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
|
fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
|
||||||
|
@ -237,7 +237,7 @@ def _parse_func_attrs(attributes):
|
|||||||
A dict of attributes where the key is the name of attribute and the value
|
A dict of attributes where the key is the name of attribute and the value
|
||||||
is the AttrValue proto.
|
is the AttrValue proto.
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If the kwargs contains unwhitelisted name or unsupported value
|
ValueError: If the kwargs contains unallowlisted name or unsupported value
|
||||||
types.
|
types.
|
||||||
"""
|
"""
|
||||||
attrs = {}
|
attrs = {}
|
||||||
@ -3625,9 +3625,9 @@ def defun_with_attributes(func=None,
|
|||||||
input_signature: same as defun()'s input_signature.
|
input_signature: same as defun()'s input_signature.
|
||||||
attributes: A dictionary of arguments which will be added to function def as
|
attributes: A dictionary of arguments which will be added to function def as
|
||||||
attributes. Currently only support primitive types as value, and only
|
attributes. Currently only support primitive types as value, and only
|
||||||
whitelisted attribute name is allowed. Unwhitelisted attribute name or
|
allowlisted attribute name is allowed. Unallowlisted attribute name or
|
||||||
unsupported value will result into ValueError. `func_name` is also one of
|
unsupported value will result into ValueError. `func_name` is also one of
|
||||||
the whitelisted argument which is a python string, and sets the name for
|
the allowlisted argument which is a python string, and sets the name for
|
||||||
this `ConcreteFunction` in the graph.
|
this `ConcreteFunction` in the graph.
|
||||||
autograph: same as defun()'s autograph.
|
autograph: same as defun()'s autograph.
|
||||||
experimental_autograph_options: same as defun()'s
|
experimental_autograph_options: same as defun()'s
|
||||||
|
@ -108,9 +108,9 @@ _ALL_BLACKLISTED_OPS = (
|
|||||||
set(ASYNC_STATEFUL_OPS) | set(LEGACY_RANDOM_OPS)
|
set(ASYNC_STATEFUL_OPS) | set(LEGACY_RANDOM_OPS)
|
||||||
| set(_ORDER_INSENSITIVE_STATEFUL_OPS))
|
| set(_ORDER_INSENSITIVE_STATEFUL_OPS))
|
||||||
|
|
||||||
# Op types that are marked as stateless, but should be whitelisted to add auto
|
# Op types that are marked as stateless, but should be allowlisted to add auto
|
||||||
# control dependencies.
|
# control dependencies.
|
||||||
_WHITELIST_STATELESS_OPS = [
|
_ALLOWLIST_STATELESS_OPS = [
|
||||||
# As TPU collective ops are blocking, if there are more than one collective
|
# As TPU collective ops are blocking, if there are more than one collective
|
||||||
# op in the function, we need to make sure different collectives ops are
|
# op in the function, we need to make sure different collectives ops are
|
||||||
# scheduled in certain orders. Otherwise if at the same time all the
|
# scheduled in certain orders. Otherwise if at the same time all the
|
||||||
@ -125,7 +125,7 @@ _WHITELIST_STATELESS_OPS = [
|
|||||||
def op_is_stateful(op):
|
def op_is_stateful(op):
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
return (op._is_stateful and op.type not in _ALL_BLACKLISTED_OPS) or (
|
return (op._is_stateful and op.type not in _ALL_BLACKLISTED_OPS) or (
|
||||||
op.type in _WHITELIST_STATELESS_OPS)
|
op.type in _ALLOWLIST_STATELESS_OPS)
|
||||||
|
|
||||||
|
|
||||||
class ResourceType(enum.Enum):
|
class ResourceType(enum.Enum):
|
||||||
|
@ -710,12 +710,12 @@ class _ConverterData(object):
|
|||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
graph_def,
|
graph_def,
|
||||||
variable_names_whitelist=None,
|
variable_names_allowlist=None,
|
||||||
variable_names_blacklist=None):
|
variable_names_blacklist=None):
|
||||||
self._graph_def = graph_def
|
self._graph_def = graph_def
|
||||||
self._tensor_data = {}
|
self._tensor_data = {}
|
||||||
self._build_node_defs_list()
|
self._build_node_defs_list()
|
||||||
self._variable_names_whitelist = variable_names_whitelist
|
self._variable_names_allowlist = variable_names_allowlist
|
||||||
self._variable_names_blacklist = variable_names_blacklist
|
self._variable_names_blacklist = variable_names_blacklist
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -740,8 +740,8 @@ class _ConverterData(object):
|
|||||||
|
|
||||||
def _should_convert(self, name):
|
def _should_convert(self, name):
|
||||||
"""Checks whether to convert the given variable name to a constant."""
|
"""Checks whether to convert the given variable name to a constant."""
|
||||||
return (self._variable_names_whitelist is None or
|
return (self._variable_names_allowlist is None or
|
||||||
name in self._variable_names_whitelist) and (
|
name in self._variable_names_allowlist) and (
|
||||||
self._variable_names_blacklist is None or
|
self._variable_names_blacklist is None or
|
||||||
name not in self._variable_names_blacklist)
|
name not in self._variable_names_blacklist)
|
||||||
|
|
||||||
@ -776,7 +776,7 @@ class _FunctionConverterData(_ConverterData):
|
|||||||
func,
|
func,
|
||||||
lower_control_flow,
|
lower_control_flow,
|
||||||
aggressive_inlining,
|
aggressive_inlining,
|
||||||
variable_names_whitelist=None,
|
variable_names_allowlist=None,
|
||||||
variable_names_blacklist=None):
|
variable_names_blacklist=None):
|
||||||
"""Creates the conversion data for the given function.
|
"""Creates the conversion data for the given function.
|
||||||
|
|
||||||
@ -787,7 +787,7 @@ class _FunctionConverterData(_ConverterData):
|
|||||||
aggressive_inlining: Boolean indicating whether or not to to aggressive
|
aggressive_inlining: Boolean indicating whether or not to to aggressive
|
||||||
function inlining (might be unsafe if function has stateful ops, not
|
function inlining (might be unsafe if function has stateful ops, not
|
||||||
properly connected to control outputs).
|
properly connected to control outputs).
|
||||||
variable_names_whitelist: The set of variable names to convert (by
|
variable_names_allowlist: The set of variable names to convert (by
|
||||||
default, all variables are converted).
|
default, all variables are converted).
|
||||||
variable_names_blacklist: The set of variable names to omit converting to
|
variable_names_blacklist: The set of variable names to omit converting to
|
||||||
constants.
|
constants.
|
||||||
@ -799,7 +799,7 @@ class _FunctionConverterData(_ConverterData):
|
|||||||
aggressive_inlining)
|
aggressive_inlining)
|
||||||
super(_FunctionConverterData, self).__init__(
|
super(_FunctionConverterData, self).__init__(
|
||||||
graph_def,
|
graph_def,
|
||||||
variable_names_whitelist=variable_names_whitelist,
|
variable_names_allowlist=variable_names_allowlist,
|
||||||
variable_names_blacklist=variable_names_blacklist)
|
variable_names_blacklist=variable_names_blacklist)
|
||||||
self._build_tensor_data()
|
self._build_tensor_data()
|
||||||
|
|
||||||
@ -849,12 +849,12 @@ class _SessionConverterData(_ConverterData):
|
|||||||
session,
|
session,
|
||||||
graph_def,
|
graph_def,
|
||||||
output_node_names,
|
output_node_names,
|
||||||
variable_names_whitelist=None,
|
variable_names_allowlist=None,
|
||||||
variable_names_blacklist=None):
|
variable_names_blacklist=None):
|
||||||
graph_def = graph_util.extract_sub_graph(graph_def, output_node_names)
|
graph_def = graph_util.extract_sub_graph(graph_def, output_node_names)
|
||||||
super(_SessionConverterData, self).__init__(
|
super(_SessionConverterData, self).__init__(
|
||||||
graph_def,
|
graph_def,
|
||||||
variable_names_whitelist=variable_names_whitelist,
|
variable_names_allowlist=variable_names_allowlist,
|
||||||
variable_names_blacklist=variable_names_blacklist)
|
variable_names_blacklist=variable_names_blacklist)
|
||||||
|
|
||||||
nodes_to_convert = []
|
nodes_to_convert = []
|
||||||
@ -1114,7 +1114,7 @@ def convert_variables_to_constants_from_session_graph(
|
|||||||
session,
|
session,
|
||||||
graph_def,
|
graph_def,
|
||||||
output_node_names,
|
output_node_names,
|
||||||
variable_names_whitelist=None,
|
variable_names_allowlist=None,
|
||||||
variable_names_blacklist=None):
|
variable_names_blacklist=None):
|
||||||
"""Replaces all the variables in a graph with constants of the same values.
|
"""Replaces all the variables in a graph with constants of the same values.
|
||||||
|
|
||||||
@ -1129,7 +1129,7 @@ def convert_variables_to_constants_from_session_graph(
|
|||||||
session: Active TensorFlow session containing the variables.
|
session: Active TensorFlow session containing the variables.
|
||||||
graph_def: A GraphDef to convert.
|
graph_def: A GraphDef to convert.
|
||||||
output_node_names: List of name strings for the result nodes of the graph.
|
output_node_names: List of name strings for the result nodes of the graph.
|
||||||
variable_names_whitelist: The set of variable names to convert (by default,
|
variable_names_allowlist: The set of variable names to convert (by default,
|
||||||
all variables are converted).
|
all variables are converted).
|
||||||
variable_names_blacklist: The set of variable names to omit converting to
|
variable_names_blacklist: The set of variable names to omit converting to
|
||||||
constants.
|
constants.
|
||||||
@ -1142,6 +1142,6 @@ def convert_variables_to_constants_from_session_graph(
|
|||||||
session=session,
|
session=session,
|
||||||
graph_def=graph_def,
|
graph_def=graph_def,
|
||||||
output_node_names=output_node_names,
|
output_node_names=output_node_names,
|
||||||
variable_names_whitelist=variable_names_whitelist,
|
variable_names_allowlist=variable_names_allowlist,
|
||||||
variable_names_blacklist=variable_names_blacklist))
|
variable_names_blacklist=variable_names_blacklist))
|
||||||
return graph_def
|
return graph_def
|
||||||
|
@ -49,7 +49,7 @@ from tensorflow.python.util import object_identity
|
|||||||
from tensorflow.python.util import tf_contextlib
|
from tensorflow.python.util import tf_contextlib
|
||||||
from tensorflow.python.util import tf_decorator
|
from tensorflow.python.util import tf_decorator
|
||||||
|
|
||||||
WHITELIST_COLLECTIONS = [
|
ALLOWLIST_COLLECTIONS = [
|
||||||
ops.GraphKeys.GLOBAL_VARIABLES,
|
ops.GraphKeys.GLOBAL_VARIABLES,
|
||||||
ops.GraphKeys.LOCAL_VARIABLES,
|
ops.GraphKeys.LOCAL_VARIABLES,
|
||||||
ops.GraphKeys.TRAINABLE_VARIABLES,
|
ops.GraphKeys.TRAINABLE_VARIABLES,
|
||||||
@ -172,9 +172,9 @@ class FuncGraph(ops.Graph):
|
|||||||
name: the name of the function.
|
name: the name of the function.
|
||||||
collections: a dictionary of collections this FuncGraph should start
|
collections: a dictionary of collections this FuncGraph should start
|
||||||
with. If not specified (None), the FuncGraph will read (but not write
|
with. If not specified (None), the FuncGraph will read (but not write
|
||||||
to) the outer graph's collections that are not whitelisted, and both
|
to) the outer graph's collections that are not allowlisted, and both
|
||||||
read and write to the outer graph's collections that are whitelisted.
|
read and write to the outer graph's collections that are allowlisted.
|
||||||
The current whitelisted collections are the global variables, the
|
The current allowlisted collections are the global variables, the
|
||||||
local variables, and the trainable variables.
|
local variables, and the trainable variables.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
capture_by_value: An optional boolean. If True, the func graph will
|
capture_by_value: An optional boolean. If True, the func graph will
|
||||||
@ -241,10 +241,10 @@ class FuncGraph(ops.Graph):
|
|||||||
|
|
||||||
if collections is None:
|
if collections is None:
|
||||||
for collection_name in graph.get_all_collection_keys():
|
for collection_name in graph.get_all_collection_keys():
|
||||||
if collection_name not in WHITELIST_COLLECTIONS:
|
if collection_name not in ALLOWLIST_COLLECTIONS:
|
||||||
self._collections[collection_name] = graph.get_collection(
|
self._collections[collection_name] = graph.get_collection(
|
||||||
collection_name)
|
collection_name)
|
||||||
for collection_name in WHITELIST_COLLECTIONS:
|
for collection_name in ALLOWLIST_COLLECTIONS:
|
||||||
self._collections[collection_name] = graph.get_collection_ref(
|
self._collections[collection_name] = graph.get_collection_ref(
|
||||||
collection_name)
|
collection_name)
|
||||||
else:
|
else:
|
||||||
@ -842,9 +842,9 @@ def func_graph_from_py_func(name,
|
|||||||
set, returning an Operation triggers an error.
|
set, returning an Operation triggers an error.
|
||||||
collections: a dictionary of collections this FuncGraph should start
|
collections: a dictionary of collections this FuncGraph should start
|
||||||
with. If not specified (None), the FuncGraph will read (but not write to)
|
with. If not specified (None), the FuncGraph will read (but not write to)
|
||||||
the outer graph's collections that are not whitelisted, and both
|
the outer graph's collections that are not allowlisted, and both
|
||||||
read and write to the outer graph's collections that are whitelisted.
|
read and write to the outer graph's collections that are allowlisted.
|
||||||
The current whitelisted collections are the global variables, the
|
The current allowlisted collections are the global variables, the
|
||||||
local variables, and the trainable variables.
|
local variables, and the trainable variables.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
capture_by_value: An optional boolean. If True, the func graph will capture
|
capture_by_value: An optional boolean. If True, the func graph will capture
|
||||||
|
@ -234,7 +234,7 @@ class _DefinedFunction(object):
|
|||||||
out_names=None,
|
out_names=None,
|
||||||
shape_func=None,
|
shape_func=None,
|
||||||
capture_by_value=False,
|
capture_by_value=False,
|
||||||
whitelisted_stateful_ops=None,
|
allowlisted_stateful_ops=None,
|
||||||
capture_resource_var_by_value=True,
|
capture_resource_var_by_value=True,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
"""Creates _DefinedFunction.
|
"""Creates _DefinedFunction.
|
||||||
@ -256,7 +256,7 @@ class _DefinedFunction(object):
|
|||||||
output shapes.
|
output shapes.
|
||||||
capture_by_value: Boolean (defaults to False). If True, captured values
|
capture_by_value: Boolean (defaults to False). If True, captured values
|
||||||
will be copied into the function body.
|
will be copied into the function body.
|
||||||
whitelisted_stateful_ops: A set of ops that if stateful we ignore and
|
allowlisted_stateful_ops: A set of ops that if stateful we ignore and
|
||||||
copy into the function body, when `capture_by_value` is True.
|
copy into the function body, when `capture_by_value` is True.
|
||||||
capture_resource_var_by_value: Boolean (defaults to True). If False,
|
capture_resource_var_by_value: Boolean (defaults to True). If False,
|
||||||
captured resource variable returns the handle instead of value.
|
captured resource variable returns the handle instead of value.
|
||||||
@ -275,9 +275,9 @@ class _DefinedFunction(object):
|
|||||||
self._out_names = out_names
|
self._out_names = out_names
|
||||||
self._shape_func = shape_func
|
self._shape_func = shape_func
|
||||||
self._capture_by_value = capture_by_value
|
self._capture_by_value = capture_by_value
|
||||||
self._whitelisted_stateful_ops = whitelisted_stateful_ops
|
self._allowlisted_stateful_ops = allowlisted_stateful_ops
|
||||||
if self._whitelisted_stateful_ops is None:
|
if self._allowlisted_stateful_ops is None:
|
||||||
self._whitelisted_stateful_ops = set()
|
self._allowlisted_stateful_ops = set()
|
||||||
self._capture_resource_var_by_value = capture_resource_var_by_value
|
self._capture_resource_var_by_value = capture_resource_var_by_value
|
||||||
self._extra_kwargs = kwargs
|
self._extra_kwargs = kwargs
|
||||||
# Constructed only when C API is disabled, lazily
|
# Constructed only when C API is disabled, lazily
|
||||||
@ -403,7 +403,7 @@ class _DefinedFunction(object):
|
|||||||
self._capture_by_value,
|
self._capture_by_value,
|
||||||
self._caller_device,
|
self._caller_device,
|
||||||
collections_ref=collections_ref,
|
collections_ref=collections_ref,
|
||||||
whitelisted_stateful_ops=self._whitelisted_stateful_ops,
|
allowlisted_stateful_ops=self._allowlisted_stateful_ops,
|
||||||
capture_resource_var_by_value=self._capture_resource_var_by_value)
|
capture_resource_var_by_value=self._capture_resource_var_by_value)
|
||||||
|
|
||||||
self._extra_inputs = temp_graph.extra_inputs
|
self._extra_inputs = temp_graph.extra_inputs
|
||||||
@ -690,11 +690,11 @@ class _FuncGraph(ops.Graph):
|
|||||||
function argument and the caller passes in the captured tensor.
|
function argument and the caller passes in the captured tensor.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, name, capture_by_value, whitelisted_stateful_ops,
|
def __init__(self, name, capture_by_value, allowlisted_stateful_ops,
|
||||||
capture_resource_var_by_value, *args, **kwargs):
|
capture_resource_var_by_value, *args, **kwargs):
|
||||||
super(_FuncGraph, self).__init__(*args, **kwargs)
|
super(_FuncGraph, self).__init__(*args, **kwargs)
|
||||||
self._capture_by_value = capture_by_value
|
self._capture_by_value = capture_by_value
|
||||||
self._whitelisted_stateful_ops = whitelisted_stateful_ops
|
self._allowlisted_stateful_ops = allowlisted_stateful_ops
|
||||||
self._capture_resource_var_by_value = capture_resource_var_by_value
|
self._capture_resource_var_by_value = capture_resource_var_by_value
|
||||||
self._building_function = True
|
self._building_function = True
|
||||||
self._outer_graph = ops.get_default_graph()
|
self._outer_graph = ops.get_default_graph()
|
||||||
@ -879,7 +879,7 @@ class _FuncGraph(ops.Graph):
|
|||||||
def _add_op_and_parents(self, op):
|
def _add_op_and_parents(self, op):
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
op_def = graph_to_function_def._get_op_def(op)
|
op_def = graph_to_function_def._get_op_def(op)
|
||||||
if op._is_stateful and op not in self._whitelisted_stateful_ops:
|
if op._is_stateful and op not in self._allowlisted_stateful_ops:
|
||||||
raise ValueError("Cannot capture a stateful node (name:%s, type:%s) "
|
raise ValueError("Cannot capture a stateful node (name:%s, type:%s) "
|
||||||
"by value." % (op.name, op.type))
|
"by value." % (op.name, op.type))
|
||||||
elif op.type in ("Placeholder", "PlaceholderV2"):
|
elif op.type in ("Placeholder", "PlaceholderV2"):
|
||||||
@ -912,7 +912,7 @@ def func_graph_from_py_func(func,
|
|||||||
container=None,
|
container=None,
|
||||||
collections_ref=None,
|
collections_ref=None,
|
||||||
arg_shapes=None,
|
arg_shapes=None,
|
||||||
whitelisted_stateful_ops=None,
|
allowlisted_stateful_ops=None,
|
||||||
capture_resource_var_by_value=True):
|
capture_resource_var_by_value=True):
|
||||||
"""Returns a _FuncGraph generated from `func`.
|
"""Returns a _FuncGraph generated from `func`.
|
||||||
|
|
||||||
@ -931,7 +931,7 @@ def func_graph_from_py_func(func,
|
|||||||
collections_ref: A reference to a collections dict the _FuncGraph should
|
collections_ref: A reference to a collections dict the _FuncGraph should
|
||||||
use internally.
|
use internally.
|
||||||
arg_shapes: A sequence of the function's argument shapes.
|
arg_shapes: A sequence of the function's argument shapes.
|
||||||
whitelisted_stateful_ops: A set of ops that if stateful we ignore and
|
allowlisted_stateful_ops: A set of ops that if stateful we ignore and
|
||||||
re-create.
|
re-create.
|
||||||
capture_resource_var_by_value: Boolean (defaults to True). If False,
|
capture_resource_var_by_value: Boolean (defaults to True). If False,
|
||||||
captured resource variable returns the handle instead of value.
|
captured resource variable returns the handle instead of value.
|
||||||
@ -944,7 +944,7 @@ def func_graph_from_py_func(func,
|
|||||||
"""
|
"""
|
||||||
if not name:
|
if not name:
|
||||||
name = function_utils.get_func_name(func)
|
name = function_utils.get_func_name(func)
|
||||||
func_graph = _FuncGraph(name, capture_by_value, whitelisted_stateful_ops,
|
func_graph = _FuncGraph(name, capture_by_value, allowlisted_stateful_ops,
|
||||||
capture_resource_var_by_value)
|
capture_resource_var_by_value)
|
||||||
|
|
||||||
with func_graph.as_default(), ops.device(device):
|
with func_graph.as_default(), ops.device(device):
|
||||||
|
@ -1043,7 +1043,7 @@ class FunctionTest(test.TestCase):
|
|||||||
self.assertFalse(all(val4 == val2))
|
self.assertFalse(all(val4 == val2))
|
||||||
|
|
||||||
@test_util.run_v1_only("currently failing on v2")
|
@test_util.run_v1_only("currently failing on v2")
|
||||||
def testStatefulFunctionWithWhitelisting(self):
|
def testStatefulFunctionWithAllowlisting(self):
|
||||||
t = random_ops.random_uniform([100], maxval=10, dtype=dtypes.int32)
|
t = random_ops.random_uniform([100], maxval=10, dtype=dtypes.int32)
|
||||||
|
|
||||||
@function.Defun(capture_by_value=True)
|
@function.Defun(capture_by_value=True)
|
||||||
@ -1054,8 +1054,8 @@ class FunctionTest(test.TestCase):
|
|||||||
with self.assertRaisesRegex(ValueError, "Cannot capture a stateful node"):
|
with self.assertRaisesRegex(ValueError, "Cannot capture a stateful node"):
|
||||||
res = StatefulFn()
|
res = StatefulFn()
|
||||||
|
|
||||||
# This time we whitelist this op, so that its recreated.
|
# This time we allowlist this op, so that its recreated.
|
||||||
@function.Defun(capture_by_value=True, whitelisted_stateful_ops=set([t.op]))
|
@function.Defun(capture_by_value=True, allowlisted_stateful_ops=set([t.op]))
|
||||||
def StatefulFn2():
|
def StatefulFn2():
|
||||||
return t + constant_op.constant(3, dtype=dtypes.int32)
|
return t + constant_op.constant(3, dtype=dtypes.int32)
|
||||||
|
|
||||||
|
@ -276,7 +276,7 @@ def convert_variables_to_constants(sess,
|
|||||||
session=sess,
|
session=sess,
|
||||||
graph_def=input_graph_def,
|
graph_def=input_graph_def,
|
||||||
output_node_names=output_node_names,
|
output_node_names=output_node_names,
|
||||||
variable_names_whitelist=variable_names_whitelist,
|
variable_names_allowlist=variable_names_whitelist,
|
||||||
variable_names_blacklist=variable_names_blacklist)
|
variable_names_blacklist=variable_names_blacklist)
|
||||||
# The previous code logic generated an empty versions field, we clear it here
|
# The previous code logic generated an empty versions field, we clear it here
|
||||||
# to maintain backwards compatibility.
|
# to maintain backwards compatibility.
|
||||||
|
@ -472,7 +472,7 @@ class ImportGraphDefTest(test.TestCase):
|
|||||||
node { name: 'B' op: 'FloatInput' input: 'A:0' }
|
node { name: 'B' op: 'FloatInput' input: 'A:0' }
|
||||||
"""))
|
"""))
|
||||||
|
|
||||||
def testShapeWhitelistViolation(self):
|
def testShapeAllowlistViolation(self):
|
||||||
# L2 loss produces a scalar shape, but the graph
|
# L2 loss produces a scalar shape, but the graph
|
||||||
# has the wrong shape, so raise an error.
|
# has the wrong shape, so raise an error.
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
|
@ -351,7 +351,7 @@ string GenEagerPythonOp::Code() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::unordered_map<string, string> type_annotations;
|
std::unordered_map<string, string> type_annotations;
|
||||||
// Only populate map for whitelisted ops
|
// Only populate map for allowlisted ops
|
||||||
if (add_type_annotations_) {
|
if (add_type_annotations_) {
|
||||||
type_annotations = GetTypeAnnotations();
|
type_annotations = GetTypeAnnotations();
|
||||||
}
|
}
|
||||||
|
@ -108,7 +108,7 @@ string InferSourceFileName(const char* argv_zero) {
|
|||||||
void PrintAllPythonOps(const std::vector<string>& op_list,
|
void PrintAllPythonOps(const std::vector<string>& op_list,
|
||||||
const std::vector<string>& api_def_dirs,
|
const std::vector<string>& api_def_dirs,
|
||||||
const string& source_file_name,
|
const string& source_file_name,
|
||||||
bool op_list_is_whitelist,
|
bool op_list_is_allowlist,
|
||||||
const std::unordered_set<string> type_annotate_ops) {
|
const std::unordered_set<string> type_annotate_ops) {
|
||||||
OpList ops;
|
OpList ops;
|
||||||
OpRegistry::Global()->Export(false, &ops);
|
OpRegistry::Global()->Export(false, &ops);
|
||||||
@ -126,11 +126,11 @@ void PrintAllPythonOps(const std::vector<string>& op_list,
|
|||||||
api_def_map.UpdateDocs();
|
api_def_map.UpdateDocs();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (op_list_is_whitelist) {
|
if (op_list_is_allowlist) {
|
||||||
std::unordered_set<string> whitelist(op_list.begin(), op_list.end());
|
std::unordered_set<string> allowlist(op_list.begin(), op_list.end());
|
||||||
OpList pruned_ops;
|
OpList pruned_ops;
|
||||||
for (const auto& op_def : ops.op()) {
|
for (const auto& op_def : ops.op()) {
|
||||||
if (whitelist.find(op_def.name()) != whitelist.end()) {
|
if (allowlist.find(op_def.name()) != allowlist.end()) {
|
||||||
*pruned_ops.mutable_op()->Add() = op_def;
|
*pruned_ops.mutable_op()->Add() = op_def;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -165,13 +165,13 @@ int main(int argc, char* argv[]) {
|
|||||||
|
|
||||||
if (argc == 2) {
|
if (argc == 2) {
|
||||||
tensorflow::PrintAllPythonOps({}, api_def_dirs, source_file_name,
|
tensorflow::PrintAllPythonOps({}, api_def_dirs, source_file_name,
|
||||||
false /* op_list_is_whitelist */,
|
false /* op_list_is_allowlist */,
|
||||||
type_annotate_ops);
|
type_annotate_ops);
|
||||||
} else if (argc == 3) {
|
} else if (argc == 3) {
|
||||||
std::vector<tensorflow::string> hidden_ops;
|
std::vector<tensorflow::string> hidden_ops;
|
||||||
TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &hidden_ops));
|
TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &hidden_ops));
|
||||||
tensorflow::PrintAllPythonOps(hidden_ops, api_def_dirs, source_file_name,
|
tensorflow::PrintAllPythonOps(hidden_ops, api_def_dirs, source_file_name,
|
||||||
false /* op_list_is_whitelist */,
|
false /* op_list_is_allowlist */,
|
||||||
type_annotate_ops);
|
type_annotate_ops);
|
||||||
} else if (argc == 4) {
|
} else if (argc == 4) {
|
||||||
std::vector<tensorflow::string> op_list;
|
std::vector<tensorflow::string> op_list;
|
||||||
|
@ -201,7 +201,7 @@ def _recurrent_lstm(c, h):
|
|||||||
def _make_node_with_color(color, input_tensor, name=None):
|
def _make_node_with_color(color, input_tensor, name=None):
|
||||||
"""Returns a node representative of the specified list type."""
|
"""Returns a node representative of the specified list type."""
|
||||||
color = color.lower()
|
color = color.lower()
|
||||||
if color == 'w': # White node
|
if color == 'w': # Allow node
|
||||||
weights = _weight(input_tensor.get_shape().as_list())
|
weights = _weight(input_tensor.get_shape().as_list())
|
||||||
return math_ops.matmul(input_tensor, weights, name=name)
|
return math_ops.matmul(input_tensor, weights, name=name)
|
||||||
if color == 'g': # Gray node
|
if color == 'g': # Gray node
|
||||||
@ -371,7 +371,7 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
The loop has different node colors in different sections of the graph. The
|
The loop has different node colors in different sections of the graph. The
|
||||||
arguments must be strings where each character represents the color of a
|
arguments must be strings where each character represents the color of a
|
||||||
node in that section of the graph: w = white, g = gray, c = clear,
|
node in that section of the graph: w = allow, g = gray, c = clear,
|
||||||
b = black. CAPITALIZED characters indicate that the node is expected to be
|
b = black. CAPITALIZED characters indicate that the node is expected to be
|
||||||
changed to DT_HALF during graph optimization.
|
changed to DT_HALF during graph optimization.
|
||||||
|
|
||||||
|
@ -1594,7 +1594,7 @@ def assert_not_batched(dataset):
|
|||||||
if isinstance(dataset, dataset_ops.DatasetV1Adapter):
|
if isinstance(dataset, dataset_ops.DatasetV1Adapter):
|
||||||
return assert_not_batched(dataset._dataset)
|
return assert_not_batched(dataset._dataset)
|
||||||
else:
|
else:
|
||||||
whitelisted_types = [
|
allowed_types = [
|
||||||
dataset_ops._OptionsDataset,
|
dataset_ops._OptionsDataset,
|
||||||
dataset_ops.ConcatenateDataset,
|
dataset_ops.ConcatenateDataset,
|
||||||
dataset_ops.CacheDataset,
|
dataset_ops.CacheDataset,
|
||||||
@ -1615,7 +1615,7 @@ def assert_not_batched(dataset):
|
|||||||
readers.TextLineDatasetV2,
|
readers.TextLineDatasetV2,
|
||||||
readers.TFRecordDatasetV2,
|
readers.TFRecordDatasetV2,
|
||||||
]
|
]
|
||||||
for ty in whitelisted_types:
|
for ty in allowed_types:
|
||||||
if isinstance(dataset, ty):
|
if isinstance(dataset, ty):
|
||||||
for input_dataset in dataset._inputs():
|
for input_dataset in dataset._inputs():
|
||||||
assert_not_batched(input_dataset)
|
assert_not_batched(input_dataset)
|
||||||
@ -1649,7 +1649,7 @@ def assert_not_shuffled(dataset):
|
|||||||
if isinstance(dataset, dataset_ops.DatasetV1Adapter):
|
if isinstance(dataset, dataset_ops.DatasetV1Adapter):
|
||||||
return assert_not_shuffled(dataset._dataset)
|
return assert_not_shuffled(dataset._dataset)
|
||||||
else:
|
else:
|
||||||
whitelisted_types = [
|
allowed_types = [
|
||||||
dataset_ops._OptionsDataset,
|
dataset_ops._OptionsDataset,
|
||||||
dataset_ops.BatchDataset,
|
dataset_ops.BatchDataset,
|
||||||
dataset_ops.ConcatenateDataset,
|
dataset_ops.ConcatenateDataset,
|
||||||
@ -1672,7 +1672,7 @@ def assert_not_shuffled(dataset):
|
|||||||
readers.TextLineDatasetV2,
|
readers.TextLineDatasetV2,
|
||||||
readers.TFRecordDatasetV2,
|
readers.TFRecordDatasetV2,
|
||||||
]
|
]
|
||||||
for ty in whitelisted_types:
|
for ty in allowed_types:
|
||||||
if isinstance(dataset, ty):
|
if isinstance(dataset, ty):
|
||||||
for input_dataset in dataset._inputs():
|
for input_dataset in dataset._inputs():
|
||||||
assert_not_shuffled(input_dataset)
|
assert_not_shuffled(input_dataset)
|
||||||
|
@ -2858,7 +2858,7 @@ class DistributedCallbackModel(Model):
|
|||||||
orig_model_weights)
|
orig_model_weights)
|
||||||
|
|
||||||
def __getattr__(self, item):
|
def __getattr__(self, item):
|
||||||
# Whitelisted attributes of the model that can be accessed by the user
|
# Allowed attributes of the model that can be accessed by the user
|
||||||
# during a callback.
|
# during a callback.
|
||||||
if item not in ('_setattr_tracking', '_layers'):
|
if item not in ('_setattr_tracking', '_layers'):
|
||||||
logging.warning('You are accessing attribute ' + item + ' of the '
|
logging.warning('You are accessing attribute ' + item + ' of the '
|
||||||
|
@ -333,7 +333,7 @@ class TimeDistributedTest(keras_parameterized.TestCase):
|
|||||||
keras.layers.RNN(keras.layers.SimpleRNNCell(10), stateful=True))
|
keras.layers.RNN(keras.layers.SimpleRNNCell(10), stateful=True))
|
||||||
self.assertFalse(td2._always_use_reshape)
|
self.assertFalse(td2._always_use_reshape)
|
||||||
|
|
||||||
# Custom layers are not whitelisted for the fast reshape implementation.
|
# Custom layers are not allowlisted for the fast reshape implementation.
|
||||||
td3 = keras.layers.TimeDistributed(NoReshapeLayer())
|
td3 = keras.layers.TimeDistributed(NoReshapeLayer())
|
||||||
self.assertFalse(td3._always_use_reshape)
|
self.assertFalse(td3._always_use_reshape)
|
||||||
|
|
||||||
|
@ -898,7 +898,7 @@ class OptimizerWithFunctionTest(test.TestCase):
|
|||||||
|
|
||||||
_NUM_LEARNERS = 50
|
_NUM_LEARNERS = 50
|
||||||
APPLY_SCOPE = 'debug_apply'
|
APPLY_SCOPE = 'debug_apply'
|
||||||
WHITELIST = [
|
ALLOWLIST = [
|
||||||
# optimizer_v2._deduplicate_indexed_slices contains an indexed slice:
|
# optimizer_v2._deduplicate_indexed_slices contains an indexed slice:
|
||||||
# array_ops.shape(unique_indices)[0]
|
# array_ops.shape(unique_indices)[0]
|
||||||
# which winds up expanding to [0:1:1] thereby creating three constants
|
# which winds up expanding to [0:1:1] thereby creating three constants
|
||||||
@ -1025,8 +1025,8 @@ def identify_redundant_ops(graph):
|
|||||||
# Certain ops are simply not worth eliminating, and are instead simply
|
# Certain ops are simply not worth eliminating, and are instead simply
|
||||||
# ignored.
|
# ignored.
|
||||||
name, op_type = op_defs[0].name, op_defs[0].type
|
name, op_type = op_defs[0].name, op_defs[0].type
|
||||||
if any(whitelisted_scope in name and op_type == whitelisted_type
|
if any(allowlisted_scope in name and op_type == allowlisted_type
|
||||||
for whitelisted_scope, whitelisted_type in WHITELIST):
|
for allowlisted_scope, allowlisted_type in ALLOWLIST):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
num_duplicates += len(op_defs)
|
num_duplicates += len(op_defs)
|
||||||
|
@ -45,7 +45,7 @@ def index_directory(directory,
|
|||||||
valid files found in the directory. Labels should be sorted according
|
valid files found in the directory. Labels should be sorted according
|
||||||
to the alphanumeric order of the image file paths
|
to the alphanumeric order of the image file paths
|
||||||
(obtained via `os.walk(directory)` in Python).
|
(obtained via `os.walk(directory)` in Python).
|
||||||
formats: Whitelist of file extensions to index (e.g. ".jpg", ".txt").
|
formats: Allowlist of file extensions to index (e.g. ".jpg", ".txt").
|
||||||
class_names: Only valid if "labels" is "inferred". This is the explict
|
class_names: Only valid if "labels" is "inferred". This is the explict
|
||||||
list of class names (must match names of subdirectories). Used
|
list of class names (must match names of subdirectories). Used
|
||||||
to control the order of the classes
|
to control the order of the classes
|
||||||
@ -136,7 +136,7 @@ def index_subdirectory(directory, class_indices, follow_links, formats):
|
|||||||
class_indices: dict mapping class names to their index.
|
class_indices: dict mapping class names to their index.
|
||||||
follow_links: boolean, whether to recursively follow subdirectories
|
follow_links: boolean, whether to recursively follow subdirectories
|
||||||
(if False, we only list top-level images in `directory`).
|
(if False, we only list top-level images in `directory`).
|
||||||
formats: Whitelist of file extensions to index (e.g. ".jpg", ".txt").
|
formats: Allowlist of file extensions to index (e.g. ".jpg", ".txt").
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple `(filenames, labels)`. `filenames` is a list of relative file
|
tuple `(filenames, labels)`. `filenames` is a list of relative file
|
||||||
|
@ -28,7 +28,7 @@ from tensorflow.python.ops import io_ops
|
|||||||
from tensorflow.python.util.tf_export import keras_export
|
from tensorflow.python.util.tf_export import keras_export
|
||||||
|
|
||||||
|
|
||||||
WHITELIST_FORMATS = ('.bmp', '.gif', '.jpeg', '.jpg', '.png')
|
ALLOWLIST_FORMATS = ('.bmp', '.gif', '.jpeg', '.jpg', '.png')
|
||||||
|
|
||||||
|
|
||||||
@keras_export('keras.preprocessing.image_dataset_from_directory', v1=[])
|
@keras_export('keras.preprocessing.image_dataset_from_directory', v1=[])
|
||||||
@ -175,7 +175,7 @@ def image_dataset_from_directory(directory,
|
|||||||
image_paths, labels, class_names = dataset_utils.index_directory(
|
image_paths, labels, class_names = dataset_utils.index_directory(
|
||||||
directory,
|
directory,
|
||||||
labels,
|
labels,
|
||||||
formats=WHITELIST_FORMATS,
|
formats=ALLOWLIST_FORMATS,
|
||||||
class_names=class_names,
|
class_names=class_names,
|
||||||
shuffle=shuffle,
|
shuffle=shuffle,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
|
@ -865,7 +865,7 @@ class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
|
|||||||
This only allows capturing tensors in the forward graph. A ValueError is
|
This only allows capturing tensors in the forward graph. A ValueError is
|
||||||
raised if an attempt is made to capture a tensor not in the forward graph.
|
raised if an attempt is made to capture a tensor not in the forward graph.
|
||||||
To manually capture capture a tensor that is not in the forward graph, call
|
To manually capture capture a tensor that is not in the forward graph, call
|
||||||
`capture` with `whitelisted=True`.
|
`capture` with `allowlisted=True`.
|
||||||
|
|
||||||
Note: The `captures` dict does not contain the forward tensor since it is not
|
Note: The `captures` dict does not contain the forward tensor since it is not
|
||||||
directly captured. It contains the accumulator corresponding to this forward
|
directly captured. It contains the accumulator corresponding to this forward
|
||||||
@ -968,16 +968,16 @@ class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
|
|||||||
op_def=op_def,
|
op_def=op_def,
|
||||||
compute_device=compute_device)
|
compute_device=compute_device)
|
||||||
|
|
||||||
def capture(self, tensor, name=None, whitelisted=False):
|
def capture(self, tensor, name=None, allowlisted=False):
|
||||||
"""Selectively captures external tensors.
|
"""Selectively captures external tensors.
|
||||||
|
|
||||||
If `whitelisted` is False only allows capturing tensors in the
|
If `allowlisted` is False only allows capturing tensors in the
|
||||||
`_forward_graph`.
|
`_forward_graph`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tensor: Tensor. May be from this FuncGraph or a different graph.
|
tensor: Tensor. May be from this FuncGraph or a different graph.
|
||||||
name: Optional name if a placeholder is created.
|
name: Optional name if a placeholder is created.
|
||||||
whitelisted: If False (default), only allows capturing tensors from the
|
allowlisted: If False (default), only allows capturing tensors from the
|
||||||
forward graph.
|
forward graph.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -985,9 +985,9 @@ class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
|
|||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If attempting to capture an external tensor not in the forward
|
ValueError: If attempting to capture an external tensor not in the forward
|
||||||
graph with `whitelisted` set to False.
|
graph with `allowlisted` set to False.
|
||||||
"""
|
"""
|
||||||
if not whitelisted and (isinstance(tensor, ops.EagerTensor) or
|
if not allowlisted and (isinstance(tensor, ops.EagerTensor) or
|
||||||
(tensor.graph is not self and
|
(tensor.graph is not self and
|
||||||
tensor.graph != self._forward_graph)):
|
tensor.graph != self._forward_graph)):
|
||||||
with self._forward_cond_graph.as_default():
|
with self._forward_cond_graph.as_default():
|
||||||
@ -1136,7 +1136,7 @@ class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
|
|||||||
"Resource tensors must be loop invariants %s." % tensor_in_outer_graph)
|
"Resource tensors must be loop invariants %s." % tensor_in_outer_graph)
|
||||||
|
|
||||||
self._indirect_captures[ops.tensor_id(tensor)] = self.capture(
|
self._indirect_captures[ops.tensor_id(tensor)] = self.capture(
|
||||||
tensor_in_outer_graph, whitelisted=True)
|
tensor_in_outer_graph, allowlisted=True)
|
||||||
return self._indirect_captures[ops.tensor_id(tensor)]
|
return self._indirect_captures[ops.tensor_id(tensor)]
|
||||||
|
|
||||||
|
|
||||||
|
@ -143,10 +143,10 @@ def _rewrite_input_as_indexed_slices(body_grad_graph, grad_output_slices,
|
|||||||
# computation.
|
# computation.
|
||||||
with body_grad_graph.as_default():
|
with body_grad_graph.as_default():
|
||||||
input_slices = ops.IndexedSlices(
|
input_slices = ops.IndexedSlices(
|
||||||
values=body_grad_graph.capture(init_slices.values, whitelisted=True),
|
values=body_grad_graph.capture(init_slices.values, allowlisted=True),
|
||||||
indices=body_grad_graph.capture(init_slices.indices, whitelisted=True),
|
indices=body_grad_graph.capture(init_slices.indices, allowlisted=True),
|
||||||
dense_shape=body_grad_graph.capture(init_slices.dense_shape,
|
dense_shape=body_grad_graph.capture(
|
||||||
whitelisted=True))
|
init_slices.dense_shape, allowlisted=True))
|
||||||
|
|
||||||
# Remove the captured tensors from the function inputs. We'll add them back
|
# Remove the captured tensors from the function inputs. We'll add them back
|
||||||
# at the correct index in _update_indexed_slices_param.
|
# at the correct index in _update_indexed_slices_param.
|
||||||
|
@ -36,7 +36,7 @@ from tensorflow.python.platform import tf_logging
|
|||||||
# corresponding kernel; nodes without a corresponding kernel (perhaps due to
|
# corresponding kernel; nodes without a corresponding kernel (perhaps due to
|
||||||
# attr types) generate a warning but are otherwise ignored. Ops in this set are
|
# attr types) generate a warning but are otherwise ignored. Ops in this set are
|
||||||
# registered even if there's no corresponding kernel.
|
# registered even if there's no corresponding kernel.
|
||||||
OPS_WITHOUT_KERNEL_WHITELIST = frozenset([
|
OPS_WITHOUT_KERNEL_ALLOWLIST = frozenset([
|
||||||
# AccumulateNV2 is rewritten away by AccumulateNV2RemovePass; see
|
# AccumulateNV2 is rewritten away by AccumulateNV2RemovePass; see
|
||||||
# core/common_runtime/accumulate_n_optimizer.cc.
|
# core/common_runtime/accumulate_n_optimizer.cc.
|
||||||
'AccumulateNV2'
|
'AccumulateNV2'
|
||||||
@ -67,7 +67,7 @@ def _get_ops_from_graphdef(graph_def):
|
|||||||
kernel_class = _pywrap_kernel_registry.TryFindKernelClass(
|
kernel_class = _pywrap_kernel_registry.TryFindKernelClass(
|
||||||
node_def.SerializeToString())
|
node_def.SerializeToString())
|
||||||
op = str(node_def.op)
|
op = str(node_def.op)
|
||||||
if kernel_class or op in OPS_WITHOUT_KERNEL_WHITELIST:
|
if kernel_class or op in OPS_WITHOUT_KERNEL_ALLOWLIST:
|
||||||
op_and_kernel = (op, str(kernel_class.decode('utf-8'))
|
op_and_kernel = (op, str(kernel_class.decode('utf-8'))
|
||||||
if kernel_class else None)
|
if kernel_class else None)
|
||||||
ops.add(op_and_kernel)
|
ops.add(op_and_kernel)
|
||||||
|
@ -68,7 +68,7 @@ class EmbeddingColumnTest(test.TestCase):
|
|||||||
tpu_fc.embedding_column(categorical_column, dimension=embedding_dimension)
|
tpu_fc.embedding_column(categorical_column, dimension=embedding_dimension)
|
||||||
|
|
||||||
def test_custom_column(self):
|
def test_custom_column(self):
|
||||||
# This column is not in any whitelist but should succeed because
|
# This column is not in any allowlist but should succeed because
|
||||||
# it inherits from V2 CategoricalColumn.
|
# it inherits from V2 CategoricalColumn.
|
||||||
categorical_column = fc_lib.categorical_column_with_identity(
|
categorical_column = fc_lib.categorical_column_with_identity(
|
||||||
key='aaa', num_buckets=10)
|
key='aaa', num_buckets=10)
|
||||||
|
@ -122,7 +122,7 @@ def enable_mixed_precision_graph_rewrite(opt, loss_scale='dynamic'):
|
|||||||
|
|
||||||
* `ClearList`: Ops that do not have numerically significant adverse effects.
|
* `ClearList`: Ops that do not have numerically significant adverse effects.
|
||||||
E.g. `ArgMax` and `Floor`.
|
E.g. `ArgMax` and `Floor`.
|
||||||
* `WhiteList`: Ops that are considered numerically safe for execution in
|
* `AllowList`: Ops that are considered numerically safe for execution in
|
||||||
float16, and thus are always converted. E.g. `Conv2D`.
|
float16, and thus are always converted. E.g. `Conv2D`.
|
||||||
* `BlackList`: Ops that are numerically unsafe to execute in float16 and
|
* `BlackList`: Ops that are numerically unsafe to execute in float16 and
|
||||||
can negatively affect downstream nodes. E.g. `Softmax`.
|
can negatively affect downstream nodes. E.g. `Softmax`.
|
||||||
@ -267,7 +267,7 @@ def enable_mixed_precision_graph_rewrite_v1(opt, loss_scale='dynamic'):
|
|||||||
|
|
||||||
* `ClearList`: Ops that do not have numerically significant adverse effects.
|
* `ClearList`: Ops that do not have numerically significant adverse effects.
|
||||||
E.g. `ArgMax` and `Floor`.
|
E.g. `ArgMax` and `Floor`.
|
||||||
* `WhiteList`: Ops that are considered numerically safe for execution in
|
* `AllowList`: Ops that are considered numerically safe for execution in
|
||||||
float16, and thus are always converted. E.g. `Conv2D`.
|
float16, and thus are always converted. E.g. `Conv2D`.
|
||||||
* `BlackList`: Ops that are numerically unsafe to execute in float16 and
|
* `BlackList`: Ops that are numerically unsafe to execute in float16 and
|
||||||
can negatively affect downstream nodes. E.g. `Softmax`.
|
can negatively affect downstream nodes. E.g. `Softmax`.
|
||||||
|
@ -93,7 +93,7 @@ def remove_undocumented(module_name, allowed_exception_list=None,
|
|||||||
doc_string_modules: a list of modules from which to take the docstrings.
|
doc_string_modules: a list of modules from which to take the docstrings.
|
||||||
If None, then a list containing only the module named `module_name` is used.
|
If None, then a list containing only the module named `module_name` is used.
|
||||||
|
|
||||||
Furthermore, if a symbol previously added with `add_to_global_whitelist`,
|
Furthermore, if a symbol previously added with `add_to_global_allowlist`,
|
||||||
then it will always be allowed. This is useful for internal tests.
|
then it will always be allowed. This is useful for internal tests.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -96,8 +96,8 @@ do_pylint() {
|
|||||||
# --incremental Performs check on only the python files changed in the
|
# --incremental Performs check on only the python files changed in the
|
||||||
# last non-merge git commit.
|
# last non-merge git commit.
|
||||||
|
|
||||||
# Use this list to whitelist pylint errors
|
# Use this list to allowlist pylint errors
|
||||||
ERROR_WHITELIST="^tensorflow/python/framework/function_test\.py.*\[E1123.*noinline "\
|
ERROR_ALLOWLIST="^tensorflow/python/framework/function_test\.py.*\[E1123.*noinline "\
|
||||||
"^tensorflow/python/platform/default/_gfile\.py.*\[E0301.*non-iterator "\
|
"^tensorflow/python/platform/default/_gfile\.py.*\[E0301.*non-iterator "\
|
||||||
"^tensorflow/python/platform/default/_googletest\.py.*\[E0102.*function\salready\sdefined "\
|
"^tensorflow/python/platform/default/_googletest\.py.*\[E0102.*function\salready\sdefined "\
|
||||||
"^tensorflow/python/feature_column/feature_column_test\.py.*\[E0110.*abstract-class-instantiated "\
|
"^tensorflow/python/feature_column/feature_column_test\.py.*\[E0110.*abstract-class-instantiated "\
|
||||||
@ -115,7 +115,7 @@ do_pylint() {
|
|||||||
"^tensorflow/python/autograph/.*_py3_test\.py.*\[E0001.*syntax-error "\
|
"^tensorflow/python/autograph/.*_py3_test\.py.*\[E0001.*syntax-error "\
|
||||||
"^tensorflow/python/keras/preprocessing/image\.py.*\[E0240.*Inconsistent method resolution "
|
"^tensorflow/python/keras/preprocessing/image\.py.*\[E0240.*Inconsistent method resolution "
|
||||||
|
|
||||||
echo "ERROR_WHITELIST=\"${ERROR_WHITELIST}\""
|
echo "ERROR_ALLOWLIST=\"${ERROR_ALLOWLIST}\""
|
||||||
|
|
||||||
if [[ $# != "0" ]] && [[ $# != "1" ]]; then
|
if [[ $# != "0" ]] && [[ $# != "1" ]]; then
|
||||||
echo "Invalid syntax when invoking do_pylint"
|
echo "Invalid syntax when invoking do_pylint"
|
||||||
@ -195,16 +195,16 @@ do_pylint() {
|
|||||||
|
|
||||||
N_ERRORS=0
|
N_ERRORS=0
|
||||||
while read -r LINE; do
|
while read -r LINE; do
|
||||||
IS_WHITELISTED=0
|
IS_ALLOWLISTED=0
|
||||||
for WL_REGEX in ${ERROR_WHITELIST}; do
|
for WL_REGEX in ${ERROR_ALLOWLIST}; do
|
||||||
if echo ${LINE} | grep -q "${WL_REGEX}"; then
|
if echo ${LINE} | grep -q "${WL_REGEX}"; then
|
||||||
echo "Found a whitelisted error:"
|
echo "Found a allowlisted error:"
|
||||||
echo " ${LINE}"
|
echo " ${LINE}"
|
||||||
IS_WHITELISTED=1
|
IS_ALLOWLISTED=1
|
||||||
fi
|
fi
|
||||||
done
|
done
|
||||||
|
|
||||||
if [[ ${IS_WHITELISTED} == "0" ]]; then
|
if [[ ${IS_ALLOWLISTED} == "0" ]]; then
|
||||||
echo "${LINE}" >> ${NONWL_ERRORS_FILE}
|
echo "${LINE}" >> ${NONWL_ERRORS_FILE}
|
||||||
echo "" >> ${NONWL_ERRORS_FILE}
|
echo "" >> ${NONWL_ERRORS_FILE}
|
||||||
((N_ERRORS++))
|
((N_ERRORS++))
|
||||||
@ -213,11 +213,11 @@ do_pylint() {
|
|||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
if [[ ${N_ERRORS} != 0 ]]; then
|
if [[ ${N_ERRORS} != 0 ]]; then
|
||||||
echo "FAIL: Found ${N_ERRORS} non-whitelisted pylint errors:"
|
echo "FAIL: Found ${N_ERRORS} non-allowlisted pylint errors:"
|
||||||
cat "${NONWL_ERRORS_FILE}"
|
cat "${NONWL_ERRORS_FILE}"
|
||||||
return 1
|
return 1
|
||||||
else
|
else
|
||||||
echo "PASS: No non-whitelisted pylint errors were found."
|
echo "PASS: No non-allowlisted pylint errors were found."
|
||||||
return 0
|
return 0
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
@ -370,7 +370,7 @@ do_external_licenses_check(){
|
|||||||
-v ${MISSING_LICENSES_FILE} > temp.txt
|
-v ${MISSING_LICENSES_FILE} > temp.txt
|
||||||
mv temp.txt ${MISSING_LICENSES_FILE}
|
mv temp.txt ${MISSING_LICENSES_FILE}
|
||||||
|
|
||||||
# Whitelist
|
# Allowlist
|
||||||
echo ${EXTRA_LICENSE_FILE}
|
echo ${EXTRA_LICENSE_FILE}
|
||||||
grep \
|
grep \
|
||||||
-e "//third_party/mkl" \
|
-e "//third_party/mkl" \
|
||||||
|
@ -40,7 +40,7 @@ FUTURES_PATTERN_2 = re.compile(
|
|||||||
FUTURES_PATTERN_3 = re.compile(r'^from __future__ import (\w+) as \w+\s*$')
|
FUTURES_PATTERN_3 = re.compile(r'^from __future__ import (\w+) as \w+\s*$')
|
||||||
REQUIRED_FUTURES = frozenset(['absolute_import', 'division', 'print_function'])
|
REQUIRED_FUTURES = frozenset(['absolute_import', 'division', 'print_function'])
|
||||||
|
|
||||||
WHITELIST = [
|
ALLOWLIST = [
|
||||||
'python/platform/control_imports.py',
|
'python/platform/control_imports.py',
|
||||||
'tools/docker/jupyter_notebook_config.py',
|
'tools/docker/jupyter_notebook_config.py',
|
||||||
'tools/ci_build/update_version.py',
|
'tools/ci_build/update_version.py',
|
||||||
@ -93,12 +93,12 @@ def main():
|
|||||||
BASE_DIR)
|
BASE_DIR)
|
||||||
|
|
||||||
# Verify that all files have futures
|
# Verify that all files have futures
|
||||||
whitelist = frozenset(os.path.join(BASE_DIR, w) for w in WHITELIST)
|
allowlist = frozenset(os.path.join(BASE_DIR, w) for w in ALLOWLIST)
|
||||||
old_division = frozenset(os.path.join(BASE_DIR, w) for w in OLD_DIVISION)
|
old_division = frozenset(os.path.join(BASE_DIR, w) for w in OLD_DIVISION)
|
||||||
for root, _, filenames in os.walk(BASE_DIR):
|
for root, _, filenames in os.walk(BASE_DIR):
|
||||||
for f in fnmatch.filter(filenames, '*.py'):
|
for f in fnmatch.filter(filenames, '*.py'):
|
||||||
path = os.path.join(root, f)
|
path = os.path.join(root, f)
|
||||||
if path not in whitelist:
|
if path not in allowlist:
|
||||||
try:
|
try:
|
||||||
check_file(path, old_division=path in old_division)
|
check_file(path, old_division=path in old_division)
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user