Replace instances of "whitelist" with "allowlist" where possible. See Google Developer guidelines at https://developers.google.com/style/word-list#blacklist for more information.
PiperOrigin-RevId: 320210110 Change-Id: I480d2b1c80d7d77fdd071b7642011758988f18c0
This commit is contained in:
parent
a85beef408
commit
7eab1f3bfe
RELEASE.mdSECURITY.md
tensorflow
compiler
core
common_runtime
framework
grappler
costs
optimizers
kernels/data
go/op
lite
delegates
flex
hexagon
nnapi
experimental/acceleration
g3doc/guide
java/src/test/java/org/tensorflow/lite/gpu
kernels
micro/tools/make
toco/tflite
tools
python
__init__.py
autograph
converters
core
g3doc/reference
impl
pyct
data/ops
debug
cli
lib
check_numerics_callback.pydebug_utils.pydebug_utils_test.pydist_session_debug_grpc_test.pygrpc_large_data_test.pysession_debug_grpc_test.pysource_utils.pysource_utils_test.py
wrappers
distribute
eager
framework
auto_control_deps.pyconvert_to_constants.pyfunc_graph.pyfunction.pyfunction_test.pygraph_util_impl.pyimporter_test.pypython_op_gen.ccpython_op_gen_main.cc
grappler
keras
engine
layers
optimizer_v2
preprocessing
ops
tools
tpu
training/experimental
util
tools
@ -50,6 +50,9 @@
|
||||
* Tracing and Debugging:
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
* Other:
|
||||
* We have replaced uses of "whitelist" with "allowlist" where possible.
|
||||
Please see https://developers.google.com/style/word-list#blacklist for more
|
||||
context.
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
|
||||
## Thanks to our Contributors
|
||||
|
@ -44,7 +44,7 @@ Even if the untrusted party only supplies the serialized computation
|
||||
graph (in form of a `GraphDef`, `SavedModel`, or equivalent on-disk format), the
|
||||
set of computation primitives available to TensorFlow is powerful enough that
|
||||
you should assume that the TensorFlow process effectively executes arbitrary
|
||||
code. One common solution is to whitelist only a few safe Ops. While this is
|
||||
code. One common solution is to allow only a few safe Ops. While this is
|
||||
possible in theory, we still recommend you sandbox the execution.
|
||||
|
||||
It depends on the computation graph whether a user provided checkpoint is safe.
|
||||
|
@ -1096,33 +1096,33 @@ StatusOr<bool> IsIdentityDrivingConstsInLoop(Node* node) {
|
||||
return true;
|
||||
}
|
||||
|
||||
absl::flat_hash_set<string> GetOrCreateWhitelist() {
|
||||
absl::flat_hash_map<string, std::vector<string>>* whitelist_table =
|
||||
tensorflow::GetWhitelistTable();
|
||||
absl::flat_hash_set<string> GetOrCreateAllowlist() {
|
||||
absl::flat_hash_map<string, std::vector<string>>* allowlist_table =
|
||||
tensorflow::GetAllowlistTable();
|
||||
MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
|
||||
absl::flat_hash_set<string> whitelist;
|
||||
absl::flat_hash_set<string> allowlist;
|
||||
|
||||
for (auto s : absl::StrSplit(flags->tf_xla_ops_to_cluster, ',')) {
|
||||
if (s == "FUSIBLE") {
|
||||
for (auto pair : *whitelist_table) {
|
||||
whitelist.insert(pair.second.begin(), pair.second.end());
|
||||
for (auto pair : *allowlist_table) {
|
||||
allowlist.insert(pair.second.begin(), pair.second.end());
|
||||
}
|
||||
} else if (whitelist_table->contains(s)) {
|
||||
auto v = whitelist_table->at(s);
|
||||
whitelist.insert(v.begin(), v.end());
|
||||
} else if (allowlist_table->contains(s)) {
|
||||
auto v = allowlist_table->at(s);
|
||||
allowlist.insert(v.begin(), v.end());
|
||||
} else if (!s.empty()) {
|
||||
// Should be a user provided TF operation.
|
||||
whitelist.insert(string(s));
|
||||
allowlist.insert(string(s));
|
||||
}
|
||||
}
|
||||
|
||||
if (VLOG_IS_ON(2) && !whitelist.empty()) {
|
||||
std::vector<string> vwhitelist(whitelist.begin(), whitelist.end());
|
||||
absl::c_sort(vwhitelist);
|
||||
if (VLOG_IS_ON(2) && !allowlist.empty()) {
|
||||
std::vector<string> vallowlist(allowlist.begin(), allowlist.end());
|
||||
absl::c_sort(vallowlist);
|
||||
VLOG(2) << "XLA clustering will only consider the following TF operations: "
|
||||
<< absl::StrJoin(vwhitelist, " ");
|
||||
<< absl::StrJoin(vallowlist, " ");
|
||||
}
|
||||
return whitelist;
|
||||
return allowlist;
|
||||
}
|
||||
|
||||
Status MarkForCompilationPassImpl::FindCompilationCandidates() {
|
||||
@ -1156,12 +1156,12 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
|
||||
|
||||
VLOG(2) << "sorted_nodes.size() = " << sorted_nodes.size();
|
||||
|
||||
auto whitelist = GetOrCreateWhitelist();
|
||||
auto allowlist = GetOrCreateAllowlist();
|
||||
|
||||
std::vector<string> vall_ops = XlaOpRegistry::GetAllRegisteredOps();
|
||||
absl::flat_hash_set<string> all_ops(vall_ops.begin(), vall_ops.end());
|
||||
// Check that user's provided TF operation really exists.
|
||||
for (const auto& s : whitelist) {
|
||||
for (const auto& s : allowlist) {
|
||||
if (!all_ops.contains(string(s))) {
|
||||
return errors::InvalidArgument(
|
||||
"The operation '", s,
|
||||
@ -1206,7 +1206,7 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!whitelist.empty() && !whitelist.contains(node->def().op())) {
|
||||
if (!allowlist.empty() && !allowlist.contains(node->def().op())) {
|
||||
VLOG(1) << "Rejecting TF operation " << node->def().op()
|
||||
<< " as it is not listed in --tf_xla_ops_to_cluster.";
|
||||
continue;
|
||||
@ -1781,7 +1781,7 @@ Status MarkForCompilationPass::RunForTest(
|
||||
return MarkForCompilation(options, debug_options);
|
||||
}
|
||||
|
||||
absl::flat_hash_map<string, std::vector<string>>* GetWhitelistTable() {
|
||||
absl::flat_hash_map<string, std::vector<string>>* GetAllowlistTable() {
|
||||
// Table format: category name: {list of TF operations in that category}
|
||||
static absl::flat_hash_map<string, std::vector<string>>* result =
|
||||
new absl::flat_hash_map<string, std::vector<string>>{
|
||||
@ -1845,7 +1845,7 @@ absl::flat_hash_map<string, std::vector<string>>* GetWhitelistTable() {
|
||||
namespace testing {
|
||||
void ResetClusterSequenceNumber() { cluster_sequence_num = 0; }
|
||||
|
||||
absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
|
||||
absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
|
||||
absl::flat_hash_set<string> result{"AdjustContrastv2",
|
||||
"AdjustHue",
|
||||
"AdjustSaturation",
|
||||
|
@ -58,7 +58,7 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap*
|
||||
uncompilable_node_info = nullptr);
|
||||
|
||||
absl::flat_hash_map<string, std::vector<string>>* GetWhitelistTable();
|
||||
absl::flat_hash_map<string, std::vector<string>>* GetAllowlistTable();
|
||||
|
||||
namespace testing {
|
||||
// DO NOT USE IN PRODUCTION.
|
||||
@ -66,8 +66,8 @@ namespace testing {
|
||||
// Resets some internal state to let us write reliable unit tests.
|
||||
void ResetClusterSequenceNumber();
|
||||
|
||||
// Return a list of operation that we choose not to put into the whitelist.
|
||||
absl::flat_hash_set<string> GetKnownXLAWhitelistOp();
|
||||
// Return a list of operation that we choose not to put into the allowlist.
|
||||
absl::flat_hash_set<string> GetKnownXLAAllowlistOp();
|
||||
} // namespace testing
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -1802,34 +1802,34 @@ TEST(XlaCompilationTest, StagePipelinePreservedByClusterScopingPass) {
|
||||
EXPECT_NE(clusters["relu0"], clusters["relu1"]);
|
||||
}
|
||||
}
|
||||
TEST(XlaCompilationTest, XLALiteWhitelist) {
|
||||
auto* whitelist_table = tensorflow::GetWhitelistTable();
|
||||
absl::flat_hash_set<string> hwhitelist;
|
||||
TEST(XlaCompilationTest, XLALiteAllowlist) {
|
||||
auto* allowlist_table = tensorflow::GetAllowlistTable();
|
||||
absl::flat_hash_set<string> hallowlist;
|
||||
std::vector<string> vall_ops = XlaOpRegistry::GetAllRegisteredOps();
|
||||
absl::flat_hash_set<string> all_ops(vall_ops.begin(), vall_ops.end());
|
||||
|
||||
// Check that all the operations in the table are existing TF operations
|
||||
for (auto pair : *whitelist_table) {
|
||||
hwhitelist.insert(pair.second.begin(), pair.second.end());
|
||||
for (auto pair : *allowlist_table) {
|
||||
hallowlist.insert(pair.second.begin(), pair.second.end());
|
||||
for (auto op : pair.second) {
|
||||
ASSERT_TRUE(all_ops.contains(op));
|
||||
}
|
||||
}
|
||||
|
||||
// Check that all registered XLA operation are in the whitelist
|
||||
// Check that all registered XLA operation are in the allowlist
|
||||
// table or are known to not be in it.
|
||||
|
||||
absl::flat_hash_set<string> known_not_in_list =
|
||||
tensorflow::testing::GetKnownXLAWhitelistOp();
|
||||
tensorflow::testing::GetKnownXLAAllowlistOp();
|
||||
std::vector<string> unknow_op;
|
||||
for (string op : vall_ops) {
|
||||
if (!hwhitelist.contains(op) && !known_not_in_list.contains(op)) {
|
||||
if (!hallowlist.contains(op) && !known_not_in_list.contains(op)) {
|
||||
unknow_op.push_back(op);
|
||||
}
|
||||
}
|
||||
EXPECT_TRUE(unknow_op.empty())
|
||||
<< "Someone added support for a new TF opeations inside XLA. They must "
|
||||
"be included in the XLALite whitelist or blacklist:\n"
|
||||
"be included in the XLALite allowlist or blacklist:\n"
|
||||
<< absl::StrJoin(unknow_op, "\n");
|
||||
}
|
||||
} // namespace
|
||||
|
@ -30,7 +30,7 @@ struct PassConfig {
|
||||
explicit PassConfig(QuantizationSpecs specs)
|
||||
: emit_builtin_tflite_ops(true),
|
||||
lower_tensor_list_ops(false),
|
||||
trim_functions_whitelist({}),
|
||||
trim_functions_allowlist({}),
|
||||
quant_specs(std::move(specs)),
|
||||
form_clusters(false),
|
||||
unfold_batch_matmul(true),
|
||||
@ -44,8 +44,8 @@ struct PassConfig {
|
||||
// If `lower_tensor_list_ops` is true, tensorlist ops will be lowered to basic
|
||||
// TF ops before legalization to TF Lite dialect.
|
||||
bool lower_tensor_list_ops;
|
||||
// The whitelist of functions that would be preserved after trimming.
|
||||
llvm::ArrayRef<std::string> trim_functions_whitelist;
|
||||
// The allowlist of functions that would be preserved after trimming.
|
||||
llvm::ArrayRef<std::string> trim_functions_allowlist;
|
||||
// All information about quantization.
|
||||
QuantizationSpecs quant_specs;
|
||||
// If `form_clusters` is true , clusters are formed by grouping consecutive
|
||||
|
@ -71,7 +71,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/lite/delegates/flex/whitelisted_flex_ops.h"
|
||||
#include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h"
|
||||
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
@ -101,7 +101,7 @@ using mlir::Value;
|
||||
using tensorflow::OpOrArgLocNameMapper;
|
||||
using tensorflow::OpOrArgNameMapper;
|
||||
using tensorflow::Status;
|
||||
using tflite::flex::IsWhitelistedFlexOp;
|
||||
using tflite::flex::IsAllowlistedFlexOp;
|
||||
using xla::StatusOr;
|
||||
|
||||
template <typename T>
|
||||
@ -972,7 +972,7 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
|
||||
// model is of an open op system.
|
||||
//
|
||||
// The following algorithm is followed:
|
||||
// if flex is enabled and the op is whitelisted as flex
|
||||
// if flex is enabled and the op is allowlisted as flex
|
||||
// we emit op as flex.
|
||||
// if custom is enabled
|
||||
// we emit the op as custom.
|
||||
@ -982,11 +982,11 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
|
||||
}
|
||||
|
||||
// Flex op case
|
||||
// Eventually, the whitelist will go away and we will rely on some TF op
|
||||
// Eventually, the allowlist will go away and we will rely on some TF op
|
||||
// trait (e.g. No side effect) to determine if it is a supported "Flex"
|
||||
// op or not.
|
||||
if (enabled_op_types_.contains(OpType::kSelectTf) &&
|
||||
IsWhitelistedFlexOp(node_def->op())) {
|
||||
IsAllowlistedFlexOp(node_def->op())) {
|
||||
// Construct ops as flex op encoding TensorFlow node definition
|
||||
// as custom options.
|
||||
// Flex ops are named with the kFlexOpNamePrefix prefix to the actual
|
||||
@ -1037,7 +1037,7 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
|
||||
}
|
||||
|
||||
// Insert failed op to `flex_ops` or `custom_ops`.
|
||||
if (IsWhitelistedFlexOp(node_def->op())) {
|
||||
if (IsAllowlistedFlexOp(node_def->op())) {
|
||||
failed_flex_ops_.insert(os.str());
|
||||
} else {
|
||||
failed_custom_ops_.insert(os.str());
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: tf-opt %s -tfl-prepare-quantize -tfl-test-quantize-whitelist="quantize_float_placeholder_only,not_reset_input" | FileCheck %s
|
||||
// RUN: tf-opt %s -tfl-prepare-quantize -tfl-test-quantize-allowlist="quantize_float_placeholder_only,not_reset_input" | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: quantize_float_placeholder_only
|
||||
func @quantize_float_placeholder_only(%arg0: tensor<f32>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xf32>) -> (tensor<f32>, tensor<2x3xi32>, tensor<2x3xf32>) {
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: tf-opt -tfl-trim-funcs-tf -tfl-trim-funcs-whitelist="bar,foobar" %s | FileCheck %s
|
||||
// RUN: tf-opt -tfl-trim-funcs-tf -tfl-trim-funcs-allowlist="bar,foobar" %s | FileCheck %s
|
||||
|
||||
func @foo(%arg0: tensor<1x4xf32>, %arg1: tensor<1x4xf32>) -> tensor<1x4xf32> {
|
||||
return %arg0 : tensor<1x4xf32>
|
||||
|
@ -61,7 +61,7 @@ std::unique_ptr<OperationPass<FuncOp>> CreatePostQuantizePass(
|
||||
// Creates an instance of the TensorFlow Lite dialect TrimFunctions
|
||||
// pass.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateTrimFunctionsPass(
|
||||
llvm::ArrayRef<std::string> trim_funcs_whitelist);
|
||||
llvm::ArrayRef<std::string> trim_funcs_allowlist);
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect PrepareCompositeFunctions
|
||||
// pass.
|
||||
|
@ -35,9 +35,9 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static llvm::cl::list<std::string> quantize_whitelist(
|
||||
"tfl-test-quantize-whitelist", llvm::cl::value_desc("list"),
|
||||
llvm::cl::desc("comma separated list of whitelisted functions to be "
|
||||
static llvm::cl::list<std::string> quantize_allowlist(
|
||||
"tfl-test-quantize-allowlist", llvm::cl::value_desc("list"),
|
||||
llvm::cl::desc("comma separated list of allowlisted functions to be "
|
||||
"quantized. Only used in tests"),
|
||||
llvm::cl::CommaSeparated);
|
||||
|
||||
@ -108,7 +108,7 @@ class PrepareQuantizePass
|
||||
|
||||
// Get the min and max values from the quantization specification for the
|
||||
// current function function and argument index. Uses default values if
|
||||
// the function is specified in the `quantize_whitelist`.
|
||||
// the function is specified in the `quantize_allowlist`.
|
||||
std::pair<llvm::Optional<double>, llvm::Optional<double>>
|
||||
GetMinMaxValuesForArgument(llvm::StringRef func_name, int index) {
|
||||
if (func_name == quant_specs_.target_func) {
|
||||
@ -132,7 +132,7 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
|
||||
// Skip this function because it isn't the target function from the spec or
|
||||
// in the function while list.
|
||||
if (target_func != func_name &&
|
||||
!llvm::is_contained(quantize_whitelist, func_name)) {
|
||||
!llvm::is_contained(quantize_allowlist, func_name)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -29,12 +29,12 @@ limitations under the License.
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
|
||||
// The cmd line flag to specify the whitelist of functions. Rest are trimmed
|
||||
// The cmd line flag to specify the allowlist of functions. Rest are trimmed
|
||||
// after this pass is run.
|
||||
// NOLINTNEXTLINE
|
||||
static llvm::cl::list<std::string> trim_funcs_whitelist(
|
||||
"tfl-trim-funcs-whitelist", llvm::cl::value_desc("list"),
|
||||
llvm::cl::desc("comma separated list of whitelisted functions. The first "
|
||||
static llvm::cl::list<std::string> trim_funcs_allowlist(
|
||||
"tfl-trim-funcs-allowlist", llvm::cl::value_desc("list"),
|
||||
llvm::cl::desc("comma separated list of allowlisted functions. The first "
|
||||
"function specified will be used as main."),
|
||||
llvm::cl::CommaSeparated);
|
||||
|
||||
@ -43,25 +43,25 @@ namespace TFL {
|
||||
namespace {
|
||||
|
||||
// The pass to trim functions before we legalize to TFL
|
||||
// dialect using the specified whitelist.
|
||||
// dialect using the specified allowlist.
|
||||
class TrimFunctionsPass
|
||||
: public mlir::PassWrapper<TrimFunctionsPass, OperationPass<ModuleOp>> {
|
||||
public:
|
||||
explicit TrimFunctionsPass() : trim_funcs_whitelist_(trim_funcs_whitelist) {}
|
||||
explicit TrimFunctionsPass(llvm::ArrayRef<std::string> trim_funcs_whitelist)
|
||||
: trim_funcs_whitelist_(trim_funcs_whitelist) {}
|
||||
explicit TrimFunctionsPass() : trim_funcs_allowlist_(trim_funcs_allowlist) {}
|
||||
explicit TrimFunctionsPass(llvm::ArrayRef<std::string> trim_funcs_allowlist)
|
||||
: trim_funcs_allowlist_(trim_funcs_allowlist) {}
|
||||
|
||||
private:
|
||||
void runOnOperation() override;
|
||||
bool TrimModule();
|
||||
void Verify();
|
||||
|
||||
llvm::ArrayRef<std::string> trim_funcs_whitelist_;
|
||||
llvm::ArrayRef<std::string> trim_funcs_allowlist_;
|
||||
};
|
||||
|
||||
void TrimFunctionsPass::runOnOperation() {
|
||||
// trim the functions in the module using the trim_funcs_whitelist_
|
||||
// by removing functions not in the whitelist.
|
||||
// trim the functions in the module using the trim_funcs_allowlist_
|
||||
// by removing functions not in the allowlist.
|
||||
if (TrimModule()) {
|
||||
// verify the updated module is still valid, if not signal the
|
||||
// pass as failed.
|
||||
@ -70,20 +70,20 @@ void TrimFunctionsPass::runOnOperation() {
|
||||
}
|
||||
|
||||
bool TrimFunctionsPass::TrimModule() {
|
||||
// if no trim_funcs_whitelist_ is specified, this pass is a no-op.
|
||||
if (trim_funcs_whitelist_.empty()) return false;
|
||||
// if no trim_funcs_allowlist_ is specified, this pass is a no-op.
|
||||
if (trim_funcs_allowlist_.empty()) return false;
|
||||
|
||||
llvm::SmallVector<FuncOp, 4> funcs_to_trim;
|
||||
for (auto func : getOperation().getOps<FuncOp>()) {
|
||||
if (llvm::is_contained(trim_funcs_whitelist_, func.getName())) {
|
||||
// If no main is specified in the whitelist, use the 1st func
|
||||
// in trim_funcs_whitelist as the main.
|
||||
if (llvm::is_contained(trim_funcs_allowlist_, func.getName())) {
|
||||
// If no main is specified in the allowlist, use the 1st func
|
||||
// in trim_funcs_allowlist as the main.
|
||||
// TODO(ashwinm): Currently tflite flatbuffer export assumes there is
|
||||
// always a main. This is strictly not required for TFlite. We need to
|
||||
// remove that restriction once we have support to attribute the main
|
||||
// tensorflow function in MLIR TF import using an entry_point attr.
|
||||
if (!llvm::is_contained(trim_funcs_whitelist_, "main") &&
|
||||
func.getName() == trim_funcs_whitelist_[0]) {
|
||||
if (!llvm::is_contained(trim_funcs_allowlist_, "main") &&
|
||||
func.getName() == trim_funcs_allowlist_[0]) {
|
||||
func.setName("main");
|
||||
}
|
||||
} else {
|
||||
@ -99,7 +99,7 @@ bool TrimFunctionsPass::TrimModule() {
|
||||
}
|
||||
|
||||
// validate that all reachable functions from the remaining functions are
|
||||
// also in the whitelist.
|
||||
// also in the allowlist.
|
||||
void TrimFunctionsPass::Verify() {
|
||||
// TODO(ashwinm): Instead, we should make sure that references to all
|
||||
// SymbolRefAttrs of all ops are present.
|
||||
@ -109,7 +109,7 @@ void TrimFunctionsPass::Verify() {
|
||||
auto walk_result = func.walk([&](CallOp op) -> WalkResult {
|
||||
if (!symbol_table.lookup<FuncOp>(op.getCallee()))
|
||||
return getOperation().emitError()
|
||||
<< func.getName() << " is not in the funcs whitelist";
|
||||
<< func.getName() << " is not in the funcs allowlist";
|
||||
return WalkResult::advance();
|
||||
});
|
||||
if (walk_result.wasInterrupted()) return signalPassFailure();
|
||||
@ -121,13 +121,13 @@ void TrimFunctionsPass::Verify() {
|
||||
// Creates an instance of the TensorFlow Lite dialect TrimFunctions
|
||||
/// pass.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateTrimFunctionsPass(
|
||||
llvm::ArrayRef<std::string> trim_funcs_whitelist) {
|
||||
return std::make_unique<TrimFunctionsPass>(trim_funcs_whitelist);
|
||||
llvm::ArrayRef<std::string> trim_funcs_allowlist) {
|
||||
return std::make_unique<TrimFunctionsPass>(trim_funcs_allowlist);
|
||||
}
|
||||
|
||||
static PassRegistration<TrimFunctionsPass> pass(
|
||||
"tfl-trim-funcs-tf",
|
||||
"Trim functions to restrict them to a specified whitelist prior to "
|
||||
"Trim functions to restrict them to a specified allowlist prior to "
|
||||
"legalization to TensorFlow lite dialect");
|
||||
|
||||
} // namespace TFL
|
||||
|
@ -23,8 +23,8 @@ func @unknown_op(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: not_whitelisted_op
|
||||
func @not_whitelisted_op(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<?x?x?xf32> {
|
||||
// CHECK-LABEL: not_allowlisted_op
|
||||
func @not_allowlisted_op(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<?x?x?xf32> {
|
||||
// CHECK: tf.TensorListReserve
|
||||
%0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor<3xi32>, tensor<i32>) -> tensor<!tf.variant<tensor<?x?x?xf32>>>
|
||||
// CHECK: tf.TensorListGetItem
|
||||
|
@ -75,10 +75,10 @@ namespace {
|
||||
template <typename T, size_t N>
|
||||
using InlinedVector = tensorflow::gtl::InlinedVector<T, N>; // non-absl ok
|
||||
|
||||
static bool IsOpWhitelisted(Operation* op) {
|
||||
static bool IsOpAllowlisted(Operation* op) {
|
||||
// White-listed TensorFlow ops are known to have well behaved tf2xla kernels
|
||||
// building valid MLIR using MlirHloBuilder.
|
||||
// TODO(hinsu): Drop explicit whitelist when MLIR based bridge is enabled for
|
||||
// TODO(hinsu): Drop explicit allowlist when MLIR based bridge is enabled for
|
||||
// all tf2xla kernels.
|
||||
// clang-format off
|
||||
static llvm::SmallDenseSet<mlir::TypeID, 512> ops = {
|
||||
@ -342,7 +342,7 @@ LogicalResult FuncLegalizer::Legalize() {
|
||||
}
|
||||
|
||||
LogicalResult FuncLegalizer::LegalizeOp(Operation* op) {
|
||||
if (!IsOpWhitelisted(op)) return success();
|
||||
if (!IsOpAllowlisted(op)) return success();
|
||||
|
||||
// Only static shaped operands are supported in XLA builders for now.
|
||||
for (Type ty : op->getOperandTypes()) {
|
||||
|
@ -63,7 +63,7 @@ XlaOpRegistry::~XlaOpRegistry() = default;
|
||||
if (x.name != y.name) return true;
|
||||
if (x.label != y.label) return true;
|
||||
// The registrations refer to the same Op: ensures they are compatible and
|
||||
// are restricted to different device whitelists.
|
||||
// are restricted to different device allowlists.
|
||||
if (x.compilation_only != y.compilation_only) {
|
||||
LOG(WARNING) << "Registrations of " << x.name
|
||||
<< " have incompatible compilation_only settings.";
|
||||
@ -84,14 +84,14 @@ XlaOpRegistry::~XlaOpRegistry() = default;
|
||||
<< " have incompatible allow_string_type settings.";
|
||||
return false;
|
||||
}
|
||||
if (!x.has_device_whitelist && !y.has_device_whitelist) {
|
||||
if (!x.has_device_allowlist && !y.has_device_allowlist) {
|
||||
LOG(WARNING) << "Duplicate registrations of " << x.name
|
||||
<< "with no device whitelists.";
|
||||
<< "with no device allowlists.";
|
||||
return false;
|
||||
}
|
||||
if (x.has_device_whitelist && y.has_device_whitelist) {
|
||||
for (const auto& device : x.device_whitelist) {
|
||||
if (y.device_whitelist.count(device) != 0) {
|
||||
if (x.has_device_allowlist && y.has_device_allowlist) {
|
||||
for (const auto& device : x.device_allowlist) {
|
||||
if (y.device_allowlist.count(device) != 0) {
|
||||
LOG(WARNING) << "Multiple registrations of " << x.name << " on device "
|
||||
<< device;
|
||||
return false;
|
||||
@ -185,28 +185,28 @@ void XlaOpRegistry::RegisterCompilationKernels() {
|
||||
// The goal is to allow the co-existence of backend-specific kernels and
|
||||
// generic kernels. To achieve this, we enforce the following order of
|
||||
// registrations for one op:
|
||||
// 1. Process op registration with device whitelists:
|
||||
// 1. Process op registration with device allowlists:
|
||||
// this pass registers backend-specific kernels for this op.
|
||||
// 2. Process op registration without device whitelists:
|
||||
// 2. Process op registration without device allowlists:
|
||||
// this pass registers the kernels for all the other supported backends.
|
||||
for (auto& ops : registry.ops_) {
|
||||
const string& op_name = ops.first;
|
||||
std::vector<std::unique_ptr<OpRegistration>>& op_registrations = ops.second;
|
||||
// Partition the op registration so that the ones with device whitelists
|
||||
// precede the one without device whitelist.
|
||||
// Partition the op registration so that the ones with device allowlists
|
||||
// precede the one without device allowlist.
|
||||
std::partition(op_registrations.begin(), op_registrations.end(),
|
||||
[](const std::unique_ptr<OpRegistration>& op_reg) {
|
||||
return op_reg->has_device_whitelist;
|
||||
return op_reg->has_device_allowlist;
|
||||
});
|
||||
|
||||
// Collect a set of backend registered by ops with device whitelists.
|
||||
// The op registration without whitelists will register a generic kernel
|
||||
// Collect a set of backend registered by ops with device allowlists.
|
||||
// The op registration without allowlists will register a generic kernel
|
||||
// for all other backends not in this set.
|
||||
std::unordered_set<string> whitelisted_backend;
|
||||
std::unordered_set<string> allowlisted_backend;
|
||||
for (auto& op_registration : op_registrations) {
|
||||
if (op_registration->has_device_whitelist) {
|
||||
whitelisted_backend.insert(op_registration->device_whitelist.begin(),
|
||||
op_registration->device_whitelist.end());
|
||||
if (op_registration->has_device_allowlist) {
|
||||
allowlisted_backend.insert(op_registration->device_allowlist.begin(),
|
||||
op_registration->device_allowlist.end());
|
||||
}
|
||||
}
|
||||
|
||||
@ -238,19 +238,19 @@ void XlaOpRegistry::RegisterCompilationKernels() {
|
||||
}
|
||||
|
||||
for (auto& backend : registry.backends_) {
|
||||
// If the operator has a device whitelist, only register on whitelisted
|
||||
// If the operator has a device allowlist, only register on allowlisted
|
||||
// devices.
|
||||
if (op_registration->has_device_whitelist &&
|
||||
op_registration->device_whitelist.find(backend.first) ==
|
||||
op_registration->device_whitelist.end()) {
|
||||
if (op_registration->has_device_allowlist &&
|
||||
op_registration->device_allowlist.find(backend.first) ==
|
||||
op_registration->device_allowlist.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// If the operator does NOT has a device whitelist, skip all devices
|
||||
// If the operator does NOT has a device allowlist, skip all devices
|
||||
// that has already been registered.
|
||||
if (!op_registration->has_device_whitelist &&
|
||||
whitelisted_backend.find(backend.first) !=
|
||||
whitelisted_backend.end()) {
|
||||
if (!op_registration->has_device_allowlist &&
|
||||
allowlisted_backend.find(backend.first) !=
|
||||
allowlisted_backend.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -478,17 +478,17 @@ XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name(
|
||||
|
||||
XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(
|
||||
absl::Span<const absl::string_view> devices) {
|
||||
registration_->has_device_whitelist = true;
|
||||
registration_->has_device_allowlist = true;
|
||||
for (absl::string_view device : devices) {
|
||||
registration_->device_whitelist.emplace(device);
|
||||
registration_->device_allowlist.emplace(device);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(
|
||||
absl::string_view device) {
|
||||
registration_->has_device_whitelist = true;
|
||||
registration_->device_whitelist.emplace(device);
|
||||
registration_->has_device_allowlist = true;
|
||||
registration_->device_allowlist.emplace(device);
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
@ -258,10 +258,10 @@ class XlaOpRegistry {
|
||||
// Mapping from attribute name to a list of supported types.
|
||||
std::unordered_map<string, std::set<DataType>> type_constraints;
|
||||
|
||||
// An optional whitelist of devices. If there is no whitelist, all devices
|
||||
// An optional allowlist of devices. If there is no allowlist, all devices
|
||||
// are permitted.
|
||||
bool has_device_whitelist = false;
|
||||
std::unordered_set<string> device_whitelist;
|
||||
bool has_device_allowlist = false;
|
||||
std::unordered_set<string> device_allowlist;
|
||||
|
||||
// Names of arguments that must be compile-time constants.
|
||||
std::unordered_set<string> compile_time_constant_inputs;
|
||||
@ -279,8 +279,8 @@ class XlaOpRegistry {
|
||||
// Returns true if registrations x and y can both be added to the registry.
|
||||
// This is always the case if they refer to different ops. If they refer to
|
||||
// the same op name, they must: have the same values for compilation_only,
|
||||
// allow_resource_types and allow_variant_types; use a device_whitelist; and
|
||||
// their whitelists must not intersect.
|
||||
// allow_resource_types and allow_variant_types; use a device_allowlist; and
|
||||
// their allowlists must not intersect.
|
||||
static bool IsCompatible(const OpRegistration& x, const OpRegistration& y);
|
||||
|
||||
static Status CompileTimeConstantInputs(const NodeDef& node_def,
|
||||
@ -319,7 +319,7 @@ class XlaOpRegistrationBuilder {
|
||||
// Starts an operator registration chain.
|
||||
static XlaOpRegistrationBuilder Name(absl::string_view name);
|
||||
|
||||
// Specifies a whitelist of devices on which the operator may run.
|
||||
// Specifies a allowlist of devices on which the operator may run.
|
||||
XlaOpRegistrationBuilder& Device(absl::string_view devices);
|
||||
XlaOpRegistrationBuilder& Device(absl::Span<const absl::string_view> devices);
|
||||
|
||||
|
@ -378,7 +378,7 @@ struct TensorAndDevice {
|
||||
};
|
||||
|
||||
// Tensors of some DataTypes cannot placed in device memory as feeds or
|
||||
// fetches. Validate against a whitelist of those known to work.
|
||||
// fetches. Validate against a allowlist of those known to work.
|
||||
bool IsFeedAndFetchSupported(DataType dtype, const string& device_type) {
|
||||
// The mechanism for supporting feeds of device-backed Tensors requires
|
||||
// the _Arg kernel to be registered for the corresponding type (and that
|
||||
@ -391,7 +391,7 @@ bool IsFeedAndFetchSupported(DataType dtype, const string& device_type) {
|
||||
// For now, we return true iff there are _Arg AND _Retval kernels for dtype on
|
||||
// the device. False negatives are okay, false positives would be bad.
|
||||
//
|
||||
// TODO(ashankar): Instead of a whitelist here, perhaps we could query
|
||||
// TODO(ashankar): Instead of a allowlist here, perhaps we could query
|
||||
// the kernel registry for _Arg and _Retval kernels instead.
|
||||
if (device_type == DEVICE_CPU) return true;
|
||||
if (device_type != DEVICE_GPU) return false;
|
||||
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_WHITELIST_H_
|
||||
#define TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_WHITELIST_H_
|
||||
#ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_ALLOWLIST_H_
|
||||
#define TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_ALLOWLIST_H_
|
||||
|
||||
#include <unordered_set>
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
@ -23,7 +23,7 @@ namespace tensorflow {
|
||||
namespace data {
|
||||
// Registry for stateful ops that need to be used in dataset functions.
|
||||
// See below macro for usage details.
|
||||
class WhitelistedStatefulOpRegistry {
|
||||
class AllowlistedStatefulOpRegistry {
|
||||
public:
|
||||
Status Add(string op_name) {
|
||||
op_names_.insert(std::move(op_name));
|
||||
@ -37,29 +37,29 @@ class WhitelistedStatefulOpRegistry {
|
||||
|
||||
bool Contains(const string& op_name) { return op_names_.count(op_name); }
|
||||
|
||||
static WhitelistedStatefulOpRegistry* Global() {
|
||||
static auto* reg = new WhitelistedStatefulOpRegistry;
|
||||
static AllowlistedStatefulOpRegistry* Global() {
|
||||
static auto* reg = new AllowlistedStatefulOpRegistry;
|
||||
return reg;
|
||||
}
|
||||
|
||||
private:
|
||||
WhitelistedStatefulOpRegistry() = default;
|
||||
WhitelistedStatefulOpRegistry(WhitelistedStatefulOpRegistry const& copy) =
|
||||
AllowlistedStatefulOpRegistry() = default;
|
||||
AllowlistedStatefulOpRegistry(AllowlistedStatefulOpRegistry const& copy) =
|
||||
delete;
|
||||
WhitelistedStatefulOpRegistry operator=(
|
||||
WhitelistedStatefulOpRegistry const& copy) = delete;
|
||||
AllowlistedStatefulOpRegistry operator=(
|
||||
AllowlistedStatefulOpRegistry const& copy) = delete;
|
||||
|
||||
std::unordered_set<string> op_names_;
|
||||
};
|
||||
|
||||
} // namespace data
|
||||
|
||||
// Use this macro to whitelist an op that is marked stateful but needs to be
|
||||
// Use this macro to allowlist an op that is marked stateful but needs to be
|
||||
// used inside a map_fn in an input pipeline. This is only needed if you wish
|
||||
// to be able to checkpoint the state of the input pipeline. We currently
|
||||
// do not allow stateful ops to be defined inside of map_fns since it is not
|
||||
// possible to save their state.
|
||||
// Note that the state of the whitelisted ops inside functions will not be
|
||||
// Note that the state of the allowlisted ops inside functions will not be
|
||||
// saved during checkpointing, hence this should only be used if the op is
|
||||
// marked stateful for reasons like to avoid constant folding during graph
|
||||
// optimization but is not stateful.
|
||||
@ -73,9 +73,9 @@ class WhitelistedStatefulOpRegistry {
|
||||
#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(ctr, name) \
|
||||
WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name)
|
||||
#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name) \
|
||||
static ::tensorflow::Status whitelist_op##ctr TF_ATTRIBUTE_UNUSED = \
|
||||
::tensorflow::data::WhitelistedStatefulOpRegistry::Global()->Add(name)
|
||||
static ::tensorflow::Status allowlist_op##ctr TF_ATTRIBUTE_UNUSED = \
|
||||
::tensorflow::data::AllowlistedStatefulOpRegistry::Global()->Add(name)
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_WHITELIST_H_
|
||||
#endif // TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_ALLOWLIST_H_
|
||||
|
@ -542,8 +542,8 @@ bool IsNumericType(const DataType dtype) {
|
||||
return kRealNumberTypes->find(dtype) != kRealNumberTypes->end();
|
||||
}
|
||||
|
||||
bool IsWhiteListedOpTypeForEvaluateNode(const string& op_type) {
|
||||
static const gtl::FlatSet<string>* const kOpTpeWhitelist =
|
||||
bool IsAllowListedOpTypeForEvaluateNode(const string& op_type) {
|
||||
static const gtl::FlatSet<string>* const kOpTpeAllowlist =
|
||||
CHECK_NOTNULL((new gtl::FlatSet<string>{
|
||||
// Unary arithmetic ops
|
||||
"Floor",
|
||||
@ -589,7 +589,7 @@ bool IsWhiteListedOpTypeForEvaluateNode(const string& op_type) {
|
||||
"Fill",
|
||||
"Cast",
|
||||
}));
|
||||
return kOpTpeWhitelist->find(op_type) != kOpTpeWhitelist->end();
|
||||
return kOpTpeAllowlist->find(op_type) != kOpTpeAllowlist->end();
|
||||
}
|
||||
|
||||
// Negative shape size of '-1' represents unknown, while negative shape sizes
|
||||
@ -1441,7 +1441,7 @@ class SymbolicShapeRefiner {
|
||||
|
||||
// Due to the cost of running EvaluateNode(), we limit only to white listed
|
||||
// op types.
|
||||
if (!IsWhiteListedOpTypeForEvaluateNode(c->op_data->op_def.name())) {
|
||||
if (!IsAllowListedOpTypeForEvaluateNode(c->op_data->op_def.name())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -1008,7 +1008,7 @@ TEST_F(GraphPropertiesTest, IdentityPassingShape) {
|
||||
|
||||
TEST_F(GraphPropertiesTest, SkippingValueInferenceForLargeTensors) {
|
||||
// When using aggressive_shape_inference, we run EvaluateNode() for
|
||||
// whitelisted ops and small input / output tensors. For instance, Fill op is
|
||||
// allowlisted ops and small input / output tensors. For instance, Fill op is
|
||||
// evaluated and produces output tensor value if output tensor size is smal
|
||||
// (currently, fewer than 17 elements); otherwise we don't run EvaluateNode().
|
||||
// This is to avoid wasting time and memory for producing huge tensors (e.g.,
|
||||
|
@ -842,11 +842,11 @@ DataTypeSet AllowedDataTypes(const OpDef& op_def, const TypeAttrId& t_attr_id) {
|
||||
return AllowedDataTypes(*attr_def);
|
||||
}
|
||||
|
||||
Status ValidateLists(const gtl::FlatSet<string>& white_list,
|
||||
Status ValidateLists(const gtl::FlatSet<string>& allow_list,
|
||||
const gtl::FlatSet<string>& black_list,
|
||||
const gtl::FlatSet<string>& gray_list,
|
||||
const gtl::FlatSet<string>& clear_list) {
|
||||
std::vector<gtl::FlatSet<string>> lists{white_list, black_list, gray_list,
|
||||
std::vector<gtl::FlatSet<string>> lists{allow_list, black_list, gray_list,
|
||||
clear_list};
|
||||
std::multiset<string> counts;
|
||||
for (const auto& list : lists) {
|
||||
@ -973,25 +973,25 @@ class AutoMixedPrecisionImpl {
|
||||
void FindTensorListImplicitFloat32Edges(
|
||||
const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
|
||||
std::vector<NodeTypeIdEdge>* implicit_data_edges) const;
|
||||
void AddWhitelistOps(absl::flat_hash_set<int>* white_set) const;
|
||||
void AddAllowlistOps(absl::flat_hash_set<int>* allow_set) const;
|
||||
void PropagateBlackFwdThroughClearAndGray(
|
||||
absl::flat_hash_set<int>* black_set) const;
|
||||
void ForceColorMatchBetweenTensorListOps(
|
||||
const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
|
||||
absl::flat_hash_set<int>* white_set,
|
||||
absl::flat_hash_set<int>* allow_set,
|
||||
absl::flat_hash_set<int>* black_set) const;
|
||||
void AddClearAndGrayToWhiteIfBetweenWhite(
|
||||
void AddClearAndGrayToAllowIfBetweenAllow(
|
||||
const absl::flat_hash_set<int>& black_set,
|
||||
absl::flat_hash_set<int>* white_set) const;
|
||||
void PropagateWhiteThroughClear(const absl::flat_hash_set<int>& black_set,
|
||||
absl::flat_hash_set<int>* white_set) const;
|
||||
absl::flat_hash_set<int>* allow_set) const;
|
||||
void PropagateAllowThroughClear(const absl::flat_hash_set<int>& black_set,
|
||||
absl::flat_hash_set<int>* allow_set) const;
|
||||
Status ForceColorMatchOnRecurrentEdges(
|
||||
absl::flat_hash_set<int>* white_set) const;
|
||||
void MakeCastsWhiteIfAllOutputsWhite(
|
||||
absl::flat_hash_set<int>* white_set) const;
|
||||
absl::flat_hash_set<int>* allow_set) const;
|
||||
void MakeCastsAllowIfAllOutputsAllow(
|
||||
absl::flat_hash_set<int>* allow_set) const;
|
||||
NodeDef BuildCastNode(const MutableGraphView::OutputPort& src, bool to_f16,
|
||||
const string& device) const;
|
||||
Status ChangeTypeAttrsAndAddCasts(const absl::flat_hash_set<int>& white_set);
|
||||
Status ChangeTypeAttrsAndAddCasts(const absl::flat_hash_set<int>& allow_set);
|
||||
|
||||
VirtualPlacer virtual_placer_;
|
||||
std::unordered_set<string> nodes_to_preserve_;
|
||||
@ -1005,7 +1005,7 @@ class AutoMixedPrecisionImpl {
|
||||
GraphTypeTopologyView graph_type_view_;
|
||||
bool force_all_fp16_;
|
||||
AutoMixedPrecisionMode mode_;
|
||||
gtl::FlatSet<string> f16_whitelist_;
|
||||
gtl::FlatSet<string> f16_allowlist_;
|
||||
gtl::FlatSet<string> f16_blacklist_;
|
||||
gtl::FlatSet<string> f16_graylist_;
|
||||
gtl::FlatSet<string> f16_clearlist_;
|
||||
@ -1079,8 +1079,8 @@ Status AutoMixedPrecisionImpl::PrintDebugLogs(bool preop, size_t timestamp) {
|
||||
f.open(fname.c_str(), std::fstream::out);
|
||||
std::unique_ptr<AutoMixedPrecisionLists> mp_lists =
|
||||
get_mixed_precision_lists();
|
||||
f << "WhiteList:\n";
|
||||
for (const auto& x : mp_lists->WhiteList()) {
|
||||
f << "AllowList:\n";
|
||||
for (const auto& x : mp_lists->AllowList()) {
|
||||
f << x << "\n";
|
||||
}
|
||||
f << "\nBlackList:\n";
|
||||
@ -1254,11 +1254,11 @@ Status AutoMixedPrecisionImpl::Optimize() {
|
||||
|
||||
std::unique_ptr<AutoMixedPrecisionLists> mp_lists =
|
||||
get_mixed_precision_lists();
|
||||
f16_whitelist_ = mp_lists->WhiteList();
|
||||
f16_allowlist_ = mp_lists->AllowList();
|
||||
f16_blacklist_ = mp_lists->BlackList();
|
||||
f16_graylist_ = mp_lists->GrayList();
|
||||
f16_clearlist_ = mp_lists->ClearList();
|
||||
TF_RETURN_IF_ERROR(ValidateLists(f16_whitelist_, f16_blacklist_,
|
||||
TF_RETURN_IF_ERROR(ValidateLists(f16_allowlist_, f16_blacklist_,
|
||||
f16_graylist_, f16_clearlist_));
|
||||
|
||||
size_t timestamp = Env::Default()->NowMicros() / 1000;
|
||||
@ -1316,8 +1316,8 @@ Status AutoMixedPrecisionImpl::Optimize() {
|
||||
// boundaries between f16/non-f16 nodes.
|
||||
|
||||
// The algorithm for deciding which nodes to change to f16 is as follows:
|
||||
// 1) Add all performance-critical ops (aka "whitelist" ops) to the white_set.
|
||||
// This is done under the assumption that whitelist ops are always
|
||||
// 1) Add all performance-critical ops (aka "allowlist" ops) to the allow_set.
|
||||
// This is done under the assumption that allowlist ops are always
|
||||
// numerically-safe in f16 and that they are the most important ops for
|
||||
// improving performance.
|
||||
// 2) Add nodes to the black_set iff they are numerically-dangerous (aka
|
||||
@ -1329,20 +1329,20 @@ Status AutoMixedPrecisionImpl::Optimize() {
|
||||
// numerical accuracy of the model.
|
||||
// 3) For all remaining nodes that are not considered dangerous (greylist
|
||||
// and clearlist ops), find those that are between (i.e., both upstream
|
||||
// and downstream of) white nodes, and add them to the white_set.
|
||||
// This is done to avoid unnecessary casts between whitelist ops.
|
||||
// 4) For all remaining clearlist nodes, add them to the white_set if they are
|
||||
// connected to a node in the white_set via other clearlist nodes.
|
||||
// This is done to increase the number of ops in the white_set without
|
||||
// and downstream of) allow nodes, and add them to the allow_set.
|
||||
// This is done to avoid unnecessary casts between allowlist ops.
|
||||
// 4) For all remaining clearlist nodes, add them to the allow_set if they are
|
||||
// connected to a node in the allow_set via other clearlist nodes.
|
||||
// This is done to increase the number of ops in the allow_set without
|
||||
// affecting numerical stability.
|
||||
|
||||
absl::flat_hash_set<int> white_set;
|
||||
VLOG(2) << "Beginning pass 1 to add whitelist ops";
|
||||
AddWhitelistOps(&white_set);
|
||||
absl::flat_hash_set<int> allow_set;
|
||||
VLOG(2) << "Beginning pass 1 to add allowlist ops";
|
||||
AddAllowlistOps(&allow_set);
|
||||
VLOG(2) << "Finished pass 1";
|
||||
|
||||
if (white_set.empty()) {
|
||||
LOG(INFO) << "No whitelist ops found, nothing to do";
|
||||
if (allow_set.empty()) {
|
||||
LOG(INFO) << "No allowlist ops found, nothing to do";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -1353,33 +1353,33 @@ Status AutoMixedPrecisionImpl::Optimize() {
|
||||
|
||||
VLOG(2) << "Forcing color match between data structure ops";
|
||||
for (const auto& cluster : tensor_list_clusters) {
|
||||
ForceColorMatchBetweenTensorListOps(cluster, &white_set, &black_set);
|
||||
ForceColorMatchBetweenTensorListOps(cluster, &allow_set, &black_set);
|
||||
}
|
||||
|
||||
VLOG(2) << "Beginning pass 3 to set clear and gray nodes to white if they "
|
||||
"are between white ops";
|
||||
AddClearAndGrayToWhiteIfBetweenWhite(black_set, &white_set);
|
||||
VLOG(2) << "Beginning pass 3 to set clear and gray nodes to allow if they "
|
||||
"are between allow ops";
|
||||
AddClearAndGrayToAllowIfBetweenAllow(black_set, &allow_set);
|
||||
VLOG(2) << "Finished pass 3";
|
||||
|
||||
VLOG(2) << "Beginning pass 4 to propagate white from white nodes through "
|
||||
VLOG(2) << "Beginning pass 4 to propagate allow from allow nodes through "
|
||||
"clearlist ops";
|
||||
PropagateWhiteThroughClear(black_set, &white_set);
|
||||
PropagateAllowThroughClear(black_set, &allow_set);
|
||||
VLOG(2) << "Finished pass 4";
|
||||
|
||||
VLOG(2) << "Forcing color match between data structure ops";
|
||||
for (const auto& cluster : tensor_list_clusters) {
|
||||
ForceColorMatchBetweenTensorListOps(cluster, &white_set, &black_set);
|
||||
ForceColorMatchBetweenTensorListOps(cluster, &allow_set, &black_set);
|
||||
}
|
||||
|
||||
VLOG(2) << "Forcing color match on loop edges";
|
||||
TF_RETURN_IF_ERROR(ForceColorMatchOnRecurrentEdges(&white_set));
|
||||
TF_RETURN_IF_ERROR(ForceColorMatchOnRecurrentEdges(&allow_set));
|
||||
|
||||
VLOG(2) << "Finding existing casts that can be made white";
|
||||
MakeCastsWhiteIfAllOutputsWhite(&white_set);
|
||||
VLOG(2) << "Finding existing casts that can be made allow";
|
||||
MakeCastsAllowIfAllOutputsAllow(&allow_set);
|
||||
|
||||
VLOG(2) << "Beginning final pass to change type attributes and insert Cast "
|
||||
"ops at paint boundaries";
|
||||
TF_RETURN_IF_ERROR(ChangeTypeAttrsAndAddCasts(white_set));
|
||||
TF_RETURN_IF_ERROR(ChangeTypeAttrsAndAddCasts(allow_set));
|
||||
VLOG(2) << "Finished final pass";
|
||||
|
||||
TF_RETURN_IF_ERROR(PrintDebugLogs(/* preop = */ false, timestamp));
|
||||
@ -1516,19 +1516,19 @@ void AutoMixedPrecisionImpl::FindTensorListImplicitFloat32Edges(
|
||||
}
|
||||
}
|
||||
|
||||
void AutoMixedPrecisionImpl::AddWhitelistOps(
|
||||
absl::flat_hash_set<int>* white_set) const {
|
||||
// Add whitelisted ops to white_set.
|
||||
void AutoMixedPrecisionImpl::AddAllowlistOps(
|
||||
absl::flat_hash_set<int>* allow_set) const {
|
||||
// Add allowlisted ops to allow_set.
|
||||
for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
|
||||
const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
|
||||
if (!ShouldProcess(*root.node)) continue;
|
||||
bool force_white = force_all_fp16_ && CanForceFP16(*root.node);
|
||||
if (f16_whitelist_.count(root.node->op()) || force_white) {
|
||||
bool inserted = white_set->insert(root_idx).second;
|
||||
bool force_allow = force_all_fp16_ && CanForceFP16(*root.node);
|
||||
if (f16_allowlist_.count(root.node->op()) || force_allow) {
|
||||
bool inserted = allow_set->insert(root_idx).second;
|
||||
if (VLOG_IS_ON(2) && inserted) {
|
||||
VLOG(2) << "Painting type " << root.type_attr.DebugString()
|
||||
<< " of node " << root.node->name() << " WHITE because its op "
|
||||
<< root.node->op() << " is on the whitelist";
|
||||
<< " of node " << root.node->name() << " ALLOW because its op "
|
||||
<< root.node->op() << " is on the allowlist";
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1537,8 +1537,8 @@ void AutoMixedPrecisionImpl::AddWhitelistOps(
|
||||
// Adds nodes to black_set iff they are on the blacklist or they are on a
|
||||
// forward path from a blacklist node to a black/gray node (including the node
|
||||
// at the end of the path) through clear and gray nodes.
|
||||
// E.g., black -> gray -> clear -> gray -> clear -> white -> gray
|
||||
// becomes: black -> black -> black -> black -> clear -> white -> gray.
|
||||
// E.g., black -> gray -> clear -> gray -> clear -> allow -> gray
|
||||
// becomes: black -> black -> black -> black -> clear -> allow -> gray.
|
||||
void AutoMixedPrecisionImpl::PropagateBlackFwdThroughClearAndGray(
|
||||
absl::flat_hash_set<int>* black_set) const {
|
||||
if (force_all_fp16_) return;
|
||||
@ -1588,14 +1588,14 @@ void AutoMixedPrecisionImpl::PropagateBlackFwdThroughClearAndGray(
|
||||
}
|
||||
}
|
||||
|
||||
void AutoMixedPrecisionImpl::AddClearAndGrayToWhiteIfBetweenWhite(
|
||||
void AutoMixedPrecisionImpl::AddClearAndGrayToAllowIfBetweenAllow(
|
||||
const absl::flat_hash_set<int>& black_set,
|
||||
absl::flat_hash_set<int>* white_set) const {
|
||||
// Find clear/graylist ops that are downstream of white ops.
|
||||
absl::flat_hash_set<int> downstream_of_white_set;
|
||||
absl::flat_hash_set<int>* allow_set) const {
|
||||
// Find clear/graylist ops that are downstream of allow ops.
|
||||
absl::flat_hash_set<int> downstream_of_allow_set;
|
||||
for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
|
||||
const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
|
||||
if (!ShouldProcess(*root.node) || !f16_whitelist_.count(root.node->op())) {
|
||||
if (!ShouldProcess(*root.node) || !f16_allowlist_.count(root.node->op())) {
|
||||
continue;
|
||||
}
|
||||
DfsTypeTraversal(
|
||||
@ -1603,8 +1603,8 @@ void AutoMixedPrecisionImpl::AddClearAndGrayToWhiteIfBetweenWhite(
|
||||
DfsTypePredicates::Enter([&](int idx) -> bool {
|
||||
const NodeTypeId& item = *graph_type_view_.GetNode(idx);
|
||||
return idx == root_idx ||
|
||||
(!downstream_of_white_set.count(idx) &&
|
||||
!f16_whitelist_.count(item.node->op()) &&
|
||||
(!downstream_of_allow_set.count(idx) &&
|
||||
!f16_allowlist_.count(item.node->op()) &&
|
||||
!black_set.count(idx) && ShouldProcess(*item.node) &&
|
||||
// TODO(benbarsdell): Consider allowing propagation through
|
||||
// ops that are already float16 in order to reduce the number
|
||||
@ -1614,45 +1614,45 @@ void AutoMixedPrecisionImpl::AddClearAndGrayToWhiteIfBetweenWhite(
|
||||
f16_graylist_.count(item.node->op())));
|
||||
}),
|
||||
DfsTypeCallbacks::PreOrder(
|
||||
[&](int idx) { downstream_of_white_set.insert(idx); }));
|
||||
[&](int idx) { downstream_of_allow_set.insert(idx); }));
|
||||
}
|
||||
|
||||
// Set nodes that are both downstream and upstream of white ops to white.
|
||||
absl::flat_hash_set<int> upstream_of_white_set;
|
||||
// Set nodes that are both downstream and upstream of allow ops to allow.
|
||||
absl::flat_hash_set<int> upstream_of_allow_set;
|
||||
for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
|
||||
const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
|
||||
if (!ShouldProcess(*root.node) || upstream_of_white_set.count(root_idx) ||
|
||||
!f16_whitelist_.count(root.node->op())) {
|
||||
if (!ShouldProcess(*root.node) || upstream_of_allow_set.count(root_idx) ||
|
||||
!f16_allowlist_.count(root.node->op())) {
|
||||
continue;
|
||||
}
|
||||
DfsTypeTraversal(
|
||||
graph_type_view_, {&root}, TypeTraversalDirection::kFollowInputs,
|
||||
DfsTypePredicates::Enter([&](int idx) -> bool {
|
||||
return idx == root_idx || (!upstream_of_white_set.count(idx) &&
|
||||
downstream_of_white_set.count(idx));
|
||||
return idx == root_idx || (!upstream_of_allow_set.count(idx) &&
|
||||
downstream_of_allow_set.count(idx));
|
||||
}),
|
||||
DfsTypeCallbacks::PreOrder([&](int idx) {
|
||||
upstream_of_white_set.insert(idx);
|
||||
bool inserted = white_set->insert(idx).second;
|
||||
upstream_of_allow_set.insert(idx);
|
||||
bool inserted = allow_set->insert(idx).second;
|
||||
if (VLOG_IS_ON(2) && inserted) {
|
||||
const NodeTypeId& item = *graph_type_view_.GetNode(idx);
|
||||
VLOG(2) << "Painting type " << item.type_attr.DebugString()
|
||||
<< " of " << item.node->op() << " node "
|
||||
<< item.node->name() << " WHITE";
|
||||
<< item.node->name() << " ALLOW";
|
||||
}
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
void AutoMixedPrecisionImpl::PropagateWhiteThroughClear(
|
||||
void AutoMixedPrecisionImpl::PropagateAllowThroughClear(
|
||||
const absl::flat_hash_set<int>& black_set,
|
||||
absl::flat_hash_set<int>* white_set) const {
|
||||
// Propagate white from white nodes through clearlist ops.
|
||||
absl::flat_hash_set<int>* allow_set) const {
|
||||
// Propagate allow from allow nodes through clearlist ops.
|
||||
absl::flat_hash_set<int> clear_prop_set;
|
||||
for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) {
|
||||
const NodeTypeId& root = *graph_type_view_.GetNode(root_idx);
|
||||
if (!ShouldProcess(*root.node) || clear_prop_set.count(root_idx) ||
|
||||
!white_set->count(root_idx)) {
|
||||
!allow_set->count(root_idx)) {
|
||||
continue;
|
||||
}
|
||||
DfsTypeTraversal(
|
||||
@ -1661,7 +1661,7 @@ void AutoMixedPrecisionImpl::PropagateWhiteThroughClear(
|
||||
DfsTypePredicates::Enter([&](int idx) -> bool {
|
||||
const NodeTypeId& item = *graph_type_view_.GetNode(idx);
|
||||
return idx == root_idx ||
|
||||
(!white_set->count(idx) && !black_set.count(idx) &&
|
||||
(!allow_set->count(idx) && !black_set.count(idx) &&
|
||||
ShouldProcess(*item.node) && IsFloat32(item) &&
|
||||
SupportsF16(item) &&
|
||||
(f16_clearlist_.count(item.node->op())) &&
|
||||
@ -1673,30 +1673,30 @@ void AutoMixedPrecisionImpl::PropagateWhiteThroughClear(
|
||||
}),
|
||||
DfsTypeCallbacks::PreOrder([&](int idx) {
|
||||
clear_prop_set.insert(idx);
|
||||
bool inserted = white_set->insert(idx).second;
|
||||
bool inserted = allow_set->insert(idx).second;
|
||||
if (VLOG_IS_ON(2) && inserted) {
|
||||
const NodeTypeId& item = *graph_type_view_.GetNode(idx);
|
||||
VLOG(2) << "Painting type " << item.type_attr.DebugString()
|
||||
<< " of " << item.node->op() << " node "
|
||||
<< item.node->name() << " WHITE";
|
||||
<< item.node->name() << " ALLOW";
|
||||
}
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
// Forces NextIteration nodes and their output Merge node(s) to have the same
|
||||
// color. Specifically, it removes them all from white_set if any of the Merge
|
||||
// nodes is not in white_set, otherwise it adds the NextIteration node to
|
||||
// white_set.
|
||||
// color. Specifically, it removes them all from allow_set if any of the Merge
|
||||
// nodes is not in allow_set, otherwise it adds the NextIteration node to
|
||||
// allow_set.
|
||||
Status AutoMixedPrecisionImpl::ForceColorMatchOnRecurrentEdges(
|
||||
absl::flat_hash_set<int>* white_set) const {
|
||||
absl::flat_hash_set<int>* allow_set) const {
|
||||
for (const NodeDef& node : graph_->node()) {
|
||||
if (node.op() == "NextIteration") {
|
||||
GraphView::OutputPort output_port(&node, 0);
|
||||
const auto& fanout = graph_view_.GetFanout(output_port);
|
||||
std::vector<int> merge_idxs;
|
||||
merge_idxs.reserve(fanout.size());
|
||||
bool any_merge_is_not_white = false;
|
||||
bool any_merge_is_not_allow = false;
|
||||
for (const auto& output : fanout) {
|
||||
const NodeDef& merge_node = *output.node;
|
||||
if (merge_node.op() != "Merge") {
|
||||
@ -1712,8 +1712,8 @@ Status AutoMixedPrecisionImpl::ForceColorMatchOnRecurrentEdges(
|
||||
}
|
||||
int merge_idx = maybe_merge_idx.value();
|
||||
merge_idxs.push_back(merge_idx);
|
||||
any_merge_is_not_white =
|
||||
any_merge_is_not_white || !white_set->count(merge_idx);
|
||||
any_merge_is_not_allow =
|
||||
any_merge_is_not_allow || !allow_set->count(merge_idx);
|
||||
}
|
||||
const absl::optional<int> maybe_nextiter_idx =
|
||||
graph_type_view_.GetNodeIndex(node.name(), TypeAttrId("T"));
|
||||
@ -1722,9 +1722,9 @@ Status AutoMixedPrecisionImpl::ForceColorMatchOnRecurrentEdges(
|
||||
node.name(), " not found in graph view");
|
||||
}
|
||||
int nextiter_idx = maybe_nextiter_idx.value();
|
||||
if (any_merge_is_not_white) {
|
||||
if (any_merge_is_not_allow) {
|
||||
for (int merge_idx : merge_idxs) {
|
||||
if (white_set->erase(merge_idx)) {
|
||||
if (allow_set->erase(merge_idx)) {
|
||||
VLOG(2) << "Painting type T of Merge node "
|
||||
<< graph_type_view_.GetNode(merge_idx)->node->name()
|
||||
<< " BLACK to match the color of its sibling Merge nodes "
|
||||
@ -1732,14 +1732,14 @@ Status AutoMixedPrecisionImpl::ForceColorMatchOnRecurrentEdges(
|
||||
<< node.name();
|
||||
}
|
||||
}
|
||||
if (white_set->erase(nextiter_idx)) {
|
||||
if (allow_set->erase(nextiter_idx)) {
|
||||
VLOG(2) << "Painting type T of NextIteration node " << node.name()
|
||||
<< " BLACK to match the color of its output Merge node(s)";
|
||||
}
|
||||
} else {
|
||||
if (white_set->insert(nextiter_idx).second) {
|
||||
if (allow_set->insert(nextiter_idx).second) {
|
||||
VLOG(2) << "Painting type T of NextIteration node " << node.name()
|
||||
<< " WHITE to match the color of its output Merge node(s)";
|
||||
<< " ALLOW to match the color of its output Merge node(s)";
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1750,10 +1750,10 @@ Status AutoMixedPrecisionImpl::ForceColorMatchOnRecurrentEdges(
|
||||
// Forces all of the given Tensor List nodes into the same color set.
|
||||
void AutoMixedPrecisionImpl::ForceColorMatchBetweenTensorListOps(
|
||||
const absl::flat_hash_set<const NodeDef*>& tensor_list_nodes,
|
||||
absl::flat_hash_set<int>* white_set,
|
||||
absl::flat_hash_set<int>* allow_set,
|
||||
absl::flat_hash_set<int>* black_set) const {
|
||||
bool any_black = false;
|
||||
bool any_white = false;
|
||||
bool any_allow = false;
|
||||
std::vector<int> node_type_idxs;
|
||||
node_type_idxs.reserve(tensor_list_nodes.size());
|
||||
for (const NodeDef* node : tensor_list_nodes) {
|
||||
@ -1769,23 +1769,23 @@ void AutoMixedPrecisionImpl::ForceColorMatchBetweenTensorListOps(
|
||||
if (black_set->count(node_type_idx)) {
|
||||
any_black = true;
|
||||
break;
|
||||
} else if (white_set->count(node_type_idx)) {
|
||||
any_white = true;
|
||||
} else if (allow_set->count(node_type_idx)) {
|
||||
any_allow = true;
|
||||
}
|
||||
}
|
||||
if (!any_black && !any_white) return;
|
||||
if (!any_black && !any_allow) return;
|
||||
for (int node_type_idx : node_type_idxs) {
|
||||
const NodeTypeId& node_type = *graph_type_view_.GetNode(node_type_idx);
|
||||
VLOG(2) << "Painting type " << node_type.type_attr.DebugString() << " of "
|
||||
<< node_type.node->op() << " node " << node_type.node->name() << " "
|
||||
<< (any_black ? "BLACK" : "WHITE")
|
||||
<< (any_black ? "BLACK" : "ALLOW")
|
||||
<< " because at least one of its siblings is "
|
||||
<< (any_black ? "BLACK" : "WHITE");
|
||||
<< (any_black ? "BLACK" : "ALLOW");
|
||||
if (any_black) {
|
||||
white_set->erase(node_type_idx);
|
||||
allow_set->erase(node_type_idx);
|
||||
black_set->insert(node_type_idx);
|
||||
} else {
|
||||
white_set->insert(node_type_idx);
|
||||
allow_set->insert(node_type_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1807,10 +1807,10 @@ bool AutoMixedPrecisionImpl::NodeImplicitlyReadsNonResourceVariable(
|
||||
return false;
|
||||
}
|
||||
|
||||
// This adds existing Cast nodes to white_set if all of their outputs are white,
|
||||
// This adds existing Cast nodes to allow_set if all of their outputs are allow,
|
||||
// avoiding the need to add a new Cast node after an existing Cast.
|
||||
void AutoMixedPrecisionImpl::MakeCastsWhiteIfAllOutputsWhite(
|
||||
absl::flat_hash_set<int>* white_set) const {
|
||||
void AutoMixedPrecisionImpl::MakeCastsAllowIfAllOutputsAllow(
|
||||
absl::flat_hash_set<int>* allow_set) const {
|
||||
int num_nodes_preop = graph_->node_size();
|
||||
for (int node_idx = 0; node_idx < num_nodes_preop; ++node_idx) {
|
||||
NodeDef* node = graph_->mutable_node(node_idx);
|
||||
@ -1818,7 +1818,7 @@ void AutoMixedPrecisionImpl::MakeCastsWhiteIfAllOutputsWhite(
|
||||
if (node->op() != "Cast" || !IsFloat32(node_type)) {
|
||||
continue;
|
||||
}
|
||||
bool all_fanouts_white = true;
|
||||
bool all_fanouts_allow = true;
|
||||
MutableGraphView::OutputPort src(node, 0);
|
||||
const auto& fanout = graph_view_.GetFanout(src);
|
||||
for (const MutableGraphView::InputPort& dst : fanout) {
|
||||
@ -1830,13 +1830,13 @@ void AutoMixedPrecisionImpl::MakeCastsWhiteIfAllOutputsWhite(
|
||||
<< "Type attribute " << dst_type_attr.DebugString() << " of node "
|
||||
<< dst.node->name() << " not found in graph view";
|
||||
int dst_type_idx = maybe_dst_type_idx.value();
|
||||
bool dst_is_white = white_set->count(dst_type_idx);
|
||||
if (!dst_is_white) {
|
||||
all_fanouts_white = false;
|
||||
bool dst_is_allow = allow_set->count(dst_type_idx);
|
||||
if (!dst_is_allow) {
|
||||
all_fanouts_allow = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!fanout.empty() && all_fanouts_white) {
|
||||
if (!fanout.empty() && all_fanouts_allow) {
|
||||
const absl::optional<int> maybe_node_type_idx =
|
||||
graph_type_view_.GetNodeIndex(node_type);
|
||||
DCHECK(maybe_node_type_idx.has_value())
|
||||
@ -1844,16 +1844,16 @@ void AutoMixedPrecisionImpl::MakeCastsWhiteIfAllOutputsWhite(
|
||||
<< " of node " << node_type.node->name()
|
||||
<< " not found in graph view";
|
||||
int node_type_idx = maybe_node_type_idx.value();
|
||||
white_set->insert(node_type_idx);
|
||||
allow_set->insert(node_type_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Changes all white-painted type attributes to DT_HALF or DT_BFLOAT16, and
|
||||
// Changes all allow-painted type attributes to DT_HALF or DT_BFLOAT16, and
|
||||
// inserts Cast nodes at node outputs for all edges that connect
|
||||
// white-painted <-> non-white-painted type attributes.
|
||||
// allow-painted <-> non-allow-painted type attributes.
|
||||
Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts(
|
||||
const absl::flat_hash_set<int>& white_set) {
|
||||
const absl::flat_hash_set<int>& allow_set) {
|
||||
int num_nodes_changed = 0;
|
||||
int num_nonvar_casts_to_f16 = 0;
|
||||
int num_nodes_preop = graph_->node_size();
|
||||
@ -1869,8 +1869,8 @@ Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts(
|
||||
}
|
||||
int node_type_idx = maybe_node_type_idx.value();
|
||||
if (!IsFloat32(*graph_type_view_.GetNode(node_type_idx))) continue;
|
||||
bool src_is_white = white_set.count(node_type_idx);
|
||||
if (src_is_white) {
|
||||
bool src_is_allow = allow_set.count(node_type_idx);
|
||||
if (src_is_allow) {
|
||||
VLOG(1) << "Changing type " << type_attr.DebugString() << " of "
|
||||
<< node->op() << " node " << node->name() << " to "
|
||||
<< DataTypeString(target_dtype_);
|
||||
@ -1896,10 +1896,10 @@ Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts(
|
||||
" not found in graph view");
|
||||
}
|
||||
int dst_type_idx = maybe_dst_type_idx.value();
|
||||
bool dst_is_white = white_set.count(dst_type_idx);
|
||||
if (src_is_white != dst_is_white) {
|
||||
bool dst_is_allow = allow_set.count(dst_type_idx);
|
||||
if (src_is_allow != dst_is_allow) {
|
||||
if (!added_cast_node) {
|
||||
bool to_f16 = dst_is_white;
|
||||
bool to_f16 = dst_is_allow;
|
||||
VLOG(1) << "Inserting cast to "
|
||||
<< (to_f16 ? DataTypeString(target_dtype_) : "DT_FLOAT")
|
||||
<< " at " << src.node->op() << " " << src.node->name()
|
||||
|
@ -23,7 +23,7 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
// Represents the four lists of ops: the white list, gray list, black list, and
|
||||
// Represents the four lists of ops: the allow list, gray list, black list, and
|
||||
// clear list. These lists determine which ops are converted to fp16/bf16
|
||||
// (referred to as 'f16' for short) and which ops stay as fp32.
|
||||
class AutoMixedPrecisionLists {
|
||||
@ -33,7 +33,7 @@ class AutoMixedPrecisionLists {
|
||||
// Returns the set of ops that are considered numerically-safe (for execution
|
||||
// in f16), performance-critical, and can run in f16. These ops are always
|
||||
// converted to f16.
|
||||
virtual gtl::FlatSet<string> WhiteList() = 0;
|
||||
virtual gtl::FlatSet<string> AllowList() = 0;
|
||||
// Returns the set of ops that can run in f16 and are considered numerically-
|
||||
// safe (for execution in f16), but which may be made unsafe by an upstream
|
||||
// blacklist op.
|
||||
@ -51,8 +51,10 @@ class AutoMixedPrecisionLists {
|
||||
protected:
|
||||
// Adds or removes ops from list if certain environmental variables are set.
|
||||
static void UpdateList(const string& list_name, gtl::FlatSet<string>* list) {
|
||||
CHECK(list_name == "WHITELIST" || list_name == "GRAYLIST" || // Crash OK.
|
||||
list_name == "BLACKLIST" || list_name == "CLEARLIST");
|
||||
CHECK(list_name == "ALLOWLIST" || list_name == "GRAYLIST" || // Crash OK.
|
||||
list_name == "BLACKLIST" || list_name == "CLEARLIST" ||
|
||||
// TODO(reedwm): for bkwds compat; remove when no longer necessary:
|
||||
list_name == "WHITELIST");
|
||||
string add_env_var =
|
||||
"TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_" + list_name + "_ADD";
|
||||
string remove_env_var =
|
||||
@ -104,7 +106,7 @@ class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists {
|
||||
AutoMixedPrecisionListsCuda(int cuda_version, int cudnn_version)
|
||||
: cuda_version_(cuda_version), cudnn_version_(cudnn_version) {}
|
||||
|
||||
gtl::FlatSet<string> WhiteList() override {
|
||||
gtl::FlatSet<string> AllowList() override {
|
||||
auto list = gtl::FlatSet<string>{
|
||||
"BlockLSTM",
|
||||
"BlockLSTMV2",
|
||||
@ -144,7 +146,11 @@ class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists {
|
||||
list.insert("Conv3DBackpropInput");
|
||||
list.insert("Conv3DBackpropInputV2");
|
||||
}
|
||||
UpdateList("ALLOWLIST", &list);
|
||||
// For backwards compatibility, keeping the original env variable here.
|
||||
// TODO(reedwm): This should be removed if we don't have active users.
|
||||
UpdateList("WHITELIST", &list);
|
||||
|
||||
return list;
|
||||
}
|
||||
|
||||
@ -338,8 +344,8 @@ class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists {
|
||||
AutoMixedPrecisionListsMkl() {}
|
||||
|
||||
// Only ops which are supported by MKL in bfloat16 should be added to the
|
||||
// white list, gray list, or clear list.
|
||||
gtl::FlatSet<string> WhiteList() override {
|
||||
// allow list, gray list, or clear list.
|
||||
gtl::FlatSet<string> AllowList() override {
|
||||
auto list = gtl::FlatSet<string>{"Conv2D",
|
||||
"Conv2DBackpropFilter",
|
||||
"Conv2DBackpropInput",
|
||||
@ -353,7 +359,7 @@ class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists {
|
||||
"BatchMatMul",
|
||||
"BatchMatMulV2"};
|
||||
|
||||
UpdateList("WHITELIST", &list);
|
||||
UpdateList("ALLOWLIST", &list);
|
||||
return list;
|
||||
}
|
||||
|
||||
|
@ -169,10 +169,10 @@ class AutoMixedPrecisionTest : public GrapplerTest {
|
||||
Output eye = ops::Const(s.WithOpName("eye"),
|
||||
GenerateIdentityMatrix<DT_FLOAT>(size, size));
|
||||
Output input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
|
||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, eye);
|
||||
Output gry1 = test_op_factory(s.WithOpName("gry1"), wht1);
|
||||
Output wht2 = ops::MatMul(s.WithOpName("wht2"), gry1, eye);
|
||||
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), wht2);
|
||||
Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, eye);
|
||||
Output gry1 = test_op_factory(s.WithOpName("gry1"), allow1);
|
||||
Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, eye);
|
||||
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);
|
||||
GrapplerItem item;
|
||||
item.fetch = {"fetch1"};
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
@ -190,9 +190,9 @@ class AutoMixedPrecisionTest : public GrapplerTest {
|
||||
GraphView output_view(&output);
|
||||
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(),
|
||||
DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
|
||||
|
||||
auto tensors = EvaluateNodes(output, item.fetch, feed);
|
||||
EXPECT_EQ(tensors.size(), tensors_expected.size());
|
||||
@ -247,8 +247,8 @@ TEST_F(AutoMixedPrecisionTest, AlreadyFp16) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output input = ops::Const(s.WithOpName("input"), 1.f, {32, 32});
|
||||
Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_HALF);
|
||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), cst1, cst1);
|
||||
Output clr1 = ops::Relu(s.WithOpName("clr1"), wht1);
|
||||
Output allow1 = ops::MatMul(s.WithOpName("allow1"), cst1, cst1);
|
||||
Output clr1 = ops::Relu(s.WithOpName("clr1"), allow1);
|
||||
Output cst2 = ops::Cast(s.WithOpName("cst2"), clr1, DT_FLOAT);
|
||||
Output clr2 = ops::Relu(s.WithOpName("clr2"), cst2);
|
||||
Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);
|
||||
@ -267,7 +267,7 @@ TEST_F(AutoMixedPrecisionTest, AlreadyFp16) {
|
||||
GraphView output_view(&output);
|
||||
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("cst1")->attr().at("DstT").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("cst2")->attr().at("SrcT").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("cst2")->attr().at("DstT").type(), DT_FLOAT);
|
||||
@ -288,8 +288,8 @@ TEST_F(AutoMixedPrecisionTest, Simple) {
|
||||
Output clr1 = ops::Relu(s.WithOpName("clr1"), blk1);
|
||||
Output gry1 = ops::Sqrt(s.WithOpName("gry1"), clr1);
|
||||
Output clr2 = ops::Relu(s.WithOpName("clr2"), gry1);
|
||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), clr2, clr2);
|
||||
Output clr3 = ops::Relu(s.WithOpName("clr3"), wht1);
|
||||
Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr2, clr2);
|
||||
Output clr3 = ops::Relu(s.WithOpName("clr3"), allow1);
|
||||
Output gry2 = ops::Log(s.WithOpName("gry2"), clr3);
|
||||
Output clr4 = ops::Relu(s.WithOpName("clr4"), gry2);
|
||||
Output blk2 = ops::SparseMatMul(s.WithOpName("blk2"), clr4, clr4);
|
||||
@ -314,7 +314,7 @@ TEST_F(AutoMixedPrecisionTest, Simple) {
|
||||
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("gry2")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
|
||||
@ -335,10 +335,10 @@ TEST_F(AutoMixedPrecisionTest, BidirectionalClearChain) {
|
||||
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
|
||||
Output clr1 = ops::Relu(s.WithOpName("clr1"), input);
|
||||
Output clr2 = ops::Relu(s.WithOpName("clr2"), input);
|
||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), clr1, clr1);
|
||||
Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr1, clr1);
|
||||
auto clr3 = ops::ShapeN(s.WithOpName("clr3"), {clr1, clr2});
|
||||
Output clr4 = ops::Relu(s.WithOpName("clr4"), clr2);
|
||||
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), wht1);
|
||||
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow1);
|
||||
Output fetch2 = ops::Identity(s.WithOpName("fetch2"), clr4);
|
||||
|
||||
GrapplerItem item;
|
||||
@ -357,7 +357,7 @@ TEST_F(AutoMixedPrecisionTest, BidirectionalClearChain) {
|
||||
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_HALF);
|
||||
|
||||
@ -372,18 +372,18 @@ TEST_F(AutoMixedPrecisionTest, BidirectionalClearChain) {
|
||||
TEST_F(AutoMixedPrecisionTest, PreserveFetches) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
|
||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, input);
|
||||
Output clr1 = ops::Relu(s.WithOpName("clr1"), wht1);
|
||||
Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
|
||||
Output clr1 = ops::Relu(s.WithOpName("clr1"), allow1);
|
||||
Output gry1 = ops::Sqrt(s.WithOpName("gry1"), clr1);
|
||||
Output blk1 = ops::Exp(s.WithOpName("blk1"), gry1);
|
||||
Output clr2 = ops::Relu(s.WithOpName("clr2"), blk1);
|
||||
Output wht2 = ops::MatMul(s.WithOpName("wht2"), clr2, clr2);
|
||||
Output clr3 = ops::Relu(s.WithOpName("clr3"), wht2);
|
||||
Output allow2 = ops::MatMul(s.WithOpName("allow2"), clr2, clr2);
|
||||
Output clr3 = ops::Relu(s.WithOpName("clr3"), allow2);
|
||||
Output blk2 = ops::Exp(s.WithOpName("blk2"), clr3);
|
||||
Output clr4 = ops::Relu(s.WithOpName("clr4"), blk2);
|
||||
|
||||
GrapplerItem item;
|
||||
item.fetch = {"wht1", "clr2", "clr3"};
|
||||
item.fetch = {"allow1", "clr2", "clr3"};
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
|
||||
|
||||
@ -396,12 +396,12 @@ TEST_F(AutoMixedPrecisionTest, PreserveFetches) {
|
||||
GraphView output_view(&output);
|
||||
EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
|
||||
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("blk1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("blk2")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
|
||||
@ -418,12 +418,13 @@ TEST_F(AutoMixedPrecisionTest, PreserveCPUNodes) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
|
||||
Output clr1 = ops::Relu(s.WithOpName("clr1"), input);
|
||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), clr1, clr1);
|
||||
Output gry1 = ops::Tanh(s.WithOpName("gry1"), wht1);
|
||||
Output wht2 = ops::MatMul(s.WithOpName("wht2").WithDevice(
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0"),
|
||||
gry1, gry1);
|
||||
Output clr2 = ops::Relu(s.WithOpName("clr2"), wht2);
|
||||
Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr1, clr1);
|
||||
Output gry1 = ops::Tanh(s.WithOpName("gry1"), allow1);
|
||||
Output allow2 =
|
||||
ops::MatMul(s.WithOpName("allow2").WithDevice(
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0"),
|
||||
gry1, gry1);
|
||||
Output clr2 = ops::Relu(s.WithOpName("clr2"), allow2);
|
||||
Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);
|
||||
|
||||
GrapplerItem item;
|
||||
@ -441,9 +442,9 @@ TEST_F(AutoMixedPrecisionTest, PreserveCPUNodes) {
|
||||
EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
|
||||
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT);
|
||||
|
||||
auto tensors = EvaluateNodes(output, item.fetch);
|
||||
@ -459,12 +460,12 @@ TEST_F(AutoMixedPrecisionTest, PreserveIdentityAfterVariable) {
|
||||
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
|
||||
Output var1 = ops::Variable(s.WithOpName("var1"), {32, 32}, DT_FLOAT);
|
||||
Output clr1 = ops::Identity(s.WithOpName("clr1"), var1);
|
||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, clr1);
|
||||
Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, clr1);
|
||||
Output input2 = ops::Const(s.WithOpName("input2"), 1.f / 32, {32, 32});
|
||||
Output clr2 = ops::Identity(s.WithOpName("clr2"), input2);
|
||||
Output wht2 = ops::MatMul(s.WithOpName("wht2"), input, clr2);
|
||||
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), wht1);
|
||||
Output fetch2 = ops::Identity(s.WithOpName("fetch2"), wht2);
|
||||
Output allow2 = ops::MatMul(s.WithOpName("allow2"), input, clr2);
|
||||
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow1);
|
||||
Output fetch2 = ops::Identity(s.WithOpName("fetch2"), allow2);
|
||||
|
||||
GrapplerItem item;
|
||||
item.fetch = {"fetch1", "fetch2"};
|
||||
@ -485,10 +486,10 @@ TEST_F(AutoMixedPrecisionTest, PreserveIdentityAfterVariable) {
|
||||
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("var1")->attr().at("dtype").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("input2")->attr().at("dtype").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
|
||||
|
||||
auto tensors = EvaluateNodes(output, item.fetch, feed);
|
||||
EXPECT_EQ(tensors.size(), tensors_expected.size());
|
||||
@ -507,22 +508,24 @@ TEST_F(AutoMixedPrecisionTest, FusedBatchNorm) {
|
||||
Output offset = ops::Const(s.WithOpName("offset"), 4.f, {16});
|
||||
Output mean = ops::Const(s.WithOpName("mean"), 5.f, {0});
|
||||
Output variance = ops::Const(s.WithOpName("variance"), 6.f, {0});
|
||||
Output wht1 = ops::Conv2D(s.WithOpName("wht1"), input, weight, {1, 1, 1, 1},
|
||||
"SAME", ops::Conv2D::DataFormat("NHWC"));
|
||||
Output allow1 =
|
||||
ops::Conv2D(s.WithOpName("allow1"), input, weight, {1, 1, 1, 1}, "SAME",
|
||||
ops::Conv2D::DataFormat("NHWC"));
|
||||
auto fbn1_op =
|
||||
ops::FusedBatchNorm(s.WithOpName("fbn1"), wht1, scale, offset, mean,
|
||||
ops::FusedBatchNorm(s.WithOpName("fbn1"), allow1, scale, offset, mean,
|
||||
variance, ops::FusedBatchNorm::DataFormat("NHWC"));
|
||||
Output fbn1 = fbn1_op.y;
|
||||
Output fbn1_rs1 = fbn1_op.reserve_space_1;
|
||||
Output fbn1_rs2 = fbn1_op.reserve_space_2;
|
||||
Output bng1 = ops::FusedBatchNormGrad(
|
||||
s.WithOpName("bng1"), fbn1, wht1, scale, fbn1_rs1, fbn1_rs2,
|
||||
ops::FusedBatchNormGrad::DataFormat("NHWC"))
|
||||
s.WithOpName("bng1"), fbn1, allow1, scale, fbn1_rs1,
|
||||
fbn1_rs2, ops::FusedBatchNormGrad::DataFormat("NHWC"))
|
||||
.x_backprop;
|
||||
Output gry1 = ops::Add(s.WithOpName("gry1"), fbn1, bng1);
|
||||
Output wht2 = ops::Conv2D(s.WithOpName("wht2"), gry1, weight, {1, 1, 1, 1},
|
||||
"SAME", ops::Conv2D::DataFormat("NHWC"));
|
||||
Output fetch = ops::Identity(s.WithOpName("fetch"), wht2);
|
||||
Output allow2 =
|
||||
ops::Conv2D(s.WithOpName("allow2"), gry1, weight, {1, 1, 1, 1}, "SAME",
|
||||
ops::Conv2D::DataFormat("NHWC"));
|
||||
Output fetch = ops::Identity(s.WithOpName("fetch"), allow2);
|
||||
|
||||
GrapplerItem item;
|
||||
item.fetch = {"fetch"};
|
||||
@ -537,7 +540,7 @@ TEST_F(AutoMixedPrecisionTest, FusedBatchNorm) {
|
||||
|
||||
GraphView output_view(&output);
|
||||
EXPECT_EQ(output.node_size(), item.graph.node_size() + 3);
|
||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("fbn1")->op(), "FusedBatchNormV2");
|
||||
EXPECT_EQ(output_view.GetNode("fbn1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("fbn1")->attr().at("U").type(), DT_FLOAT);
|
||||
@ -545,7 +548,7 @@ TEST_F(AutoMixedPrecisionTest, FusedBatchNorm) {
|
||||
EXPECT_EQ(output_view.GetNode("bng1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("bng1")->attr().at("U").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
|
||||
|
||||
auto tensors = EvaluateNodes(output, item.fetch);
|
||||
EXPECT_EQ(tensors.size(), tensors_expected.size());
|
||||
@ -558,13 +561,13 @@ TEST_F(AutoMixedPrecisionTest, FusedBatchNorm) {
|
||||
TEST_F(AutoMixedPrecisionTest, RepeatedAndListTypeAttrs) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
|
||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, input);
|
||||
auto clr1_op = ops::IdentityN(s.WithOpName("clr1"), {wht1, wht1, wht1});
|
||||
Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
|
||||
auto clr1_op = ops::IdentityN(s.WithOpName("clr1"), {allow1, allow1, allow1});
|
||||
Output gry1 =
|
||||
ops::AddN(s.WithOpName("gry1"),
|
||||
{clr1_op.output[0], clr1_op.output[1], clr1_op.output[2]});
|
||||
Output wht2 = ops::MatMul(s.WithOpName("wht2"), gry1, gry1);
|
||||
Output fetch = ops::Identity(s.WithOpName("fetch"), wht2);
|
||||
Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, gry1);
|
||||
Output fetch = ops::Identity(s.WithOpName("fetch"), allow2);
|
||||
|
||||
GrapplerItem item;
|
||||
item.fetch = {"fetch"};
|
||||
@ -580,12 +583,12 @@ TEST_F(AutoMixedPrecisionTest, RepeatedAndListTypeAttrs) {
|
||||
GraphView output_view(&output);
|
||||
EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
|
||||
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||
for (auto type : output_view.GetNode("clr1")->attr().at("T").list().type()) {
|
||||
EXPECT_EQ(type, DT_HALF);
|
||||
}
|
||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
|
||||
|
||||
auto tensors = EvaluateNodes(output, item.fetch);
|
||||
EXPECT_EQ(tensors.size(), tensors_expected.size());
|
||||
@ -599,8 +602,8 @@ TEST_F(AutoMixedPrecisionTest, ExistingCast) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output input = ops::Const(s.WithOpName("input"), true, {32, 32});
|
||||
Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_FLOAT);
|
||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), cst1, cst1);
|
||||
Output fetch = ops::Identity(s.WithOpName("fetch"), wht1);
|
||||
Output allow1 = ops::MatMul(s.WithOpName("allow1"), cst1, cst1);
|
||||
Output fetch = ops::Identity(s.WithOpName("fetch"), allow1);
|
||||
|
||||
GrapplerItem item;
|
||||
item.fetch = {"fetch"};
|
||||
@ -617,7 +620,7 @@ TEST_F(AutoMixedPrecisionTest, ExistingCast) {
|
||||
EXPECT_EQ(output.node_size(), item.graph.node_size() + 1);
|
||||
EXPECT_EQ(output_view.GetNode("cst1")->attr().at("SrcT").type(), DT_BOOL);
|
||||
EXPECT_EQ(output_view.GetNode("cst1")->attr().at("DstT").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||
|
||||
auto tensors = EvaluateNodes(output, item.fetch);
|
||||
EXPECT_EQ(tensors.size(), tensors_expected.size());
|
||||
@ -640,8 +643,8 @@ TEST_F(AutoMixedPrecisionTest, RecurrentEdgeColorMismatch) {
|
||||
Output lpc1 = ops::LoopCond(s.WithOpName("lpc1"), con1).output;
|
||||
auto swt1 = ops::Switch(s.WithOpName("swt1"), mrg1, lpc1);
|
||||
Output gry1 = ops::Sqrt(s.WithOpName("gry1"), swt1.output_true);
|
||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), gry1, gry1);
|
||||
Output nxt1 = ops::NextIteration(s.WithOpName("nxt1"), wht1);
|
||||
Output allow1 = ops::MatMul(s.WithOpName("allow1"), gry1, gry1);
|
||||
Output nxt1 = ops::NextIteration(s.WithOpName("nxt1"), allow1);
|
||||
Output ext1 = ops::internal::Exit(s.WithOpName("ext1"), swt1.output_false);
|
||||
Output fetch = ops::Identity(s.WithOpName("fetch"), ext1);
|
||||
// Add a second merge node from the same NextIteration node. This case arises
|
||||
@ -670,13 +673,13 @@ TEST_F(AutoMixedPrecisionTest, RecurrentEdgeColorMismatch) {
|
||||
EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
|
||||
// Note that mrg1 gets painted black because it is between blk1 and gry1. This
|
||||
// forces nxt1 and mrg2 to be painted black as well (they would otherwise be
|
||||
// painted white because they are clear and have a direct path to wht1).
|
||||
// painted allow because they are clear and have a direct path to allow1).
|
||||
EXPECT_EQ(output_view.GetNode("blk1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("ent1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("mrg1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("swt1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("nxt1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("ext1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("mrg2")->attr().at("T").type(), DT_FLOAT);
|
||||
@ -699,9 +702,9 @@ TEST_F(AutoMixedPrecisionTest, TensorListSetGet) {
|
||||
Output idx3 = ops::Const(s.WithOpName("idx3"), 3);
|
||||
auto tl1w1 =
|
||||
ops::TensorListSetItem(s.WithOpName("tl1w1"), tl1.handle, idx1, input);
|
||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, input);
|
||||
Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
|
||||
auto tl1w2 =
|
||||
ops::TensorListSetItem(s.WithOpName("tl1w2"), tl1.handle, idx2, wht1);
|
||||
ops::TensorListSetItem(s.WithOpName("tl1w2"), tl1.handle, idx2, allow1);
|
||||
// Ensure that TensorListResize doesn't cause any problems.
|
||||
Output tl1rs =
|
||||
ops::TensorListResize(s.WithOpName("tl1rs"), tl1w2.output_handle, 6);
|
||||
@ -709,9 +712,9 @@ TEST_F(AutoMixedPrecisionTest, TensorListSetGet) {
|
||||
shape, DT_FLOAT)
|
||||
.item;
|
||||
Output gry1 = ops::Tanh(s.WithOpName("gry1"), tl1r1);
|
||||
Output wht2 = ops::MatMul(s.WithOpName("wht2"), gry1, gry1);
|
||||
Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, gry1);
|
||||
auto tl1w3 =
|
||||
ops::TensorListSetItem(s.WithOpName("tl1w3"), tl1.handle, idx3, wht2);
|
||||
ops::TensorListSetItem(s.WithOpName("tl1w3"), tl1.handle, idx3, allow2);
|
||||
Output tl1r2 =
|
||||
ops::TensorListGetItem(s.WithOpName("tl1r2"), tl1w3.output_handle, idx3,
|
||||
shape, DT_FLOAT)
|
||||
@ -742,11 +745,11 @@ TEST_F(AutoMixedPrecisionTest, TensorListSetGet) {
|
||||
const char* type_key = "element_dtype";
|
||||
EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl1w1")->attr().at(type_key).type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_FLOAT);
|
||||
@ -767,15 +770,16 @@ TEST_F(AutoMixedPrecisionTest, TensorListPushPop) {
|
||||
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
|
||||
auto tl1w1 =
|
||||
ops::TensorListPushBack(s.WithOpName("tl1w1"), tl1.handle, input);
|
||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, input);
|
||||
auto tl1w2 =
|
||||
ops::TensorListPushBack(s.WithOpName("tl1w2"), tl1w1.output_handle, wht1);
|
||||
Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
|
||||
auto tl1w2 = ops::TensorListPushBack(s.WithOpName("tl1w2"),
|
||||
tl1w1.output_handle, allow1);
|
||||
Output tl1r1 = ops::TensorListPopBack(s.WithOpName("tl1r1"),
|
||||
tl1w2.output_handle, shape, DT_FLOAT)
|
||||
.tensor;
|
||||
Output gry1 = ops::Tanh(s.WithOpName("gry1"), tl1r1);
|
||||
Output wht2 = ops::MatMul(s.WithOpName("wht2"), gry1, gry1);
|
||||
auto tl1w3 = ops::TensorListPushBack(s.WithOpName("tl1w3"), tl1.handle, wht2);
|
||||
Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, gry1);
|
||||
auto tl1w3 =
|
||||
ops::TensorListPushBack(s.WithOpName("tl1w3"), tl1.handle, allow2);
|
||||
Output tl1r2 = ops::TensorListPopBack(s.WithOpName("tl1r2"),
|
||||
tl1w3.output_handle, shape, DT_FLOAT)
|
||||
.tensor;
|
||||
@ -804,11 +808,11 @@ TEST_F(AutoMixedPrecisionTest, TensorListPushPop) {
|
||||
const char* type_key = "element_dtype";
|
||||
EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl1w1")->attr().at(type_key).type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_FLOAT);
|
||||
@ -826,19 +830,19 @@ TEST_F(AutoMixedPrecisionTest, TensorListFromTensor) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
tensorflow::Input shape = {32};
|
||||
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
|
||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, input);
|
||||
auto tl1 = ops::TensorListFromTensor(s.WithOpName("tl1"), wht1, shape);
|
||||
Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
|
||||
auto tl1 = ops::TensorListFromTensor(s.WithOpName("tl1"), allow1, shape);
|
||||
Output tl1r1 = ops::TensorListStack(s.WithOpName("tl1r1"), tl1.output_handle,
|
||||
shape, DT_FLOAT)
|
||||
.tensor;
|
||||
Output gry1 = ops::Tanh(s.WithOpName("gry1"), tl1r1);
|
||||
Output wht2 = ops::MatMul(s.WithOpName("wht2"), gry1, gry1);
|
||||
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), wht2);
|
||||
Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, gry1);
|
||||
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);
|
||||
|
||||
// This tests that a white-painted object node (tl2) will force an unpainted
|
||||
// client node (tl2w1) to be painted white as well. (Without the force, tl2w1
|
||||
// This tests that a allow-painted object node (tl2) will force an unpainted
|
||||
// client node (tl2w1) to be painted allow as well. (Without the force, tl2w1
|
||||
// would remain unpainted, producing an invalid graph).
|
||||
auto tl2 = ops::TensorListFromTensor(s.WithOpName("tl2"), wht1, shape);
|
||||
auto tl2 = ops::TensorListFromTensor(s.WithOpName("tl2"), allow1, shape);
|
||||
auto tl2w1 =
|
||||
ops::TensorListPushBack(s.WithOpName("tl2w1"), tl2.output_handle, input);
|
||||
|
||||
@ -856,11 +860,11 @@ TEST_F(AutoMixedPrecisionTest, TensorListFromTensor) {
|
||||
GraphView output_view(&output);
|
||||
EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
|
||||
const char* type_key = "element_dtype";
|
||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_HALF);
|
||||
|
||||
@ -878,12 +882,13 @@ TEST_F(AutoMixedPrecisionTest, TensorListPushBackBatchAndConcatLists) {
|
||||
auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
|
||||
auto tl2 = ops::EmptyTensorList(s.WithOpName("tl2"), {32, 32}, 8, DT_FLOAT);
|
||||
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
|
||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, input);
|
||||
Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
|
||||
Output tl1_tl2 =
|
||||
ops::Stack(s.WithOpName("tl1_tl2"), {tl1.handle, tl2.handle});
|
||||
Output wht1_wht1 = ops::Stack(s.WithOpName("wht1_wht1"), {wht1, wht1});
|
||||
auto tl12w1 =
|
||||
ops::TensorListPushBackBatch(s.WithOpName("tl12w1"), tl1_tl2, wht1_wht1);
|
||||
Output allow1_allow1 =
|
||||
ops::Stack(s.WithOpName("allow1_allow1"), {allow1, allow1});
|
||||
auto tl12w1 = ops::TensorListPushBackBatch(s.WithOpName("tl12w1"), tl1_tl2,
|
||||
allow1_allow1);
|
||||
OutputList tl12w1_outputs =
|
||||
ops::Split(s.WithOpName("tl12w1_outputs"), 0, tl12w1.output_handles, 2)
|
||||
.output;
|
||||
@ -898,8 +903,8 @@ TEST_F(AutoMixedPrecisionTest, TensorListPushBackBatchAndConcatLists) {
|
||||
ops::TensorListPopBack(s.WithOpName("tl3r1"), tl3, shape, DT_FLOAT)
|
||||
.tensor;
|
||||
Output gry1 = ops::Tanh(s.WithOpName("gry1"), tl3r1);
|
||||
Output wht2 = ops::MatMul(s.WithOpName("wht2"), gry1, gry1);
|
||||
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), wht2);
|
||||
Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, gry1);
|
||||
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);
|
||||
|
||||
GrapplerItem item;
|
||||
item.fetch = {"fetch1"};
|
||||
@ -915,8 +920,8 @@ TEST_F(AutoMixedPrecisionTest, TensorListPushBackBatchAndConcatLists) {
|
||||
GraphView output_view(&output);
|
||||
EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
|
||||
const char* type_key = "element_dtype";
|
||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
|
||||
@ -961,8 +966,8 @@ TEST_F(AutoMixedPrecisionTest, TensorListThroughFunction) {
|
||||
TF_CHECK_OK(s.graph()->AddFunctionLibrary(function_lib));
|
||||
tensorflow::Input shape = {32, 32};
|
||||
Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32});
|
||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, input);
|
||||
Output gry1 = ops::Tanh(s.WithOpName("gry1"), wht1);
|
||||
Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
|
||||
Output gry1 = ops::Tanh(s.WithOpName("gry1"), allow1);
|
||||
auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT);
|
||||
auto tl1w1 = ops::TensorListPushBack(s.WithOpName("tl1w1"), tl1.handle, gry1);
|
||||
auto _gry1 = tensorflow::ops::AsNodeOut(s, gry1);
|
||||
@ -981,8 +986,8 @@ TEST_F(AutoMixedPrecisionTest, TensorListThroughFunction) {
|
||||
Output tl2r1 = ops::TensorListPopBack(s.WithOpName("tl2r1"),
|
||||
tl2w1.output_handle, shape, DT_FLOAT)
|
||||
.tensor;
|
||||
Output wht2 = ops::MatMul(s.WithOpName("wht2"), tl1r1, tl2r1);
|
||||
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), wht2);
|
||||
Output allow2 = ops::MatMul(s.WithOpName("allow2"), tl1r1, tl2r1);
|
||||
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2);
|
||||
|
||||
GrapplerItem item;
|
||||
item.fetch = {"fetch1"};
|
||||
@ -997,8 +1002,8 @@ TEST_F(AutoMixedPrecisionTest, TensorListThroughFunction) {
|
||||
|
||||
GraphView output_view(&output);
|
||||
const char* type_key = "element_dtype";
|
||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_HALF);
|
||||
@ -1031,8 +1036,8 @@ int GetCudaVersion(const Cluster& cluster) {
|
||||
TEST_F(AutoMixedPrecisionTest, BatchMatMul) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output input = ops::Const(s.WithOpName("input"), 1.f / 33, {64, 32, 32});
|
||||
Output wht1 = ops::BatchMatMul(s.WithOpName("wht1"), input, input);
|
||||
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), wht1);
|
||||
Output allow1 = ops::BatchMatMul(s.WithOpName("allow1"), input, input);
|
||||
Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow1);
|
||||
|
||||
GrapplerItem item;
|
||||
item.fetch = {"fetch1"};
|
||||
@ -1049,10 +1054,10 @@ TEST_F(AutoMixedPrecisionTest, BatchMatMul) {
|
||||
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
|
||||
if (GetCudaVersion(*virtual_cluster_.get()) >= 9010) {
|
||||
EXPECT_EQ(output.node_size(), item.graph.node_size() + 2);
|
||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF);
|
||||
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF);
|
||||
} else {
|
||||
EXPECT_EQ(output.node_size(), item.graph.node_size());
|
||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_FLOAT);
|
||||
}
|
||||
|
||||
auto tensors = EvaluateNodes(output, item.fetch);
|
||||
@ -1187,8 +1192,8 @@ TEST_F(AutoMixedPrecisionMklTest, AlreadyBf16) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output input = ops::Const(s.WithOpName("input"), 1.f, {32, 32});
|
||||
Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_BFLOAT16);
|
||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), cst1, cst1);
|
||||
Output clr1 = ops::Relu(s.WithOpName("clr1"), wht1);
|
||||
Output allow1 = ops::MatMul(s.WithOpName("allow1"), cst1, cst1);
|
||||
Output clr1 = ops::Relu(s.WithOpName("clr1"), allow1);
|
||||
Output cst2 = ops::Cast(s.WithOpName("cst2"), clr1, DT_FLOAT);
|
||||
Output clr2 = ops::Relu(s.WithOpName("clr2"), cst2);
|
||||
Output fetch = ops::Identity(s.WithOpName("fetch"), clr2);
|
||||
@ -1207,7 +1212,7 @@ TEST_F(AutoMixedPrecisionMklTest, AlreadyBf16) {
|
||||
GraphView output_view(&output);
|
||||
EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("cst1")->attr().at("DstT").type(), DT_BFLOAT16);
|
||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_BFLOAT16);
|
||||
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_BFLOAT16);
|
||||
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_BFLOAT16);
|
||||
EXPECT_EQ(output_view.GetNode("cst2")->attr().at("SrcT").type(), DT_BFLOAT16);
|
||||
EXPECT_EQ(output_view.GetNode("cst2")->attr().at("DstT").type(), DT_FLOAT);
|
||||
@ -1228,8 +1233,8 @@ TEST_F(AutoMixedPrecisionMklTest, Simple) {
|
||||
Output clr1 = ops::Relu(s.WithOpName("clr1"), blk1);
|
||||
Output gry1 = ops::Sqrt(s.WithOpName("gry1"), clr1);
|
||||
Output clr2 = ops::Relu(s.WithOpName("clr2"), gry1);
|
||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), clr2, clr2);
|
||||
Output clr3 = ops::Relu(s.WithOpName("clr3"), wht1);
|
||||
Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr2, clr2);
|
||||
Output clr3 = ops::Relu(s.WithOpName("clr3"), allow1);
|
||||
Output blk2 = ops::Log(s.WithOpName("blk2"), clr3);
|
||||
Output clr4 = ops::Relu(s.WithOpName("clr4"), blk2);
|
||||
Output blk3 = ops::SparseMatMul(s.WithOpName("blk3"), clr4, clr4);
|
||||
@ -1254,7 +1259,7 @@ TEST_F(AutoMixedPrecisionMklTest, Simple) {
|
||||
EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_BFLOAT16);
|
||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_BFLOAT16);
|
||||
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_BFLOAT16);
|
||||
EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_BFLOAT16);
|
||||
EXPECT_EQ(output_view.GetNode("blk2")->attr().at("T").type(), DT_FLOAT);
|
||||
EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT);
|
||||
@ -1280,9 +1285,9 @@ TEST_F(AutoMixedPrecisionMklTest, TensorListSetGet) {
|
||||
Output idx3 = ops::Const(s.WithOpName("idx3"), 3);
|
||||
auto tl1w1 =
|
||||
ops::TensorListSetItem(s.WithOpName("tl1w1"), tl1.handle, idx1, input);
|
||||
Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, input);
|
||||
Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input);
|
||||
auto tl1w2 =
|
||||
ops::TensorListSetItem(s.WithOpName("tl1w2"), tl1.handle, idx2, wht1);
|
||||
ops::TensorListSetItem(s.WithOpName("tl1w2"), tl1.handle, idx2, allow1);
|
||||
// Ensure that TensorListResize doesn't cause any problems.
|
||||
Output tl1rs =
|
||||
ops::TensorListResize(s.WithOpName("tl1rs"), tl1w2.output_handle, 6);
|
||||
@ -1290,9 +1295,9 @@ TEST_F(AutoMixedPrecisionMklTest, TensorListSetGet) {
|
||||
shape, DT_FLOAT)
|
||||
.item;
|
||||
Output gry1 = ops::Mul(s.WithOpName("gry1"), tl1r1, tl1r1);
|
||||
Output wht2 = ops::MatMul(s.WithOpName("wht2"), gry1, gry1);
|
||||
Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, gry1);
|
||||
auto tl1w3 =
|
||||
ops::TensorListSetItem(s.WithOpName("tl1w3"), tl1.handle, idx3, wht2);
|
||||
ops::TensorListSetItem(s.WithOpName("tl1w3"), tl1.handle, idx3, allow2);
|
||||
Output tl1r2 =
|
||||
ops::TensorListGetItem(s.WithOpName("tl1r2"), tl1w3.output_handle, idx3,
|
||||
shape, DT_FLOAT)
|
||||
@ -1325,13 +1330,13 @@ TEST_F(AutoMixedPrecisionMklTest, TensorListSetGet) {
|
||||
DT_BFLOAT16);
|
||||
EXPECT_EQ(output_view.GetNode("tl1w1")->attr().at(type_key).type(),
|
||||
DT_BFLOAT16);
|
||||
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_BFLOAT16);
|
||||
EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_BFLOAT16);
|
||||
EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(),
|
||||
DT_BFLOAT16);
|
||||
EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(),
|
||||
DT_BFLOAT16);
|
||||
EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_BFLOAT16);
|
||||
EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_BFLOAT16);
|
||||
EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_BFLOAT16);
|
||||
EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(),
|
||||
DT_BFLOAT16);
|
||||
EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT);
|
||||
|
@ -1020,9 +1020,9 @@ bool ConstantFolding::MaybeFoldable(const NodeDef& node,
|
||||
return false;
|
||||
}
|
||||
|
||||
// Skips nodes that must be preserved except whitelisted nodes.
|
||||
// Skips nodes that must be preserved except allowlisted nodes.
|
||||
if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end() &&
|
||||
nodes_whitelist_.find(node.name()) == nodes_whitelist_.end()) {
|
||||
nodes_allowlist_.find(node.name()) == nodes_allowlist_.end()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -1082,13 +1082,13 @@ bool ConstantFolding::MaybeFoldable(const NodeDef& node,
|
||||
}
|
||||
}
|
||||
|
||||
// Don't fold nodes that have no outgoing edges except whitelisted nodes.
|
||||
// Don't fold nodes that have no outgoing edges except allowlisted nodes.
|
||||
// Such nodes could be introduced by an earlier constant folding pass and are
|
||||
// preserved in case users want to fetch their values; re-processing them
|
||||
// would lead to an error of adding a duplicated node to graph.
|
||||
const auto& outputs = node_map_->GetOutputs(node.name());
|
||||
if (outputs.empty() &&
|
||||
nodes_whitelist_.find(node.name()) == nodes_whitelist_.end()) {
|
||||
nodes_allowlist_.find(node.name()) == nodes_allowlist_.end()) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
@ -3874,7 +3874,7 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
|
||||
GraphDef* optimized_graph) {
|
||||
graph_ = &item->graph;
|
||||
node_map_.reset(new NodeMap(graph_));
|
||||
nodes_whitelist_.clear();
|
||||
nodes_allowlist_.clear();
|
||||
// Fold fetch nodes iff it has a single fanout. Note that if a fetch node
|
||||
// has a single fanout, it would be rewritten as a constant with the same
|
||||
// node name, and therefore users are still able to fetch it. This is not
|
||||
@ -3885,7 +3885,7 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
|
||||
for (const auto& fetch : item->fetch) {
|
||||
const NodeDef* fetch_node = node_map_->GetNode(fetch);
|
||||
if (fetch_node && NumOutputs(*fetch_node, graph_) == 1) {
|
||||
nodes_whitelist_.insert(fetch_node->name());
|
||||
nodes_allowlist_.insert(fetch_node->name());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -328,7 +328,7 @@ class ConstantFolding : public GraphOptimizer {
|
||||
std::unique_ptr<NodeMap> node_map_;
|
||||
std::unordered_set<string> nodes_to_preserve_;
|
||||
// TODO(rmlarsen): Could these be keyed on absl::string_view?
|
||||
absl::flat_hash_set<string> nodes_whitelist_;
|
||||
absl::flat_hash_set<string> nodes_allowlist_;
|
||||
absl::flat_hash_set<string> feed_nodes_;
|
||||
absl::flat_hash_map<string, bool> maybe_foldable_nodes_;
|
||||
bool has_fetch_;
|
||||
|
@ -232,16 +232,16 @@ Status IsFunctionStateful(const FunctionLibraryDefinition& library,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Returns whether an op has been whitelisted as stateless. Uses a heuristic to
|
||||
// whitelist source dataset ops which have been marked stateful due to
|
||||
// Returns whether an op has been allowlisted as stateless. Uses a heuristic to
|
||||
// allowlist source dataset ops which have been marked stateful due to
|
||||
// b/65524810. Also looks up the `op_def->name` in the global
|
||||
// `WhitelistedStatefulOpRegistry`.
|
||||
bool IsOpWhitelisted(const OpDef* op_def) {
|
||||
// `AllowlistedStatefulOpRegistry`.
|
||||
bool IsOpAllowlisted(const OpDef* op_def) {
|
||||
return (op_def->output_arg_size() == 1 &&
|
||||
op_def->output_arg(0).type() == DT_VARIANT &&
|
||||
(absl::EndsWith(op_def->name(), "Dataset") ||
|
||||
absl::EndsWith(op_def->name(), "DatasetV2"))) ||
|
||||
WhitelistedStatefulOpRegistry::Global()->Contains(op_def->name());
|
||||
AllowlistedStatefulOpRegistry::Global()->Contains(op_def->name());
|
||||
}
|
||||
|
||||
Status LookupFunction(const FunctionLibraryDefinition& lib_def,
|
||||
@ -389,7 +389,7 @@ Status IsNodeStateful(const FunctionLibraryDefinition& library,
|
||||
// TODO(jsimsa): Fix C++ unit tests so that we do not have to ignore
|
||||
// `LookUpOpDef` errors here.
|
||||
if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok() ||
|
||||
IsOpWhitelisted(op_def) || !op_def->is_stateful() ||
|
||||
IsOpAllowlisted(op_def) || !op_def->is_stateful() ||
|
||||
op_def->name() == "Assert") {
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -41478,13 +41478,13 @@ func ParseExample(scope *Scope, serialized tf.Output, names tf.Output, sparse_ke
|
||||
// DatasetToGraphAttr is an optional argument to DatasetToGraph.
|
||||
type DatasetToGraphAttr func(optionalAttr)
|
||||
|
||||
// DatasetToGraphStatefulWhitelist sets the optional stateful_whitelist attribute to value.
|
||||
// DatasetToGraphStatefulAllowlist sets the optional stateful_allowlist attribute to value.
|
||||
// If not specified, defaults to <>
|
||||
//
|
||||
// REQUIRES: len(value) >= 0
|
||||
func DatasetToGraphStatefulWhitelist(value []string) DatasetToGraphAttr {
|
||||
func DatasetToGraphStatefulAllowlist(value []string) DatasetToGraphAttr {
|
||||
return func(m optionalAttr) {
|
||||
m["stateful_whitelist"] = value
|
||||
m["stateful_allowlist"] = value
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -233,10 +233,10 @@ tf_cc_test(
|
||||
cc_library(
|
||||
name = "whitelisted_flex_ops_lib",
|
||||
srcs = [
|
||||
"whitelisted_flex_ops.cc",
|
||||
"allowlisted_flex_ops.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"whitelisted_flex_ops.h",
|
||||
"allowlisted_flex_ops.h",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -12,15 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/delegates/flex/whitelisted_flex_ops.h"
|
||||
|
||||
#include <set>
|
||||
|
||||
#include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace flex {
|
||||
|
||||
bool IsWhitelistedFlexOp(const std::string& tensorflow_op_name) {
|
||||
static const std::set<std::string>* whitelisted_flex_ops =
|
||||
bool IsAllowlistedFlexOp(const std::string& tensorflow_op_name) {
|
||||
static const std::set<std::string>* allowlisted_flex_ops =
|
||||
new std::set<std::string>({
|
||||
// go/keep-sorted start
|
||||
"Abort",
|
||||
@ -538,8 +538,8 @@ bool IsWhitelistedFlexOp(const std::string& tensorflow_op_name) {
|
||||
"_Send",
|
||||
// go/keep-sorted end
|
||||
});
|
||||
return whitelisted_flex_ops->find(tensorflow_op_name) !=
|
||||
whitelisted_flex_ops->end();
|
||||
return allowlisted_flex_ops->find(tensorflow_op_name) !=
|
||||
allowlisted_flex_ops->end();
|
||||
// Prevent lint error about this function being too long. This function
|
||||
// is a set of ops, and making it shorter won't help readbility.
|
||||
// NOLINTNEXTLINE
|
@ -12,24 +12,24 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_DELEGATES_FLEX_WHITELISTED_FLEX_OPS_H_
|
||||
#define TENSORFLOW_LITE_DELEGATES_FLEX_WHITELISTED_FLEX_OPS_H_
|
||||
#ifndef TENSORFLOW_LITE_DELEGATES_FLEX_ALLOWLISTED_FLEX_OPS_H_
|
||||
#define TENSORFLOW_LITE_DELEGATES_FLEX_ALLOWLISTED_FLEX_OPS_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace tflite {
|
||||
namespace flex {
|
||||
|
||||
// Whether the given op has been statically whitelisted for flex export.
|
||||
// Whether the given op has been statically allowlisted for flex export.
|
||||
//
|
||||
// This static whitelist is formed by the intersection of ops supported by
|
||||
// This static allowlist is formed by the intersection of ops supported by
|
||||
// TensorFlowMobile on both iOS and Android. As the converter is likely running
|
||||
// on a host that has the full suite of TensorFlow ops available, we use this
|
||||
// static whitelist to ensure compatibility when deploying to a mobile device.
|
||||
// TODO(b/118389105): Automate generation of the whitelisted flex ops.
|
||||
bool IsWhitelistedFlexOp(const std::string& tensorflow_op_name);
|
||||
// static allowlist to ensure compatibility when deploying to a mobile device.
|
||||
// TODO(b/118389105): Automate generation of the allowlisted flex ops.
|
||||
bool IsAllowlistedFlexOp(const std::string& tensorflow_op_name);
|
||||
|
||||
} // namespace flex
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_DELEGATES_FLEX_WHITELISTED_FLEX_OPS_H_
|
||||
#endif // TENSORFLOW_LITE_DELEGATES_FLEX_ALLOWLISTED_FLEX_OPS_H_
|
@ -70,7 +70,7 @@ TfLiteStatus Get4DShape(unsigned int* batch_size, unsigned int* height_size,
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// We maintain an op-version whitelist here to ensure we don't accept unintended
|
||||
// We maintain an op-version allowlist here to ensure we don't accept unintended
|
||||
// ops.
|
||||
bool CheckOpVersion(const TfLiteRegistration* registration) {
|
||||
switch (registration->builtin_code) {
|
||||
|
@ -18,7 +18,7 @@ namespace tflite {
|
||||
|
||||
const constexpr char* NnapiAccelerationTestParams::kAccelerationTestConfig =
|
||||
R"(
|
||||
## Every Test can be whitelisted or blacklisted using a regexp on its test_id
|
||||
## Every Test can be allowlisted or blacklisted using a regexp on its test_id
|
||||
|
||||
## Test_id
|
||||
#
|
||||
@ -28,7 +28,7 @@ const constexpr char* NnapiAccelerationTestParams::kAccelerationTestConfig =
|
||||
# the ordinal is the position in the list of parameters generated by the
|
||||
# cardinal product of all the different parameter sets
|
||||
|
||||
# Blacklist/Whitelist
|
||||
# Blacklist/Allowlist
|
||||
# To blacklist an element simply add - before the test_id regex
|
||||
|
||||
## Rules evaluation
|
||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
||||
|
||||
namespace tflite {
|
||||
|
||||
// NNAPI specific configuration for the validation whitelist.
|
||||
// NNAPI specific configuration for the validation allowlist.
|
||||
class NnapiAccelerationTestParams {
|
||||
public:
|
||||
// Content in nnapi_acceleration_test_list.cc.
|
||||
|
@ -4526,7 +4526,7 @@ TfLiteStatus StatefulNnApiDelegate::DoPrepare(TfLiteContext* context,
|
||||
} else {
|
||||
// If no accelerator is specified, only use NNAPI if an accelerator is
|
||||
// available. Any available accelerator will make the device_count larger
|
||||
// than 1. More sophisticated check and whitelisting can be added later.
|
||||
// than 1. More sophisticated check and allowlisting can be added later.
|
||||
uint32_t device_count = 0;
|
||||
RETURN_TFLITE_ERROR_IF_NN_ERROR(
|
||||
context, nnapi->ANeuralNetworks_getDeviceCount(&device_count),
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Accelerator whitelisting
|
||||
# Accelerator allowlisting
|
||||
|
||||
Experimental library and tools for determining whether an accelerator engine
|
||||
works well on a given device, and for a given model.
|
||||
@ -6,7 +6,7 @@ works well on a given device, and for a given model.
|
||||
## Platform-agnostic, Android-first
|
||||
|
||||
Android-focused, since the much smaller set of configurations on iOS means there
|
||||
is much less need for whitelisting on iOS.
|
||||
is much less need for allowlisting on iOS.
|
||||
|
||||
## Not just for TfLite
|
||||
|
||||
|
@ -32,7 +32,7 @@ package tflite.proto;
|
||||
// compatibility list entries have been developed for and what settings are used
|
||||
// for NNAPI.
|
||||
enum ExecutionPreference {
|
||||
// Match any selected preference. Whitelist (semantically - value is same as
|
||||
// Match any selected preference. Allowlist (semantically - value is same as
|
||||
// on input).
|
||||
ANY = 0;
|
||||
// Match low latency preference. Both compatibility list and input.
|
||||
|
@ -39,8 +39,8 @@ for `target_spec.supported_ops`:
|
||||
|
||||
* `TFLITE_BUILTINS` - Converts models using TensorFlow Lite builtin ops.
|
||||
* `SELECT_TF_OPS` - Converts models using TensorFlow ops. The exact subset of
|
||||
supported ops can be found in the whitelist at
|
||||
`lite/delegates/flex/whitelisted_flex_ops.cc`.
|
||||
supported ops can be found in the allowlist at
|
||||
`lite/delegates/flex/allowlisted_flex_ops.cc`.
|
||||
|
||||
Note: `target_spec.supported_ops` was previously `target_ops` in the Python API.
|
||||
|
||||
|
@ -27,8 +27,8 @@ public final class CompatibilityListTest {
|
||||
|
||||
@Test
|
||||
public void testBasic() throws Exception {
|
||||
try (CompatibilityList whitelist = new CompatibilityList()) {
|
||||
assertThat(whitelist.isDelegateSupportedOnThisDevice()).isTrue();
|
||||
try (CompatibilityList allowlist = new CompatibilityList()) {
|
||||
assertThat(allowlist.isDelegateSupportedOnThisDevice()).isTrue();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
namespace tflite {
|
||||
|
||||
// Returns the test id to use to retrieve the acceleration configuration
|
||||
// in the acceleration whitelist.
|
||||
// in the acceleration allowlist.
|
||||
std::string GetCurrentTestId();
|
||||
|
||||
} // namespace tflite
|
||||
|
@ -51,14 +51,14 @@ struct SimpleConfig {
|
||||
|
||||
class ReadAccelerationConfigTest : public ::testing::Test {
|
||||
public:
|
||||
std::unordered_map<std::string, SimpleConfig> whitelist_;
|
||||
std::unordered_map<std::string, SimpleConfig> allowlist_;
|
||||
std::unordered_map<std::string, SimpleConfig> blacklist_;
|
||||
std::function<void(std::string, std::string, bool)> consumer_ =
|
||||
[this](std::string key, std::string value, bool is_blacklist) {
|
||||
if (is_blacklist) {
|
||||
blacklist_[key] = {value};
|
||||
} else {
|
||||
whitelist_[key] = {value};
|
||||
allowlist_[key] = {value};
|
||||
}
|
||||
};
|
||||
};
|
||||
@ -66,21 +66,21 @@ class ReadAccelerationConfigTest : public ::testing::Test {
|
||||
TEST_F(ReadAccelerationConfigTest, ReadsAKeyOnlyLine) {
|
||||
ReadAccelerationConfig("key", consumer_);
|
||||
|
||||
EXPECT_THAT(whitelist_.find("key"), Not(Eq(whitelist_.end())));
|
||||
EXPECT_THAT(allowlist_.find("key"), Not(Eq(allowlist_.end())));
|
||||
EXPECT_TRUE(blacklist_.empty());
|
||||
}
|
||||
|
||||
TEST_F(ReadAccelerationConfigTest, ReadsABlacklistKeyOnlyLine) {
|
||||
ReadAccelerationConfig("-key", consumer_);
|
||||
|
||||
EXPECT_THAT(blacklist_.find("key"), Not(Eq(whitelist_.end())));
|
||||
EXPECT_TRUE(whitelist_.empty());
|
||||
EXPECT_THAT(blacklist_.find("key"), Not(Eq(allowlist_.end())));
|
||||
EXPECT_TRUE(allowlist_.empty());
|
||||
}
|
||||
|
||||
TEST_F(ReadAccelerationConfigTest, ReadsAKeyValueLine) {
|
||||
ReadAccelerationConfig("key,value", consumer_);
|
||||
|
||||
EXPECT_THAT(whitelist_["key"].value, Eq("value"));
|
||||
EXPECT_THAT(allowlist_["key"].value, Eq("value"));
|
||||
EXPECT_TRUE(blacklist_.empty());
|
||||
}
|
||||
|
||||
@ -88,13 +88,13 @@ TEST_F(ReadAccelerationConfigTest, ReadsABlackListKeyValueLine) {
|
||||
ReadAccelerationConfig("-key,value", consumer_);
|
||||
|
||||
EXPECT_THAT(blacklist_["key"].value, Eq("value"));
|
||||
EXPECT_TRUE(whitelist_.empty());
|
||||
EXPECT_TRUE(allowlist_.empty());
|
||||
}
|
||||
|
||||
TEST_F(ReadAccelerationConfigTest, KeysAreLeftTrimmed) {
|
||||
ReadAccelerationConfig(" key,value", consumer_);
|
||||
|
||||
EXPECT_THAT(whitelist_["key"].value, Eq("value"));
|
||||
EXPECT_THAT(allowlist_["key"].value, Eq("value"));
|
||||
EXPECT_TRUE(blacklist_.empty());
|
||||
}
|
||||
|
||||
@ -102,58 +102,58 @@ TEST_F(ReadAccelerationConfigTest, BlKeysAreLeftTrimmed) {
|
||||
ReadAccelerationConfig(" -key,value", consumer_);
|
||||
|
||||
EXPECT_THAT(blacklist_["key"].value, Eq("value"));
|
||||
EXPECT_TRUE(whitelist_.empty());
|
||||
EXPECT_TRUE(allowlist_.empty());
|
||||
}
|
||||
|
||||
TEST_F(ReadAccelerationConfigTest, IgnoresCommentedLines) {
|
||||
ReadAccelerationConfig("#key,value", consumer_);
|
||||
|
||||
EXPECT_TRUE(whitelist_.empty());
|
||||
EXPECT_TRUE(allowlist_.empty());
|
||||
EXPECT_TRUE(blacklist_.empty());
|
||||
}
|
||||
|
||||
TEST_F(ReadAccelerationConfigTest, CommentCanHaveTrailingBlanks) {
|
||||
ReadAccelerationConfig(" #key,value", consumer_);
|
||||
|
||||
EXPECT_TRUE(whitelist_.empty());
|
||||
EXPECT_TRUE(allowlist_.empty());
|
||||
EXPECT_TRUE(blacklist_.empty());
|
||||
}
|
||||
|
||||
TEST_F(ReadAccelerationConfigTest, CommentsAreOnlyForTheFullLine) {
|
||||
ReadAccelerationConfig("key,value #comment", consumer_);
|
||||
|
||||
EXPECT_THAT(whitelist_["key"].value, Eq("value #comment"));
|
||||
EXPECT_THAT(allowlist_["key"].value, Eq("value #comment"));
|
||||
}
|
||||
|
||||
TEST_F(ReadAccelerationConfigTest, IgnoresEmptyLines) {
|
||||
ReadAccelerationConfig("", consumer_);
|
||||
|
||||
EXPECT_TRUE(whitelist_.empty());
|
||||
EXPECT_TRUE(allowlist_.empty());
|
||||
EXPECT_TRUE(blacklist_.empty());
|
||||
}
|
||||
|
||||
TEST_F(ReadAccelerationConfigTest, ParsesMultipleLines) {
|
||||
ReadAccelerationConfig("key1,value1\nkey2,value2\n-key3,value3", consumer_);
|
||||
|
||||
EXPECT_THAT(whitelist_["key1"].value, Eq("value1"));
|
||||
EXPECT_THAT(whitelist_["key2"].value, Eq("value2"));
|
||||
EXPECT_THAT(allowlist_["key1"].value, Eq("value1"));
|
||||
EXPECT_THAT(allowlist_["key2"].value, Eq("value2"));
|
||||
EXPECT_THAT(blacklist_["key3"].value, Eq("value3"));
|
||||
}
|
||||
|
||||
TEST_F(ReadAccelerationConfigTest, ParsesMultipleLinesWithCommentsAndSpaces) {
|
||||
ReadAccelerationConfig("key1,value1\n#comment\n\nkey2,value2", consumer_);
|
||||
|
||||
EXPECT_THAT(whitelist_["key1"].value, Eq("value1"));
|
||||
EXPECT_THAT(whitelist_["key2"].value, Eq("value2"));
|
||||
EXPECT_THAT(allowlist_["key1"].value, Eq("value1"));
|
||||
EXPECT_THAT(allowlist_["key2"].value, Eq("value2"));
|
||||
}
|
||||
|
||||
TEST_F(ReadAccelerationConfigTest, ParsesMultipleLinesWithMissingConfigValues) {
|
||||
ReadAccelerationConfig("key1\nkey2,value2\nkey3\nkey4,value4", consumer_);
|
||||
|
||||
EXPECT_THAT(whitelist_["key1"].value, Eq(""));
|
||||
EXPECT_THAT(whitelist_["key2"].value, Eq("value2"));
|
||||
EXPECT_THAT(whitelist_["key3"].value, Eq(""));
|
||||
EXPECT_THAT(whitelist_["key4"].value, Eq("value4"));
|
||||
EXPECT_THAT(allowlist_["key1"].value, Eq(""));
|
||||
EXPECT_THAT(allowlist_["key2"].value, Eq("value2"));
|
||||
EXPECT_THAT(allowlist_["key3"].value, Eq(""));
|
||||
EXPECT_THAT(allowlist_["key4"].value, Eq("value4"));
|
||||
}
|
||||
|
||||
TEST(GetAccelerationTestParam, LoadsTestConfig) {
|
||||
|
@ -27,7 +27,7 @@ import six
|
||||
|
||||
|
||||
def sanitize_xml(unsanitized):
|
||||
"""Uses a whitelist to avoid generating bad XML."""
|
||||
"""Uses a allowlist to avoid generating bad XML."""
|
||||
return re.sub(r'[^a-zA-Z0-9+_\-/\\.]', '', six.ensure_str(unsanitized))
|
||||
|
||||
|
||||
|
@ -794,7 +794,7 @@ TEST(OperatorKeyTest, TestFlexWithUnsupportedOp) {
|
||||
EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_CUSTOM);
|
||||
EXPECT_EQ(key.custom_code(), "HashTableV2");
|
||||
EXPECT_EQ(key.version(), 1);
|
||||
// While HashTableV2 is excluded from the whitelisted flex op list, eventually
|
||||
// While HashTableV2 is excluded from the allowlisted flex op list, eventually
|
||||
// it won't be, and the following expectations will need to change as the op
|
||||
// is explicitly blacklisted due to lack of asset support.
|
||||
EXPECT_FALSE(key.is_flex_op());
|
||||
|
@ -24,7 +24,7 @@ limitations under the License.
|
||||
|
||||
// TODO(ycling): Consider refactoring to extract the LSTM definition out of
|
||||
// graph_transformation module.
|
||||
#include "tensorflow/lite/delegates/flex/whitelisted_flex_ops.h"
|
||||
#include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/toco/graph_transformations/lstm_utils.h"
|
||||
#include "tensorflow/lite/toco/model.h"
|
||||
@ -2116,7 +2116,7 @@ bool ShouldExportAsFlexOp(bool enable_select_tf_ops,
|
||||
return false;
|
||||
}
|
||||
// Check if we can find the `OpDef` for the TensorFlow op. If we can find
|
||||
// it and it has been whitelisted, export the op as an Flex op. Otherwise,
|
||||
// it and it has been allowlisted, export the op as an Flex op. Otherwise,
|
||||
// export it as a regular custom op.
|
||||
const tensorflow::OpDef* op_def = nullptr;
|
||||
if (!tensorflow::OpRegistry::Global()
|
||||
@ -2125,9 +2125,9 @@ bool ShouldExportAsFlexOp(bool enable_select_tf_ops,
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!::tflite::flex::IsWhitelistedFlexOp(tensorflow_op_name)) {
|
||||
if (!::tflite::flex::IsAllowlistedFlexOp(tensorflow_op_name)) {
|
||||
LOG(WARNING) << "Op " << tensorflow_op_name
|
||||
<< " is a valid TensorFlow op but has not been whitelisted for"
|
||||
<< " is a valid TensorFlow op but has not been allowlisted for"
|
||||
" the TensorFlow Lite flex op set.";
|
||||
return false;
|
||||
}
|
||||
|
@ -156,7 +156,7 @@ To do so, we utilize the `preprocess_coco_minival` Python binary as follows:
|
||||
bazel run //tensorflow/lite/tools/evaluation/tasks/coco_object_detection:preprocess_coco_minival -- \
|
||||
--images_folder=/path/to/val2014 \
|
||||
--instances_file=/path/to/instances_val2014.json \
|
||||
--whitelist_file=/path/to/minival_whitelist.txt \
|
||||
--allowlist_file=/path/to/minival_allowlist.txt \
|
||||
--output_folder=/path/to/output/folder
|
||||
|
||||
```
|
||||
|
@ -16,13 +16,13 @@
|
||||
|
||||
The 2014 validation images & annotations can be downloaded from:
|
||||
http://cocodataset.org/#download
|
||||
The minival image ID whitelist, a subset of the 2014 validation set, can be
|
||||
The minival image ID allowlist, a subset of the 2014 validation set, can be
|
||||
found here:
|
||||
https://github.com/tensorflow/models/blob/master/research/object_detection/data/mscoco_minival_ids.txt.
|
||||
|
||||
This script takes in the original images folder, instances JSON file and
|
||||
image ID whitelist and produces the following in the specified output folder:
|
||||
A subfolder for whitelisted images (images/), and a file (ground_truth.pbtxt)
|
||||
image ID allowlist and produces the following in the specified output folder:
|
||||
A subfolder for allowlisted images (images/), and a file (ground_truth.pbtxt)
|
||||
containing an instance of tflite::evaluation::ObjectDetectionGroundTruth.
|
||||
"""
|
||||
|
||||
@ -40,17 +40,17 @@ from tensorflow.lite.tools.evaluation.proto import evaluation_stages_pb2
|
||||
|
||||
|
||||
def _get_ground_truth_detections(instances_file,
|
||||
whitelist_file=None,
|
||||
allowlist_file=None,
|
||||
num_images=None):
|
||||
"""Processes the annotations JSON file and returns ground truth data corresponding to whitelisted image IDs.
|
||||
"""Processes the annotations JSON file and returns ground truth data corresponding to allowlisted image IDs.
|
||||
|
||||
Args:
|
||||
instances_file: COCO instances JSON file, usually named as
|
||||
instances_val20xx.json.
|
||||
whitelist_file: File containing COCO minival image IDs to whitelist for
|
||||
allowlist_file: File containing COCO minival image IDs to allowlist for
|
||||
evaluation, one per line.
|
||||
num_images: Number of whitelisted images to pre-process. First num_images
|
||||
are chosen based on sorted list of filenames. If None, all whitelisted
|
||||
num_images: Number of allowlisted images to pre-process. First num_images
|
||||
are chosen based on sorted list of filenames. If None, all allowlisted
|
||||
files are preprocessed.
|
||||
|
||||
Returns:
|
||||
@ -70,17 +70,17 @@ def _get_ground_truth_detections(instances_file,
|
||||
image_data = collections.OrderedDict()
|
||||
all_file_names = []
|
||||
|
||||
# Read whitelist.
|
||||
if whitelist_file is not None:
|
||||
with open(whitelist_file, 'r') as whitelist:
|
||||
image_id_whitelist = set([int(x) for x in whitelist.readlines()])
|
||||
# Read allowlist.
|
||||
if allowlist_file is not None:
|
||||
with open(allowlist_file, 'r') as allowlist:
|
||||
image_id_allowlist = set([int(x) for x in allowlist.readlines()])
|
||||
else:
|
||||
image_id_whitelist = [image['id'] for image in data_dict['images']]
|
||||
image_id_allowlist = [image['id'] for image in data_dict['images']]
|
||||
|
||||
# Get image names and dimensions.
|
||||
for image_dict in data_dict['images']:
|
||||
image_id = image_dict['id']
|
||||
if image_id not in image_id_whitelist:
|
||||
if image_id not in image_id_allowlist:
|
||||
continue
|
||||
image_data_dict = {}
|
||||
image_data_dict['id'] = image_dict['id']
|
||||
@ -99,7 +99,7 @@ def _get_ground_truth_detections(instances_file,
|
||||
# Get detected object annotations per image.
|
||||
for annotation_dict in data_dict['annotations']:
|
||||
image_id = annotation_dict['image_id']
|
||||
if image_id not in image_id_whitelist:
|
||||
if image_id not in image_id_allowlist:
|
||||
continue
|
||||
if image_id not in image_data:
|
||||
continue
|
||||
@ -133,7 +133,7 @@ def _dump_data(ground_truth_detections, images_folder_path, output_folder_path):
|
||||
"""Dumps images & data from ground-truth objects into output_folder_path.
|
||||
|
||||
The following are created in output_folder_path:
|
||||
images/: sub-folder for whitelisted validation images.
|
||||
images/: sub-folder for allowlisted validation images.
|
||||
ground_truth.pb: A binary proto file containing all ground-truth
|
||||
object-sets.
|
||||
|
||||
@ -193,14 +193,14 @@ def _parse_args():
|
||||
help='Full path of the input JSON file, like instances_val20xx.json.',
|
||||
required=True)
|
||||
parser.add_argument(
|
||||
'--whitelist_file',
|
||||
'--allowlist_file',
|
||||
type=str,
|
||||
help='File with COCO image ids to preprocess, one on each line.',
|
||||
required=False)
|
||||
parser.add_argument(
|
||||
'--num_images',
|
||||
type=int,
|
||||
help='Number of whitelisted images to preprocess into the output folder.',
|
||||
help='Number of allowlisted images to preprocess into the output folder.',
|
||||
required=False)
|
||||
parser.add_argument(
|
||||
'--output_folder',
|
||||
@ -213,6 +213,6 @@ def _parse_args():
|
||||
if __name__ == '__main__':
|
||||
args = _parse_args()
|
||||
ground_truths = _get_ground_truth_detections(args.instances_file,
|
||||
args.whitelist_file,
|
||||
args.allowlist_file,
|
||||
args.num_images)
|
||||
_dump_data(ground_truths, args.images_folder, args.output_folder)
|
||||
|
@ -55,7 +55,7 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder,
|
||||
const TensorType& output_type, bool allow_float,
|
||||
ErrorReporter* error_reporter);
|
||||
|
||||
// Same as above, but enables only quantizing a whitelist of operations,
|
||||
// Same as above, but enables only quantizing an allowlist of operations,
|
||||
// specified by their operator output name.
|
||||
//
|
||||
// Note: This is a private API, subject to change.
|
||||
|
@ -158,6 +158,6 @@ _exported_dunders = set([
|
||||
'__monolithic_build__',
|
||||
])
|
||||
|
||||
# Expose symbols minus dunders, unless they are whitelisted above.
|
||||
# Expose symbols minus dunders, unless they are allowlisted above.
|
||||
# This is necessary to export our dunders.
|
||||
__all__ = [s for s in dir() if s in _exported_dunders or not s.startswith('_')]
|
||||
|
@ -177,7 +177,7 @@ class CallTreeTransformer(converter.Base):
|
||||
# Calls to pdb.set_trace or ipdb.set_trace are never converted. We don't use
|
||||
# the normal mechanisms to bypass these literals because they are sensitive
|
||||
# to the frame they are being called from.
|
||||
# TODO(mdan): Generalize this to a "static whitelist" config.
|
||||
# TODO(mdan): Generalize this to a "static allowlist" config.
|
||||
if full_name in ('pdb.set_trace', 'ipdb.set_trace', 'breakpoint'):
|
||||
global set_trace_warned
|
||||
if not set_trace_warned:
|
||||
|
@ -32,16 +32,16 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def whitelist(f):
|
||||
def allowlist(f):
|
||||
"""Helper that marks a callable as whtelitisted."""
|
||||
if 'whitelisted_module_for_testing' not in sys.modules:
|
||||
whitelisted_mod = imp.new_module('whitelisted_module_for_testing')
|
||||
sys.modules['whitelisted_module_for_testing'] = whitelisted_mod
|
||||
if 'allowlisted_module_for_testing' not in sys.modules:
|
||||
allowlisted_mod = imp.new_module('allowlisted_module_for_testing')
|
||||
sys.modules['allowlisted_module_for_testing'] = allowlisted_mod
|
||||
config.CONVERSION_RULES = (
|
||||
(config.DoNotConvert('whitelisted_module_for_testing'),) +
|
||||
(config.DoNotConvert('allowlisted_module_for_testing'),) +
|
||||
config.CONVERSION_RULES)
|
||||
|
||||
f.__module__ = 'whitelisted_module_for_testing'
|
||||
f.__module__ = 'allowlisted_module_for_testing'
|
||||
|
||||
|
||||
def is_inside_generated_code():
|
||||
|
@ -44,18 +44,18 @@ are handled correctly.
|
||||
|
||||
The following types of functions are not converted:
|
||||
|
||||
* functions already converted
|
||||
* functions defined in in a whitelisted module (see autograph/core/config.py)
|
||||
* non-Python functions (such as native bindings)
|
||||
* `print`, `pdb.set_trace`, `ipdb.set_trace`
|
||||
* most built-in functions (exceptions are listed in
|
||||
* functions already converted
|
||||
* functions defined in in a allowlisted module (see autograph/core/config.py)
|
||||
* non-Python functions (such as native bindings)
|
||||
* `print`, `pdb.set_trace`, `ipdb.set_trace`
|
||||
* most built-in functions (exceptions are listed in
|
||||
autograph/operators/py_builtins.py)
|
||||
* constructors
|
||||
* functions without source code attached (prints a warning)(see
|
||||
* constructors
|
||||
* functions without source code attached (prints a warning)(see
|
||||
[limitations](limitations.md))
|
||||
* generator functions (prints a warning)
|
||||
* iterator protocol methods (`__next__`, `__iter__`)
|
||||
* context manager methods (`__enter__`, `__exit__`)
|
||||
* generator functions (prints a warning)
|
||||
* iterator protocol methods (`__next__`, `__iter__`)
|
||||
* context manager methods (`__enter__`, `__exit__`)
|
||||
|
||||
When AutoGraph encounters a function that it cannot convert outside of this
|
||||
list, it prints a warning.
|
||||
|
@ -342,16 +342,16 @@ def converted_call(f,
|
||||
raise ValueError('either caller_fn_scope or options must have a value')
|
||||
options = caller_fn_scope.callopts
|
||||
|
||||
if conversion.is_in_whitelist_cache(f, options):
|
||||
logging.log(2, 'Whitelisted %s: from cache', f)
|
||||
if conversion.is_in_allowlist_cache(f, options):
|
||||
logging.log(2, 'Allowlisted %s: from cache', f)
|
||||
return _call_unconverted(f, args, kwargs, options, False)
|
||||
|
||||
if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
|
||||
logging.log(2, 'Whitelisted: %s: AutoGraph is disabled in context', f)
|
||||
logging.log(2, 'Allowlisted: %s: AutoGraph is disabled in context', f)
|
||||
return _call_unconverted(f, args, kwargs, options, False)
|
||||
|
||||
if is_autograph_artifact(f):
|
||||
logging.log(2, 'Permanently whitelisted: %s: AutoGraph artifact', f)
|
||||
logging.log(2, 'Permanently allowed: %s: AutoGraph artifact', f)
|
||||
return _call_unconverted(f, args, kwargs, options)
|
||||
|
||||
# If this is a partial, unwrap it and redo all the checks.
|
||||
@ -385,7 +385,7 @@ def converted_call(f,
|
||||
if conversion.is_unsupported(f):
|
||||
return _call_unconverted(f, args, kwargs, options)
|
||||
|
||||
if not options.user_requested and conversion.is_whitelisted(f):
|
||||
if not options.user_requested and conversion.is_allowlisted(f):
|
||||
return _call_unconverted(f, args, kwargs, options)
|
||||
|
||||
# internal_convert_user_code is for example turned off when issuing a dynamic
|
||||
@ -425,13 +425,13 @@ def converted_call(f,
|
||||
return _fall_back_unconverted(f, args, kwargs, options, e)
|
||||
|
||||
if not hasattr(target_entity, '__code__'):
|
||||
logging.log(2, 'Permanently whitelisted: %s: native binding',
|
||||
logging.log(2, 'Permanently allowed: %s: native binding',
|
||||
target_entity)
|
||||
return _call_unconverted(f, args, kwargs, options)
|
||||
elif (hasattr(target_entity.__code__, 'co_filename') and
|
||||
target_entity.__code__.co_filename == '<string>'):
|
||||
# TODO(mdan): __globals__['txt'] might work in Py3.
|
||||
logging.log(2, 'Permanently whitelisted: %s: dynamic code (exec?)',
|
||||
logging.log(2, 'Permanently allowed: %s: dynamic code (exec?)',
|
||||
target_entity)
|
||||
return _call_unconverted(f, args, kwargs, options)
|
||||
|
||||
@ -462,7 +462,7 @@ def converted_call(f,
|
||||
def _call_unconverted(f, args, kwargs, options, update_cache=True):
|
||||
"""Calls the original function without converting with AutoGraph."""
|
||||
if update_cache:
|
||||
conversion.cache_whitelisted(f, options)
|
||||
conversion.cache_allowlisted(f, options)
|
||||
|
||||
if inspect.ismethod(f) and isinstance(f.__self__, function.TfMethodTarget):
|
||||
return f.__self__.call(args, kwargs)
|
||||
@ -482,7 +482,7 @@ def _fall_back_unconverted(f, args, kwargs, options, exc):
|
||||
'To silence this warning, decorate the function with'
|
||||
' @tf.autograph.experimental.do_not_convert')
|
||||
if isinstance(exc, errors.UnsupportedLanguageElementError):
|
||||
if not conversion.is_in_whitelist_cache(f, options):
|
||||
if not conversion.is_in_allowlist_cache(f, options):
|
||||
logging.warn(warning_template, f, '', exc)
|
||||
else:
|
||||
file_bug_message = (
|
||||
@ -516,7 +516,7 @@ def tf_convert(f, ctx, convert_by_default=True, user_requested=False):
|
||||
ctx: ag_ctx.ControlStatusCtx, the Autograph context in which `f` is used.
|
||||
convert_by_default: bool, whether to use AutoGraph when the context doesn't
|
||||
specify.
|
||||
user_requested: bool, whether to ignore the conversion whitelist. See
|
||||
user_requested: bool, whether to ignore the conversion allowlist. See
|
||||
ConversionOptions.user_requested.
|
||||
|
||||
Returns:
|
||||
|
@ -203,14 +203,14 @@ class ApiTest(test.TestCase):
|
||||
z = x + y
|
||||
return z
|
||||
|
||||
test_method_whitelisted = api.do_not_convert(test_method)
|
||||
test_method_allowlisted = api.do_not_convert(test_method)
|
||||
|
||||
tc = TestClass()
|
||||
self.assertTrue(tf_inspect.ismethod(tc.test_method_whitelisted))
|
||||
self.assertTrue(tf_inspect.ismethod(tc.test_method_allowlisted))
|
||||
# Because the wrapped function is not generated, we can't preserve its
|
||||
# arg spec.
|
||||
self.assertEqual((),
|
||||
tuple(function_utils.fn_args(tc.test_method_whitelisted)))
|
||||
tuple(function_utils.fn_args(tc.test_method_allowlisted)))
|
||||
|
||||
def test_do_not_convert_callable_object(self):
|
||||
|
||||
@ -521,12 +521,12 @@ class ApiTest(test.TestCase):
|
||||
ag_logging.set_verbosity(0, False)
|
||||
os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1'
|
||||
|
||||
def test_converted_call_partial_of_whitelisted_function(self):
|
||||
def test_converted_call_partial_of_allowlisted_function(self):
|
||||
|
||||
def test_fn(_):
|
||||
self.assertFalse(converter_testing.is_inside_generated_code())
|
||||
|
||||
converter_testing.whitelist(test_fn)
|
||||
converter_testing.allowlist(test_fn)
|
||||
api.converted_call(
|
||||
functools.partial(test_fn, None), (), None, options=DEFAULT_RECURSIVE)
|
||||
|
||||
@ -563,7 +563,7 @@ class ApiTest(test.TestCase):
|
||||
f, (g, constant_op.constant(1)), None, options=DEFAULT_RECURSIVE)
|
||||
self.assertEqual(self.evaluate(x), 1)
|
||||
|
||||
def test_converted_call_forced_when_explicitly_whitelisted(self):
|
||||
def test_converted_call_forced_when_explicitly_allowlisted(self):
|
||||
|
||||
@api.do_not_convert()
|
||||
def f(x):
|
||||
@ -606,7 +606,7 @@ class ApiTest(test.TestCase):
|
||||
self.assertIsNotNone(
|
||||
api.converted_call(f, (1, 2, 3, 4), None, options=opts))
|
||||
|
||||
def test_converted_call_whitelisted_method(self):
|
||||
def test_converted_call_allowlisted_method(self):
|
||||
|
||||
class TestClass(object):
|
||||
|
||||
@ -614,19 +614,19 @@ class ApiTest(test.TestCase):
|
||||
return converter_testing.is_inside_generated_code()
|
||||
|
||||
obj = TestClass()
|
||||
converter_testing.whitelist(obj.method.__func__)
|
||||
converter_testing.allowlist(obj.method.__func__)
|
||||
|
||||
self.assertFalse(
|
||||
api.converted_call(obj.method, (), {}, options=DEFAULT_RECURSIVE))
|
||||
|
||||
def test_converted_call_whitelisted_method_via_owner(self):
|
||||
def test_converted_call_allowlisted_method_via_owner(self):
|
||||
|
||||
class TestClass(object):
|
||||
|
||||
def method(self):
|
||||
return converter_testing.is_inside_generated_code()
|
||||
|
||||
converter_testing.whitelist(TestClass)
|
||||
converter_testing.allowlist(TestClass)
|
||||
|
||||
obj = TestClass()
|
||||
self.assertFalse(
|
||||
@ -852,7 +852,7 @@ class ApiTest(test.TestCase):
|
||||
# invocation would fail.
|
||||
self.assertEqual(self.evaluate(call_in_default_context()), 1)
|
||||
|
||||
def test_converted_call_caching_of_whitelisted_bound_methods(self):
|
||||
def test_converted_call_caching_of_allowlisted_bound_methods(self):
|
||||
|
||||
class TestClass(object):
|
||||
|
||||
@ -863,7 +863,7 @@ class ApiTest(test.TestCase):
|
||||
return self.__private
|
||||
|
||||
# TODO(mdan): Refactor to avoid this use of global state.
|
||||
cache_size_before = len(conversion._WHITELIST_CACHE)
|
||||
cache_size_before = len(conversion._ALLOWLIST_CACHE)
|
||||
|
||||
# First invocation with fallback on, to allow recording it into cache.
|
||||
os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '0'
|
||||
@ -871,15 +871,15 @@ class ApiTest(test.TestCase):
|
||||
api.converted_call(tc.test_method, (), None, options=DEFAULT_RECURSIVE)
|
||||
os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1'
|
||||
|
||||
# Entry should be added to the whitelist cache.
|
||||
self.assertEqual(len(conversion._WHITELIST_CACHE), cache_size_before + 1)
|
||||
# Entry should be added to the allowlist cache.
|
||||
self.assertEqual(len(conversion._ALLOWLIST_CACHE), cache_size_before + 1)
|
||||
|
||||
# A second invocation should go through even with fallback off.
|
||||
tc = TestClass()
|
||||
api.converted_call(tc.test_method, (), None, options=DEFAULT_RECURSIVE)
|
||||
|
||||
# No new entries should appear in the whitelist cache.
|
||||
self.assertEqual(len(conversion._WHITELIST_CACHE), cache_size_before + 1)
|
||||
# No new entries should appear in the allowlist cache.
|
||||
self.assertEqual(len(conversion._ALLOWLIST_CACHE), cache_size_before + 1)
|
||||
|
||||
def test_context_tracking_direct_calls(self):
|
||||
|
||||
@ -1102,7 +1102,7 @@ class ApiTest(test.TestCase):
|
||||
|
||||
test_fn(ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED))
|
||||
|
||||
def test_tf_convert_whitelisted_method(self):
|
||||
def test_tf_convert_allowlisted_method(self):
|
||||
|
||||
if six.PY2:
|
||||
self.skipTest('Test bank not comptible with Python 2.')
|
||||
@ -1112,7 +1112,7 @@ class ApiTest(test.TestCase):
|
||||
def method(self):
|
||||
return converter_testing.is_inside_generated_code()
|
||||
|
||||
converter_testing.whitelist(TestClass.method)
|
||||
converter_testing.allowlist(TestClass.method)
|
||||
|
||||
obj = TestClass()
|
||||
converted_call = api.tf_convert(
|
||||
|
@ -31,7 +31,7 @@ from tensorflow.python.eager import function
|
||||
from tensorflow.python.util import tf_inspect
|
||||
|
||||
|
||||
_WHITELIST_CACHE = cache.UnboundInstanceCache()
|
||||
_ALLOWLIST_CACHE = cache.UnboundInstanceCache()
|
||||
|
||||
|
||||
def _is_of_known_loaded_module(f, module_name):
|
||||
@ -80,53 +80,53 @@ def is_unsupported(o):
|
||||
'{} appears to be decorated by wrapt, which is not yet supported'
|
||||
' by AutoGraph. The function will run as-is.'
|
||||
' You may still apply AutoGraph before the wrapt decorator.'.format(o))
|
||||
logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', o)
|
||||
logging.log(2, 'Permanently allowed: %s: wrapt decorated', o)
|
||||
return True
|
||||
|
||||
if _is_known_loaded_type(o, 'functools', '_lru_cache_wrapper'):
|
||||
logging.log(2, 'Permanently whitelisted: %s: lru_cache', o)
|
||||
logging.log(2, 'Permanently allowed: %s: lru_cache', o)
|
||||
return True
|
||||
|
||||
# Constructors are permanently whitelisted.
|
||||
# Constructors are permanently allowed.
|
||||
# TODO(mdan): Toggle as experimental feature instead.
|
||||
# TODO(b/124016764): Remove this limitation.
|
||||
if inspect_utils.isconstructor(o):
|
||||
logging.log(2, 'Permanently whitelisted: %s: constructor', o)
|
||||
logging.log(2, 'Permanently allowed: %s: constructor', o)
|
||||
return True
|
||||
|
||||
# Other built-in modules are permanently whitelisted.
|
||||
# Other built-in modules are permanently allowed.
|
||||
# TODO(mdan): Figure out how to do this consistently for all stdlib modules.
|
||||
if any(
|
||||
_is_of_known_loaded_module(o, m)
|
||||
for m in ('collections', 'pdb', 'copy', 'inspect', 're')):
|
||||
logging.log(2, 'Permanently whitelisted: %s: part of builtin module', o)
|
||||
logging.log(2, 'Permanently allowed: %s: part of builtin module', o)
|
||||
return True
|
||||
|
||||
# Custom ops and kernels are also permanently whitelisted.
|
||||
# Custom ops and kernels are also permanently allowed.
|
||||
# See tensorflow.framework.load_library.
|
||||
if (hasattr(o, '__module__') and
|
||||
hasattr(o.__module__, '_IS_TENSORFLOW_PLUGIN')):
|
||||
logging.log(2, 'Permanently whitelisted: %s: TensorFlow plugin', o)
|
||||
logging.log(2, 'Permanently allowed: %s: TensorFlow plugin', o)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# TODO(mdan): allow_namedtuple_subclass should be hardcoded to True.
|
||||
def is_whitelisted(
|
||||
def is_allowlisted(
|
||||
o, check_call_override=True, allow_namedtuple_subclass=False):
|
||||
"""Checks whether an entity is whitelisted for use in graph mode.
|
||||
"""Checks whether an entity is allowed for use in graph mode.
|
||||
|
||||
Examples of whitelisted entities include all members of the tensorflow
|
||||
Examples of allowed entities include all members of the tensorflow
|
||||
package.
|
||||
|
||||
Args:
|
||||
o: A Python entity.
|
||||
check_call_override: Reserved for internal use. When set to `False`, it
|
||||
disables the rule according to which classes are whitelisted if their
|
||||
__call__ method is whitelisted.
|
||||
disables the rule according to which classes are allowed if their
|
||||
__call__ method is allowed.
|
||||
allow_namedtuple_subclass: Reserved for internal use. When `True`,
|
||||
namedtuple subclasses are not whitelisted.
|
||||
namedtuple subclasses are not allowed.
|
||||
|
||||
Returns:
|
||||
Boolean
|
||||
@ -144,10 +144,10 @@ def is_whitelisted(
|
||||
for rule in config.CONVERSION_RULES:
|
||||
action = rule.get_action(m)
|
||||
if action == config.Action.CONVERT:
|
||||
logging.log(2, 'Not whitelisted: %s: %s', o, rule)
|
||||
logging.log(2, 'Not allowed: %s: %s', o, rule)
|
||||
return False
|
||||
elif action == config.Action.DO_NOT_CONVERT:
|
||||
logging.log(2, 'Whitelisted: %s: %s', o, rule)
|
||||
logging.log(2, 'Allowlisted: %s: %s', o, rule)
|
||||
return True
|
||||
|
||||
# The check for __code__ below is because isgeneratorfunction crashes
|
||||
@ -156,26 +156,26 @@ def is_whitelisted(
|
||||
logging.warn(
|
||||
'Entity %s appears to be a generator function. It will not be converted'
|
||||
' by AutoGraph.', o)
|
||||
logging.log(2, 'Whitelisted: %s: generator functions are not converted', o)
|
||||
logging.log(2, 'Allowlisted: %s: generator functions are not converted', o)
|
||||
return True
|
||||
|
||||
if (check_call_override and not tf_inspect.isclass(o) and
|
||||
hasattr(o, '__call__')):
|
||||
# Callable objects: whitelisted if their __call__ method is.
|
||||
# Callable objects: allowed if their __call__ method is.
|
||||
# The type check avoids infinite recursion around the __call__ method
|
||||
# of function objects.
|
||||
if (type(o) != type(o.__call__)) and is_whitelisted(o.__call__): # pylint: disable=unidiomatic-typecheck
|
||||
logging.log(2, 'Whitelisted: %s: object __call__ whitelisted', o)
|
||||
if (type(o) != type(o.__call__)) and is_allowlisted(o.__call__): # pylint: disable=unidiomatic-typecheck
|
||||
logging.log(2, 'Allowlisted: %s: object __call__ allowed', o)
|
||||
return True
|
||||
|
||||
owner_class = None
|
||||
if tf_inspect.ismethod(o):
|
||||
# Methods of whitelisted classes are also whitelisted, even if they are
|
||||
# Methods of allowed classes are also allowed, even if they are
|
||||
# bound via user subclasses.
|
||||
#
|
||||
# For example, suppose `tf.Foo` has a method called `bar`, and `baz` is
|
||||
# defined as below. `tf.Foo` is whitelisted. Then `baz.bar` is also
|
||||
# whitelisted.
|
||||
# defined as below. `tf.Foo` is allowed. Then `baz.bar` is also
|
||||
# allowed.
|
||||
#
|
||||
# class Custom(tf.Foo):
|
||||
# pass
|
||||
@ -183,22 +183,22 @@ def is_whitelisted(
|
||||
# baz = Custom()
|
||||
#
|
||||
# For the example above, if `Custom` did overload `bar`, then it would no
|
||||
# longer be whitelisted.
|
||||
# longer be allowed.
|
||||
|
||||
owner_class = inspect_utils.getmethodclass(o)
|
||||
if owner_class is function.TfMethodTarget:
|
||||
owner_class = o.__self__.target_class
|
||||
if owner_class is not None:
|
||||
if issubclass(owner_class, unittest.TestCase):
|
||||
logging.log(2, 'Whitelisted: %s: method of TestCase subclass', o)
|
||||
logging.log(2, 'Allowlisted: %s: method of TestCase subclass', o)
|
||||
return True
|
||||
|
||||
owner_class = inspect_utils.getdefiningclass(o, owner_class)
|
||||
if is_whitelisted(
|
||||
if is_allowlisted(
|
||||
owner_class,
|
||||
check_call_override=False,
|
||||
allow_namedtuple_subclass=True):
|
||||
logging.log(2, 'Whitelisted: %s: owner is whitelisted %s', o,
|
||||
logging.log(2, 'Allowlisted: %s: owner is allowed %s', o,
|
||||
owner_class)
|
||||
return True
|
||||
|
||||
@ -208,27 +208,27 @@ def is_whitelisted(
|
||||
# graph mode since they are just containers.
|
||||
if allow_namedtuple_subclass:
|
||||
if not any(inspect_utils.isnamedtuple(base) for base in o.__bases__):
|
||||
logging.log(2, 'Whitelisted: %s: named tuple', o)
|
||||
logging.log(2, 'Allowlisted: %s: named tuple', o)
|
||||
return True
|
||||
else:
|
||||
logging.log(2, 'Whitelisted: %s: named tuple or subclass', o)
|
||||
logging.log(2, 'Allowlisted: %s: named tuple or subclass', o)
|
||||
return True
|
||||
|
||||
logging.log(2, 'Not whitelisted: %s: default rule', o)
|
||||
logging.log(2, 'Not allowed: %s: default rule', o)
|
||||
return False
|
||||
|
||||
|
||||
def is_in_whitelist_cache(entity, options):
|
||||
def is_in_allowlist_cache(entity, options):
|
||||
try:
|
||||
return _WHITELIST_CACHE.has(entity, options)
|
||||
return _ALLOWLIST_CACHE.has(entity, options)
|
||||
except TypeError:
|
||||
# Catch-all for entities that are unhashable or don't allow weakrefs.
|
||||
return False
|
||||
|
||||
|
||||
def cache_whitelisted(entity, options):
|
||||
def cache_allowlisted(entity, options):
|
||||
try:
|
||||
_WHITELIST_CACHE[entity][options] = True
|
||||
_ALLOWLIST_CACHE[entity][options] = True
|
||||
except TypeError:
|
||||
# Catch-all for entities that are unhashable or don't allow weakrefs.
|
||||
pass
|
||||
|
@ -43,16 +43,16 @@ class ConversionTest(test.TestCase):
|
||||
options=converter.ConversionOptions(recursive=True),
|
||||
autograph_module=api)
|
||||
|
||||
def test_is_whitelisted(self):
|
||||
def test_is_allowlisted(self):
|
||||
|
||||
def test_fn():
|
||||
return constant_op.constant(1)
|
||||
|
||||
self.assertFalse(conversion.is_whitelisted(test_fn))
|
||||
self.assertTrue(conversion.is_whitelisted(utils))
|
||||
self.assertTrue(conversion.is_whitelisted(constant_op.constant))
|
||||
self.assertFalse(conversion.is_allowlisted(test_fn))
|
||||
self.assertTrue(conversion.is_allowlisted(utils))
|
||||
self.assertTrue(conversion.is_allowlisted(constant_op.constant))
|
||||
|
||||
def test_is_whitelisted_tensorflow_like(self):
|
||||
def test_is_allowlisted_tensorflow_like(self):
|
||||
|
||||
tf_like = imp.new_module('tensorflow_foo')
|
||||
def test_fn():
|
||||
@ -60,13 +60,13 @@ class ConversionTest(test.TestCase):
|
||||
tf_like.test_fn = test_fn
|
||||
test_fn.__module__ = tf_like
|
||||
|
||||
self.assertFalse(conversion.is_whitelisted(tf_like.test_fn))
|
||||
self.assertFalse(conversion.is_allowlisted(tf_like.test_fn))
|
||||
|
||||
def test_is_whitelisted_callable_whitelisted_call(self):
|
||||
def test_is_allowlisted_callable_allowlisted_call(self):
|
||||
|
||||
whitelisted_mod = imp.new_module('test_whitelisted_call')
|
||||
sys.modules['test_whitelisted_call'] = whitelisted_mod
|
||||
config.CONVERSION_RULES = ((config.DoNotConvert('test_whitelisted_call'),) +
|
||||
allowlisted_mod = imp.new_module('test_allowlisted_call')
|
||||
sys.modules['test_allowlisted_call'] = allowlisted_mod
|
||||
config.CONVERSION_RULES = ((config.DoNotConvert('test_allowlisted_call'),) +
|
||||
config.CONVERSION_RULES)
|
||||
|
||||
class TestClass(object):
|
||||
@ -74,14 +74,14 @@ class ConversionTest(test.TestCase):
|
||||
def __call__(self):
|
||||
pass
|
||||
|
||||
def whitelisted_method(self):
|
||||
def allowlisted_method(self):
|
||||
pass
|
||||
|
||||
TestClass.__module__ = 'test_whitelisted_call'
|
||||
TestClass.__module__ = 'test_allowlisted_call'
|
||||
if six.PY2:
|
||||
TestClass.__call__.__func__.__module__ = 'test_whitelisted_call'
|
||||
TestClass.__call__.__func__.__module__ = 'test_allowlisted_call'
|
||||
else:
|
||||
TestClass.__call__.__module__ = 'test_whitelisted_call'
|
||||
TestClass.__call__.__module__ = 'test_allowlisted_call'
|
||||
|
||||
class Subclass(TestClass):
|
||||
|
||||
@ -90,20 +90,21 @@ class ConversionTest(test.TestCase):
|
||||
|
||||
tc = Subclass()
|
||||
|
||||
self.assertTrue(conversion.is_whitelisted(TestClass.__call__))
|
||||
self.assertTrue(conversion.is_whitelisted(tc))
|
||||
self.assertTrue(conversion.is_whitelisted(tc.__call__))
|
||||
self.assertTrue(conversion.is_whitelisted(tc.whitelisted_method))
|
||||
self.assertFalse(conversion.is_whitelisted(Subclass))
|
||||
self.assertFalse(conversion.is_whitelisted(tc.converted_method))
|
||||
self.assertTrue(conversion.is_allowlisted(TestClass.__call__))
|
||||
self.assertTrue(conversion.is_allowlisted(tc))
|
||||
self.assertTrue(conversion.is_allowlisted(tc.__call__))
|
||||
self.assertTrue(conversion.is_allowlisted(tc.allowlisted_method))
|
||||
self.assertFalse(conversion.is_allowlisted(Subclass))
|
||||
self.assertFalse(conversion.is_allowlisted(tc.converted_method))
|
||||
|
||||
def test_is_allowlisted_tfmethodwrapper(self):
|
||||
|
||||
def test_is_whitelisted_tfmethodwrapper(self):
|
||||
class TestClass(object):
|
||||
|
||||
def member_function(self):
|
||||
pass
|
||||
|
||||
TestClass.__module__ = 'test_whitelisted_call'
|
||||
TestClass.__module__ = 'test_allowlisted_call'
|
||||
test_obj = TestClass()
|
||||
|
||||
def test_fn(self):
|
||||
@ -114,14 +115,14 @@ class ConversionTest(test.TestCase):
|
||||
function.TfMethodTarget(
|
||||
weakref.ref(test_obj), test_obj.member_function))
|
||||
|
||||
self.assertTrue(conversion.is_whitelisted(bound_method))
|
||||
self.assertTrue(conversion.is_allowlisted(bound_method))
|
||||
|
||||
def test_is_whitelisted_pybind(self):
|
||||
def test_is_allowlisted_pybind(self):
|
||||
test_object = pybind_for_testing.TestClassDef()
|
||||
with test.mock.patch.object(config, 'CONVERSION_RULES', ()):
|
||||
# TODO(mdan): This should return True for functions and methods.
|
||||
# Note: currently, native bindings are whitelisted by a separate check.
|
||||
self.assertFalse(conversion.is_whitelisted(test_object.method))
|
||||
# Note: currently, native bindings are allowlisted by a separate check.
|
||||
self.assertFalse(conversion.is_allowlisted(test_object.method))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -477,12 +477,14 @@ class AnfConfiguredTest(AnfTestBase):
|
||||
def test_anf_some_function_calls(self):
|
||||
# Another example specific configuration that differs from the default:
|
||||
# Moving all arguments out of some function calls but leaving others be.
|
||||
whitelist = ['foo']
|
||||
allowlist = ['foo']
|
||||
|
||||
def transform(parent, field, child):
|
||||
del field
|
||||
del child
|
||||
func_name = parent.func.id
|
||||
return str(func_name) in whitelist
|
||||
return str(func_name) in allowlist
|
||||
|
||||
config = [(anf.ASTEdgePattern(gast.Call, anf.ANY, anf.ANY), transform)]
|
||||
|
||||
def test_function(x, foo, bar):
|
||||
|
@ -24,10 +24,9 @@ from tensorflow.python.autograph.pyct import origin_info
|
||||
|
||||
|
||||
class FrameInfo(
|
||||
collections.namedtuple(
|
||||
'FrameInfo',
|
||||
('filename', 'lineno', 'function_name', 'code', 'is_converted',
|
||||
'is_whitelisted'))):
|
||||
collections.namedtuple('FrameInfo',
|
||||
('filename', 'lineno', 'function_name', 'code',
|
||||
'is_converted', 'is_allowlisted'))):
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
@ -75,7 +74,7 @@ def _stack_trace_inside_mapped_code(tb, source_map, converter_filename):
|
||||
origin_info.create_source_map.
|
||||
converter_filename: str, the file path of the converted module. Call frames
|
||||
corresponding to this module are elided and their preceding frames are
|
||||
marked as whitelisted. Note that frames enclosing converted code are
|
||||
marked as allowlisted. Note that frames enclosing converted code are
|
||||
dropped using a different mechanism.
|
||||
|
||||
Returns:
|
||||
@ -93,7 +92,7 @@ def _stack_trace_inside_mapped_code(tb, source_map, converter_filename):
|
||||
function_name=origin.function_name,
|
||||
code=origin.source_code_line,
|
||||
is_converted=True,
|
||||
is_whitelisted=False)
|
||||
is_allowlisted=False)
|
||||
result_frames.append(fi)
|
||||
break
|
||||
|
||||
@ -107,7 +106,7 @@ def _stack_trace_inside_mapped_code(tb, source_map, converter_filename):
|
||||
function_name=prev.function_name,
|
||||
code=prev.code,
|
||||
is_converted=False,
|
||||
is_whitelisted=True)
|
||||
is_allowlisted=True)
|
||||
result_frames[-1] = fi
|
||||
continue
|
||||
|
||||
@ -117,7 +116,7 @@ def _stack_trace_inside_mapped_code(tb, source_map, converter_filename):
|
||||
function_name=function_name,
|
||||
code=text,
|
||||
is_converted=False,
|
||||
is_whitelisted=False)
|
||||
is_allowlisted=False)
|
||||
result_frames.append(fi)
|
||||
|
||||
return tuple(result_frames)
|
||||
@ -188,7 +187,7 @@ class ErrorMetadataBase(object):
|
||||
frame_info.function_name)
|
||||
if frame_info.is_converted:
|
||||
formatted_line += ' *'
|
||||
elif frame_info.is_whitelisted:
|
||||
elif frame_info.is_allowlisted:
|
||||
formatted_line += ' **'
|
||||
lines.append(formatted_line)
|
||||
|
||||
|
@ -2250,7 +2250,7 @@ class DatasetV1(DatasetV2):
|
||||
# by value _make_dataset() function would try to capture these variant
|
||||
# tensor dataset inputs, which are marked as stateful ops and would throw
|
||||
# an error if we try and capture them. We therefore traverse the graph
|
||||
# to find all these ops and whitelist them so that the capturing
|
||||
# to find all these ops and allowlist them so that the capturing
|
||||
# logic instead of throwing an error recreates these ops which is what was
|
||||
# happening before.
|
||||
all_ds_ops = traverse.obtain_all_variant_tensor_ops(self)
|
||||
@ -2258,7 +2258,7 @@ class DatasetV1(DatasetV2):
|
||||
|
||||
# NOTE(mrry): We capture by value here to ensure that `_make_dataset()` is
|
||||
# a 0-argument function.
|
||||
@function.Defun(capture_by_value=True, whitelisted_stateful_ops=all_ds_ops)
|
||||
@function.Defun(capture_by_value=True, allowlisted_stateful_ops=all_ds_ops)
|
||||
def _make_dataset():
|
||||
"""Factory function for a dataset."""
|
||||
# NOTE(mrry): `Defun` does not capture the graph-level seed from the
|
||||
|
@ -1246,8 +1246,8 @@ class DebugAnalyzer(object):
|
||||
parsed = self._arg_parsers["list_source"].parse_args(args)
|
||||
source_list = source_utils.list_source_files_against_dump(
|
||||
self._debug_dump,
|
||||
path_regex_whitelist=parsed.path_filter,
|
||||
node_name_regex_whitelist=parsed.node_name_filter)
|
||||
path_regex_allowlist=parsed.path_filter,
|
||||
node_name_regex_allowlist=parsed.node_name_filter)
|
||||
|
||||
top_lines = [
|
||||
RL("List of source files that created nodes in this run", "bold")]
|
||||
|
@ -1578,9 +1578,9 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testListSourceWithCompiledPythonSourceWorks(self):
|
||||
def fake_list_source_files_against_dump(dump,
|
||||
path_regex_whitelist=None,
|
||||
node_name_regex_whitelist=None):
|
||||
del dump, path_regex_whitelist, node_name_regex_whitelist
|
||||
path_regex_allowlist=None,
|
||||
node_name_regex_allowlist=None):
|
||||
del dump, path_regex_allowlist, node_name_regex_allowlist
|
||||
return [("compiled_1.pyc", False, 10, 20, 30, 4),
|
||||
("compiled_2.pyo", False, 10, 20, 30, 5),
|
||||
("uncompiled.py", False, 10, 20, 30, 6)]
|
||||
|
@ -38,7 +38,7 @@ from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
# Many ops have benign NaN outputs, and running them with check_numerics
|
||||
# on will create unwanted errors
|
||||
# TODO(b/142497024): Replace this whitelist with function decorators in the ops
|
||||
# TODO(b/142497024): Replace this allowlist with function decorators in the ops
|
||||
IGNORE_OP_OUTPUTS = (
|
||||
# For FusedBatchNorm, if the input tensor is empty then batch_mean and
|
||||
# batch_variance will be NaN. reserve_space holds intermediate values
|
||||
|
@ -83,16 +83,16 @@ def watch_graph(run_options,
|
||||
graph,
|
||||
debug_ops="DebugIdentity",
|
||||
debug_urls=None,
|
||||
node_name_regex_whitelist=None,
|
||||
op_type_regex_whitelist=None,
|
||||
tensor_dtype_regex_whitelist=None,
|
||||
node_name_regex_allowlist=None,
|
||||
op_type_regex_allowlist=None,
|
||||
tensor_dtype_regex_allowlist=None,
|
||||
tolerate_debug_op_creation_failures=False,
|
||||
global_step=-1,
|
||||
reset_disk_byte_usage=False):
|
||||
"""Add debug watches to `RunOptions` for a TensorFlow graph.
|
||||
|
||||
To watch all `Tensor`s on the graph, let both `node_name_regex_whitelist`
|
||||
and `op_type_regex_whitelist` be the default (`None`).
|
||||
To watch all `Tensor`s on the graph, let both `node_name_regex_allowlist`
|
||||
and `op_type_regex_allowlist` be the default (`None`).
|
||||
|
||||
N.B.:
|
||||
1. Under certain circumstances, the `Tensor` may not get actually watched
|
||||
@ -114,17 +114,17 @@ def watch_graph(run_options,
|
||||
For debug op types with customizable attributes, each debug op name string
|
||||
can optionally contain a list of attribute names, in the syntax of:
|
||||
debug_op_name(attr_name_1=attr_value_1;attr_name_2=attr_value_2;...)
|
||||
node_name_regex_whitelist: Regular-expression whitelist for node_name,
|
||||
node_name_regex_allowlist: Regular-expression allowlist for node_name,
|
||||
e.g., `"(weight_[0-9]+|bias_.*)"`
|
||||
op_type_regex_whitelist: Regular-expression whitelist for the op type of
|
||||
op_type_regex_allowlist: Regular-expression allowlist for the op type of
|
||||
nodes, e.g., `"(Variable|Add)"`.
|
||||
If both `node_name_regex_whitelist` and `op_type_regex_whitelist`
|
||||
If both `node_name_regex_allowlist` and `op_type_regex_allowlist`
|
||||
are set, the two filtering operations will occur in a logical `AND`
|
||||
relation. In other words, a node will be included if and only if it
|
||||
hits both whitelists.
|
||||
tensor_dtype_regex_whitelist: Regular-expression whitelist for Tensor
|
||||
hits both allowlists.
|
||||
tensor_dtype_regex_allowlist: Regular-expression allowlist for Tensor
|
||||
data type, e.g., `"^int.*"`.
|
||||
This whitelist operates in logical `AND` relations to the two whitelists
|
||||
This allowlist operates in logical `AND` relations to the two allowlists
|
||||
above.
|
||||
tolerate_debug_op_creation_failures: (`bool`) whether debug op creation
|
||||
failures (e.g., due to dtype incompatibility) are to be tolerated by not
|
||||
@ -142,12 +142,14 @@ def watch_graph(run_options,
|
||||
if isinstance(debug_ops, str):
|
||||
debug_ops = [debug_ops]
|
||||
|
||||
node_name_pattern = (re.compile(node_name_regex_whitelist)
|
||||
if node_name_regex_whitelist else None)
|
||||
op_type_pattern = (re.compile(op_type_regex_whitelist)
|
||||
if op_type_regex_whitelist else None)
|
||||
tensor_dtype_pattern = (re.compile(tensor_dtype_regex_whitelist)
|
||||
if tensor_dtype_regex_whitelist else None)
|
||||
node_name_pattern = (
|
||||
re.compile(node_name_regex_allowlist)
|
||||
if node_name_regex_allowlist else None)
|
||||
op_type_pattern = (
|
||||
re.compile(op_type_regex_allowlist) if op_type_regex_allowlist else None)
|
||||
tensor_dtype_pattern = (
|
||||
re.compile(tensor_dtype_regex_allowlist)
|
||||
if tensor_dtype_regex_allowlist else None)
|
||||
|
||||
ops = graph.get_operations()
|
||||
for op in ops:
|
||||
@ -210,7 +212,7 @@ def watch_graph_with_blacklists(run_options,
|
||||
"""Add debug tensor watches, blacklisting nodes and op types.
|
||||
|
||||
This is similar to `watch_graph()`, but the node names and op types are
|
||||
blacklisted, instead of whitelisted.
|
||||
blacklisted, instead of allowlisted.
|
||||
|
||||
N.B.:
|
||||
1. Under certain circumstances, the `Tensor` may not get actually watched
|
||||
@ -238,7 +240,7 @@ def watch_graph_with_blacklists(run_options,
|
||||
neither of the blacklists.
|
||||
tensor_dtype_regex_blacklist: Regular-expression blacklist for Tensor
|
||||
data type, e.g., `"^int.*"`.
|
||||
This blacklist operates in logical `OR` relations to the two whitelists
|
||||
This blacklist operates in logical `OR` relations to the two allowlists
|
||||
above.
|
||||
tolerate_debug_op_creation_failures: (`bool`) whether debug op creation
|
||||
failures (e.g., due to dtype incompatibility) are to be tolerated by not
|
||||
|
@ -227,12 +227,12 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
|
||||
# Assert that the wildcard node name has been created.
|
||||
self.assertIn("*", node_names)
|
||||
|
||||
def testWatchGraph_nodeNameWhitelist(self):
|
||||
def testWatchGraph_nodeNameAllowlist(self):
|
||||
debug_utils.watch_graph(
|
||||
self._run_options,
|
||||
self._graph,
|
||||
debug_urls="file:///tmp/tfdbg_1",
|
||||
node_name_regex_whitelist="(a1$|a1_init$|a1/.*|p1$)")
|
||||
node_name_regex_allowlist="(a1$|a1_init$|a1/.*|p1$)")
|
||||
|
||||
node_names = self._verify_watches(
|
||||
self._run_options.debug_options.debug_tensor_watch_opts, 0,
|
||||
@ -241,50 +241,50 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
|
||||
sorted(["a1_init", "a1", "a1/Assign", "a1/read", "p1"]),
|
||||
sorted(node_names))
|
||||
|
||||
def testWatchGraph_opTypeWhitelist(self):
|
||||
def testWatchGraph_opTypeAllowlist(self):
|
||||
debug_utils.watch_graph(
|
||||
self._run_options,
|
||||
self._graph,
|
||||
debug_urls="file:///tmp/tfdbg_1",
|
||||
op_type_regex_whitelist="(Variable|MatMul)")
|
||||
op_type_regex_allowlist="(Variable|MatMul)")
|
||||
|
||||
node_names = self._verify_watches(
|
||||
self._run_options.debug_options.debug_tensor_watch_opts, 0,
|
||||
["DebugIdentity"], ["file:///tmp/tfdbg_1"])
|
||||
self.assertEqual(sorted(["a1", "b", "p1"]), sorted(node_names))
|
||||
|
||||
def testWatchGraph_nodeNameAndOpTypeWhitelists(self):
|
||||
def testWatchGraph_nodeNameAndOpTypeAllowlists(self):
|
||||
debug_utils.watch_graph(
|
||||
self._run_options,
|
||||
self._graph,
|
||||
debug_urls="file:///tmp/tfdbg_1",
|
||||
node_name_regex_whitelist="([a-z]+1$)",
|
||||
op_type_regex_whitelist="(MatMul)")
|
||||
node_name_regex_allowlist="([a-z]+1$)",
|
||||
op_type_regex_allowlist="(MatMul)")
|
||||
|
||||
node_names = self._verify_watches(
|
||||
self._run_options.debug_options.debug_tensor_watch_opts, 0,
|
||||
["DebugIdentity"], ["file:///tmp/tfdbg_1"])
|
||||
self.assertEqual(["p1"], node_names)
|
||||
|
||||
def testWatchGraph_tensorDTypeWhitelist(self):
|
||||
def testWatchGraph_tensorDTypeAllowlist(self):
|
||||
debug_utils.watch_graph(
|
||||
self._run_options,
|
||||
self._graph,
|
||||
debug_urls="file:///tmp/tfdbg_1",
|
||||
tensor_dtype_regex_whitelist=".*_ref")
|
||||
tensor_dtype_regex_allowlist=".*_ref")
|
||||
|
||||
node_names = self._verify_watches(
|
||||
self._run_options.debug_options.debug_tensor_watch_opts, 0,
|
||||
["DebugIdentity"], ["file:///tmp/tfdbg_1"])
|
||||
self.assertItemsEqual(["a1", "a1/Assign", "b", "b/Assign"], node_names)
|
||||
|
||||
def testWatchGraph_nodeNameAndTensorDTypeWhitelists(self):
|
||||
def testWatchGraph_nodeNameAndTensorDTypeAllowlists(self):
|
||||
debug_utils.watch_graph(
|
||||
self._run_options,
|
||||
self._graph,
|
||||
debug_urls="file:///tmp/tfdbg_1",
|
||||
node_name_regex_whitelist="^a.*",
|
||||
tensor_dtype_regex_whitelist=".*_ref")
|
||||
node_name_regex_allowlist="^a.*",
|
||||
tensor_dtype_regex_allowlist=".*_ref")
|
||||
|
||||
node_names = self._verify_watches(
|
||||
self._run_options.debug_options.debug_tensor_watch_opts, 0,
|
||||
|
@ -143,7 +143,7 @@ class DistributedSessionDebugTest(test_util.TensorFlowTestCase):
|
||||
debug_utils.watch_graph(
|
||||
run_options,
|
||||
sess.graph,
|
||||
node_name_regex_whitelist=r"a",
|
||||
node_name_regex_allowlist=r"a",
|
||||
debug_ops=["DebugIdentity"],
|
||||
debug_urls=[self.debug_server_url])
|
||||
|
||||
@ -155,7 +155,7 @@ class DistributedSessionDebugTest(test_util.TensorFlowTestCase):
|
||||
debug_utils.watch_graph(
|
||||
run_options,
|
||||
sess.graph,
|
||||
node_name_regex_whitelist=r"p",
|
||||
node_name_regex_allowlist=r"p",
|
||||
debug_ops=["DebugIdentity(gated_grpc=True)"],
|
||||
debug_urls=[self.debug_server_url])
|
||||
|
||||
@ -209,8 +209,8 @@ class DistributedSessionDebugTest(test_util.TensorFlowTestCase):
|
||||
def watch_fn(feeds, fetch_keys):
|
||||
del feeds, fetch_keys
|
||||
return framework.WatchOptions(
|
||||
debug_ops=["DebugIdentity"],
|
||||
node_name_regex_whitelist=r"p")
|
||||
debug_ops=["DebugIdentity"], node_name_regex_allowlist=r"p")
|
||||
|
||||
sess = grpc_wrapper.GrpcDebugWrapperSession(
|
||||
sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
|
||||
|
||||
|
@ -71,7 +71,8 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
|
||||
del fetches, feeds
|
||||
return framework.WatchOptions(
|
||||
debug_ops=["DebugIdentity"],
|
||||
node_name_regex_whitelist=r"original_u")
|
||||
node_name_regex_allowlist=r"original_u")
|
||||
|
||||
sess = grpc_wrapper.GrpcDebugWrapperSession(
|
||||
sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
|
||||
self.assertAllClose(42.0, sess.run(u))
|
||||
@ -101,8 +102,8 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
|
||||
def watch_fn(fetches, feeds):
|
||||
del fetches, feeds # Unused by this watch_fn.
|
||||
return framework.WatchOptions(
|
||||
debug_ops=["DebugIdentity"],
|
||||
node_name_regex_whitelist=r"u_init")
|
||||
debug_ops=["DebugIdentity"], node_name_regex_allowlist=r"u_init")
|
||||
|
||||
sess = grpc_wrapper.GrpcDebugWrapperSession(
|
||||
sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
|
||||
sess.run(u.initializer)
|
||||
@ -125,8 +126,8 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
|
||||
def watch_fn(fetches, feeds):
|
||||
del fetches, feeds
|
||||
return framework.WatchOptions(
|
||||
debug_ops=["DebugIdentity"],
|
||||
node_name_regex_whitelist=r"u_init")
|
||||
debug_ops=["DebugIdentity"], node_name_regex_allowlist=r"u_init")
|
||||
|
||||
sess = grpc_wrapper.GrpcDebugWrapperSession(
|
||||
sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
|
||||
sess.run(u.initializer)
|
||||
@ -155,8 +156,8 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
|
||||
def watch_fn(fetches, feeds):
|
||||
del fetches, feeds
|
||||
return framework.WatchOptions(
|
||||
debug_ops=["DebugIdentity"],
|
||||
node_name_regex_whitelist=r"u_init")
|
||||
debug_ops=["DebugIdentity"], node_name_regex_allowlist=r"u_init")
|
||||
|
||||
sess = grpc_wrapper.GrpcDebugWrapperSession(
|
||||
sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
|
||||
sess.run(u.initializer)
|
||||
@ -177,8 +178,8 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
|
||||
def watch_fn(fetches, feeds):
|
||||
del fetches, feeds
|
||||
return framework.WatchOptions(
|
||||
debug_ops=["DebugIdentity"],
|
||||
node_name_regex_whitelist=r"u_init")
|
||||
debug_ops=["DebugIdentity"], node_name_regex_allowlist=r"u_init")
|
||||
|
||||
sess = grpc_wrapper.GrpcDebugWrapperSession(
|
||||
sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
|
||||
sess.run(u.initializer)
|
||||
@ -200,8 +201,8 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
|
||||
def watch_fn(fetches, feeds):
|
||||
del fetches, feeds
|
||||
return framework.WatchOptions(
|
||||
debug_ops=["DebugIdentity"],
|
||||
node_name_regex_whitelist=r"u_init")
|
||||
debug_ops=["DebugIdentity"], node_name_regex_allowlist=r"u_init")
|
||||
|
||||
sess = grpc_wrapper.GrpcDebugWrapperSession(
|
||||
sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
|
||||
sess.run(u.initializer)
|
||||
|
@ -207,8 +207,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
|
||||
del feeds, fetch_keys
|
||||
return framework.WatchOptions(
|
||||
debug_ops=["DebugIdentity", "DebugNumericSummary"],
|
||||
node_name_regex_whitelist=r".*/read",
|
||||
op_type_regex_whitelist=None,
|
||||
node_name_regex_allowlist=r".*/read",
|
||||
op_type_regex_allowlist=None,
|
||||
tolerate_debug_op_creation_failures=True)
|
||||
|
||||
u = variables.VariableV1(2.1, name="u")
|
||||
|
@ -221,15 +221,15 @@ def annotate_source(dump,
|
||||
|
||||
|
||||
def list_source_files_against_dump(dump,
|
||||
path_regex_whitelist=None,
|
||||
node_name_regex_whitelist=None):
|
||||
path_regex_allowlist=None,
|
||||
node_name_regex_allowlist=None):
|
||||
"""Generate a list of source files with information regarding ops and tensors.
|
||||
|
||||
Args:
|
||||
dump: (`DebugDumpDir`) A `DebugDumpDir` object of which the Python graph
|
||||
has been loaded.
|
||||
path_regex_whitelist: A regular-expression filter for source file path.
|
||||
node_name_regex_whitelist: A regular-expression filter for node names.
|
||||
path_regex_allowlist: A regular-expression filter for source file path.
|
||||
node_name_regex_allowlist: A regular-expression filter for node names.
|
||||
|
||||
Returns:
|
||||
A list of tuples regarding the Python source files involved in constructing
|
||||
@ -264,10 +264,11 @@ def list_source_files_against_dump(dump,
|
||||
path_to_first_line = {}
|
||||
tensor_name_to_num_dumps = {}
|
||||
|
||||
path_regex = (re.compile(path_regex_whitelist)
|
||||
if path_regex_whitelist else None)
|
||||
node_name_regex = (re.compile(node_name_regex_whitelist)
|
||||
if node_name_regex_whitelist else None)
|
||||
path_regex = (
|
||||
re.compile(path_regex_allowlist) if path_regex_allowlist else None)
|
||||
node_name_regex = (
|
||||
re.compile(node_name_regex_allowlist)
|
||||
if node_name_regex_allowlist else None)
|
||||
|
||||
to_skip_file_paths = set()
|
||||
for op in py_graph.get_operations():
|
||||
|
@ -406,7 +406,7 @@ class ListSourceAgainstDumpTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testGenerateSourceListWithNodeNameFilter(self):
|
||||
source_list = source_utils.list_source_files_against_dump(
|
||||
self.dump, node_name_regex_whitelist=r"while/Add.*")
|
||||
self.dump, node_name_regex_allowlist=r"while/Add.*")
|
||||
|
||||
# Assert that the file paths are sorted.
|
||||
file_paths = [item[0] for item in source_list]
|
||||
@ -433,8 +433,8 @@ class ListSourceAgainstDumpTest(test_util.TensorFlowTestCase):
|
||||
curr_file_basename = os.path.basename(self.curr_file_path)
|
||||
source_list = source_utils.list_source_files_against_dump(
|
||||
self.dump,
|
||||
path_regex_whitelist=(
|
||||
".*" + curr_file_basename.replace(".", "\\.") + "$"))
|
||||
path_regex_allowlist=(".*" + curr_file_basename.replace(".", "\\.") +
|
||||
"$"))
|
||||
|
||||
self.assertEqual(1, len(source_list))
|
||||
(file_path, is_tf_py_library, num_nodes, num_tensors, num_dumps,
|
||||
|
@ -169,7 +169,7 @@ class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase):
|
||||
log_usage=False)
|
||||
|
||||
def testDumpingWithLegacyWatchFnOnFetchesWorks(self):
|
||||
"""Use a watch_fn that returns different whitelists for different runs."""
|
||||
"""Use a watch_fn that returns different allowlists for different runs."""
|
||||
|
||||
def watch_fn(fetches, feeds):
|
||||
del feeds
|
||||
@ -240,9 +240,9 @@ class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase):
|
||||
del fetches, feeds
|
||||
return framework.WatchOptions(
|
||||
debug_ops=["DebugIdentity", "DebugNumericSummary"],
|
||||
node_name_regex_whitelist=r"^v.*",
|
||||
op_type_regex_whitelist=r".*",
|
||||
tensor_dtype_regex_whitelist=".*_ref")
|
||||
node_name_regex_allowlist=r"^v.*",
|
||||
op_type_regex_allowlist=r".*",
|
||||
tensor_dtype_regex_allowlist=".*_ref")
|
||||
|
||||
sess = dumping_wrapper.DumpingDebugWrapperSession(
|
||||
self.sess,
|
||||
@ -288,14 +288,13 @@ class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase):
|
||||
if watch_fn_state["run_counter"] % 2 == 1:
|
||||
# If odd-index run (1-based), watch every ref-type tensor.
|
||||
return framework.WatchOptions(
|
||||
debug_ops="DebugIdentity",
|
||||
tensor_dtype_regex_whitelist=".*_ref")
|
||||
debug_ops="DebugIdentity", tensor_dtype_regex_allowlist=".*_ref")
|
||||
else:
|
||||
# If even-index run, watch nothing.
|
||||
return framework.WatchOptions(
|
||||
debug_ops="DebugIdentity",
|
||||
node_name_regex_whitelist=r"^$",
|
||||
op_type_regex_whitelist=r"^$")
|
||||
node_name_regex_allowlist=r"^$",
|
||||
op_type_regex_allowlist=r"^$")
|
||||
|
||||
dumping_hook = hooks.DumpingDebugHook(
|
||||
self.session_root, watch_fn=counting_watch_fn, log_usage=False)
|
||||
|
@ -234,9 +234,9 @@ class OnRunStartResponse(object):
|
||||
action,
|
||||
debug_urls,
|
||||
debug_ops="DebugIdentity",
|
||||
node_name_regex_whitelist=None,
|
||||
op_type_regex_whitelist=None,
|
||||
tensor_dtype_regex_whitelist=None,
|
||||
node_name_regex_allowlist=None,
|
||||
op_type_regex_allowlist=None,
|
||||
tensor_dtype_regex_allowlist=None,
|
||||
tolerate_debug_op_creation_failures=False):
|
||||
"""Constructor of `OnRunStartResponse`.
|
||||
|
||||
@ -247,10 +247,10 @@ class OnRunStartResponse(object):
|
||||
during the run() call.
|
||||
debug_ops: (`str` or `list` of `str`) Debug op(s) to be used by the
|
||||
debugger.
|
||||
node_name_regex_whitelist: Regular-expression whitelist for node
|
||||
node_name_regex_allowlist: Regular-expression allowlist for node
|
||||
name.
|
||||
op_type_regex_whitelist: Regular-expression whitelist for op type.
|
||||
tensor_dtype_regex_whitelist: Regular-expression whitelist for tensor
|
||||
op_type_regex_allowlist: Regular-expression allowlist for op type.
|
||||
tensor_dtype_regex_allowlist: Regular-expression allowlist for tensor
|
||||
dtype.
|
||||
tolerate_debug_op_creation_failures: Whether debug op creation failures
|
||||
are to be tolerated.
|
||||
@ -264,9 +264,9 @@ class OnRunStartResponse(object):
|
||||
|
||||
self.debug_ops = debug_ops
|
||||
|
||||
self.node_name_regex_whitelist = node_name_regex_whitelist
|
||||
self.op_type_regex_whitelist = op_type_regex_whitelist
|
||||
self.tensor_dtype_regex_whitelist = tensor_dtype_regex_whitelist
|
||||
self.node_name_regex_allowlist = node_name_regex_allowlist
|
||||
self.op_type_regex_allowlist = op_type_regex_allowlist
|
||||
self.tensor_dtype_regex_allowlist = tensor_dtype_regex_allowlist
|
||||
self.tolerate_debug_op_creation_failures = (
|
||||
tolerate_debug_op_creation_failures)
|
||||
|
||||
@ -329,7 +329,7 @@ class BaseDebugWrapperSession(session.SessionInterface):
|
||||
Args:
|
||||
sess: An (unwrapped) TensorFlow session instance. It should be a subtype
|
||||
of `BaseSession` or `tf.MonitoredSession`.
|
||||
thread_name_filter: Regular-expression filter (whitelist) for name(s) of
|
||||
thread_name_filter: Regular-expression filter (allowlist) for name(s) of
|
||||
thread(s) on which the wrapper session will be active. This regular
|
||||
expression is used in a start-anchored fashion on the thread name, i.e.,
|
||||
by applying the `match` method of the compiled pattern. The default
|
||||
@ -545,11 +545,10 @@ class BaseDebugWrapperSession(session.SessionInterface):
|
||||
decorated_run_options,
|
||||
run_start_resp.debug_urls,
|
||||
debug_ops=run_start_resp.debug_ops,
|
||||
node_name_regex_whitelist=(
|
||||
run_start_resp.node_name_regex_whitelist),
|
||||
op_type_regex_whitelist=run_start_resp.op_type_regex_whitelist,
|
||||
tensor_dtype_regex_whitelist=(
|
||||
run_start_resp.tensor_dtype_regex_whitelist),
|
||||
node_name_regex_allowlist=(run_start_resp.node_name_regex_allowlist),
|
||||
op_type_regex_allowlist=run_start_resp.op_type_regex_allowlist,
|
||||
tensor_dtype_regex_allowlist=(
|
||||
run_start_resp.tensor_dtype_regex_allowlist),
|
||||
tolerate_debug_op_creation_failures=(
|
||||
run_start_resp.tolerate_debug_op_creation_failures))
|
||||
|
||||
@ -707,9 +706,9 @@ class BaseDebugWrapperSession(session.SessionInterface):
|
||||
run_options,
|
||||
debug_urls,
|
||||
debug_ops="DebugIdentity",
|
||||
node_name_regex_whitelist=None,
|
||||
op_type_regex_whitelist=None,
|
||||
tensor_dtype_regex_whitelist=None,
|
||||
node_name_regex_allowlist=None,
|
||||
op_type_regex_allowlist=None,
|
||||
tensor_dtype_regex_allowlist=None,
|
||||
tolerate_debug_op_creation_failures=False):
|
||||
"""Modify a RunOptions object for debug tensor watching.
|
||||
|
||||
@ -721,10 +720,10 @@ class BaseDebugWrapperSession(session.SessionInterface):
|
||||
debug_urls: (list of str) debug URLs to be entered in run_options.
|
||||
debug_tensor_watch_opts.
|
||||
debug_ops: (str or list of str) debug op(s) to be used by the debugger.
|
||||
node_name_regex_whitelist: Regular-expression whitelist for node
|
||||
node_name_regex_allowlist: Regular-expression allowlist for node
|
||||
name.
|
||||
op_type_regex_whitelist: Regular-expression whitelist for op type.
|
||||
tensor_dtype_regex_whitelist: Regular-expression whitelist for tensor
|
||||
op_type_regex_allowlist: Regular-expression allowlist for op type.
|
||||
tensor_dtype_regex_allowlist: Regular-expression allowlist for tensor
|
||||
dtype.
|
||||
tolerate_debug_op_creation_failures: Whether debug op creation failures
|
||||
are to be tolerated.
|
||||
@ -736,9 +735,9 @@ class BaseDebugWrapperSession(session.SessionInterface):
|
||||
self._sess.graph,
|
||||
debug_urls=debug_urls,
|
||||
debug_ops=debug_ops,
|
||||
node_name_regex_whitelist=node_name_regex_whitelist,
|
||||
op_type_regex_whitelist=op_type_regex_whitelist,
|
||||
tensor_dtype_regex_whitelist=tensor_dtype_regex_whitelist,
|
||||
node_name_regex_allowlist=node_name_regex_allowlist,
|
||||
op_type_regex_allowlist=op_type_regex_allowlist,
|
||||
tensor_dtype_regex_allowlist=tensor_dtype_regex_allowlist,
|
||||
tolerate_debug_op_creation_failures=tolerate_debug_op_creation_failures,
|
||||
reset_disk_byte_usage=(self._run_call_count == 1 or
|
||||
self._is_disk_usage_reset_each_run()))
|
||||
@ -821,8 +820,8 @@ class BaseDebugWrapperSession(session.SessionInterface):
|
||||
def close(self):
|
||||
self._sess.close()
|
||||
|
||||
# TODO(cais): Add _node_name_regex_whitelist and
|
||||
# _node_op_type_regex_whitelist.
|
||||
# TODO(cais): Add _node_name_regex_allowlist and
|
||||
# _node_op_type_regex_allowlist.
|
||||
|
||||
def should_stop(self):
|
||||
if hasattr(self._sess, "should_stop"):
|
||||
@ -838,9 +837,9 @@ class WatchOptions(object):
|
||||
|
||||
def __init__(self,
|
||||
debug_ops=None,
|
||||
node_name_regex_whitelist=None,
|
||||
op_type_regex_whitelist=None,
|
||||
tensor_dtype_regex_whitelist=None,
|
||||
node_name_regex_allowlist=None,
|
||||
op_type_regex_allowlist=None,
|
||||
tensor_dtype_regex_allowlist=None,
|
||||
tolerate_debug_op_creation_failures=False):
|
||||
"""Constructor of WatchOptions: Debug watch options.
|
||||
|
||||
@ -848,17 +847,17 @@ class WatchOptions(object):
|
||||
|
||||
Args:
|
||||
debug_ops: (`str` or `list of str`) Debug ops to be used.
|
||||
node_name_regex_whitelist: Regular-expression whitelist for node_name,
|
||||
node_name_regex_allowlist: Regular-expression allowlist for node_name,
|
||||
e.g., `"(weight_[0-9]+|bias_.*)"`
|
||||
op_type_regex_whitelist: Regular-expression whitelist for the op type of
|
||||
op_type_regex_allowlist: Regular-expression allowlist for the op type of
|
||||
nodes, e.g., `"(Variable|Add)"`.
|
||||
If both `node_name_regex_whitelist` and `op_type_regex_whitelist`
|
||||
If both `node_name_regex_allowlist` and `op_type_regex_allowlist`
|
||||
are set, the two filtering operations will occur in a logical `AND`
|
||||
relation. In other words, a node will be included if and only if it
|
||||
hits both whitelists.
|
||||
tensor_dtype_regex_whitelist: Regular-expression whitelist for Tensor
|
||||
hits both allowlists.
|
||||
tensor_dtype_regex_allowlist: Regular-expression allowlist for Tensor
|
||||
data type, e.g., `"^int.*"`.
|
||||
This whitelist operates in logical `AND` relations to the two whitelists
|
||||
This allowlist operates in logical `AND` relations to the two allowlists
|
||||
above.
|
||||
tolerate_debug_op_creation_failures: (`bool`) whether debug op creation
|
||||
failures (e.g., due to dtype incompatibility) are to be tolerated by not
|
||||
@ -868,19 +867,19 @@ class WatchOptions(object):
|
||||
self.debug_ops = debug_ops
|
||||
else:
|
||||
self.debug_ops = ["DebugIdentity"]
|
||||
self.node_name_regex_whitelist = node_name_regex_whitelist
|
||||
self.op_type_regex_whitelist = op_type_regex_whitelist
|
||||
self.tensor_dtype_regex_whitelist = tensor_dtype_regex_whitelist
|
||||
self.node_name_regex_allowlist = node_name_regex_allowlist
|
||||
self.op_type_regex_allowlist = op_type_regex_allowlist
|
||||
self.tensor_dtype_regex_allowlist = tensor_dtype_regex_allowlist
|
||||
self.tolerate_debug_op_creation_failures = (
|
||||
tolerate_debug_op_creation_failures)
|
||||
|
||||
def __repr__(self):
|
||||
return ("WatchOptions(debug_ops=%r, node_name_regex_whitelist=%r, "
|
||||
"op_type_regex_whitelist=%r, tensor_dtype_regex_whitelist=%r, "
|
||||
"tolerate_debug_op_creation_failures=%r)" % (
|
||||
self.debug_ops, self.node_name_regex_whitelist,
|
||||
self.op_type_regex_whitelist, self.tensor_dtype_regex_whitelist,
|
||||
self.tolerate_debug_op_creation_failures))
|
||||
return ("WatchOptions(debug_ops=%r, node_name_regex_allowlist=%r, "
|
||||
"op_type_regex_allowlist=%r, tensor_dtype_regex_allowlist=%r, "
|
||||
"tolerate_debug_op_creation_failures=%r)" %
|
||||
(self.debug_ops, self.node_name_regex_allowlist,
|
||||
self.op_type_regex_allowlist, self.tensor_dtype_regex_allowlist,
|
||||
self.tolerate_debug_op_creation_failures))
|
||||
|
||||
|
||||
class NonInteractiveDebugWrapperSession(BaseDebugWrapperSession):
|
||||
@ -952,14 +951,14 @@ class NonInteractiveDebugWrapperSession(BaseDebugWrapperSession):
|
||||
OnRunStartAction.DEBUG_RUN,
|
||||
debug_urls,
|
||||
debug_ops=watch_opts.debug_ops,
|
||||
node_name_regex_whitelist=watch_opts.node_name_regex_whitelist,
|
||||
op_type_regex_whitelist=watch_opts.op_type_regex_whitelist,
|
||||
tensor_dtype_regex_whitelist=watch_opts.tensor_dtype_regex_whitelist,
|
||||
node_name_regex_allowlist=watch_opts.node_name_regex_allowlist,
|
||||
op_type_regex_allowlist=watch_opts.op_type_regex_allowlist,
|
||||
tensor_dtype_regex_allowlist=watch_opts.tensor_dtype_regex_allowlist,
|
||||
tolerate_debug_op_creation_failures=(
|
||||
watch_opts.tolerate_debug_op_creation_failures))
|
||||
|
||||
def _prepare_run_watch_config(self, fetches, feed_dict):
|
||||
"""Get the debug_urls, and node/op whitelists for the current run() call.
|
||||
"""Get the debug_urls, and node/op allowlists for the current run() call.
|
||||
|
||||
Args:
|
||||
fetches: Same as the `fetches` argument to `Session.run()`.
|
||||
@ -969,7 +968,7 @@ class NonInteractiveDebugWrapperSession(BaseDebugWrapperSession):
|
||||
debug_urls: (str or list of str) Debug URLs for the current run() call.
|
||||
Currently, the list consists of only one URL that is a file:// URL.
|
||||
watch_options: (WatchOptions) The return value of a watch_fn, containing
|
||||
options including debug_ops, and whitelists.
|
||||
options including debug_ops, and allowlists.
|
||||
"""
|
||||
|
||||
debug_urls = self.prepare_run_debug_urls(fetches, feed_dict)
|
||||
|
@ -124,12 +124,12 @@ class LocalCLIDebugHook(session_run_hook.SessionRunHook):
|
||||
run_args.options,
|
||||
on_run_start_response.debug_urls,
|
||||
debug_ops=on_run_start_response.debug_ops,
|
||||
node_name_regex_whitelist=(
|
||||
on_run_start_response.node_name_regex_whitelist),
|
||||
op_type_regex_whitelist=(
|
||||
on_run_start_response.op_type_regex_whitelist),
|
||||
tensor_dtype_regex_whitelist=(
|
||||
on_run_start_response.tensor_dtype_regex_whitelist),
|
||||
node_name_regex_allowlist=(
|
||||
on_run_start_response.node_name_regex_allowlist),
|
||||
op_type_regex_allowlist=(
|
||||
on_run_start_response.op_type_regex_allowlist),
|
||||
tensor_dtype_regex_allowlist=(
|
||||
on_run_start_response.tensor_dtype_regex_allowlist),
|
||||
tolerate_debug_op_creation_failures=(
|
||||
on_run_start_response.tolerate_debug_op_creation_failures))
|
||||
# pylint: enable=protected-access
|
||||
@ -205,9 +205,9 @@ class DumpingDebugHook(session_run_hook.SessionRunHook):
|
||||
run_context.session.graph,
|
||||
debug_urls=debug_urls,
|
||||
debug_ops=watch_options.debug_ops,
|
||||
node_name_regex_whitelist=watch_options.node_name_regex_whitelist,
|
||||
op_type_regex_whitelist=watch_options.op_type_regex_whitelist,
|
||||
tensor_dtype_regex_whitelist=watch_options.tensor_dtype_regex_whitelist,
|
||||
node_name_regex_allowlist=watch_options.node_name_regex_allowlist,
|
||||
op_type_regex_allowlist=watch_options.op_type_regex_allowlist,
|
||||
tensor_dtype_regex_allowlist=watch_options.tensor_dtype_regex_allowlist,
|
||||
tolerate_debug_op_creation_failures=(
|
||||
watch_options.tolerate_debug_op_creation_failures),
|
||||
reset_disk_byte_usage=reset_disk_byte_usage)
|
||||
@ -292,9 +292,9 @@ class GrpcDebugHook(session_run_hook.SessionRunHook):
|
||||
debug_urls=self._grpc_debug_wrapper_session.prepare_run_debug_urls(
|
||||
fetches, feed_dict),
|
||||
debug_ops=watch_options.debug_ops,
|
||||
node_name_regex_whitelist=watch_options.node_name_regex_whitelist,
|
||||
op_type_regex_whitelist=watch_options.op_type_regex_whitelist,
|
||||
tensor_dtype_regex_whitelist=watch_options.tensor_dtype_regex_whitelist,
|
||||
node_name_regex_allowlist=watch_options.node_name_regex_allowlist,
|
||||
op_type_regex_allowlist=watch_options.op_type_regex_allowlist,
|
||||
tensor_dtype_regex_allowlist=watch_options.tensor_dtype_regex_allowlist,
|
||||
tolerate_debug_op_creation_failures=(
|
||||
watch_options.tolerate_debug_op_creation_failures))
|
||||
|
||||
|
@ -552,9 +552,9 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
|
||||
run_start_response = framework.OnRunStartResponse(
|
||||
action,
|
||||
debug_urls,
|
||||
node_name_regex_whitelist=parsed.node_name_filter,
|
||||
op_type_regex_whitelist=parsed.op_type_filter,
|
||||
tensor_dtype_regex_whitelist=parsed.tensor_dtype_filter)
|
||||
node_name_regex_allowlist=parsed.node_name_filter,
|
||||
op_type_regex_allowlist=parsed.op_type_filter,
|
||||
tensor_dtype_regex_allowlist=parsed.tensor_dtype_filter)
|
||||
|
||||
if parsed.till_filter_pass:
|
||||
# For the run-till-filter-pass (run -f) mode, use the DEBUG_RUN
|
||||
|
@ -88,7 +88,7 @@ def call_for_each_replica(strategy, fn, args=None, kwargs=None):
|
||||
else:
|
||||
# When a tf.function is wrapped to trigger _call_for_each_replica (see
|
||||
# the other branch above), AutoGraph stops conversion at
|
||||
# _call_for_each_replica itself (TF library functions are whitelisted).
|
||||
# _call_for_each_replica itself (TF library functions are allowlisted).
|
||||
# This makes sure that the Python function that originally passed to
|
||||
# the tf.function is still converted.
|
||||
fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx())
|
||||
|
@ -237,7 +237,7 @@ def _parse_func_attrs(attributes):
|
||||
A dict of attributes where the key is the name of attribute and the value
|
||||
is the AttrValue proto.
|
||||
Raises:
|
||||
ValueError: If the kwargs contains unwhitelisted name or unsupported value
|
||||
ValueError: If the kwargs contains unallowlisted name or unsupported value
|
||||
types.
|
||||
"""
|
||||
attrs = {}
|
||||
@ -3625,9 +3625,9 @@ def defun_with_attributes(func=None,
|
||||
input_signature: same as defun()'s input_signature.
|
||||
attributes: A dictionary of arguments which will be added to function def as
|
||||
attributes. Currently only support primitive types as value, and only
|
||||
whitelisted attribute name is allowed. Unwhitelisted attribute name or
|
||||
allowlisted attribute name is allowed. Unallowlisted attribute name or
|
||||
unsupported value will result into ValueError. `func_name` is also one of
|
||||
the whitelisted argument which is a python string, and sets the name for
|
||||
the allowlisted argument which is a python string, and sets the name for
|
||||
this `ConcreteFunction` in the graph.
|
||||
autograph: same as defun()'s autograph.
|
||||
experimental_autograph_options: same as defun()'s
|
||||
|
@ -108,9 +108,9 @@ _ALL_BLACKLISTED_OPS = (
|
||||
set(ASYNC_STATEFUL_OPS) | set(LEGACY_RANDOM_OPS)
|
||||
| set(_ORDER_INSENSITIVE_STATEFUL_OPS))
|
||||
|
||||
# Op types that are marked as stateless, but should be whitelisted to add auto
|
||||
# Op types that are marked as stateless, but should be allowlisted to add auto
|
||||
# control dependencies.
|
||||
_WHITELIST_STATELESS_OPS = [
|
||||
_ALLOWLIST_STATELESS_OPS = [
|
||||
# As TPU collective ops are blocking, if there are more than one collective
|
||||
# op in the function, we need to make sure different collectives ops are
|
||||
# scheduled in certain orders. Otherwise if at the same time all the
|
||||
@ -125,7 +125,7 @@ _WHITELIST_STATELESS_OPS = [
|
||||
def op_is_stateful(op):
|
||||
# pylint: disable=protected-access
|
||||
return (op._is_stateful and op.type not in _ALL_BLACKLISTED_OPS) or (
|
||||
op.type in _WHITELIST_STATELESS_OPS)
|
||||
op.type in _ALLOWLIST_STATELESS_OPS)
|
||||
|
||||
|
||||
class ResourceType(enum.Enum):
|
||||
|
@ -710,12 +710,12 @@ class _ConverterData(object):
|
||||
|
||||
def __init__(self,
|
||||
graph_def,
|
||||
variable_names_whitelist=None,
|
||||
variable_names_allowlist=None,
|
||||
variable_names_blacklist=None):
|
||||
self._graph_def = graph_def
|
||||
self._tensor_data = {}
|
||||
self._build_node_defs_list()
|
||||
self._variable_names_whitelist = variable_names_whitelist
|
||||
self._variable_names_allowlist = variable_names_allowlist
|
||||
self._variable_names_blacklist = variable_names_blacklist
|
||||
|
||||
@property
|
||||
@ -740,8 +740,8 @@ class _ConverterData(object):
|
||||
|
||||
def _should_convert(self, name):
|
||||
"""Checks whether to convert the given variable name to a constant."""
|
||||
return (self._variable_names_whitelist is None or
|
||||
name in self._variable_names_whitelist) and (
|
||||
return (self._variable_names_allowlist is None or
|
||||
name in self._variable_names_allowlist) and (
|
||||
self._variable_names_blacklist is None or
|
||||
name not in self._variable_names_blacklist)
|
||||
|
||||
@ -776,7 +776,7 @@ class _FunctionConverterData(_ConverterData):
|
||||
func,
|
||||
lower_control_flow,
|
||||
aggressive_inlining,
|
||||
variable_names_whitelist=None,
|
||||
variable_names_allowlist=None,
|
||||
variable_names_blacklist=None):
|
||||
"""Creates the conversion data for the given function.
|
||||
|
||||
@ -787,7 +787,7 @@ class _FunctionConverterData(_ConverterData):
|
||||
aggressive_inlining: Boolean indicating whether or not to to aggressive
|
||||
function inlining (might be unsafe if function has stateful ops, not
|
||||
properly connected to control outputs).
|
||||
variable_names_whitelist: The set of variable names to convert (by
|
||||
variable_names_allowlist: The set of variable names to convert (by
|
||||
default, all variables are converted).
|
||||
variable_names_blacklist: The set of variable names to omit converting to
|
||||
constants.
|
||||
@ -799,7 +799,7 @@ class _FunctionConverterData(_ConverterData):
|
||||
aggressive_inlining)
|
||||
super(_FunctionConverterData, self).__init__(
|
||||
graph_def,
|
||||
variable_names_whitelist=variable_names_whitelist,
|
||||
variable_names_allowlist=variable_names_allowlist,
|
||||
variable_names_blacklist=variable_names_blacklist)
|
||||
self._build_tensor_data()
|
||||
|
||||
@ -849,12 +849,12 @@ class _SessionConverterData(_ConverterData):
|
||||
session,
|
||||
graph_def,
|
||||
output_node_names,
|
||||
variable_names_whitelist=None,
|
||||
variable_names_allowlist=None,
|
||||
variable_names_blacklist=None):
|
||||
graph_def = graph_util.extract_sub_graph(graph_def, output_node_names)
|
||||
super(_SessionConverterData, self).__init__(
|
||||
graph_def,
|
||||
variable_names_whitelist=variable_names_whitelist,
|
||||
variable_names_allowlist=variable_names_allowlist,
|
||||
variable_names_blacklist=variable_names_blacklist)
|
||||
|
||||
nodes_to_convert = []
|
||||
@ -1114,7 +1114,7 @@ def convert_variables_to_constants_from_session_graph(
|
||||
session,
|
||||
graph_def,
|
||||
output_node_names,
|
||||
variable_names_whitelist=None,
|
||||
variable_names_allowlist=None,
|
||||
variable_names_blacklist=None):
|
||||
"""Replaces all the variables in a graph with constants of the same values.
|
||||
|
||||
@ -1129,7 +1129,7 @@ def convert_variables_to_constants_from_session_graph(
|
||||
session: Active TensorFlow session containing the variables.
|
||||
graph_def: A GraphDef to convert.
|
||||
output_node_names: List of name strings for the result nodes of the graph.
|
||||
variable_names_whitelist: The set of variable names to convert (by default,
|
||||
variable_names_allowlist: The set of variable names to convert (by default,
|
||||
all variables are converted).
|
||||
variable_names_blacklist: The set of variable names to omit converting to
|
||||
constants.
|
||||
@ -1142,6 +1142,6 @@ def convert_variables_to_constants_from_session_graph(
|
||||
session=session,
|
||||
graph_def=graph_def,
|
||||
output_node_names=output_node_names,
|
||||
variable_names_whitelist=variable_names_whitelist,
|
||||
variable_names_allowlist=variable_names_allowlist,
|
||||
variable_names_blacklist=variable_names_blacklist))
|
||||
return graph_def
|
||||
|
@ -49,7 +49,7 @@ from tensorflow.python.util import object_identity
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
from tensorflow.python.util import tf_decorator
|
||||
|
||||
WHITELIST_COLLECTIONS = [
|
||||
ALLOWLIST_COLLECTIONS = [
|
||||
ops.GraphKeys.GLOBAL_VARIABLES,
|
||||
ops.GraphKeys.LOCAL_VARIABLES,
|
||||
ops.GraphKeys.TRAINABLE_VARIABLES,
|
||||
@ -172,9 +172,9 @@ class FuncGraph(ops.Graph):
|
||||
name: the name of the function.
|
||||
collections: a dictionary of collections this FuncGraph should start
|
||||
with. If not specified (None), the FuncGraph will read (but not write
|
||||
to) the outer graph's collections that are not whitelisted, and both
|
||||
read and write to the outer graph's collections that are whitelisted.
|
||||
The current whitelisted collections are the global variables, the
|
||||
to) the outer graph's collections that are not allowlisted, and both
|
||||
read and write to the outer graph's collections that are allowlisted.
|
||||
The current allowlisted collections are the global variables, the
|
||||
local variables, and the trainable variables.
|
||||
Defaults to None.
|
||||
capture_by_value: An optional boolean. If True, the func graph will
|
||||
@ -241,10 +241,10 @@ class FuncGraph(ops.Graph):
|
||||
|
||||
if collections is None:
|
||||
for collection_name in graph.get_all_collection_keys():
|
||||
if collection_name not in WHITELIST_COLLECTIONS:
|
||||
if collection_name not in ALLOWLIST_COLLECTIONS:
|
||||
self._collections[collection_name] = graph.get_collection(
|
||||
collection_name)
|
||||
for collection_name in WHITELIST_COLLECTIONS:
|
||||
for collection_name in ALLOWLIST_COLLECTIONS:
|
||||
self._collections[collection_name] = graph.get_collection_ref(
|
||||
collection_name)
|
||||
else:
|
||||
@ -842,9 +842,9 @@ def func_graph_from_py_func(name,
|
||||
set, returning an Operation triggers an error.
|
||||
collections: a dictionary of collections this FuncGraph should start
|
||||
with. If not specified (None), the FuncGraph will read (but not write to)
|
||||
the outer graph's collections that are not whitelisted, and both
|
||||
read and write to the outer graph's collections that are whitelisted.
|
||||
The current whitelisted collections are the global variables, the
|
||||
the outer graph's collections that are not allowlisted, and both
|
||||
read and write to the outer graph's collections that are allowlisted.
|
||||
The current allowlisted collections are the global variables, the
|
||||
local variables, and the trainable variables.
|
||||
Defaults to None.
|
||||
capture_by_value: An optional boolean. If True, the func graph will capture
|
||||
|
@ -234,7 +234,7 @@ class _DefinedFunction(object):
|
||||
out_names=None,
|
||||
shape_func=None,
|
||||
capture_by_value=False,
|
||||
whitelisted_stateful_ops=None,
|
||||
allowlisted_stateful_ops=None,
|
||||
capture_resource_var_by_value=True,
|
||||
**kwargs):
|
||||
"""Creates _DefinedFunction.
|
||||
@ -256,7 +256,7 @@ class _DefinedFunction(object):
|
||||
output shapes.
|
||||
capture_by_value: Boolean (defaults to False). If True, captured values
|
||||
will be copied into the function body.
|
||||
whitelisted_stateful_ops: A set of ops that if stateful we ignore and
|
||||
allowlisted_stateful_ops: A set of ops that if stateful we ignore and
|
||||
copy into the function body, when `capture_by_value` is True.
|
||||
capture_resource_var_by_value: Boolean (defaults to True). If False,
|
||||
captured resource variable returns the handle instead of value.
|
||||
@ -275,9 +275,9 @@ class _DefinedFunction(object):
|
||||
self._out_names = out_names
|
||||
self._shape_func = shape_func
|
||||
self._capture_by_value = capture_by_value
|
||||
self._whitelisted_stateful_ops = whitelisted_stateful_ops
|
||||
if self._whitelisted_stateful_ops is None:
|
||||
self._whitelisted_stateful_ops = set()
|
||||
self._allowlisted_stateful_ops = allowlisted_stateful_ops
|
||||
if self._allowlisted_stateful_ops is None:
|
||||
self._allowlisted_stateful_ops = set()
|
||||
self._capture_resource_var_by_value = capture_resource_var_by_value
|
||||
self._extra_kwargs = kwargs
|
||||
# Constructed only when C API is disabled, lazily
|
||||
@ -403,7 +403,7 @@ class _DefinedFunction(object):
|
||||
self._capture_by_value,
|
||||
self._caller_device,
|
||||
collections_ref=collections_ref,
|
||||
whitelisted_stateful_ops=self._whitelisted_stateful_ops,
|
||||
allowlisted_stateful_ops=self._allowlisted_stateful_ops,
|
||||
capture_resource_var_by_value=self._capture_resource_var_by_value)
|
||||
|
||||
self._extra_inputs = temp_graph.extra_inputs
|
||||
@ -690,11 +690,11 @@ class _FuncGraph(ops.Graph):
|
||||
function argument and the caller passes in the captured tensor.
|
||||
"""
|
||||
|
||||
def __init__(self, name, capture_by_value, whitelisted_stateful_ops,
|
||||
def __init__(self, name, capture_by_value, allowlisted_stateful_ops,
|
||||
capture_resource_var_by_value, *args, **kwargs):
|
||||
super(_FuncGraph, self).__init__(*args, **kwargs)
|
||||
self._capture_by_value = capture_by_value
|
||||
self._whitelisted_stateful_ops = whitelisted_stateful_ops
|
||||
self._allowlisted_stateful_ops = allowlisted_stateful_ops
|
||||
self._capture_resource_var_by_value = capture_resource_var_by_value
|
||||
self._building_function = True
|
||||
self._outer_graph = ops.get_default_graph()
|
||||
@ -879,7 +879,7 @@ class _FuncGraph(ops.Graph):
|
||||
def _add_op_and_parents(self, op):
|
||||
# pylint: disable=protected-access
|
||||
op_def = graph_to_function_def._get_op_def(op)
|
||||
if op._is_stateful and op not in self._whitelisted_stateful_ops:
|
||||
if op._is_stateful and op not in self._allowlisted_stateful_ops:
|
||||
raise ValueError("Cannot capture a stateful node (name:%s, type:%s) "
|
||||
"by value." % (op.name, op.type))
|
||||
elif op.type in ("Placeholder", "PlaceholderV2"):
|
||||
@ -912,7 +912,7 @@ def func_graph_from_py_func(func,
|
||||
container=None,
|
||||
collections_ref=None,
|
||||
arg_shapes=None,
|
||||
whitelisted_stateful_ops=None,
|
||||
allowlisted_stateful_ops=None,
|
||||
capture_resource_var_by_value=True):
|
||||
"""Returns a _FuncGraph generated from `func`.
|
||||
|
||||
@ -931,7 +931,7 @@ def func_graph_from_py_func(func,
|
||||
collections_ref: A reference to a collections dict the _FuncGraph should
|
||||
use internally.
|
||||
arg_shapes: A sequence of the function's argument shapes.
|
||||
whitelisted_stateful_ops: A set of ops that if stateful we ignore and
|
||||
allowlisted_stateful_ops: A set of ops that if stateful we ignore and
|
||||
re-create.
|
||||
capture_resource_var_by_value: Boolean (defaults to True). If False,
|
||||
captured resource variable returns the handle instead of value.
|
||||
@ -944,7 +944,7 @@ def func_graph_from_py_func(func,
|
||||
"""
|
||||
if not name:
|
||||
name = function_utils.get_func_name(func)
|
||||
func_graph = _FuncGraph(name, capture_by_value, whitelisted_stateful_ops,
|
||||
func_graph = _FuncGraph(name, capture_by_value, allowlisted_stateful_ops,
|
||||
capture_resource_var_by_value)
|
||||
|
||||
with func_graph.as_default(), ops.device(device):
|
||||
|
@ -1043,7 +1043,7 @@ class FunctionTest(test.TestCase):
|
||||
self.assertFalse(all(val4 == val2))
|
||||
|
||||
@test_util.run_v1_only("currently failing on v2")
|
||||
def testStatefulFunctionWithWhitelisting(self):
|
||||
def testStatefulFunctionWithAllowlisting(self):
|
||||
t = random_ops.random_uniform([100], maxval=10, dtype=dtypes.int32)
|
||||
|
||||
@function.Defun(capture_by_value=True)
|
||||
@ -1054,8 +1054,8 @@ class FunctionTest(test.TestCase):
|
||||
with self.assertRaisesRegex(ValueError, "Cannot capture a stateful node"):
|
||||
res = StatefulFn()
|
||||
|
||||
# This time we whitelist this op, so that its recreated.
|
||||
@function.Defun(capture_by_value=True, whitelisted_stateful_ops=set([t.op]))
|
||||
# This time we allowlist this op, so that its recreated.
|
||||
@function.Defun(capture_by_value=True, allowlisted_stateful_ops=set([t.op]))
|
||||
def StatefulFn2():
|
||||
return t + constant_op.constant(3, dtype=dtypes.int32)
|
||||
|
||||
|
@ -276,7 +276,7 @@ def convert_variables_to_constants(sess,
|
||||
session=sess,
|
||||
graph_def=input_graph_def,
|
||||
output_node_names=output_node_names,
|
||||
variable_names_whitelist=variable_names_whitelist,
|
||||
variable_names_allowlist=variable_names_whitelist,
|
||||
variable_names_blacklist=variable_names_blacklist)
|
||||
# The previous code logic generated an empty versions field, we clear it here
|
||||
# to maintain backwards compatibility.
|
||||
|
@ -472,7 +472,7 @@ class ImportGraphDefTest(test.TestCase):
|
||||
node { name: 'B' op: 'FloatInput' input: 'A:0' }
|
||||
"""))
|
||||
|
||||
def testShapeWhitelistViolation(self):
|
||||
def testShapeAllowlistViolation(self):
|
||||
# L2 loss produces a scalar shape, but the graph
|
||||
# has the wrong shape, so raise an error.
|
||||
with ops.Graph().as_default():
|
||||
|
@ -351,7 +351,7 @@ string GenEagerPythonOp::Code() {
|
||||
}
|
||||
|
||||
std::unordered_map<string, string> type_annotations;
|
||||
// Only populate map for whitelisted ops
|
||||
// Only populate map for allowlisted ops
|
||||
if (add_type_annotations_) {
|
||||
type_annotations = GetTypeAnnotations();
|
||||
}
|
||||
|
@ -108,7 +108,7 @@ string InferSourceFileName(const char* argv_zero) {
|
||||
void PrintAllPythonOps(const std::vector<string>& op_list,
|
||||
const std::vector<string>& api_def_dirs,
|
||||
const string& source_file_name,
|
||||
bool op_list_is_whitelist,
|
||||
bool op_list_is_allowlist,
|
||||
const std::unordered_set<string> type_annotate_ops) {
|
||||
OpList ops;
|
||||
OpRegistry::Global()->Export(false, &ops);
|
||||
@ -126,11 +126,11 @@ void PrintAllPythonOps(const std::vector<string>& op_list,
|
||||
api_def_map.UpdateDocs();
|
||||
}
|
||||
|
||||
if (op_list_is_whitelist) {
|
||||
std::unordered_set<string> whitelist(op_list.begin(), op_list.end());
|
||||
if (op_list_is_allowlist) {
|
||||
std::unordered_set<string> allowlist(op_list.begin(), op_list.end());
|
||||
OpList pruned_ops;
|
||||
for (const auto& op_def : ops.op()) {
|
||||
if (whitelist.find(op_def.name()) != whitelist.end()) {
|
||||
if (allowlist.find(op_def.name()) != allowlist.end()) {
|
||||
*pruned_ops.mutable_op()->Add() = op_def;
|
||||
}
|
||||
}
|
||||
@ -165,13 +165,13 @@ int main(int argc, char* argv[]) {
|
||||
|
||||
if (argc == 2) {
|
||||
tensorflow::PrintAllPythonOps({}, api_def_dirs, source_file_name,
|
||||
false /* op_list_is_whitelist */,
|
||||
false /* op_list_is_allowlist */,
|
||||
type_annotate_ops);
|
||||
} else if (argc == 3) {
|
||||
std::vector<tensorflow::string> hidden_ops;
|
||||
TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &hidden_ops));
|
||||
tensorflow::PrintAllPythonOps(hidden_ops, api_def_dirs, source_file_name,
|
||||
false /* op_list_is_whitelist */,
|
||||
false /* op_list_is_allowlist */,
|
||||
type_annotate_ops);
|
||||
} else if (argc == 4) {
|
||||
std::vector<tensorflow::string> op_list;
|
||||
|
@ -201,7 +201,7 @@ def _recurrent_lstm(c, h):
|
||||
def _make_node_with_color(color, input_tensor, name=None):
|
||||
"""Returns a node representative of the specified list type."""
|
||||
color = color.lower()
|
||||
if color == 'w': # White node
|
||||
if color == 'w': # Allow node
|
||||
weights = _weight(input_tensor.get_shape().as_list())
|
||||
return math_ops.matmul(input_tensor, weights, name=name)
|
||||
if color == 'g': # Gray node
|
||||
@ -371,7 +371,7 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
The loop has different node colors in different sections of the graph. The
|
||||
arguments must be strings where each character represents the color of a
|
||||
node in that section of the graph: w = white, g = gray, c = clear,
|
||||
node in that section of the graph: w = allow, g = gray, c = clear,
|
||||
b = black. CAPITALIZED characters indicate that the node is expected to be
|
||||
changed to DT_HALF during graph optimization.
|
||||
|
||||
|
@ -1594,7 +1594,7 @@ def assert_not_batched(dataset):
|
||||
if isinstance(dataset, dataset_ops.DatasetV1Adapter):
|
||||
return assert_not_batched(dataset._dataset)
|
||||
else:
|
||||
whitelisted_types = [
|
||||
allowed_types = [
|
||||
dataset_ops._OptionsDataset,
|
||||
dataset_ops.ConcatenateDataset,
|
||||
dataset_ops.CacheDataset,
|
||||
@ -1615,7 +1615,7 @@ def assert_not_batched(dataset):
|
||||
readers.TextLineDatasetV2,
|
||||
readers.TFRecordDatasetV2,
|
||||
]
|
||||
for ty in whitelisted_types:
|
||||
for ty in allowed_types:
|
||||
if isinstance(dataset, ty):
|
||||
for input_dataset in dataset._inputs():
|
||||
assert_not_batched(input_dataset)
|
||||
@ -1649,7 +1649,7 @@ def assert_not_shuffled(dataset):
|
||||
if isinstance(dataset, dataset_ops.DatasetV1Adapter):
|
||||
return assert_not_shuffled(dataset._dataset)
|
||||
else:
|
||||
whitelisted_types = [
|
||||
allowed_types = [
|
||||
dataset_ops._OptionsDataset,
|
||||
dataset_ops.BatchDataset,
|
||||
dataset_ops.ConcatenateDataset,
|
||||
@ -1672,7 +1672,7 @@ def assert_not_shuffled(dataset):
|
||||
readers.TextLineDatasetV2,
|
||||
readers.TFRecordDatasetV2,
|
||||
]
|
||||
for ty in whitelisted_types:
|
||||
for ty in allowed_types:
|
||||
if isinstance(dataset, ty):
|
||||
for input_dataset in dataset._inputs():
|
||||
assert_not_shuffled(input_dataset)
|
||||
|
@ -2858,7 +2858,7 @@ class DistributedCallbackModel(Model):
|
||||
orig_model_weights)
|
||||
|
||||
def __getattr__(self, item):
|
||||
# Whitelisted attributes of the model that can be accessed by the user
|
||||
# Allowed attributes of the model that can be accessed by the user
|
||||
# during a callback.
|
||||
if item not in ('_setattr_tracking', '_layers'):
|
||||
logging.warning('You are accessing attribute ' + item + ' of the '
|
||||
|
@ -333,7 +333,7 @@ class TimeDistributedTest(keras_parameterized.TestCase):
|
||||
keras.layers.RNN(keras.layers.SimpleRNNCell(10), stateful=True))
|
||||
self.assertFalse(td2._always_use_reshape)
|
||||
|
||||
# Custom layers are not whitelisted for the fast reshape implementation.
|
||||
# Custom layers are not allowlisted for the fast reshape implementation.
|
||||
td3 = keras.layers.TimeDistributed(NoReshapeLayer())
|
||||
self.assertFalse(td3._always_use_reshape)
|
||||
|
||||
|
@ -898,7 +898,7 @@ class OptimizerWithFunctionTest(test.TestCase):
|
||||
|
||||
_NUM_LEARNERS = 50
|
||||
APPLY_SCOPE = 'debug_apply'
|
||||
WHITELIST = [
|
||||
ALLOWLIST = [
|
||||
# optimizer_v2._deduplicate_indexed_slices contains an indexed slice:
|
||||
# array_ops.shape(unique_indices)[0]
|
||||
# which winds up expanding to [0:1:1] thereby creating three constants
|
||||
@ -1025,8 +1025,8 @@ def identify_redundant_ops(graph):
|
||||
# Certain ops are simply not worth eliminating, and are instead simply
|
||||
# ignored.
|
||||
name, op_type = op_defs[0].name, op_defs[0].type
|
||||
if any(whitelisted_scope in name and op_type == whitelisted_type
|
||||
for whitelisted_scope, whitelisted_type in WHITELIST):
|
||||
if any(allowlisted_scope in name and op_type == allowlisted_type
|
||||
for allowlisted_scope, allowlisted_type in ALLOWLIST):
|
||||
continue
|
||||
|
||||
num_duplicates += len(op_defs)
|
||||
|
@ -45,7 +45,7 @@ def index_directory(directory,
|
||||
valid files found in the directory. Labels should be sorted according
|
||||
to the alphanumeric order of the image file paths
|
||||
(obtained via `os.walk(directory)` in Python).
|
||||
formats: Whitelist of file extensions to index (e.g. ".jpg", ".txt").
|
||||
formats: Allowlist of file extensions to index (e.g. ".jpg", ".txt").
|
||||
class_names: Only valid if "labels" is "inferred". This is the explict
|
||||
list of class names (must match names of subdirectories). Used
|
||||
to control the order of the classes
|
||||
@ -136,7 +136,7 @@ def index_subdirectory(directory, class_indices, follow_links, formats):
|
||||
class_indices: dict mapping class names to their index.
|
||||
follow_links: boolean, whether to recursively follow subdirectories
|
||||
(if False, we only list top-level images in `directory`).
|
||||
formats: Whitelist of file extensions to index (e.g. ".jpg", ".txt").
|
||||
formats: Allowlist of file extensions to index (e.g. ".jpg", ".txt").
|
||||
|
||||
Returns:
|
||||
tuple `(filenames, labels)`. `filenames` is a list of relative file
|
||||
|
@ -28,7 +28,7 @@ from tensorflow.python.ops import io_ops
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
|
||||
WHITELIST_FORMATS = ('.bmp', '.gif', '.jpeg', '.jpg', '.png')
|
||||
ALLOWLIST_FORMATS = ('.bmp', '.gif', '.jpeg', '.jpg', '.png')
|
||||
|
||||
|
||||
@keras_export('keras.preprocessing.image_dataset_from_directory', v1=[])
|
||||
@ -175,7 +175,7 @@ def image_dataset_from_directory(directory,
|
||||
image_paths, labels, class_names = dataset_utils.index_directory(
|
||||
directory,
|
||||
labels,
|
||||
formats=WHITELIST_FORMATS,
|
||||
formats=ALLOWLIST_FORMATS,
|
||||
class_names=class_names,
|
||||
shuffle=shuffle,
|
||||
seed=seed,
|
||||
|
@ -865,7 +865,7 @@ class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
|
||||
This only allows capturing tensors in the forward graph. A ValueError is
|
||||
raised if an attempt is made to capture a tensor not in the forward graph.
|
||||
To manually capture capture a tensor that is not in the forward graph, call
|
||||
`capture` with `whitelisted=True`.
|
||||
`capture` with `allowlisted=True`.
|
||||
|
||||
Note: The `captures` dict does not contain the forward tensor since it is not
|
||||
directly captured. It contains the accumulator corresponding to this forward
|
||||
@ -968,16 +968,16 @@ class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
|
||||
op_def=op_def,
|
||||
compute_device=compute_device)
|
||||
|
||||
def capture(self, tensor, name=None, whitelisted=False):
|
||||
def capture(self, tensor, name=None, allowlisted=False):
|
||||
"""Selectively captures external tensors.
|
||||
|
||||
If `whitelisted` is False only allows capturing tensors in the
|
||||
If `allowlisted` is False only allows capturing tensors in the
|
||||
`_forward_graph`.
|
||||
|
||||
Args:
|
||||
tensor: Tensor. May be from this FuncGraph or a different graph.
|
||||
name: Optional name if a placeholder is created.
|
||||
whitelisted: If False (default), only allows capturing tensors from the
|
||||
allowlisted: If False (default), only allows capturing tensors from the
|
||||
forward graph.
|
||||
|
||||
Returns:
|
||||
@ -985,9 +985,9 @@ class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
|
||||
|
||||
Raises:
|
||||
ValueError: If attempting to capture an external tensor not in the forward
|
||||
graph with `whitelisted` set to False.
|
||||
graph with `allowlisted` set to False.
|
||||
"""
|
||||
if not whitelisted and (isinstance(tensor, ops.EagerTensor) or
|
||||
if not allowlisted and (isinstance(tensor, ops.EagerTensor) or
|
||||
(tensor.graph is not self and
|
||||
tensor.graph != self._forward_graph)):
|
||||
with self._forward_cond_graph.as_default():
|
||||
@ -1136,7 +1136,7 @@ class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
|
||||
"Resource tensors must be loop invariants %s." % tensor_in_outer_graph)
|
||||
|
||||
self._indirect_captures[ops.tensor_id(tensor)] = self.capture(
|
||||
tensor_in_outer_graph, whitelisted=True)
|
||||
tensor_in_outer_graph, allowlisted=True)
|
||||
return self._indirect_captures[ops.tensor_id(tensor)]
|
||||
|
||||
|
||||
|
@ -143,10 +143,10 @@ def _rewrite_input_as_indexed_slices(body_grad_graph, grad_output_slices,
|
||||
# computation.
|
||||
with body_grad_graph.as_default():
|
||||
input_slices = ops.IndexedSlices(
|
||||
values=body_grad_graph.capture(init_slices.values, whitelisted=True),
|
||||
indices=body_grad_graph.capture(init_slices.indices, whitelisted=True),
|
||||
dense_shape=body_grad_graph.capture(init_slices.dense_shape,
|
||||
whitelisted=True))
|
||||
values=body_grad_graph.capture(init_slices.values, allowlisted=True),
|
||||
indices=body_grad_graph.capture(init_slices.indices, allowlisted=True),
|
||||
dense_shape=body_grad_graph.capture(
|
||||
init_slices.dense_shape, allowlisted=True))
|
||||
|
||||
# Remove the captured tensors from the function inputs. We'll add them back
|
||||
# at the correct index in _update_indexed_slices_param.
|
||||
|
@ -36,7 +36,7 @@ from tensorflow.python.platform import tf_logging
|
||||
# corresponding kernel; nodes without a corresponding kernel (perhaps due to
|
||||
# attr types) generate a warning but are otherwise ignored. Ops in this set are
|
||||
# registered even if there's no corresponding kernel.
|
||||
OPS_WITHOUT_KERNEL_WHITELIST = frozenset([
|
||||
OPS_WITHOUT_KERNEL_ALLOWLIST = frozenset([
|
||||
# AccumulateNV2 is rewritten away by AccumulateNV2RemovePass; see
|
||||
# core/common_runtime/accumulate_n_optimizer.cc.
|
||||
'AccumulateNV2'
|
||||
@ -67,7 +67,7 @@ def _get_ops_from_graphdef(graph_def):
|
||||
kernel_class = _pywrap_kernel_registry.TryFindKernelClass(
|
||||
node_def.SerializeToString())
|
||||
op = str(node_def.op)
|
||||
if kernel_class or op in OPS_WITHOUT_KERNEL_WHITELIST:
|
||||
if kernel_class or op in OPS_WITHOUT_KERNEL_ALLOWLIST:
|
||||
op_and_kernel = (op, str(kernel_class.decode('utf-8'))
|
||||
if kernel_class else None)
|
||||
ops.add(op_and_kernel)
|
||||
|
@ -68,7 +68,7 @@ class EmbeddingColumnTest(test.TestCase):
|
||||
tpu_fc.embedding_column(categorical_column, dimension=embedding_dimension)
|
||||
|
||||
def test_custom_column(self):
|
||||
# This column is not in any whitelist but should succeed because
|
||||
# This column is not in any allowlist but should succeed because
|
||||
# it inherits from V2 CategoricalColumn.
|
||||
categorical_column = fc_lib.categorical_column_with_identity(
|
||||
key='aaa', num_buckets=10)
|
||||
|
@ -122,7 +122,7 @@ def enable_mixed_precision_graph_rewrite(opt, loss_scale='dynamic'):
|
||||
|
||||
* `ClearList`: Ops that do not have numerically significant adverse effects.
|
||||
E.g. `ArgMax` and `Floor`.
|
||||
* `WhiteList`: Ops that are considered numerically safe for execution in
|
||||
* `AllowList`: Ops that are considered numerically safe for execution in
|
||||
float16, and thus are always converted. E.g. `Conv2D`.
|
||||
* `BlackList`: Ops that are numerically unsafe to execute in float16 and
|
||||
can negatively affect downstream nodes. E.g. `Softmax`.
|
||||
@ -267,7 +267,7 @@ def enable_mixed_precision_graph_rewrite_v1(opt, loss_scale='dynamic'):
|
||||
|
||||
* `ClearList`: Ops that do not have numerically significant adverse effects.
|
||||
E.g. `ArgMax` and `Floor`.
|
||||
* `WhiteList`: Ops that are considered numerically safe for execution in
|
||||
* `AllowList`: Ops that are considered numerically safe for execution in
|
||||
float16, and thus are always converted. E.g. `Conv2D`.
|
||||
* `BlackList`: Ops that are numerically unsafe to execute in float16 and
|
||||
can negatively affect downstream nodes. E.g. `Softmax`.
|
||||
|
@ -93,7 +93,7 @@ def remove_undocumented(module_name, allowed_exception_list=None,
|
||||
doc_string_modules: a list of modules from which to take the docstrings.
|
||||
If None, then a list containing only the module named `module_name` is used.
|
||||
|
||||
Furthermore, if a symbol previously added with `add_to_global_whitelist`,
|
||||
Furthermore, if a symbol previously added with `add_to_global_allowlist`,
|
||||
then it will always be allowed. This is useful for internal tests.
|
||||
|
||||
Returns:
|
||||
|
@ -96,8 +96,8 @@ do_pylint() {
|
||||
# --incremental Performs check on only the python files changed in the
|
||||
# last non-merge git commit.
|
||||
|
||||
# Use this list to whitelist pylint errors
|
||||
ERROR_WHITELIST="^tensorflow/python/framework/function_test\.py.*\[E1123.*noinline "\
|
||||
# Use this list to allowlist pylint errors
|
||||
ERROR_ALLOWLIST="^tensorflow/python/framework/function_test\.py.*\[E1123.*noinline "\
|
||||
"^tensorflow/python/platform/default/_gfile\.py.*\[E0301.*non-iterator "\
|
||||
"^tensorflow/python/platform/default/_googletest\.py.*\[E0102.*function\salready\sdefined "\
|
||||
"^tensorflow/python/feature_column/feature_column_test\.py.*\[E0110.*abstract-class-instantiated "\
|
||||
@ -115,7 +115,7 @@ do_pylint() {
|
||||
"^tensorflow/python/autograph/.*_py3_test\.py.*\[E0001.*syntax-error "\
|
||||
"^tensorflow/python/keras/preprocessing/image\.py.*\[E0240.*Inconsistent method resolution "
|
||||
|
||||
echo "ERROR_WHITELIST=\"${ERROR_WHITELIST}\""
|
||||
echo "ERROR_ALLOWLIST=\"${ERROR_ALLOWLIST}\""
|
||||
|
||||
if [[ $# != "0" ]] && [[ $# != "1" ]]; then
|
||||
echo "Invalid syntax when invoking do_pylint"
|
||||
@ -195,16 +195,16 @@ do_pylint() {
|
||||
|
||||
N_ERRORS=0
|
||||
while read -r LINE; do
|
||||
IS_WHITELISTED=0
|
||||
for WL_REGEX in ${ERROR_WHITELIST}; do
|
||||
IS_ALLOWLISTED=0
|
||||
for WL_REGEX in ${ERROR_ALLOWLIST}; do
|
||||
if echo ${LINE} | grep -q "${WL_REGEX}"; then
|
||||
echo "Found a whitelisted error:"
|
||||
echo "Found a allowlisted error:"
|
||||
echo " ${LINE}"
|
||||
IS_WHITELISTED=1
|
||||
IS_ALLOWLISTED=1
|
||||
fi
|
||||
done
|
||||
|
||||
if [[ ${IS_WHITELISTED} == "0" ]]; then
|
||||
if [[ ${IS_ALLOWLISTED} == "0" ]]; then
|
||||
echo "${LINE}" >> ${NONWL_ERRORS_FILE}
|
||||
echo "" >> ${NONWL_ERRORS_FILE}
|
||||
((N_ERRORS++))
|
||||
@ -213,11 +213,11 @@ do_pylint() {
|
||||
|
||||
echo ""
|
||||
if [[ ${N_ERRORS} != 0 ]]; then
|
||||
echo "FAIL: Found ${N_ERRORS} non-whitelisted pylint errors:"
|
||||
echo "FAIL: Found ${N_ERRORS} non-allowlisted pylint errors:"
|
||||
cat "${NONWL_ERRORS_FILE}"
|
||||
return 1
|
||||
else
|
||||
echo "PASS: No non-whitelisted pylint errors were found."
|
||||
echo "PASS: No non-allowlisted pylint errors were found."
|
||||
return 0
|
||||
fi
|
||||
}
|
||||
@ -370,7 +370,7 @@ do_external_licenses_check(){
|
||||
-v ${MISSING_LICENSES_FILE} > temp.txt
|
||||
mv temp.txt ${MISSING_LICENSES_FILE}
|
||||
|
||||
# Whitelist
|
||||
# Allowlist
|
||||
echo ${EXTRA_LICENSE_FILE}
|
||||
grep \
|
||||
-e "//third_party/mkl" \
|
||||
|
@ -40,7 +40,7 @@ FUTURES_PATTERN_2 = re.compile(
|
||||
FUTURES_PATTERN_3 = re.compile(r'^from __future__ import (\w+) as \w+\s*$')
|
||||
REQUIRED_FUTURES = frozenset(['absolute_import', 'division', 'print_function'])
|
||||
|
||||
WHITELIST = [
|
||||
ALLOWLIST = [
|
||||
'python/platform/control_imports.py',
|
||||
'tools/docker/jupyter_notebook_config.py',
|
||||
'tools/ci_build/update_version.py',
|
||||
@ -93,12 +93,12 @@ def main():
|
||||
BASE_DIR)
|
||||
|
||||
# Verify that all files have futures
|
||||
whitelist = frozenset(os.path.join(BASE_DIR, w) for w in WHITELIST)
|
||||
allowlist = frozenset(os.path.join(BASE_DIR, w) for w in ALLOWLIST)
|
||||
old_division = frozenset(os.path.join(BASE_DIR, w) for w in OLD_DIVISION)
|
||||
for root, _, filenames in os.walk(BASE_DIR):
|
||||
for f in fnmatch.filter(filenames, '*.py'):
|
||||
path = os.path.join(root, f)
|
||||
if path not in whitelist:
|
||||
if path not in allowlist:
|
||||
try:
|
||||
check_file(path, old_division=path in old_division)
|
||||
except AssertionError as e:
|
||||
|
Loading…
Reference in New Issue
Block a user