Merge branch 'master' into cast_zeros_like_micro_build

This commit is contained in:
rsun-bdti 2021-02-05 19:27:03 -08:00 committed by GitHub
commit eb7a1fac1b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
35 changed files with 630 additions and 80 deletions

View File

@ -38,6 +38,7 @@ limitations under the License.
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/ToolOutputFile.h"
@ -525,6 +526,13 @@ class Translator {
BufferOffset<tflite::Operator> BuildNumericVerifyOperator(
mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
// Builds Assign/Read Variable ops.
template <typename T>
BufferOffset<tflite::Operator> BuildVariableOperator(
T op, const std::string& op_name, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
BufferOffset<tflite::Operator> BuildCustomOperator(
Operation* inst, mlir::TFL::CustomOp op,
const std::vector<int32_t>& operands,
@ -936,6 +944,17 @@ BufferOffset<tflite::Operator> Translator::BuildNumericVerifyOperator(
tflite::CustomOptionsFormat_FLEXBUFFERS);
}
// Builds Assign/Read Variable ops.
template <typename T>
BufferOffset<tflite::Operator> Translator::BuildVariableOperator(
T op, const std::string& op_name, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results) {
auto opcode_index = GetOpcodeIndex(op_name, tflite::BuiltinOperator_CUSTOM);
return tflite::CreateOperator(
builder_, opcode_index, builder_.CreateVector(operands),
builder_.CreateVector(results), tflite::BuiltinOptions_NONE);
}
BufferOffset<tflite::Operator> Translator::BuildCustomOperator(
Operation* inst, mlir::TFL::CustomOp op,
const std::vector<int32_t>& operands, const std::vector<int32_t>& results) {
@ -1077,6 +1096,18 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
return llvm::None;
}
// TODO(b/149099381): Remove this once the kernels are promoted as
// builtin TFLite kernels.
// We export the Assign/Read variable ops as custom ops.
if (auto read_op = llvm::dyn_cast<mlir::TFL::ReadVariableOp>(inst)) {
return BuildVariableOperator<mlir::TFL::ReadVariableOp>(
read_op, "ReadVariable", operands, results);
} else if (auto assign_op =
llvm::dyn_cast<mlir::TFL::AssignVariableOp>(inst)) {
return BuildVariableOperator<mlir::TFL::AssignVariableOp>(
assign_op, "AssignVariable", operands, results);
}
// If TFLite built in op, create operator as a builtin op.
if (dialect == tfl_dialect_) {
// Only if built-in TFLite op emission is enabled, would legalization have

View File

@ -214,6 +214,11 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification));
if (pass_config.enable_tflite_variables) {
pass_manager->addPass(mlir::TFL::CreateInitializeVariablesPass());
pass_manager->addPass(mlir::TFL::CreateLegalizeVariablesPass());
pass_manager->addPass(mlir::TFL::CreateRemoveArgsAndGlobalTensors());
}
pass_manager->addNestedPass<mlir::FuncOp>(mlir::TFL::CreateOptimizePass());
// This pass operates on TensorFlow ops but is triggered after legalization
// so that it can target constants introduced once TensorFlow Identity ops

View File

@ -256,16 +256,11 @@ Status ExecuteWrapperAfterExecution(
}
}
const auto& dump_path =
executable->module_config().debug_options().xla_dump_to();
if (executable->module_config().debug_options().xla_hlo_profile() &&
state.profile_ptr != nullptr && !dump_path.empty()) {
const std::string full_path =
tensorflow::io::JoinPath(dump_path, "hlo_execution_profile_data");
TF_CHECK_OK(tensorflow::WriteStringToFile(
tensorflow::Env::Default(), full_path,
state.profile_ptr->ToProto().SerializeAsString()))
<< "Error saving HloExecutionProfileData to " << full_path;
state.profile_ptr != nullptr) {
DumpToFileInDir(executable->module(), /*file_prefix=*/"",
/*file_suffix=*/"hlo_execution_profile_data",
state.profile_ptr->ToProto().SerializeAsString());
}
if (state.profile_ptr != nullptr) {

View File

@ -45,8 +45,8 @@ class HloRematerialization : public HloModulePass {
// Helper struct that communicates the before / after sizes for the
// rematerialization process.
struct RematerializationSizes {
int64 before_bytes;
int64 after_bytes;
int64 before_bytes = -1;
int64 after_bytes = -1;
};
// Mode in which the rematerialization algorithm should be run.

View File

@ -203,6 +203,7 @@ FRAMEWORK_PROTO_SRCS = [
"//tensorflow/core/framework:model.proto",
"//tensorflow/core/framework:node_def.proto",
"//tensorflow/core/framework:op_def.proto",
"//tensorflow/core/framework:dataset_options.proto",
"//tensorflow/core/framework:reader_base.proto",
"//tensorflow/core/framework:remote_fused_graph_execute_info.proto",
"//tensorflow/core/framework:resource_handle.proto",

View File

@ -35,6 +35,7 @@ limitations under the License.
namespace tensorflow {
constexpr BFCAllocator::ChunkHandle BFCAllocator::kInvalidChunkHandle;
constexpr uint64 BFCAllocator::kMemDebugHistorySize;
BFCAllocator::BFCAllocator(SubAllocator* sub_allocator, size_t total_memory,
bool allow_growth, const string& name,

View File

@ -1277,7 +1277,7 @@ Status EagerContext::UpdateRemoteMaster(
context_view_id_++;
remote_eager_workers_ = std::move(remote_eager_workers);
pflr_->InitializeDeviceSet();
pflr_->InitializeDeviceAndFlr();
InitPrioritizedDeviceTypeList();
default_executor_.ClearError();
@ -1496,7 +1496,7 @@ Status EagerContext::UpdateRemoteWorker(
remote_contexts_ = remote_contexts;
remote_eager_workers_ = std::move(remote_eager_workers);
InitPrioritizedDeviceTypeList();
pflr_->InitializeDeviceSet();
pflr_->InitializeDeviceAndFlr();
}
// No need to update remote_device_manager_ since it's not owned for remote

View File

@ -486,6 +486,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const;
const SessionOptions& session_options() const { return opts_; }
void InitPrioritizedDeviceTypeList();
private:
Rendezvous* CreateRendezvous(int64 step_id) const {
@ -510,7 +511,6 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
~EagerContext() override;
void InitPrioritizedDeviceTypeList();
Status MaybeRegisterFunctionRemotely(const FunctionDef& fdef);
Status RegisterExistingFunctionsOnRemoteWorkers(
const std::vector<string>& remote_workers);

View File

@ -100,7 +100,9 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
std::unique_ptr<FunctionLibraryRuntime>>),
next_handle_(0),
session_metadata_(session_metadata),
rendezvous_factory_(std::move(rendezvous_factory)) {
rendezvous_factory_(std::move(rendezvous_factory)),
optimizer_options_(optimizer_options),
graph_def_version_(graph_def_version) {
if (device_mgr == nullptr) {
(*flr_map_)[nullptr] = NewFunctionLibraryRuntime(
nullptr, env, config_ ? &(*config_) : nullptr, nullptr,
@ -108,14 +110,7 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
session_metadata_, this);
return;
}
for (Device* d : device_mgr->ListDevices()) {
(*flr_map_)[d] = NewFunctionLibraryRuntime(
device_mgr, env, config_ ? &(*config_) : nullptr, d, graph_def_version,
lib_def_, default_thread_pool, optimizer_options, session_metadata_,
this);
}
InitializeDeviceSet();
InitializeDeviceAndFlr();
}
/* static */
@ -214,7 +209,7 @@ Status ProcessFunctionLibraryRuntime::GetDeviceContext(
"function executions");
}
void ProcessFunctionLibraryRuntime::InitializeDeviceSet() {
void ProcessFunctionLibraryRuntime::InitializeDeviceAndFlr() {
DeviceMgr const* all_devices = device_mgr_;
if (parent_ != nullptr && parent_->remote_device_mgr() != nullptr) {
all_devices = parent_->remote_device_mgr();
@ -225,6 +220,14 @@ void ProcessFunctionLibraryRuntime::InitializeDeviceSet() {
for (auto d : all_devices->ListDevices()) {
device_set_->AddDevice(d);
}
for (Device* d : device_mgr_->ListDevices()) {
if ((*flr_map_)[d] == nullptr) {
(*flr_map_)[d] = NewFunctionLibraryRuntime(
device_mgr_, env_, config_ ? &(*config_) : nullptr, d,
graph_def_version_, lib_def_, default_thread_pool_,
optimizer_options_, session_metadata_, this);
}
}
}
FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(

View File

@ -207,8 +207,9 @@ class ProcessFunctionLibraryRuntime {
return device_set_;
}
// Initialize the set of local and remote devices for op device selection.
void InitializeDeviceSet();
// Initialize the set of local and remote devices and corresponding flr for op
// device selection.
void InitializeDeviceAndFlr();
const ConfigProto* config() const { return config_ ? &(*config_) : nullptr; }
@ -478,6 +479,9 @@ class ProcessFunctionLibraryRuntime {
int next_handle_ TF_GUARDED_BY(mu_);
const SessionMetadata* const session_metadata_;
const Rendezvous::Factory rendezvous_factory_;
const OptimizerOptions optimizer_options_;
const int graph_def_version_;
};
} // namespace tensorflow

View File

@ -119,6 +119,7 @@ exports_files(
"api_def.proto",
"attr_value.proto",
"cost_graph.proto",
"dataset_options.proto",
"device_attributes.proto",
"function.proto",
"graph.proto",
@ -1660,6 +1661,13 @@ tf_proto_library(
make_default_target_header_only = True,
)
tf_proto_library(
name = "dataset_options_proto",
srcs = ["dataset_options.proto"],
cc_api_version = 2,
make_default_target_header_only = True,
)
tf_proto_library(
name = "protos_all",
cc_api_version = 2,
@ -1678,6 +1686,7 @@ tf_proto_library(
":model_proto",
":node_def_proto",
":op_def_proto",
":dataset_options_proto",
":reader_base_proto",
":remote_fused_graph_execute_info_proto",
":resource_handle_proto",

View File

@ -0,0 +1,179 @@
syntax = "proto3";
package tensorflow.data;
// Represents the type of auto-sharding we enable.
enum AutoShardPolicy {
AUTO = 0;
FILE = 1;
DATA = 2;
OFF = -1;
}
message DistributeOptions {
// The type of sharding that auto-shard should attempt. If this is set to
// FILE, then we will attempt to shard by files (each worker will get a set of
// files to process). If we cannot find a set of files to shard for at least
// one file per worker, we will error out. When this option is selected, make
// sure that you have enough files so that each worker gets at least one file.
// There will be a runtime error thrown if there are insufficient files. If
// this is set to DATA, then we will shard by elements produced by the
// dataset, and each worker will process the whole dataset and discard the
// portion that is not for itself. If this is set to OFF, then we will not
// autoshard, and each worker will receive a copy of the full dataset. This
// option is set to AUTO by default, AUTO will attempt to first shard by FILE,
// and fall back to sharding by DATA if we cannot find a set of files to
// shard.
AutoShardPolicy auto_shard_policy = 1;
// The number of devices attached to this input pipeline.
oneof optional_num_devices {
int32 num_devices = 2;
}
}
message MapVectorization {
// Whether to vectorize map transformations.
oneof optional_enabled {
bool enabled = 1;
}
// Whether to use ChooseFastestBranchDataset with this transformation. If
// True, the pipeline picks between the vectorized and original segment at
// runtime based on their iterations speed.
oneof optional_use_choose_fastest {
bool use_choose_fastest = 2;
}
}
message OptimizationOptions {
// Whether to apply default graph optimizations. If False, only graph
// optimizations that have been explicitly enabled will be applied.
oneof optional_apply_default_optimizations {
bool apply_default_optimizations = 1;
}
// Whether to automatically tune performance knobs.
oneof optional_autotune {
bool autotune = 2;
}
// When autotuning is enabled (through autotune), determines whether to also
// autotune buffer sizes for datasets with parallelism.
oneof optional_autotune_buffers {
bool autotune_buffers = 3;
}
// When autotuning is enabled (through autotune), determines the CPU budget to
// use. Values greater than the number of schedulable CPU cores are allowed
// but may result in CPU contention.
oneof optional_autotune_cpu_budget {
int32 autotune_cpu_budget = 4;
}
// When autotuning is enabled (through autotune), determines the RAM budget to
// use. Values greater than the available RAM in bytes may result in OOM. If
// 0, defaults to half of the available RAM in bytes.
oneof optional_autotune_ram_budget {
int32 autotune_ram_budget = 5;
}
// Whether to fuse filter transformations.
oneof optional_filter_fusion {
bool filter_fusion = 6;
}
// Whether to fuse filter dataset that predicts random_uniform < rate into a
// sampling dataset.
oneof optional_filter_with_random_uniform_fusion {
bool filter_with_random_uniform_fusion = 7;
}
// Whether to hoist tf.random_uniform() ops out of map transformations.
oneof optional_hoist_random_uniform {
bool hoist_random_uniform = 8;
}
// Whether to fuse map and batch transformations.
oneof optional_map_and_batch_fusion {
bool map_and_batch_fusion = 9;
}
// Whether to fuse map and filter transformations.
oneof optional_map_and_filter_fusion {
bool map_and_filter_fusion = 10;
}
// Whether to fuse map transformations.
oneof optional_map_fusion {
bool map_fusion = 11;
}
// Whether to parallelize stateless map transformations.
oneof optional_map_parallelization {
bool map_parallelization = 12;
}
// The map vectorization options associated with the dataset.
MapVectorization map_vectorization = 13;
// Whether to eliminate no-op transformations.
oneof optional_noop_elimination {
bool noop_elimination = 14;
}
// Whether to parallelize copying of batch elements. This optimization is
// highly experimental and can cause performance degradation (e.g. when the
// parallelization overhead exceeds the benefits of performing the data copies
// in parallel). You should only enable this optimization if a) your input
// pipeline is bottlenecked on batching and b) you have validated that this
// optimization improves performance.
oneof optional_parallel_batch {
bool parallel_batch = 15;
}
// Whether to reorder ops that will discard data to the front of unary
// cardinality preserving transformations, e.g. dataset.map(...).take(3) will
// be optimized to dataset.take(3).map(...). For now this optimization will
// move `skip`, `shard` and `take` to the front of `map` and `prefetch`. This
// optimization is only for performance; it will not affect the output of the
// dataset.
oneof optional_reorder_data_discarding_ops {
bool reorder_data_discarding_ops = 16;
}
// Whether to fuse shuffle and repeat transformations.
oneof optional_shuffle_and_repeat_fusion {
bool shuffle_and_repeat_fusion = 17;
}
}
message ThreadingOptions {
// If set, it overrides the maximum degree of intra-op parallelism.
oneof optional_max_intra_op_parallelism {
int32 max_intra_op_parallelism = 1;
}
// If set, the dataset will use a private threadpool of the given size.
oneof optional_private_threadpool_size {
int32 private_threadpool_size = 2;
}
}
// Represents how to handle external state during serialization.
enum ExternalStatePolicy {
WARN = 0;
IGNORE = 1;
FAIL = 2;
}
// Message stored with Dataset objects to control how datasets are processed and
// optimized.
message Options {
// Whether the outputs need to be produced in deterministic order.
oneof optional_deterministic {
bool deterministic = 1;
}
// The distribution strategy options associated with the dataset.
DistributeOptions distribute_options = 2;
// The optimization options associated with the dataset.
OptimizationOptions optimization_options = 3;
// Whether to introduce 'slack' in the last `prefetch` of the input pipeline,
// if it exists. This may reduce CPU contention with accelerator host-side
// activity at the start of a step. The slack frequency is determined by the
// number of devices attached to this input pipeline.
oneof optional_slack {
bool slack = 4;
}
// The threading options associated with the dataset.
ThreadingOptions threading_options = 5;
// This option can be used to override the default policy for how to handle
// external state when serializing a dataset or checkpointing its iterator.
// There are three settings available - IGNORE: External state is ignored
// without a warning; WARN: External state is ignored and a warning is logged;
// FAIL: External state results in an error.
oneof optional_external_state_policy {
ExternalStatePolicy external_state_policy = 6;
}
}

View File

@ -25,6 +25,10 @@ limitations under the License.
namespace tensorflow {
namespace data {
namespace model {
constexpr int64 Model::kOptimizationPeriodMinMs;
constexpr int64 Model::kOptimizationPeriodMaxMs;
namespace {
// Helper function for node traversal that doesn't skip any nodes.

View File

@ -216,7 +216,20 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
Status Initialize(IteratorContext* ctx) override {
mutex_lock l(*mu_);
if (num_parallel_calls_->value == model::kAutotune) {
num_parallel_calls_->value = ctx->runner_threadpool_size();
// If autotuning is enabled, we initialize the parallelism to 1 to
// avoid accidentally running the machine out of memory before the
// optimization can pick values that respect the memory budget.
//
// If autotuning is disabled but the transformation uses `AUTOTUNE`, we
// default the parallelism to the size of the threadpool used for
// executing the user-defined computation. If this causes OOM, the
// input pipeline should either enable autotuning, or replace
// `AUTOTUNE` with fixed parallelism.
if (TF_PREDICT_TRUE(ctx->model())) {
num_parallel_calls_->value = 1;
} else {
num_parallel_calls_->value = ctx->runner_threadpool_size();
}
}
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
ctx->cancellation_manager(),

View File

@ -221,7 +221,20 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
Status Initialize(IteratorContext* ctx) override {
mutex_lock l(*mu_);
if (num_parallel_calls_->value == model::kAutotune) {
num_parallel_calls_->value = ctx->runner_threadpool_size();
// If autotuning is enabled, we initialize the parallelism to 1 to
// avoid accidentally running the machine out of memory before the
// optimization can pick values that respect the memory budget.
//
// If autotuning is disabled but the transformation uses `AUTOTUNE`, we
// default the parallelism to the size of the threadpool used for
// executing the user-defined computation. If this causes OOM, the
// input pipeline should either enable autotuning, or replace
// `AUTOTUNE` with fixed parallelism.
if (TF_PREDICT_TRUE(ctx->model())) {
num_parallel_calls_->value = 1;
} else {
num_parallel_calls_->value = ctx->runner_threadpool_size();
}
}
cancellation_manager_ =
absl::make_unique<CancellationManager>(ctx->cancellation_manager());

View File

@ -22,7 +22,7 @@ limitations under the License.
namespace tensorflow {
namespace io {
namespace internal {
string JoinPathImpl(std::initializer_list<tensorflow::StringPiece> paths);
std::string JoinPathImpl(std::initializer_list<tensorflow::StringPiece> paths);
}
// Utility routines for processing filenames
@ -43,7 +43,7 @@ string JoinPathImpl(std::initializer_list<tensorflow::StringPiece> paths);
// string path = io::JoinPath(FLAGS_test_srcdir, filename);
// string path = io::JoinPath("/full", "path", "to", "filename");
template <typename... T>
string JoinPath(const T&... args) {
std::string JoinPath(const T&... args) {
return internal::JoinPathImpl({args...});
}
#endif /* SWIG */
@ -71,7 +71,7 @@ tensorflow::StringPiece Extension(tensorflow::StringPiece path);
// "/alpha/beta/".
//
// Does not perform any path normalization.
string CommonPathPrefix(absl::Span<string const> paths);
std::string CommonPathPrefix(absl::Span<std::string const> paths);
// Collapse duplicate "/"s, resolve ".." and "." path elements, remove
// trailing "/".
@ -80,7 +80,7 @@ string CommonPathPrefix(absl::Span<string const> paths);
// invoke any system calls (getcwd(2)) in order to resolve relative
// paths with respect to the actual working directory. That is, this is purely
// string manipulation, completely independent of process state.
string CleanPath(tensorflow::StringPiece path);
std::string CleanPath(tensorflow::StringPiece path);
// Populates the scheme, host, and path from a URI. scheme, host, and path are
// guaranteed by this function to point into the contents of uri, even if
@ -95,11 +95,12 @@ void ParseURI(tensorflow::StringPiece uri, tensorflow::StringPiece* scheme,
// Creates a URI from a scheme, host, and path. If the scheme is empty, we just
// return the path.
string CreateURI(tensorflow::StringPiece scheme, tensorflow::StringPiece host,
tensorflow::StringPiece path);
std::string CreateURI(tensorflow::StringPiece scheme,
tensorflow::StringPiece host,
tensorflow::StringPiece path);
// Creates a temporary file name with an extension.
string GetTempFilename(const string& extension);
std::string GetTempFilename(const std::string& extension);
// Reads the TEST_UNDECLARED_OUTPUTS_DIR environment variable, and if set
// assigns `dir` to the value. `dir` is not modified if the environment variable
@ -108,7 +109,7 @@ string GetTempFilename(const string& extension);
//
// Note: This function obviates the need to deal with Bazel's odd path decisions
// on Windows, and should be preferred over a simple `getenv`.
bool GetTestUndeclaredOutputsDir(string* dir);
bool GetTestUndeclaredOutputsDir(std::string* dir);
} // namespace io
} // namespace tensorflow

View File

@ -107,7 +107,12 @@ Status TpuTracer::CollectData(XSpace* space) {
tpu::OpsApiFn()->TpuProfiler_CollectDataFn(tpu_profiler_, status.c_status,
buffer.data(), &size_in_bytes);
// Deserialize XSpace from the buffer and return it.
space->ParseFromArray(buffer.data(), buffer.size());
XSpace tpu_space;
tpu_space.ParseFromArray(buffer.data(), buffer.size());
for (XPlane& tpu_plane : *tpu_space.mutable_planes()) {
XPlane* plane = space->add_planes();
plane->Swap(&tpu_plane);
}
}
if (!status.ok()) {
LOG(ERROR) << "TPU tracer failed to collect data.";

View File

@ -584,10 +584,11 @@ void EventForest::ConnectInterThread(
}
}
void EventForest::ProcessLegacyRootEvents(
const std::vector<int64 /*EventType*/>& root_event_types) {
for (int64 root_event_type : root_event_types) {
if (auto root_events = gtl::FindOrNull(event_node_map_, root_event_type)) {
void EventForest::ProcessUserDefinedRootEvents(
const std::vector<int64 /*EventType*/>& user_defined_root_event_types) {
for (int64 user_defined_root_event_type : user_defined_root_event_types) {
if (auto root_events =
gtl::FindOrNull(event_node_map_, user_defined_root_event_type)) {
for (const auto& root_event : *root_events) {
root_event->SetIsRoot(true);
root_events_.push_back(root_event.get());
@ -869,10 +870,11 @@ void EventForest::ConnectTfDataEvents() {
VLOG(1) << num_matched << " consumer iterators matched.";
}
void EventForest::GroupEvents(const std::vector<int64>& root_event_types) {
void EventForest::GroupEvents(
const std::vector<int64>& user_defined_root_event_types) {
ProcessTensorFlowLoop();
ProcessWorker();
ProcessLegacyRootEvents(root_event_types);
ProcessUserDefinedRootEvents(user_defined_root_event_types);
CreateEventGroups();
MarkEagerlyExecutedGpuKernels();
MarkEagerlyExecutedCpuTfOps();

View File

@ -176,7 +176,8 @@ class EventForest {
void ConnectTfDataEvents();
void GroupEvents(const std::vector<int64>& root_event_types = {});
void GroupEvents(
const std::vector<int64>& user_defined_root_event_types = {});
const EventNodeMap& GetEventNodeMap() const { return event_node_map_; }
@ -198,8 +199,8 @@ class EventForest {
void ConnectInterThread(
const std::vector<InterThreadConnectInfo>& connect_info_list);
void ProcessLegacyRootEvents(
const std::vector<int64 /*EventType*/>& root_event_types);
void ProcessUserDefinedRootEvents(
const std::vector<int64 /*EventType*/>& user_defined_root_event_types);
// Creates event groups and populates group_metadata_map. If a TF loop is
// used, each TF loop iteration becomes a root. Otherwise, top root events

View File

@ -26,23 +26,22 @@ std::pair<std::vector<std::string>, std::vector<const char*>>
GetLibTpuInitArguments() {
// We make copies of the arguments returned by getenv because the memory
// returned may be altered or invalidated by further calls to getenv.
std::vector<std::string> argv;
std::vector<const char*> argv_ptr;
std::vector<std::string> args;
std::vector<const char*> arg_ptrs;
// Retrieve arguments from environment if applicable.
char* env = getenv("LIBTPU_INIT_ARGS");
if (env != nullptr) {
// TODO(frankchn): Handles quotes properly if necessary.
argv = absl::StrSplit(env, ' ');
args = absl::StrSplit(env, ' ');
}
argv_ptr.reserve(argv.size());
for (int i = 0; i < argv.size(); ++i) {
argv_ptr.push_back(argv[i].data());
arg_ptrs.reserve(args.size());
for (int i = 0; i < args.size(); ++i) {
arg_ptrs.push_back(args[i].data());
}
argv_ptr.push_back(nullptr);
return {argv, argv_ptr};
return {args, arg_ptrs};
}
} // namespace tpu

View File

@ -44,7 +44,7 @@ Status ReadFloatFromEnvVar(StringPiece env_var_name, float default_val,
// Returns a string into "value" from the environmental variable "env_var_name".
// If it is unset, the default value is used.
Status ReadStringFromEnvVar(StringPiece env_var_name, StringPiece default_val,
string* value);
std::string* value);
} // namespace tensorflow

View File

@ -248,6 +248,19 @@ tflite_micro_cc_test(
],
)
tflite_micro_cc_test(
name = "exp_test",
srcs = ["exp_test.cc"],
deps = [
":kernel_runner",
"//tensorflow/lite/c:common",
"//tensorflow/lite/micro:debug_log",
"//tensorflow/lite/micro:op_resolvers",
"//tensorflow/lite/micro:test_helpers",
"//tensorflow/lite/micro/testing:micro_test",
],
)
tflite_micro_cc_test(
name = "zeros_like_test",
srcs = ["zeros_like_test.cc"],

View File

@ -54,7 +54,6 @@ void TestExp(const int* input_dims_data, const float* input_data,
TF_LITE_MICRO_EXPECT_NEAR(expected_output_data[i], output_data[i], 1e-5f);
}
}
} // namespace
} // namespace testing
} // namespace tflite
@ -62,13 +61,16 @@ void TestExp(const int* input_dims_data, const float* input_data,
TF_LITE_MICRO_TESTS_BEGIN
TF_LITE_MICRO_TEST(SingleDim) {
float output_data[7];
const int input_dims[] = {2, 1, 7};
const float input_values[] = {0.0f, 1.0f, -1.0f, 100.0f,
-100.0f, 0.01f, -0.01f};
const float golden[] = {
1.0f, 2.71828f, 0.36788f, std::numeric_limits<float>::infinity(),
1.17549e-38f, 1.01005f, 0.99005f};
constexpr int kInputSize = 7;
float output_data[kInputSize];
const int input_dims[] = {2, 1, kInputSize};
const float input_values[kInputSize] = {0.0f, 1.0f, -1.0f, 100.0f,
-100.0f, 0.01f, -0.01f};
float golden[kInputSize];
for (int i = 0; i < kInputSize; ++i) {
golden[i] = std::exp(input_values[i]);
}
tflite::testing::TestExp(input_dims, input_values, golden, output_data);
}

View File

@ -142,12 +142,14 @@ extern bool did_test_fail;
} \
} while (false)
// The check vx != vy is needed to properly handle the case where both
// x and y evaluate to infinity. See #46960 for more details.
#define TF_LITE_MICRO_EXPECT_NEAR(x, y, epsilon) \
do { \
auto vx = (x); \
auto vy = (y); \
auto delta = ((vx) > (vy)) ? ((vx) - (vy)) : ((vy) - (vx)); \
if (delta > epsilon) { \
if (vx != vy && delta > epsilon) { \
MicroPrintf(#x " (%f) near " #y " (%f) failed at %s:%d", \
static_cast<double>(vx), static_cast<double>(vy), __FILE__, \
__LINE__); \

View File

@ -564,22 +564,19 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
# Tests that vectorization maintains the determinism setting.
expect_determinism = local_determinism or (local_determinism is None and
global_determinism)
elements = list(range(1000))
num_elements = 1000
def dataset_fn(delay_ms):
def sleep(x):
time.sleep(delay_ms / 1000)
# Inject random delay in the interval [0, delay_ms / 1000).
time.sleep(delay_ms * (np.random.randint(x + 1) / (x + 1)) / 1000)
return x
def map_function(x):
if math_ops.equal(x, 0):
return check_ops.ensure_shape(
script_ops.py_func(sleep, [x], x.dtype, stateful=False), ())
else:
return x
return check_ops.ensure_shape(
script_ops.py_func(sleep, [x], x.dtype, stateful=False), ())
dataset = dataset_ops.Dataset.from_tensor_slices(elements)
dataset = dataset_ops.Dataset.range(num_elements)
dataset = dataset.map(
map_function, num_parallel_calls=10, deterministic=local_determinism)
dataset = dataset.batch(1)
@ -595,7 +592,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
self.checkDeterminism(
dataset_fn,
expect_determinism,
expected_elements=[[element] for element in elements])
expected_elements=[[element] for element in range(num_elements)])
@combinations.generate(test_base.default_test_combinations())
def testOptimizationIgnoreStateful(self):

View File

@ -19,6 +19,7 @@ from __future__ import print_function
import enum
from tensorflow.core.framework import dataset_options_pb2
from tensorflow.python.data.util import options
from tensorflow.python.util.tf_export import tf_export
@ -35,6 +36,34 @@ class AutoShardPolicy(enum.IntEnum):
FILE = 1
DATA = 2
@classmethod
def _to_proto(cls, obj):
"""Convert enum to proto."""
if obj == cls.OFF:
return dataset_options_pb2.AutoShardPolicy.OFF
if obj == cls.FILE:
return dataset_options_pb2.AutoShardPolicy.FILE
if obj == cls.DATA:
return dataset_options_pb2.AutoShardPolicy.DATA
if obj == cls.AUTO:
return dataset_options_pb2.AutoShardPolicy.AUTO
raise ValueError("%s._to_proto() is called with undefined enum %s." %
(cls.__name__, obj.name))
@classmethod
def _from_proto(cls, pb):
"""Convert proto to enum."""
if pb == dataset_options_pb2.AutoShardPolicy.OFF:
return cls.OFF
if pb == dataset_options_pb2.AutoShardPolicy.FILE:
return cls.FILE
if pb == dataset_options_pb2.AutoShardPolicy.DATA:
return cls.DATA
if pb == dataset_options_pb2.AutoShardPolicy.AUTO:
return cls.AUTO
raise ValueError("%s._from_proto() is called with undefined enum %s." %
(cls.__name__, pb))
@tf_export("data.experimental.ExternalStatePolicy")
class ExternalStatePolicy(enum.Enum):
@ -47,6 +76,30 @@ class ExternalStatePolicy(enum.Enum):
IGNORE = 1
FAIL = 2
@classmethod
def _to_proto(cls, obj):
"""Convert enum to proto."""
if obj == cls.IGNORE:
return dataset_options_pb2.ExternalStatePolicy.IGNORE
if obj == cls.FAIL:
return dataset_options_pb2.ExternalStatePolicy.FAIL
if obj == cls.WARN:
return dataset_options_pb2.ExternalStatePolicy.WARN
raise ValueError("%s._to_proto() is called with undefined enum %s." %
(cls.__name__, obj.name))
@classmethod
def _from_proto(cls, pb):
"""Convert proto to enum."""
if pb == dataset_options_pb2.ExternalStatePolicy.IGNORE:
return cls.IGNORE
if pb == dataset_options_pb2.ExternalStatePolicy.FAIL:
return cls.FAIL
if pb == dataset_options_pb2.ExternalStatePolicy.WARN:
return cls.WARN
raise ValueError("%s._from_proto() is called with undefined enum %s." %
(cls.__name__, pb))
@tf_export("data.experimental.DistributeOptions")
class DistributeOptions(options.OptionsBase):
@ -89,3 +142,15 @@ class DistributeOptions(options.OptionsBase):
docstring=
"The number of devices attached to this input pipeline. This will be "
"automatically set by MultiDeviceIterator.")
def _to_proto(self):
pb = dataset_options_pb2.DistributeOptions()
pb.auto_shard_policy = AutoShardPolicy._to_proto(self.auto_shard_policy) # pylint: disable=protected-access
if self.num_devices is not None:
pb.num_devices = self.num_devices
return pb
def _from_proto(self, pb):
self.auto_shard_policy = AutoShardPolicy._from_proto(pb.auto_shard_policy) # pylint: disable=protected-access
if pb.WhichOneof("optional_num_devices") is not None:
self.num_devices = pb.num_devices

View File

@ -19,6 +19,7 @@ from __future__ import print_function
import enum
from tensorflow.core.framework import dataset_options_pb2
from tensorflow.python.data.util import options
from tensorflow.python.util.tf_export import tf_export
@ -69,6 +70,20 @@ class MapVectorizationOptions(options.OptionsBase):
else:
return ["map_vectorization:use_choose_fastest:false"]
def _to_proto(self):
pb = dataset_options_pb2.MapVectorization()
if self.enabled is not None:
pb.enabled = self.enabled
if self.use_choose_fastest is not None:
pb.use_choose_fastest = self.use_choose_fastest
return pb
def _from_proto(self, pb):
if pb.WhichOneof("optional_enabled") is not None:
self.enabled = pb.enabled
if pb.WhichOneof("optional_use_choose_fastest") is not None:
self.use_choose_fastest = pb.use_choose_fastest
@tf_export("data.experimental.OptimizationOptions")
class OptimizationOptions(options.OptionsBase):
@ -327,3 +342,77 @@ class OptimizationOptions(options.OptionsBase):
graph_rewrite_configs.append(optimization + ":autotune:true")
return graph_rewrite_configs
def _to_proto(self):
pb = dataset_options_pb2.OptimizationOptions()
if self.apply_default_optimizations is not None:
pb.apply_default_optimizations = self.apply_default_optimizations
if self.autotune is not None:
pb.autotune = self.autotune
if self.autotune_buffers is not None:
pb.autotune_buffers = self.autotune_buffers
if self.autotune_cpu_budget is not None:
pb.autotune_cpu_budget = self.autotune_cpu_budget
if self.autotune_ram_budget is not None:
pb.autotune_ram_budget = self.autotune_ram_budget
if self.filter_fusion is not None:
pb.filter_fusion = self.filter_fusion
if self.filter_with_random_uniform_fusion is not None:
pb.filter_with_random_uniform_fusion = (
self.filter_with_random_uniform_fusion)
if self.hoist_random_uniform is not None:
pb.hoist_random_uniform = self.hoist_random_uniform
if self.map_and_batch_fusion is not None:
pb.map_and_batch_fusion = self.map_and_batch_fusion
if self.map_and_filter_fusion is not None:
pb.map_and_filter_fusion = self.map_and_filter_fusion
if self.map_fusion is not None:
pb.map_fusion = self.map_fusion
if self.map_parallelization is not None:
pb.map_parallelization = self.map_parallelization
pb.map_vectorization.CopyFrom(self.map_vectorization._to_proto()) # pylint: disable=protected-access
if self.noop_elimination is not None:
pb.noop_elimination = self.noop_elimination
if self.parallel_batch is not None:
pb.parallel_batch = self.parallel_batch
if self.reorder_data_discarding_ops is not None:
pb.reorder_data_discarding_ops = self.reorder_data_discarding_ops
if self.shuffle_and_repeat_fusion is not None:
pb.shuffle_and_repeat_fusion = self.shuffle_and_repeat_fusion
return pb
def _from_proto(self, pb):
if pb.WhichOneof("optional_apply_default_optimizations") is not None:
self.apply_default_optimizations = pb.apply_default_optimizations
if pb.WhichOneof("optional_autotune") is not None:
self.autotune = pb.autotune
if pb.WhichOneof("optional_autotune_buffers") is not None:
self.autotune_buffers = pb.autotune_buffers
if pb.WhichOneof("optional_autotune_cpu_budget") is not None:
self.autotune_cpu_budget = pb.autotune_cpu_budget
if pb.WhichOneof("optional_autotune_ram_budget") is not None:
self.autotune_ram_budget = pb.autotune_ram_budget
if pb.WhichOneof("optional_filter_fusion") is not None:
self.filter_fusion = pb.filter_fusion
if pb.WhichOneof("optional_filter_with_random_uniform_fusion") is not None:
self.filter_with_random_uniform_fusion = (
pb.filter_with_random_uniform_fusion)
if pb.WhichOneof("optional_hoist_random_uniform") is not None:
self.hoist_random_uniform = pb.hoist_random_uniform
if pb.WhichOneof("optional_map_and_batch_fusion") is not None:
self.map_and_batch_fusion = pb.map_and_batch_fusion
if pb.WhichOneof("optional_map_and_filter_fusion") is not None:
self.map_and_filter_fusion = pb.map_and_filter_fusion
if pb.WhichOneof("optional_map_fusion") is not None:
self.map_fusion = pb.map_fusion
if pb.WhichOneof("optional_map_parallelization") is not None:
self.map_parallelization = pb.map_parallelization
self.map_vectorization._from_proto(pb.map_vectorization) # pylint: disable=protected-access
if pb.WhichOneof("optional_noop_elimination") is not None:
self.noop_elimination = pb.noop_elimination
if pb.WhichOneof("optional_parallel_batch") is not None:
self.parallel_batch = pb.parallel_batch
if pb.WhichOneof("optional_reorder_data_discarding_ops") is not None:
self.reorder_data_discarding_ops = pb.reorder_data_discarding_ops
if pb.WhichOneof("optional_shuffle_and_repeat_fusion") is not None:
self.shuffle_and_repeat_fusion = pb.shuffle_and_repeat_fusion

View File

@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.framework import dataset_options_pb2
from tensorflow.python.data.util import options
from tensorflow.python.util.tf_export import tf_export
@ -48,3 +49,17 @@ class ThreadingOptions(options.OptionsBase):
ty=int,
docstring=
"If set, the dataset will use a private threadpool of the given size.")
def _to_proto(self):
pb = dataset_options_pb2.ThreadingOptions()
if self.max_intra_op_parallelism is not None:
pb.max_intra_op_parallelism = self.max_intra_op_parallelism
if self.private_threadpool_size is not None:
pb.private_threadpool_size = self.private_threadpool_size
return pb
def _from_proto(self, pb):
if pb.WhichOneof("optional_max_intra_op_parallelism") is not None:
self.max_intra_op_parallelism = pb.max_intra_op_parallelism
if pb.WhichOneof("optional_private_threadpool_size") is not None:
self.private_threadpool_size = pb.private_threadpool_size

View File

@ -23,6 +23,8 @@ import sys
from absl.testing import parameterized
from tensorflow.core.framework import dataset_options_pb2
from tensorflow.python.data.experimental.ops import distribute_options
from tensorflow.python.data.experimental.ops import optimization_options
from tensorflow.python.data.experimental.ops import stats_options
from tensorflow.python.data.experimental.ops import threading_options
@ -127,6 +129,67 @@ class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase):
result = result.concatenate(ds)
self.assertDatasetProduces(result, [0]*1000)
@combinations.generate(test_base.default_test_combinations())
def testOptionsProtoRoundTrip(self):
options = dataset_ops.Options()
options.experimental_deterministic = True
options.experimental_external_state_policy = (
distribute_options.ExternalStatePolicy.FAIL)
options.experimental_distribute.auto_shard_policy = (
distribute_options.AutoShardPolicy.DATA)
options.experimental_distribute.num_devices = 1000
options.experimental_optimization.apply_default_optimizations = True
options.experimental_optimization.autotune = True
options.experimental_optimization.autotune_buffers = True
options.experimental_optimization.autotune_cpu_budget = 10
options.experimental_optimization.autotune_ram_budget = 20
options.experimental_optimization.filter_fusion = True
options.experimental_optimization.filter_with_random_uniform_fusion = True
options.experimental_optimization.hoist_random_uniform = True
options.experimental_optimization.map_and_batch_fusion = True
options.experimental_optimization.map_and_filter_fusion = True
options.experimental_optimization.map_fusion = True
options.experimental_optimization.map_parallelization = True
options.experimental_optimization.map_vectorization.enabled = True
options.experimental_optimization.map_vectorization.use_choose_fastest = (
True)
options.experimental_optimization.noop_elimination = True
options.experimental_optimization.parallel_batch = True
options.experimental_optimization.reorder_data_discarding_ops = True
options.experimental_optimization.shuffle_and_repeat_fusion = True
options.experimental_slack = True
options.experimental_threading.max_intra_op_parallelism = 30
options.experimental_threading.private_threadpool_size = 40
pb = options._to_proto()
result = dataset_ops.Options()
result._from_proto(pb)
self.assertEqual(options, result)
@combinations.generate(test_base.default_test_combinations())
def testOptionsProtoDefaultValuesRoundTrip(self):
options = dataset_ops.Options()
pb = options._to_proto()
result = dataset_ops.Options()
result._from_proto(pb)
self.assertEqual(options, result)
@combinations.generate(test_base.default_test_combinations())
def testProtoOptionsDefaultValuesRoundTrip(self):
pb = dataset_options_pb2.Options()
options = dataset_ops.Options()
options._from_proto(pb)
result = options._to_proto()
expected_pb = dataset_options_pb2.Options()
expected_pb.distribute_options.CopyFrom(
dataset_options_pb2.DistributeOptions())
expected_pb.optimization_options.CopyFrom(
dataset_options_pb2.OptimizationOptions())
expected_pb.optimization_options.map_vectorization.CopyFrom(
dataset_options_pb2.MapVectorization())
expected_pb.threading_options.CopyFrom(
dataset_options_pb2.ThreadingOptions())
self.assertProtoEquals(expected_pb, result)
if __name__ == "__main__":
test.main()

View File

@ -340,6 +340,7 @@ class DatasetTestBase(test.TestCase):
dataset = dataset_fn(delay_ms)
actual = self.getDatasetOutput(dataset)
self.assertCountEqual(expected_elements, actual)
if actual[0] != expected_elements[0]:
return
for i in range(len(actual)):
if actual[i] != expected_elements[i]:
return
self.fail("Failed to observe nondeterministic ordering")

View File

@ -28,6 +28,7 @@ import numpy as np
import six
from six.moves import queue as Queue # pylint: disable=redefined-builtin
from tensorflow.core.framework import dataset_options_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.python import tf2
from tensorflow.python.data.experimental.ops import distribute_options
@ -3039,6 +3040,34 @@ class Options(options_lib.OptionsBase):
"state is ignored and a warning is logged; FAIL: External state results "
"in an error.")
def _to_proto(self):
pb = dataset_options_pb2.Options()
if self.experimental_deterministic is not None:
pb.deterministic = self.experimental_deterministic
pb.distribute_options.CopyFrom(self.experimental_distribute._to_proto()) # pylint: disable=protected-access
if self.experimental_external_state_policy is not None:
pb.external_state_policy = (
distribute_options.ExternalStatePolicy._to_proto( # pylint: disable=protected-access
self.experimental_external_state_policy))
pb.optimization_options.CopyFrom(self.experimental_optimization._to_proto()) # pylint: disable=protected-access
if self.experimental_slack is not None:
pb.slack = self.experimental_slack
pb.threading_options.CopyFrom(self.experimental_threading._to_proto()) # pylint: disable=protected-access
return pb
def _from_proto(self, pb):
if pb.WhichOneof("optional_deterministic") is not None:
self.experimental_deterministic = pb.deterministic
self.experimental_distribute._from_proto(pb.distribute_options) # pylint: disable=protected-access
if pb.WhichOneof("optional_external_state_policy") is not None:
self.experimental_external_state_policy = (
distribute_options.ExternalStatePolicy._from_proto( # pylint: disable=protected-access
pb.external_state_policy))
self.experimental_optimization._from_proto(pb.optimization_options) # pylint: disable=protected-access
if pb.WhichOneof("optional_slack") is not None:
self.experimental_slack = pb.slack
self.experimental_threading._from_proto(pb.threading_options) # pylint: disable=protected-access
def _graph_rewrites(self):
"""Produces lists of enabled, disabled, default static graph rewrites.

View File

@ -59,6 +59,14 @@ class OptionsBase(object):
raise AttributeError(
"Cannot set the property %s on %s." % (name, type(self).__name__))
def _to_proto(self):
"""Convert options to protocol buffer."""
raise NotImplementedError("%s._to_proto()" % type(self).__name__)
def _from_proto(self, pb):
"""Convert protocol buffer to options."""
raise NotImplementedError("%s._from_proto()" % type(self).__name__)
# Creates a namedtuple with three keys for optimization graph rewrites settings.
def graph_rewrites():

View File

@ -95,7 +95,7 @@ std::ostream& operator<<(std::ostream& os, ComputationType ty) {
return os << ComputationTypeString(ty);
}
string DataTypeString(DataType ty) {
std::string DataTypeString(DataType ty) {
switch (ty) {
case DataType::kHalf:
return "f16";

View File

@ -134,7 +134,7 @@ enum class PointerMode {
};
// Converts a ComputationType to a string.
string DataTypeString(DataType ty);
std::string DataTypeString(DataType ty);
std::ostream &operator<<(std::ostream &os, DataType ty);

View File

@ -685,8 +685,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
)
# Check out LLVM and MLIR from llvm-project.
LLVM_COMMIT = "a1a1d338e99dc9c6d1234b70f43dea2e1bb2f8ce"
LLVM_SHA256 = "0adf75d405fe714b2c8a0ab1db4c10dcf9629b57e001191d3e5520407d563cc5"
LLVM_COMMIT = "a4fa667dee6012e350bd405ee7a759a53738b279"
LLVM_SHA256 = "11ef06ff3c01638d3bd11d9095259db92ab69ec85f101f4969c6c4ad9f154f8e"
LLVM_URLS = [
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
"https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),