From 5d9973623d4e7a0d632721af46ab23fb66a6d202 Mon Sep 17 00:00:00 2001 From: Advait Jain Date: Fri, 5 Feb 2021 10:13:38 -0800 Subject: [PATCH 01/19] Fix kernel_exp_test with the Xtensa toolchain. Underlying issue is a bug in the EXPECT_NEAR macro, as described in Also, * added an exp_test rule to the BUILD file. * changed the golden value computation to make use of std::exp instead of hard-coded values. This is closer to the TfLite test as well. Manually confirmed that the following command passes: ``` make -f tensorflow/lite/micro/tools/make/Makefile TARGET=xtensa OPTIMIZED_KERNEL_DIR=xtensa TARGET_ARCH=fusion_f1 XTENSA_CORE=F1_190305_swupgrade test -j8 ``` Fixes #46960 --- tensorflow/lite/micro/kernels/BUILD | 13 +++++++++++++ tensorflow/lite/micro/kernels/exp_test.cc | 18 ++++++++++-------- tensorflow/lite/micro/testing/micro_test.h | 4 +++- 3 files changed, 26 insertions(+), 9 deletions(-) diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD index 857ca367017..74766450a58 100644 --- a/tensorflow/lite/micro/kernels/BUILD +++ b/tensorflow/lite/micro/kernels/BUILD @@ -235,6 +235,19 @@ tflite_micro_cc_test( ], ) +tflite_micro_cc_test( + name = "exp_test", + srcs = ["exp_test.cc"], + deps = [ + ":kernel_runner", + "//tensorflow/lite/c:common", + "//tensorflow/lite/micro:debug_log", + "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", + "//tensorflow/lite/micro/testing:micro_test", + ], +) + tflite_micro_cc_test( name = "pooling_test", srcs = [ diff --git a/tensorflow/lite/micro/kernels/exp_test.cc b/tensorflow/lite/micro/kernels/exp_test.cc index 536b7f491c4..9a77686fd8b 100644 --- a/tensorflow/lite/micro/kernels/exp_test.cc +++ b/tensorflow/lite/micro/kernels/exp_test.cc @@ -54,7 +54,6 @@ void TestExp(const int* input_dims_data, const float* input_data, TF_LITE_MICRO_EXPECT_NEAR(expected_output_data[i], output_data[i], 1e-5f); } } - } // namespace } // namespace testing } // namespace tflite @@ -62,13 +61,16 @@ void TestExp(const int* input_dims_data, const float* input_data, TF_LITE_MICRO_TESTS_BEGIN TF_LITE_MICRO_TEST(SingleDim) { - float output_data[7]; - const int input_dims[] = {2, 1, 7}; - const float input_values[] = {0.0f, 1.0f, -1.0f, 100.0f, - -100.0f, 0.01f, -0.01f}; - const float golden[] = { - 1.0f, 2.71828f, 0.36788f, std::numeric_limits::infinity(), - 1.17549e-38f, 1.01005f, 0.99005f}; + constexpr int kInputSize = 7; + float output_data[kInputSize]; + const int input_dims[] = {2, 1, kInputSize}; + const float input_values[kInputSize] = {0.0f, 1.0f, -1.0f, 100.0f, + -100.0f, 0.01f, -0.01f}; + float golden[kInputSize]; + for (int i = 0; i < kInputSize; ++i) { + golden[i] = std::exp(input_values[i]); + } + tflite::testing::TestExp(input_dims, input_values, golden, output_data); } diff --git a/tensorflow/lite/micro/testing/micro_test.h b/tensorflow/lite/micro/testing/micro_test.h index 8b08d0f0751..b751876cb60 100644 --- a/tensorflow/lite/micro/testing/micro_test.h +++ b/tensorflow/lite/micro/testing/micro_test.h @@ -142,12 +142,14 @@ extern bool did_test_fail; } \ } while (false) +// The check vx != vy is needed to properly handle the case where both +// x and y evaluate to infinity. See #46960 for more details. #define TF_LITE_MICRO_EXPECT_NEAR(x, y, epsilon) \ do { \ auto vx = (x); \ auto vy = (y); \ auto delta = ((vx) > (vy)) ? 
((vx) - (vy)) : ((vy) - (vx)); \ - if (delta > epsilon) { \ + if (vx != vy && delta > epsilon) { \ MicroPrintf(#x " (%f) near " #y " (%f) failed at %s:%d", \ static_cast(vx), static_cast(vy), __FILE__, \ __LINE__); \ From 4413f34d5cf499df99999cb1d99d6021cbd99e13 Mon Sep 17 00:00:00 2001 From: Xiao Yu Date: Fri, 5 Feb 2021 14:47:25 -0800 Subject: [PATCH 02/19] Refresh device in EagerContext and pflr when device is updated. This is required to allow RuntimeFallback and KernelFallback to access TPU device created by tfrt. PiperOrigin-RevId: 355932739 Change-Id: I043d217e06612734ba0e2b0bbf1cd672b4c9bf44 --- .../core/common_runtime/eager/context.cc | 4 ++-- .../core/common_runtime/eager/context.h | 2 +- .../process_function_library_runtime.cc | 23 +++++++++++-------- .../process_function_library_runtime.h | 8 +++++-- 4 files changed, 22 insertions(+), 15 deletions(-) diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 4f3ca28813e..7c20766a1ce 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -1277,7 +1277,7 @@ Status EagerContext::UpdateRemoteMaster( context_view_id_++; remote_eager_workers_ = std::move(remote_eager_workers); - pflr_->InitializeDeviceSet(); + pflr_->InitializeDeviceAndFlr(); InitPrioritizedDeviceTypeList(); default_executor_.ClearError(); @@ -1496,7 +1496,7 @@ Status EagerContext::UpdateRemoteWorker( remote_contexts_ = remote_contexts; remote_eager_workers_ = std::move(remote_eager_workers); InitPrioritizedDeviceTypeList(); - pflr_->InitializeDeviceSet(); + pflr_->InitializeDeviceAndFlr(); } // No need to update remote_device_manager_ since it's not owned for remote diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 309d1c1ec44..28c0b40f43b 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -486,6 +486,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const; const SessionOptions& session_options() const { return opts_; } + void InitPrioritizedDeviceTypeList(); private: Rendezvous* CreateRendezvous(int64 step_id) const { @@ -510,7 +511,6 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { ~EagerContext() override; - void InitPrioritizedDeviceTypeList(); Status MaybeRegisterFunctionRemotely(const FunctionDef& fdef); Status RegisterExistingFunctionsOnRemoteWorkers( const std::vector& remote_workers); diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 9e02fe8bfca..659d29601e8 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -100,7 +100,9 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( std::unique_ptr>), next_handle_(0), session_metadata_(session_metadata), - rendezvous_factory_(std::move(rendezvous_factory)) { + rendezvous_factory_(std::move(rendezvous_factory)), + optimizer_options_(optimizer_options), + graph_def_version_(graph_def_version) { if (device_mgr == nullptr) { (*flr_map_)[nullptr] = NewFunctionLibraryRuntime( nullptr, env, config_ ? 
&(*config_) : nullptr, nullptr, @@ -108,14 +110,7 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( session_metadata_, this); return; } - for (Device* d : device_mgr->ListDevices()) { - (*flr_map_)[d] = NewFunctionLibraryRuntime( - device_mgr, env, config_ ? &(*config_) : nullptr, d, graph_def_version, - lib_def_, default_thread_pool, optimizer_options, session_metadata_, - this); - } - - InitializeDeviceSet(); + InitializeDeviceAndFlr(); } /* static */ @@ -214,7 +209,7 @@ Status ProcessFunctionLibraryRuntime::GetDeviceContext( "function executions"); } -void ProcessFunctionLibraryRuntime::InitializeDeviceSet() { +void ProcessFunctionLibraryRuntime::InitializeDeviceAndFlr() { DeviceMgr const* all_devices = device_mgr_; if (parent_ != nullptr && parent_->remote_device_mgr() != nullptr) { all_devices = parent_->remote_device_mgr(); @@ -225,6 +220,14 @@ void ProcessFunctionLibraryRuntime::InitializeDeviceSet() { for (auto d : all_devices->ListDevices()) { device_set_->AddDevice(d); } + for (Device* d : device_mgr_->ListDevices()) { + if ((*flr_map_)[d] == nullptr) { + (*flr_map_)[d] = NewFunctionLibraryRuntime( + device_mgr_, env_, config_ ? &(*config_) : nullptr, d, + graph_def_version_, lib_def_, default_thread_pool_, + optimizer_options_, session_metadata_, this); + } + } } FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR( diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h index 5d0b654a2c4..3dbb2d0fc93 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.h +++ b/tensorflow/core/common_runtime/process_function_library_runtime.h @@ -207,8 +207,9 @@ class ProcessFunctionLibraryRuntime { return device_set_; } - // Initialize the set of local and remote devices for op device selection. - void InitializeDeviceSet(); + // Initialize the set of local and remote devices and corresponding flr for op + // device selection. + void InitializeDeviceAndFlr(); const ConfigProto* config() const { return config_ ? &(*config_) : nullptr; } @@ -478,6 +479,9 @@ class ProcessFunctionLibraryRuntime { int next_handle_ TF_GUARDED_BY(mu_); const SessionMetadata* const session_metadata_; const Rendezvous::Factory rendezvous_factory_; + + const OptimizerOptions optimizer_options_; + const int graph_def_version_; }; } // namespace tensorflow From 5e12e5cde1a2fa13ae2222fa8a5e091b8f26b62a Mon Sep 17 00:00:00 2001 From: David Majnemer Date: Fri, 5 Feb 2021 15:06:28 -0800 Subject: [PATCH 03/19] [XLA] Use DumpToFileInDir for hlo_execution_profile_data This ensures that hlo_execution_profile_data when xla_dump_to points to special, non-path, locations that must be interpreted. 
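For illustration only, a hand-written sketch (not part of this change) of how the profile dump is typically requested; the flag spellings below are assumed from the corresponding DebugOptions field names (`xla_dump_to`, `xla_hlo_profile`) and must be set before XLA initializes:

```python
# Hand-written illustration, not code from this change. Flag spellings are
# assumed from the DebugOptions fields referenced in executable.cc.
import os

# Must be set before TensorFlow/XLA is initialized.
os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_dump --xla_hlo_profile"

# ... run an XLA-compiled computation here. With this change the profile proto
# is written via DumpToFileInDir, so it lands next to the other dump artifacts
# (including special, non-path xla_dump_to destinations) as a file whose name
# ends in "hlo_execution_profile_data".
```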
PiperOrigin-RevId: 355936626 Change-Id: I689fd3aba9017453878abc0add0f02f4ca4aedc9 --- tensorflow/compiler/xla/service/executable.cc | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index fd2f1dd54d2..8905470a2cb 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -256,16 +256,11 @@ Status ExecuteWrapperAfterExecution( } } - const auto& dump_path = - executable->module_config().debug_options().xla_dump_to(); if (executable->module_config().debug_options().xla_hlo_profile() && - state.profile_ptr != nullptr && !dump_path.empty()) { - const std::string full_path = - tensorflow::io::JoinPath(dump_path, "hlo_execution_profile_data"); - TF_CHECK_OK(tensorflow::WriteStringToFile( - tensorflow::Env::Default(), full_path, - state.profile_ptr->ToProto().SerializeAsString())) - << "Error saving HloExecutionProfileData to " << full_path; + state.profile_ptr != nullptr) { + DumpToFileInDir(executable->module(), /*file_prefix=*/"", + /*file_suffix=*/"hlo_execution_profile_data", + state.profile_ptr->ToProto().SerializeAsString()); } if (state.profile_ptr != nullptr) { From eb001c716506815ebc3bcd5b1ac6eb2cadb6d244 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 5 Feb 2021 15:49:32 -0800 Subject: [PATCH 04/19] [tf.data] Add a new proto to store tf.data.Options data and conversion functions between tf.data.Options and their proto representation. PiperOrigin-RevId: 355945339 Change-Id: I3eba0c05899aeda629e34483b75fc61c435cb4d8 --- tensorflow/core/BUILD | 1 + tensorflow/core/framework/BUILD | 9 + .../core/framework/dataset_options.proto | 179 ++++++++++++++++++ .../experimental/ops/distribute_options.py | 65 +++++++ .../experimental/ops/optimization_options.py | 89 +++++++++ .../experimental/ops/threading_options.py | 15 ++ .../python/data/kernel_tests/options_test.py | 63 ++++++ tensorflow/python/data/ops/dataset_ops.py | 29 +++ tensorflow/python/data/util/options.py | 8 + 9 files changed, 458 insertions(+) create mode 100644 tensorflow/core/framework/dataset_options.proto diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 744c0b75ae5..eafee44d2ef 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -203,6 +203,7 @@ FRAMEWORK_PROTO_SRCS = [ "//tensorflow/core/framework:model.proto", "//tensorflow/core/framework:node_def.proto", "//tensorflow/core/framework:op_def.proto", + "//tensorflow/core/framework:dataset_options.proto", "//tensorflow/core/framework:reader_base.proto", "//tensorflow/core/framework:remote_fused_graph_execute_info.proto", "//tensorflow/core/framework:resource_handle.proto", diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD index 12d637ad30a..30bf857163b 100644 --- a/tensorflow/core/framework/BUILD +++ b/tensorflow/core/framework/BUILD @@ -119,6 +119,7 @@ exports_files( "api_def.proto", "attr_value.proto", "cost_graph.proto", + "dataset_options.proto", "device_attributes.proto", "function.proto", "graph.proto", @@ -1660,6 +1661,13 @@ tf_proto_library( make_default_target_header_only = True, ) +tf_proto_library( + name = "dataset_options_proto", + srcs = ["dataset_options.proto"], + cc_api_version = 2, + make_default_target_header_only = True, +) + tf_proto_library( name = "protos_all", cc_api_version = 2, @@ -1678,6 +1686,7 @@ tf_proto_library( ":model_proto", ":node_def_proto", ":op_def_proto", + ":dataset_options_proto", 
":reader_base_proto", ":remote_fused_graph_execute_info_proto", ":resource_handle_proto", diff --git a/tensorflow/core/framework/dataset_options.proto b/tensorflow/core/framework/dataset_options.proto new file mode 100644 index 00000000000..05e15e15625 --- /dev/null +++ b/tensorflow/core/framework/dataset_options.proto @@ -0,0 +1,179 @@ +syntax = "proto3"; + +package tensorflow.data; + +// Represents the type of auto-sharding we enable. +enum AutoShardPolicy { + AUTO = 0; + FILE = 1; + DATA = 2; + OFF = -1; +} + +message DistributeOptions { + // The type of sharding that auto-shard should attempt. If this is set to + // FILE, then we will attempt to shard by files (each worker will get a set of + // files to process). If we cannot find a set of files to shard for at least + // one file per worker, we will error out. When this option is selected, make + // sure that you have enough files so that each worker gets at least one file. + // There will be a runtime error thrown if there are insufficient files. If + // this is set to DATA, then we will shard by elements produced by the + // dataset, and each worker will process the whole dataset and discard the + // portion that is not for itself. If this is set to OFF, then we will not + // autoshard, and each worker will receive a copy of the full dataset. This + // option is set to AUTO by default, AUTO will attempt to first shard by FILE, + // and fall back to sharding by DATA if we cannot find a set of files to + // shard. + AutoShardPolicy auto_shard_policy = 1; + // The number of devices attached to this input pipeline. + oneof optional_num_devices { + int32 num_devices = 2; + } +} + +message MapVectorization { + // Whether to vectorize map transformations. + oneof optional_enabled { + bool enabled = 1; + } + // Whether to use ChooseFastestBranchDataset with this transformation. If + // True, the pipeline picks between the vectorized and original segment at + // runtime based on their iterations speed. + oneof optional_use_choose_fastest { + bool use_choose_fastest = 2; + } +} + +message OptimizationOptions { + // Whether to apply default graph optimizations. If False, only graph + // optimizations that have been explicitly enabled will be applied. + oneof optional_apply_default_optimizations { + bool apply_default_optimizations = 1; + } + // Whether to automatically tune performance knobs. + oneof optional_autotune { + bool autotune = 2; + } + // When autotuning is enabled (through autotune), determines whether to also + // autotune buffer sizes for datasets with parallelism. + oneof optional_autotune_buffers { + bool autotune_buffers = 3; + } + // When autotuning is enabled (through autotune), determines the CPU budget to + // use. Values greater than the number of schedulable CPU cores are allowed + // but may result in CPU contention. + oneof optional_autotune_cpu_budget { + int32 autotune_cpu_budget = 4; + } + // When autotuning is enabled (through autotune), determines the RAM budget to + // use. Values greater than the available RAM in bytes may result in OOM. If + // 0, defaults to half of the available RAM in bytes. + oneof optional_autotune_ram_budget { + int32 autotune_ram_budget = 5; + } + // Whether to fuse filter transformations. + oneof optional_filter_fusion { + bool filter_fusion = 6; + } + // Whether to fuse filter dataset that predicts random_uniform < rate into a + // sampling dataset. 
+ oneof optional_filter_with_random_uniform_fusion { + bool filter_with_random_uniform_fusion = 7; + } + // Whether to hoist tf.random_uniform() ops out of map transformations. + oneof optional_hoist_random_uniform { + bool hoist_random_uniform = 8; + } + // Whether to fuse map and batch transformations. + oneof optional_map_and_batch_fusion { + bool map_and_batch_fusion = 9; + } + // Whether to fuse map and filter transformations. + oneof optional_map_and_filter_fusion { + bool map_and_filter_fusion = 10; + } + // Whether to fuse map transformations. + oneof optional_map_fusion { + bool map_fusion = 11; + } + // Whether to parallelize stateless map transformations. + oneof optional_map_parallelization { + bool map_parallelization = 12; + } + // The map vectorization options associated with the dataset. + MapVectorization map_vectorization = 13; + // Whether to eliminate no-op transformations. + oneof optional_noop_elimination { + bool noop_elimination = 14; + } + // Whether to parallelize copying of batch elements. This optimization is + // highly experimental and can cause performance degradation (e.g. when the + // parallelization overhead exceeds the benefits of performing the data copies + // in parallel). You should only enable this optimization if a) your input + // pipeline is bottlenecked on batching and b) you have validated that this + // optimization improves performance. + oneof optional_parallel_batch { + bool parallel_batch = 15; + } + // Whether to reorder ops that will discard data to the front of unary + // cardinality preserving transformations, e.g. dataset.map(...).take(3) will + // be optimized to dataset.take(3).map(...). For now this optimization will + // move `skip`, `shard` and `take` to the front of `map` and `prefetch`. This + // optimization is only for performance; it will not affect the output of the + // dataset. + oneof optional_reorder_data_discarding_ops { + bool reorder_data_discarding_ops = 16; + } + // Whether to fuse shuffle and repeat transformations. + oneof optional_shuffle_and_repeat_fusion { + bool shuffle_and_repeat_fusion = 17; + } +} + +message ThreadingOptions { + // If set, it overrides the maximum degree of intra-op parallelism. + oneof optional_max_intra_op_parallelism { + int32 max_intra_op_parallelism = 1; + } + // If set, the dataset will use a private threadpool of the given size. + oneof optional_private_threadpool_size { + int32 private_threadpool_size = 2; + } +} + +// Represents how to handle external state during serialization. +enum ExternalStatePolicy { + WARN = 0; + IGNORE = 1; + FAIL = 2; +} + +// Message stored with Dataset objects to control how datasets are processed and +// optimized. +message Options { + // Whether the outputs need to be produced in deterministic order. + oneof optional_deterministic { + bool deterministic = 1; + } + // The distribution strategy options associated with the dataset. + DistributeOptions distribute_options = 2; + // The optimization options associated with the dataset. + OptimizationOptions optimization_options = 3; + // Whether to introduce 'slack' in the last `prefetch` of the input pipeline, + // if it exists. This may reduce CPU contention with accelerator host-side + // activity at the start of a step. The slack frequency is determined by the + // number of devices attached to this input pipeline. + oneof optional_slack { + bool slack = 4; + } + // The threading options associated with the dataset. 
+ ThreadingOptions threading_options = 5; + // This option can be used to override the default policy for how to handle + // external state when serializing a dataset or checkpointing its iterator. + // There are three settings available - IGNORE: External state is ignored + // without a warning; WARN: External state is ignored and a warning is logged; + // FAIL: External state results in an error. + oneof optional_external_state_policy { + ExternalStatePolicy external_state_policy = 6; + } +} diff --git a/tensorflow/python/data/experimental/ops/distribute_options.py b/tensorflow/python/data/experimental/ops/distribute_options.py index 82c498ff993..9a18528513d 100644 --- a/tensorflow/python/data/experimental/ops/distribute_options.py +++ b/tensorflow/python/data/experimental/ops/distribute_options.py @@ -19,6 +19,7 @@ from __future__ import print_function import enum +from tensorflow.core.framework import dataset_options_pb2 from tensorflow.python.data.util import options from tensorflow.python.util.tf_export import tf_export @@ -35,6 +36,34 @@ class AutoShardPolicy(enum.IntEnum): FILE = 1 DATA = 2 + @classmethod + def _to_proto(cls, obj): + """Convert enum to proto.""" + if obj == cls.OFF: + return dataset_options_pb2.AutoShardPolicy.OFF + if obj == cls.FILE: + return dataset_options_pb2.AutoShardPolicy.FILE + if obj == cls.DATA: + return dataset_options_pb2.AutoShardPolicy.DATA + if obj == cls.AUTO: + return dataset_options_pb2.AutoShardPolicy.AUTO + raise ValueError("%s._to_proto() is called with undefined enum %s." % + (cls.__name__, obj.name)) + + @classmethod + def _from_proto(cls, pb): + """Convert proto to enum.""" + if pb == dataset_options_pb2.AutoShardPolicy.OFF: + return cls.OFF + if pb == dataset_options_pb2.AutoShardPolicy.FILE: + return cls.FILE + if pb == dataset_options_pb2.AutoShardPolicy.DATA: + return cls.DATA + if pb == dataset_options_pb2.AutoShardPolicy.AUTO: + return cls.AUTO + raise ValueError("%s._from_proto() is called with undefined enum %s." % + (cls.__name__, pb)) + @tf_export("data.experimental.ExternalStatePolicy") class ExternalStatePolicy(enum.Enum): @@ -47,6 +76,30 @@ class ExternalStatePolicy(enum.Enum): IGNORE = 1 FAIL = 2 + @classmethod + def _to_proto(cls, obj): + """Convert enum to proto.""" + if obj == cls.IGNORE: + return dataset_options_pb2.ExternalStatePolicy.IGNORE + if obj == cls.FAIL: + return dataset_options_pb2.ExternalStatePolicy.FAIL + if obj == cls.WARN: + return dataset_options_pb2.ExternalStatePolicy.WARN + raise ValueError("%s._to_proto() is called with undefined enum %s." % + (cls.__name__, obj.name)) + + @classmethod + def _from_proto(cls, pb): + """Convert proto to enum.""" + if pb == dataset_options_pb2.ExternalStatePolicy.IGNORE: + return cls.IGNORE + if pb == dataset_options_pb2.ExternalStatePolicy.FAIL: + return cls.FAIL + if pb == dataset_options_pb2.ExternalStatePolicy.WARN: + return cls.WARN + raise ValueError("%s._from_proto() is called with undefined enum %s." % + (cls.__name__, pb)) + @tf_export("data.experimental.DistributeOptions") class DistributeOptions(options.OptionsBase): @@ -89,3 +142,15 @@ class DistributeOptions(options.OptionsBase): docstring= "The number of devices attached to this input pipeline. 
This will be " "automatically set by MultiDeviceIterator.") + + def _to_proto(self): + pb = dataset_options_pb2.DistributeOptions() + pb.auto_shard_policy = AutoShardPolicy._to_proto(self.auto_shard_policy) # pylint: disable=protected-access + if self.num_devices is not None: + pb.num_devices = self.num_devices + return pb + + def _from_proto(self, pb): + self.auto_shard_policy = AutoShardPolicy._from_proto(pb.auto_shard_policy) # pylint: disable=protected-access + if pb.WhichOneof("optional_num_devices") is not None: + self.num_devices = pb.num_devices diff --git a/tensorflow/python/data/experimental/ops/optimization_options.py b/tensorflow/python/data/experimental/ops/optimization_options.py index 5c69855e15f..992ea647955 100644 --- a/tensorflow/python/data/experimental/ops/optimization_options.py +++ b/tensorflow/python/data/experimental/ops/optimization_options.py @@ -19,6 +19,7 @@ from __future__ import print_function import enum +from tensorflow.core.framework import dataset_options_pb2 from tensorflow.python.data.util import options from tensorflow.python.util.tf_export import tf_export @@ -69,6 +70,20 @@ class MapVectorizationOptions(options.OptionsBase): else: return ["map_vectorization:use_choose_fastest:false"] + def _to_proto(self): + pb = dataset_options_pb2.MapVectorization() + if self.enabled is not None: + pb.enabled = self.enabled + if self.use_choose_fastest is not None: + pb.use_choose_fastest = self.use_choose_fastest + return pb + + def _from_proto(self, pb): + if pb.WhichOneof("optional_enabled") is not None: + self.enabled = pb.enabled + if pb.WhichOneof("optional_use_choose_fastest") is not None: + self.use_choose_fastest = pb.use_choose_fastest + @tf_export("data.experimental.OptimizationOptions") class OptimizationOptions(options.OptionsBase): @@ -327,3 +342,77 @@ class OptimizationOptions(options.OptionsBase): graph_rewrite_configs.append(optimization + ":autotune:true") return graph_rewrite_configs + + def _to_proto(self): + pb = dataset_options_pb2.OptimizationOptions() + if self.apply_default_optimizations is not None: + pb.apply_default_optimizations = self.apply_default_optimizations + if self.autotune is not None: + pb.autotune = self.autotune + if self.autotune_buffers is not None: + pb.autotune_buffers = self.autotune_buffers + if self.autotune_cpu_budget is not None: + pb.autotune_cpu_budget = self.autotune_cpu_budget + if self.autotune_ram_budget is not None: + pb.autotune_ram_budget = self.autotune_ram_budget + if self.filter_fusion is not None: + pb.filter_fusion = self.filter_fusion + if self.filter_with_random_uniform_fusion is not None: + pb.filter_with_random_uniform_fusion = ( + self.filter_with_random_uniform_fusion) + if self.hoist_random_uniform is not None: + pb.hoist_random_uniform = self.hoist_random_uniform + if self.map_and_batch_fusion is not None: + pb.map_and_batch_fusion = self.map_and_batch_fusion + if self.map_and_filter_fusion is not None: + pb.map_and_filter_fusion = self.map_and_filter_fusion + if self.map_fusion is not None: + pb.map_fusion = self.map_fusion + if self.map_parallelization is not None: + pb.map_parallelization = self.map_parallelization + pb.map_vectorization.CopyFrom(self.map_vectorization._to_proto()) # pylint: disable=protected-access + if self.noop_elimination is not None: + pb.noop_elimination = self.noop_elimination + if self.parallel_batch is not None: + pb.parallel_batch = self.parallel_batch + if self.reorder_data_discarding_ops is not None: + pb.reorder_data_discarding_ops = 
self.reorder_data_discarding_ops + if self.shuffle_and_repeat_fusion is not None: + pb.shuffle_and_repeat_fusion = self.shuffle_and_repeat_fusion + return pb + + def _from_proto(self, pb): + if pb.WhichOneof("optional_apply_default_optimizations") is not None: + self.apply_default_optimizations = pb.apply_default_optimizations + if pb.WhichOneof("optional_autotune") is not None: + self.autotune = pb.autotune + if pb.WhichOneof("optional_autotune_buffers") is not None: + self.autotune_buffers = pb.autotune_buffers + if pb.WhichOneof("optional_autotune_cpu_budget") is not None: + self.autotune_cpu_budget = pb.autotune_cpu_budget + if pb.WhichOneof("optional_autotune_ram_budget") is not None: + self.autotune_ram_budget = pb.autotune_ram_budget + if pb.WhichOneof("optional_filter_fusion") is not None: + self.filter_fusion = pb.filter_fusion + if pb.WhichOneof("optional_filter_with_random_uniform_fusion") is not None: + self.filter_with_random_uniform_fusion = ( + pb.filter_with_random_uniform_fusion) + if pb.WhichOneof("optional_hoist_random_uniform") is not None: + self.hoist_random_uniform = pb.hoist_random_uniform + if pb.WhichOneof("optional_map_and_batch_fusion") is not None: + self.map_and_batch_fusion = pb.map_and_batch_fusion + if pb.WhichOneof("optional_map_and_filter_fusion") is not None: + self.map_and_filter_fusion = pb.map_and_filter_fusion + if pb.WhichOneof("optional_map_fusion") is not None: + self.map_fusion = pb.map_fusion + if pb.WhichOneof("optional_map_parallelization") is not None: + self.map_parallelization = pb.map_parallelization + self.map_vectorization._from_proto(pb.map_vectorization) # pylint: disable=protected-access + if pb.WhichOneof("optional_noop_elimination") is not None: + self.noop_elimination = pb.noop_elimination + if pb.WhichOneof("optional_parallel_batch") is not None: + self.parallel_batch = pb.parallel_batch + if pb.WhichOneof("optional_reorder_data_discarding_ops") is not None: + self.reorder_data_discarding_ops = pb.reorder_data_discarding_ops + if pb.WhichOneof("optional_shuffle_and_repeat_fusion") is not None: + self.shuffle_and_repeat_fusion = pb.shuffle_and_repeat_fusion diff --git a/tensorflow/python/data/experimental/ops/threading_options.py b/tensorflow/python/data/experimental/ops/threading_options.py index d713b9ae075..39da39353d6 100644 --- a/tensorflow/python/data/experimental/ops/threading_options.py +++ b/tensorflow/python/data/experimental/ops/threading_options.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function +from tensorflow.core.framework import dataset_options_pb2 from tensorflow.python.data.util import options from tensorflow.python.util.tf_export import tf_export @@ -48,3 +49,17 @@ class ThreadingOptions(options.OptionsBase): ty=int, docstring= "If set, the dataset will use a private threadpool of the given size.") + + def _to_proto(self): + pb = dataset_options_pb2.ThreadingOptions() + if self.max_intra_op_parallelism is not None: + pb.max_intra_op_parallelism = self.max_intra_op_parallelism + if self.private_threadpool_size is not None: + pb.private_threadpool_size = self.private_threadpool_size + return pb + + def _from_proto(self, pb): + if pb.WhichOneof("optional_max_intra_op_parallelism") is not None: + self.max_intra_op_parallelism = pb.max_intra_op_parallelism + if pb.WhichOneof("optional_private_threadpool_size") is not None: + self.private_threadpool_size = pb.private_threadpool_size diff --git a/tensorflow/python/data/kernel_tests/options_test.py 
b/tensorflow/python/data/kernel_tests/options_test.py index 31220c69d9e..efd3f598a1f 100644 --- a/tensorflow/python/data/kernel_tests/options_test.py +++ b/tensorflow/python/data/kernel_tests/options_test.py @@ -23,6 +23,8 @@ import sys from absl.testing import parameterized +from tensorflow.core.framework import dataset_options_pb2 +from tensorflow.python.data.experimental.ops import distribute_options from tensorflow.python.data.experimental.ops import optimization_options from tensorflow.python.data.experimental.ops import stats_options from tensorflow.python.data.experimental.ops import threading_options @@ -127,6 +129,67 @@ class OptionsTest(test_base.DatasetTestBase, parameterized.TestCase): result = result.concatenate(ds) self.assertDatasetProduces(result, [0]*1000) + @combinations.generate(test_base.default_test_combinations()) + def testOptionsProtoRoundTrip(self): + options = dataset_ops.Options() + options.experimental_deterministic = True + options.experimental_external_state_policy = ( + distribute_options.ExternalStatePolicy.FAIL) + options.experimental_distribute.auto_shard_policy = ( + distribute_options.AutoShardPolicy.DATA) + options.experimental_distribute.num_devices = 1000 + options.experimental_optimization.apply_default_optimizations = True + options.experimental_optimization.autotune = True + options.experimental_optimization.autotune_buffers = True + options.experimental_optimization.autotune_cpu_budget = 10 + options.experimental_optimization.autotune_ram_budget = 20 + options.experimental_optimization.filter_fusion = True + options.experimental_optimization.filter_with_random_uniform_fusion = True + options.experimental_optimization.hoist_random_uniform = True + options.experimental_optimization.map_and_batch_fusion = True + options.experimental_optimization.map_and_filter_fusion = True + options.experimental_optimization.map_fusion = True + options.experimental_optimization.map_parallelization = True + options.experimental_optimization.map_vectorization.enabled = True + options.experimental_optimization.map_vectorization.use_choose_fastest = ( + True) + options.experimental_optimization.noop_elimination = True + options.experimental_optimization.parallel_batch = True + options.experimental_optimization.reorder_data_discarding_ops = True + options.experimental_optimization.shuffle_and_repeat_fusion = True + options.experimental_slack = True + options.experimental_threading.max_intra_op_parallelism = 30 + options.experimental_threading.private_threadpool_size = 40 + pb = options._to_proto() + result = dataset_ops.Options() + result._from_proto(pb) + self.assertEqual(options, result) + + @combinations.generate(test_base.default_test_combinations()) + def testOptionsProtoDefaultValuesRoundTrip(self): + options = dataset_ops.Options() + pb = options._to_proto() + result = dataset_ops.Options() + result._from_proto(pb) + self.assertEqual(options, result) + + @combinations.generate(test_base.default_test_combinations()) + def testProtoOptionsDefaultValuesRoundTrip(self): + pb = dataset_options_pb2.Options() + options = dataset_ops.Options() + options._from_proto(pb) + result = options._to_proto() + expected_pb = dataset_options_pb2.Options() + expected_pb.distribute_options.CopyFrom( + dataset_options_pb2.DistributeOptions()) + expected_pb.optimization_options.CopyFrom( + dataset_options_pb2.OptimizationOptions()) + expected_pb.optimization_options.map_vectorization.CopyFrom( + dataset_options_pb2.MapVectorization()) + expected_pb.threading_options.CopyFrom( + 
dataset_options_pb2.ThreadingOptions()) + self.assertProtoEquals(expected_pb, result) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 9b5aa9f6dda..6497cb2143b 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -28,6 +28,7 @@ import numpy as np import six from six.moves import queue as Queue # pylint: disable=redefined-builtin +from tensorflow.core.framework import dataset_options_pb2 from tensorflow.core.framework import graph_pb2 from tensorflow.python import tf2 from tensorflow.python.data.experimental.ops import distribute_options @@ -3039,6 +3040,34 @@ class Options(options_lib.OptionsBase): "state is ignored and a warning is logged; FAIL: External state results " "in an error.") + def _to_proto(self): + pb = dataset_options_pb2.Options() + if self.experimental_deterministic is not None: + pb.deterministic = self.experimental_deterministic + pb.distribute_options.CopyFrom(self.experimental_distribute._to_proto()) # pylint: disable=protected-access + if self.experimental_external_state_policy is not None: + pb.external_state_policy = ( + distribute_options.ExternalStatePolicy._to_proto( # pylint: disable=protected-access + self.experimental_external_state_policy)) + pb.optimization_options.CopyFrom(self.experimental_optimization._to_proto()) # pylint: disable=protected-access + if self.experimental_slack is not None: + pb.slack = self.experimental_slack + pb.threading_options.CopyFrom(self.experimental_threading._to_proto()) # pylint: disable=protected-access + return pb + + def _from_proto(self, pb): + if pb.WhichOneof("optional_deterministic") is not None: + self.experimental_deterministic = pb.deterministic + self.experimental_distribute._from_proto(pb.distribute_options) # pylint: disable=protected-access + if pb.WhichOneof("optional_external_state_policy") is not None: + self.experimental_external_state_policy = ( + distribute_options.ExternalStatePolicy._from_proto( # pylint: disable=protected-access + pb.external_state_policy)) + self.experimental_optimization._from_proto(pb.optimization_options) # pylint: disable=protected-access + if pb.WhichOneof("optional_slack") is not None: + self.experimental_slack = pb.slack + self.experimental_threading._from_proto(pb.threading_options) # pylint: disable=protected-access + def _graph_rewrites(self): """Produces lists of enabled, disabled, default static graph rewrites. diff --git a/tensorflow/python/data/util/options.py b/tensorflow/python/data/util/options.py index 8af773ed68b..3df6f000bb6 100644 --- a/tensorflow/python/data/util/options.py +++ b/tensorflow/python/data/util/options.py @@ -59,6 +59,14 @@ class OptionsBase(object): raise AttributeError( "Cannot set the property %s on %s." % (name, type(self).__name__)) + def _to_proto(self): + """Convert options to protocol buffer.""" + raise NotImplementedError("%s._to_proto()" % type(self).__name__) + + def _from_proto(self, pb): + """Convert protocol buffer to options.""" + raise NotImplementedError("%s._from_proto()" % type(self).__name__) + # Creates a namedtuple with three keys for optimization graph rewrites settings. def graph_rewrites(): From 519c01c3891f58a09072d0aaa3d3c13681ff81af Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Fri, 5 Feb 2021 15:51:53 -0800 Subject: [PATCH 05/19] Integrate LLVM at llvm/llvm-project@a4fa667dee60 Updates LLVM usage to match [a4fa667dee60](https://github.com/llvm/llvm-project/commit/a4fa667dee60) PiperOrigin-RevId: 355945749 Change-Id: Ib7c3b487bc5dda535ef8515ba37fe52102ae9b25 --- tensorflow/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 3b2005d9afe..c5b4e9c9098 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -685,8 +685,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): ) # Check out LLVM and MLIR from llvm-project. - LLVM_COMMIT = "a1a1d338e99dc9c6d1234b70f43dea2e1bb2f8ce" - LLVM_SHA256 = "0adf75d405fe714b2c8a0ab1db4c10dcf9629b57e001191d3e5520407d563cc5" + LLVM_COMMIT = "a4fa667dee6012e350bd405ee7a759a53738b279" + LLVM_SHA256 = "11ef06ff3c01638d3bd11d9095259db92ab69ec85f101f4969c6c4ad9f154f8e" LLVM_URLS = [ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), From 385019cd2429d6f5a406d2156a620ca3ca44409e Mon Sep 17 00:00:00 2001 From: Michael Banfield Date: Fri, 5 Feb 2021 15:53:38 -0800 Subject: [PATCH 06/19] Append tpu planes to XSpace. PiperOrigin-RevId: 355946077 Change-Id: I0e9eb9bbfde6a3230bfccf5b1c783013a7ae8c85 --- tensorflow/core/profiler/internal/tpu/tpu_tracer.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/profiler/internal/tpu/tpu_tracer.cc b/tensorflow/core/profiler/internal/tpu/tpu_tracer.cc index dbe32ee043c..528432fef91 100644 --- a/tensorflow/core/profiler/internal/tpu/tpu_tracer.cc +++ b/tensorflow/core/profiler/internal/tpu/tpu_tracer.cc @@ -107,7 +107,12 @@ Status TpuTracer::CollectData(XSpace* space) { tpu::OpsApiFn()->TpuProfiler_CollectDataFn(tpu_profiler_, status.c_status, buffer.data(), &size_in_bytes); // Deserialize XSpace from the buffer and return it. - space->ParseFromArray(buffer.data(), buffer.size()); + XSpace tpu_space; + tpu_space.ParseFromArray(buffer.data(), buffer.size()); + for (XPlane& tpu_plane : *tpu_space.mutable_planes()) { + XPlane* plane = space->add_planes(); + plane->Swap(&tpu_plane); + } } if (!status.ok()) { LOG(ERROR) << "TPU tracer failed to collect data."; From 911d9336e222f4365b7c3a3a0d417cff2c42067e Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Fri, 5 Feb 2021 15:56:03 -0800 Subject: [PATCH 07/19] [tf.data] Default autotuning to conservative values to avoid accidentally allocating too much memory before optimization loop picks values that respect the memory budget. 
PiperOrigin-RevId: 355946537 Change-Id: I2e581f42fa5c016b0240fd69f0b16f08fc2fdbfd --- .../experimental/map_and_batch_dataset_op.cc | 15 ++++++++++++++- .../kernels/data/parallel_map_dataset_op.cc | 15 ++++++++++++++- .../optimization/map_vectorization_test.py | 17 +++++++---------- .../python/data/kernel_tests/test_base.py | 5 +++-- 4 files changed, 38 insertions(+), 14 deletions(-) diff --git a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc index 51a02ef45a2..69b91211c5a 100644 --- a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc @@ -216,7 +216,20 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { Status Initialize(IteratorContext* ctx) override { mutex_lock l(*mu_); if (num_parallel_calls_->value == model::kAutotune) { - num_parallel_calls_->value = ctx->runner_threadpool_size(); + // If autotuning is enabled, we initialize the parallelism to 1 to + // avoid accidentally running the machine out of memory before the + // optimization can pick values that respect the memory budget. + // + // If autotuning is disabled but the transformation uses `AUTOTUNE`, we + // default the parallelism to the size of the threadpool used for + // executing the user-defined computation. If this causes OOM, the + // input pipeline should either enable autotuning, or replace + // `AUTOTUNE` with fixed parallelism. + if (TF_PREDICT_TRUE(ctx->model())) { + num_parallel_calls_->value = 1; + } else { + num_parallel_calls_->value = ctx->runner_threadpool_size(); + } } TF_RETURN_IF_ERROR(RegisterCancellationCallback( ctx->cancellation_manager(), diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index ece14401724..4f7e7156d47 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -221,7 +221,20 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { Status Initialize(IteratorContext* ctx) override { mutex_lock l(*mu_); if (num_parallel_calls_->value == model::kAutotune) { - num_parallel_calls_->value = ctx->runner_threadpool_size(); + // If autotuning is enabled, we initialize the parallelism to 1 to + // avoid accidentally running the machine out of memory before the + // optimization can pick values that respect the memory budget. + // + // If autotuning is disabled but the transformation uses `AUTOTUNE`, we + // default the parallelism to the size of the threadpool used for + // executing the user-defined computation. If this causes OOM, the + // input pipeline should either enable autotuning, or replace + // `AUTOTUNE` with fixed parallelism. 
+ if (TF_PREDICT_TRUE(ctx->model())) { + num_parallel_calls_->value = 1; + } else { + num_parallel_calls_->value = ctx->runner_threadpool_size(); + } } cancellation_manager_ = absl::make_unique(ctx->cancellation_manager()); diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py index 3876408697f..e1aa0957994 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py @@ -564,22 +564,19 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase): # Tests that vectorization maintains the determinism setting. expect_determinism = local_determinism or (local_determinism is None and global_determinism) - elements = list(range(1000)) - + num_elements = 1000 def dataset_fn(delay_ms): def sleep(x): - time.sleep(delay_ms / 1000) + # Inject random delay in the interval [0, delay_ms / 1000). + time.sleep(delay_ms * (np.random.randint(x + 1) / (x + 1)) / 1000) return x def map_function(x): - if math_ops.equal(x, 0): - return check_ops.ensure_shape( - script_ops.py_func(sleep, [x], x.dtype, stateful=False), ()) - else: - return x + return check_ops.ensure_shape( + script_ops.py_func(sleep, [x], x.dtype, stateful=False), ()) - dataset = dataset_ops.Dataset.from_tensor_slices(elements) + dataset = dataset_ops.Dataset.range(num_elements) dataset = dataset.map( map_function, num_parallel_calls=10, deterministic=local_determinism) dataset = dataset.batch(1) @@ -595,7 +592,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase): self.checkDeterminism( dataset_fn, expect_determinism, - expected_elements=[[element] for element in elements]) + expected_elements=[[element] for element in range(num_elements)]) @combinations.generate(test_base.default_test_combinations()) def testOptimizationIgnoreStateful(self): diff --git a/tensorflow/python/data/kernel_tests/test_base.py b/tensorflow/python/data/kernel_tests/test_base.py index 1af40feb4c1..0675d0f83a2 100644 --- a/tensorflow/python/data/kernel_tests/test_base.py +++ b/tensorflow/python/data/kernel_tests/test_base.py @@ -340,6 +340,7 @@ class DatasetTestBase(test.TestCase): dataset = dataset_fn(delay_ms) actual = self.getDatasetOutput(dataset) self.assertCountEqual(expected_elements, actual) - if actual[0] != expected_elements[0]: - return + for i in range(len(actual)): + if actual[i] != expected_elements[i]: + return self.fail("Failed to observe nondeterministic ordering") From 8617a34cee6bbdb8aa60c675dcc7d8814a0c18be Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 5 Feb 2021 16:07:44 -0800 Subject: [PATCH 08/19] Qualifying "string" as std::string PiperOrigin-RevId: 355948874 Change-Id: I02284fb8753bde69d12c54bd2b85be8040ea378b --- tensorflow/core/platform/path.h | 17 +++++++++-------- tensorflow/core/util/env_var.h | 2 +- tensorflow/stream_executor/blas.cc | 2 +- tensorflow/stream_executor/blas.h | 2 +- 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/tensorflow/core/platform/path.h b/tensorflow/core/platform/path.h index 47c0dc07614..408f5abe011 100644 --- a/tensorflow/core/platform/path.h +++ b/tensorflow/core/platform/path.h @@ -22,7 +22,7 @@ limitations under the License. 
namespace tensorflow { namespace io { namespace internal { -string JoinPathImpl(std::initializer_list paths); +std::string JoinPathImpl(std::initializer_list paths); } // Utility routines for processing filenames @@ -43,7 +43,7 @@ string JoinPathImpl(std::initializer_list paths); // string path = io::JoinPath(FLAGS_test_srcdir, filename); // string path = io::JoinPath("/full", "path", "to", "filename"); template -string JoinPath(const T&... args) { +std::string JoinPath(const T&... args) { return internal::JoinPathImpl({args...}); } #endif /* SWIG */ @@ -71,7 +71,7 @@ tensorflow::StringPiece Extension(tensorflow::StringPiece path); // "/alpha/beta/". // // Does not perform any path normalization. -string CommonPathPrefix(absl::Span paths); +std::string CommonPathPrefix(absl::Span paths); // Collapse duplicate "/"s, resolve ".." and "." path elements, remove // trailing "/". @@ -80,7 +80,7 @@ string CommonPathPrefix(absl::Span paths); // invoke any system calls (getcwd(2)) in order to resolve relative // paths with respect to the actual working directory. That is, this is purely // string manipulation, completely independent of process state. -string CleanPath(tensorflow::StringPiece path); +std::string CleanPath(tensorflow::StringPiece path); // Populates the scheme, host, and path from a URI. scheme, host, and path are // guaranteed by this function to point into the contents of uri, even if @@ -95,11 +95,12 @@ void ParseURI(tensorflow::StringPiece uri, tensorflow::StringPiece* scheme, // Creates a URI from a scheme, host, and path. If the scheme is empty, we just // return the path. -string CreateURI(tensorflow::StringPiece scheme, tensorflow::StringPiece host, - tensorflow::StringPiece path); +std::string CreateURI(tensorflow::StringPiece scheme, + tensorflow::StringPiece host, + tensorflow::StringPiece path); // Creates a temporary file name with an extension. -string GetTempFilename(const string& extension); +std::string GetTempFilename(const std::string& extension); // Reads the TEST_UNDECLARED_OUTPUTS_DIR environment variable, and if set // assigns `dir` to the value. `dir` is not modified if the environment variable @@ -108,7 +109,7 @@ string GetTempFilename(const string& extension); // // Note: This function obviates the need to deal with Bazel's odd path decisions // on Windows, and should be preferred over a simple `getenv`. -bool GetTestUndeclaredOutputsDir(string* dir); +bool GetTestUndeclaredOutputsDir(std::string* dir); } // namespace io } // namespace tensorflow diff --git a/tensorflow/core/util/env_var.h b/tensorflow/core/util/env_var.h index 7d10f229102..1125e21abfd 100644 --- a/tensorflow/core/util/env_var.h +++ b/tensorflow/core/util/env_var.h @@ -44,7 +44,7 @@ Status ReadFloatFromEnvVar(StringPiece env_var_name, float default_val, // Returns a string into "value" from the environmental variable "env_var_name". // If it is unset, the default value is used. 
Status ReadStringFromEnvVar(StringPiece env_var_name, StringPiece default_val, - string* value); + std::string* value); } // namespace tensorflow diff --git a/tensorflow/stream_executor/blas.cc b/tensorflow/stream_executor/blas.cc index ca597595969..5c7783d99c1 100644 --- a/tensorflow/stream_executor/blas.cc +++ b/tensorflow/stream_executor/blas.cc @@ -95,7 +95,7 @@ std::ostream& operator<<(std::ostream& os, ComputationType ty) { return os << ComputationTypeString(ty); } -string DataTypeString(DataType ty) { +std::string DataTypeString(DataType ty) { switch (ty) { case DataType::kHalf: return "f16"; diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h index 20776b8416d..0f3e77352af 100644 --- a/tensorflow/stream_executor/blas.h +++ b/tensorflow/stream_executor/blas.h @@ -134,7 +134,7 @@ enum class PointerMode { }; // Converts a ComputationType to a string. -string DataTypeString(DataType ty); +std::string DataTypeString(DataType ty); std::ostream &operator<<(std::ostream &os, DataType ty); From fa3b02f2b153cd533da211c490f0f95097f75aff Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Fri, 5 Feb 2021 16:08:08 -0800 Subject: [PATCH 09/19] Add space for reporting memory bandwidth statistics PiperOrigin-RevId: 355948962 Change-Id: I606e722f5800774783e7abc46be7946a801ac183 --- tensorflow/core/kernels/data/BUILD | 1 + tensorflow/core/kernels/data/iterator_ops.cc | 17 ++++++++++++++--- tensorflow/core/platform/default/port.cc | 2 +- tensorflow/core/platform/mem.h | 1 + tensorflow/core/platform/windows/port.cc | 2 +- 5 files changed, 18 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index e5adb2dfba2..448a5def807 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -463,6 +463,7 @@ tf_kernel_library( "//tensorflow/core:session_options", "//tensorflow/core/kernels:ops_util", "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/core/profiler/lib:traceme_encode", "@com_google_absl//absl/memory", ], ) diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index f3d2382db1f..dce6ebd49bc 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -47,10 +47,12 @@ limitations under the License. 
#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/refcount.h" #include "tensorflow/core/platform/resource.h" #include "tensorflow/core/profiler/lib/traceme.h" +#include "tensorflow/core/profiler/lib/traceme_encode.h" #include "tensorflow/core/public/session_options.h" namespace tensorflow { @@ -965,9 +967,18 @@ void RecordElementSize(const std::vector element, Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) { profiler::TraceMe traceme( [&] { - return strings::StrCat( - "IteratorGetNextOp::DoCompute#id=", ctx->step_id(), - ",iter_num=", ctx->frame_iter().iter_id, "#"); + int64 mem_bw = port::GetMemoryInfo().bw_used; + + if (mem_bw != INT64_MAX) { + return profiler::TraceMeEncode( + "IteratorGetNextOp::DoCompute", + {{"id", ctx->step_id()}, + {"iter_num", ctx->frame_iter().iter_id}, + {"mem_bw_used", mem_bw}}); + } + return profiler::TraceMeEncode( + "IteratorGetNextOp::DoCompute", + {{"id", ctx->step_id()}, {"iter_num", ctx->frame_iter().iter_id}}); }, profiler::kInfo); tensorflow::ResourceTagger tag(kTFDataResourceTag, diff --git a/tensorflow/core/platform/default/port.cc b/tensorflow/core/platform/default/port.cc index e25ed074844..6e82c67be99 100644 --- a/tensorflow/core/platform/default/port.cc +++ b/tensorflow/core/platform/default/port.cc @@ -357,7 +357,7 @@ double NominalCPUFrequency() { } MemoryInfo GetMemoryInfo() { - MemoryInfo mem_info = {INT64_MAX, INT64_MAX}; + MemoryInfo mem_info = {INT64_MAX, INT64_MAX, INT64_MAX}; #if defined(__linux__) && !defined(__ANDROID__) struct sysinfo info; int err = sysinfo(&info); diff --git a/tensorflow/core/platform/mem.h b/tensorflow/core/platform/mem.h index 27ad3574182..65411eeac28 100644 --- a/tensorflow/core/platform/mem.h +++ b/tensorflow/core/platform/mem.h @@ -62,6 +62,7 @@ std::size_t MallocExtension_GetAllocatedSize(const void* p); struct MemoryInfo { int64 total = 0; int64 free = 0; + int64 bw_used = 0; // memory bandwidth used across all CPU }; // Retrieves the host memory information. If any of the fields in the returned diff --git a/tensorflow/core/platform/windows/port.cc b/tensorflow/core/platform/windows/port.cc index 16b5a328256..256f525a38d 100644 --- a/tensorflow/core/platform/windows/port.cc +++ b/tensorflow/core/platform/windows/port.cc @@ -192,7 +192,7 @@ double NominalCPUFrequency() { } MemoryInfo GetMemoryInfo() { - MemoryInfo mem_info = {INT64_MAX, INT64_MAX}; + MemoryInfo mem_info = {INT64_MAX, INT64_MAX, INT64_MAX}; MEMORYSTATUSEX statex; statex.dwLength = sizeof(statex); if (GlobalMemoryStatusEx(&statex)) { From 361f3d654a1fc8c996e707b126eaa959757fcfde Mon Sep 17 00:00:00 2001 From: Tianrun Li Date: Fri, 5 Feb 2021 16:11:31 -0800 Subject: [PATCH 10/19] Change ProcessLegacyRootEvents to ProcessUserDefinedRootEvents. 
PiperOrigin-RevId: 355949556 Change-Id: Id1a08ab35fc009ee5e8affa3e47ca71307e3fe4e --- tensorflow/core/profiler/utils/group_events.cc | 14 ++++++++------ tensorflow/core/profiler/utils/group_events.h | 7 ++++--- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/tensorflow/core/profiler/utils/group_events.cc b/tensorflow/core/profiler/utils/group_events.cc index 566a4186344..2cf45f7797c 100644 --- a/tensorflow/core/profiler/utils/group_events.cc +++ b/tensorflow/core/profiler/utils/group_events.cc @@ -584,10 +584,11 @@ void EventForest::ConnectInterThread( } } -void EventForest::ProcessLegacyRootEvents( - const std::vector& root_event_types) { - for (int64 root_event_type : root_event_types) { - if (auto root_events = gtl::FindOrNull(event_node_map_, root_event_type)) { +void EventForest::ProcessUserDefinedRootEvents( + const std::vector& user_defined_root_event_types) { + for (int64 user_defined_root_event_type : user_defined_root_event_types) { + if (auto root_events = + gtl::FindOrNull(event_node_map_, user_defined_root_event_type)) { for (const auto& root_event : *root_events) { root_event->SetIsRoot(true); root_events_.push_back(root_event.get()); @@ -869,10 +870,11 @@ void EventForest::ConnectTfDataEvents() { VLOG(1) << num_matched << " consumer iterators matched."; } -void EventForest::GroupEvents(const std::vector& root_event_types) { +void EventForest::GroupEvents( + const std::vector& user_defined_root_event_types) { ProcessTensorFlowLoop(); ProcessWorker(); - ProcessLegacyRootEvents(root_event_types); + ProcessUserDefinedRootEvents(user_defined_root_event_types); CreateEventGroups(); MarkEagerlyExecutedGpuKernels(); MarkEagerlyExecutedCpuTfOps(); diff --git a/tensorflow/core/profiler/utils/group_events.h b/tensorflow/core/profiler/utils/group_events.h index 706a2cbf67d..aeffaaa37aa 100644 --- a/tensorflow/core/profiler/utils/group_events.h +++ b/tensorflow/core/profiler/utils/group_events.h @@ -176,7 +176,8 @@ class EventForest { void ConnectTfDataEvents(); - void GroupEvents(const std::vector& root_event_types = {}); + void GroupEvents( + const std::vector& user_defined_root_event_types = {}); const EventNodeMap& GetEventNodeMap() const { return event_node_map_; } @@ -198,8 +199,8 @@ class EventForest { void ConnectInterThread( const std::vector& connect_info_list); - void ProcessLegacyRootEvents( - const std::vector& root_event_types); + void ProcessUserDefinedRootEvents( + const std::vector& user_defined_root_event_types); // Creates event groups and populates group_metadata_map. If a TF loop is // used, each TF loop iteration becomes a root. Otherwise, top root events From cf18ba2332fa5cd561803e7425b67c8fd008c44f Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Fri, 5 Feb 2021 16:38:39 -0800 Subject: [PATCH 11/19] PR #46901: Fold StridedSliceOp when input is defined by ShapeOp. Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/46901 Fixes #46879 and #46080. This PR adds a sub-shape folder from StridedSliceOp. Fold StridedSliceOp when input is defined by ShapeOp. The pattern is common in TF python library like ```python height = tf.shape(x)[1] spatial_shape = tf.shape(x)[1:3] ``` When `x` has some dynamic dimensions (typically batch dim), `tf.shape` can not be folded so `height` and `spatial_shape` can not be inferred as a constant even if the corresponding dimensions are static. This PR folds this kind of patterns to improve sub-shape constant folding. 
Note that there is a workaround in python lib to use ```python height = x.shape[1] or tf.shape(x)[1] spatial_shape = [x.shape[i] or tf.shape(x)[i] for i in (1, 2)] ``` to make sure constant is propagated. However, most of TensorFlow codes do not use this. Hi @abattery, would you mind taking a look at this? Thank you! Copybara import of the project: -- 8c5522dbef98505c51dca136ac17bf8feba2565a by Tzu-Wei Sung : Fold StridedSliceOp when input is defined by ShapeOp. The pattern is common is TF python library like height = tf.shape(x)[1]. When x has some dynamic dimensions (typically batch dim), tf.shape can not be constant folded so height cannot be inferred as a constant. This PR folds this kind of patterns to improve sub-shape constant folding. -- 14d97ae85fd921209c1ac1cf2c2d9cfdb874e91d by Tzu-Wei Sung : Rename some testcases -- de9e828d0a5711152c83b3d0879f13a4d3448efa by Tzu-Wei Sung : Correctly handle negative strides -- 085d490edf936b4549835b7cbcb787abe02d28ae by Tzu-Wei Sung : Add testcases for out of bound begin and end -- ae86fe419036f0adc758c39138045b95754e087e by Tzu-Wei Sung : clang-format -- 4e0af9d81c086e08747957f0f4d62e2b13f4ca46 by Tzu-Wei Sung : Address comments -- 8c6d03818534594bbda5e5956ed261f0926ebf84 by Tzu-Wei Sung : Fix Windows build. Templated Lambda is not supported in MSVC. COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/46901 from WindQAQ:fold-strided-slice-from-shape 8c6d03818534594bbda5e5956ed261f0926ebf84 PiperOrigin-RevId: 355954030 Change-Id: I2601b6276e649536bf4737285c90ca674ba24363 --- .../mlir/tensorflow/ir/tf_generated_ops.td | 2 + .../compiler/mlir/tensorflow/ir/tf_ops_n_z.cc | 111 +++++++++ .../mlir/tensorflow/tests/canonicalize.mlir | 212 ++++++++++++++++-- 3 files changed, 309 insertions(+), 16 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index cce5f71651a..bc61d27daf1 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -14412,6 +14412,8 @@ receive 0, 0, and 1, respectively. The appropriate bits in `begin_mask` and TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr Index = TF_DerivedOperandTypeAttr<1>; + let hasFolder = 1; + let verifier = [{ return VerifyStridedSliceBase(*this); }]; let extraClassDeclaration = [{ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc index 0f8a423124f..ae255f6db00 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc @@ -1886,6 +1886,117 @@ bool StridedSliceOp::GetSlicedBoundRanges( return true; } +OpFoldResult StridedSliceOp::fold(ArrayRef operands) { + // Fold StridedSlice operation if it extracts statically known dimensions. + // + // For example, + // + // %shape = tf.Shape(%arg) // %arg: tensor + // %height = tf.StridedSlice(%shape, 1, 2, 1) + // + // In this case %height can be replaced with a constant 2. + // + // Or, + // + // %shape = tf.Shape(%arg) // %arg: tensor + // %spatial_shape = tf.StridedSlice(%shape, 1, 3, 1) + // + // In this case %spatial_shape can be replaced with a constant [2, 3]. + + // Input to strided slice op is defined by shape operation. + auto shape_op = input().getDefiningOp(); + if (!shape_op) { + return {}; + } + + // `begin`, `end` and `strides` should be constant in order to infer static + // dimension. 
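+  // Only the single-element `begin`/`end`/`strides` produced by 1-D slices
+  // such as tf.shape(x)[a:b] (or tf.shape(x)[a]) are handled; anything else
+  // is left unfolded by the checks below.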
+ DenseIntElementsAttr begin_attr, end_attr, strides_attr; + if (!matchPattern(begin(), m_Constant(&begin_attr)) || + !matchPattern(end(), m_Constant(&end_attr)) || + !matchPattern(strides(), m_Constant(&strides_attr)) || + begin_attr.getNumElements() != 1 || end_attr.getNumElements() != 1 || + strides_attr.getNumElements() != 1) { + return {}; + } + + auto tensor_ty = shape_op.input().getType().dyn_cast(); + // Only ranked tensor can be folded. + if (!tensor_ty) return {}; + + int64_t rank = tensor_ty.getRank(); + int64_t begin_int = begin_attr.getValue(0).getSExtValue(); + int64_t end_int = end_attr.getValue(0).getSExtValue(); + int64_t strides_int = strides_attr.getValue(0).getSExtValue(); + + // Canonicalize `begin` and `end` in case of negative index. + if (begin_int < 0) begin_int += rank; + if (end_int < 0) end_int += rank; + + // Create `begin` and `end` from `*_mask`. Note that we don't care about + // `new_axis_mask` as it can be inferred from `output_ty`. + if (shrink_axis_mask() == 1) { + // When `shrink_axis_mask` is set, output is always a scalar so only + // one element is sliced. + end_int = begin_int + 1; + } + if (begin_mask() == 1) { + begin_int = (strides_int > 0) ? 0 : rank - 1; + } + if (end_mask() == 1) { + end_int = (strides_int > 0) ? rank : -1; + } + if (ellipsis_mask() == 1) { + begin_int = 0; + end_int = rank; + } + + // It's possible that `begin` and `end` are out of bound. See + // https://docs.python.org/3/library/stdtypes.html#common-sequence-operations. + if (strides_int > 0) { + begin_int = std::min(begin_int, rank); + end_int = std::min(end_int, rank); + } else { + begin_int = std::min(begin_int, rank - 1); + end_int = std::min(end_int, rank - 1); + } + + SmallVector sub_shape; + // Only handle cases that have something to slice to avoid infinite for-loop. + if ((end_int > begin_int && strides_int > 0) || + (end_int < begin_int && strides_int < 0)) { + // Extract sub-shape only if all of those dimensions are static. + for (int64_t i = begin_int; (strides_int > 0) ? i < end_int : i > end_int; + i += strides_int) { + if (tensor_ty.isDynamicDim(i)) { + return {}; + } + sub_shape.push_back(tensor_ty.getDimSize(i)); + } + } + + // Down-cast to 32 bit int if needed. + auto output_elt_ty = output().getType().cast().getElementType(); + + auto output_ty = output().getType().dyn_cast(); + if (!output_ty) { + // If the output is unranked, we infer a result using a 0-ranked tensor for + // scalar element. 
+ if (sub_shape.size() == 1) + output_ty = RankedTensorType::get({}, output_elt_ty); + else + output_ty = RankedTensorType::get( + {static_cast(sub_shape.size())}, output_elt_ty); + } + if (output_elt_ty.isInteger(32)) { + SmallVector sub_shape_i32(sub_shape.size()); + std::transform(sub_shape.begin(), sub_shape.end(), sub_shape_i32.begin(), + [](int64_t d) { return static_cast(d); }); + return DenseIntElementsAttr::get(output_ty, sub_shape_i32); + } + return DenseIntElementsAttr::get(output_ty, sub_shape); +} + //===----------------------------------------------------------------------===// // StridedSliceGradOp //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index e2a0552ef48..64cb3480bca 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -486,7 +486,7 @@ func @testBroadcastToNoOp(%arg0: tensor<2x4xf32>, %arg1: tensor<2xi32>) -> tenso } // CHECK-LABEL: func @testPackShapeComputation -func @testPackShapeComputation(%arg0: tensor, %arg1: tensor, %arg2: tensor<*xf32>) -> (tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32>) { +func @testPackShapeComputation(%arg0: tensor, %arg1: tensor, %arg2: tensor<*xf32>) -> (tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32>) { // Test dimensions sizes. %d1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor %d2 = "tf.Const"() {value = dense<2> : tensor} : () -> tensor @@ -526,26 +526,20 @@ func @testPackShapeComputation(%arg0: tensor, %arg1: tensor, %15 = "tf.Pack"(%14, %d2, %d1) {axis = 0 : i64} : (tensor, tensor, tensor) -> tensor<3xi32> // CHECK: %[[PACK0:.*]] = "tf.Pack" - // StridedSlice takes second dimension from the shape: - // begin = [1], end = [2], stride = [1] - %17 = "tf.StridedSlice"(%7, %1, %2, %1) {shrink_axis_mask = 1 : i64} : (tensor<3xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %18 = "tf.Pack"(%17, %d1, %d2) {axis = 0 : i64} : (tensor, tensor, tensor) -> tensor<3xi32> - // CHECK: %[[PACK1:.*]] = "tf.Pack" - // Packed dimensions have higher rank than the reshape operand: // [?, 1] vs [?, 1, 1] - %20 = "tf.StridedSlice"(%3, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %21 = "tf.Pack"(%20, %d1, %d1) {axis = 0 : i64} : (tensor, tensor, tensor) -> tensor<3xi32> - // CHECK: %[[PACK2:.*]] = "tf.Pack" + %16 = "tf.StridedSlice"(%3, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %17 = "tf.Pack"(%16, %d1, %d1) {axis = 0 : i64} : (tensor, tensor, tensor) -> tensor<3xi32> + // CHECK: %[[PACK1:.*]] = "tf.Pack" // Make sure a dynamic ranked shape doesn't crash the "canonicalize" pass - %23 = "tf.Shape"(%arg2) : (tensor<*xf32>) -> tensor<*xi32> - %24 = "tf.StridedSlice"(%23, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<*xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32> - %25 = "tf.Pack"(%24, %d1) {axis = 0 : i64} : (tensor<*xi32>, tensor) -> tensor<*xi32> - // CHECK: %[[PACK3:.*]] = "tf.Pack" + %18 = "tf.Shape"(%arg2) : (tensor<*xf32>) -> tensor<*xi32> + %19 = "tf.StridedSlice"(%18, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<*xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32> + %20 = "tf.Pack"(%19, %d1) {axis = 0 : i64} : 
(tensor<*xi32>, tensor) -> tensor<*xi32> + // CHECK: %[[PACK2:.*]] = "tf.Pack" - // CHECK: return %[[SHAPE0]], %[[SHAPE1]], %[[PACK0]], %[[PACK1]], %[[PACK2]], %[[PACK3]] - return %5, %9, %15, %18, %21, %25 : tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32> + // CHECK: return %[[SHAPE0]], %[[SHAPE1]], %[[PACK0]], %[[PACK1]], %[[PACK2]] + return %5, %9, %15, %17, %20 : tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32> } // CHECK-LABEL: testSelectScalarPred @@ -1373,3 +1367,189 @@ func @testUnpackAndCwiseUnary(%arg0: tensor) -> (tensor, tensor< // CHECK: return %[[UNPACK]]#0, %[[UNPACK]]#1 return %0, %1 : tensor, tensor } + +// CHECK-LABEL: testFoldStridedSliceShapeI32 +func @testFoldStridedSliceShapeI32(%arg0: tensor) -> (tensor<2xi32>) { + %0 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi32> + %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + return %3 : tensor<2xi32> + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK: return %[[CST]] +} + +// CHECK-LABEL: testFoldStridedSliceShapeI64 +func @testFoldStridedSliceShapeI64(%arg0: tensor) -> (tensor<2xi64>) { + %0 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi64> + %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi64> + return %3 : tensor<2xi64> + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64> + // CHECK: return %[[CST]] +} + +// CHECK-LABEL: testFoldStridedSliceShapeWithShrinkAxisMaskI32 +func @testFoldStridedSliceShapeWithShrinkAxisMaskI32(%arg0: tensor) -> (tensor) { + %0 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi32> + %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + return %3 : tensor + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + // CHECK: return %[[CST]] +} + +// CHECK-LABEL: testFoldStridedSliceShapeWithShrinkAxisMaskI64 +func @testFoldStridedSliceShapeWithShrinkAxisMaskI64(%arg0: tensor) -> (tensor) { + %0 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi64> + %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + return %3 : tensor + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<1> : tensor} 
: () -> tensor + // CHECK: return %[[CST]] +} + +// CHECK-LABEL: testFoldStridedSliceShapeWithShrinkAxisMaskNegativeBegin1 +func @testFoldStridedSliceShapeWithShrinkAxisMaskNegativeBegin1(%arg0: tensor) -> (tensor) { + %0 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %3 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi32> + %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + return %4 : tensor + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<3> : tensor} : () -> tensor + // CHECK: return %[[CST]] +} + +// CHECK-LABEL: testFoldStridedSliceShapeWithShrinkAxisMaskNegativeBegin2 +func @testFoldStridedSliceShapeWithShrinkAxisMaskNegativeBegin2(%arg0: tensor) -> (tensor) { + %0 = "tf.Const"() {value = dense<-2> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %3 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi32> + %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + return %4 : tensor + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<2> : tensor} : () -> tensor + // CHECK: return %[[CST]] +} + +// CHECK-LABEL: testUnfoldedStridedSliceShape +func @testUnfoldedStridedSliceShape(%arg0: tensor) -> (tensor<2xi32>) { + %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %3 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi32> + %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + return %4 : tensor<2xi32> + // CHECK: %[[SLICE:.*]] = "tf.StridedSlice" + // CHECK: return %[[SLICE]] +} + +// CHECK-LABEL: testFoldStridedSliceShapeWithBeginMask +func @testFoldStridedSliceShapeWithBeginMask(%arg0: tensor<1x2x3x?xf32>) -> (tensor<2xi32>) { + %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %3 = "tf.Shape"(%arg0) : (tensor<1x2x3x?xf32>) -> tensor<4xi32> + %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 1 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + return %4 : tensor<2xi32> + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK: return %[[CST]] +} + +// CHECK-LABEL: testFoldStridedSliceShapeWithEndMask +func @testFoldStridedSliceShapeWithEndMask(%arg0: tensor) -> (tensor<3xi32>) { + %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<1> 
: tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi32> + %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> + return %3 : tensor<3xi32> + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32> + // CHECK: return %[[CST]] +} + +// CHECK-LABEL: testFoldStridedSliceShapeWithPositiveStrides +func @testFoldStridedSliceShapeWithPositiveStrides(%arg0: tensor<1x2x3x4x?xf32>) -> (tensor<2xi32>) { + %0 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<4> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> + %3 = "tf.Shape"(%arg0) : (tensor<1x2x3x4x?xf32>) -> tensor<5xi32> + %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<5xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + return %4 : tensor<2xi32> + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[2, 4]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK: return %[[CST]] +} + +// CHECK-LABEL: testFoldStridedSliceShapeWithPositiveStridesOutOfBoundEnd +func @testFoldStridedSliceShapeWithPositiveStridesOutOfBoundEnd(%arg0: tensor) -> (tensor<3xi32>) { + %0 = "tf.Const"() {value = dense<20> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi32> + %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> + return %3 : tensor<3xi32> + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32> + // CHECK: return %[[CST]] +} + +// CHECK-LABEL: testFoldStridedSliceShapeWithNegativeStrides +func @testFoldStridedSliceShapeWithNegativeStrides(%arg0: tensor<1x2x3x?xf32>) -> (tensor<1xi32>) { + %0 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> + %3 = "tf.Shape"(%arg0) : (tensor<1x2x3x?xf32>) -> tensor<4xi32> + %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + return %4 : tensor<1xi32> + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK: return %[[CST]] +} + +// CHECK-LABEL: testFoldStridedSliceShapeWithNegativeStridesOutOfBoundBegin +func @testFoldStridedSliceShapeWithNegativeStridesOutOfBoundBegin(%arg0: tensor) -> (tensor<2xi32>) { + %0 = "tf.Const"() {value = dense<20> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> + %3 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi32> + %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, 
ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + return %4 : tensor<2xi32> + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[3, 2]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK: return %[[CST]] +} + +// CHECK-LABEL: testFoldStridedSliceShapeWithNegativeStridesBeginMask +func @testFoldStridedSliceShapeWithNegativeStridesBeginMask(%arg0: tensor) -> (tensor<2xi32>) { + %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> + %3 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi32> + %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 1 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + return %4 : tensor<2xi32> + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[3, 2]> : tensor<2xi32>} : () -> tensor<2xi32> + // CHECK: return %[[CST]] +} + +// CHECK-LABEL: testFoldStridedSliceShapeWithNegativeStridesEndMask +func @testFoldStridedSliceShapeWithNegativeStridesEndMask(%arg0: tensor<1x2x3x?xf32>) -> (tensor<3xi32>) { + %0 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> + %3 = "tf.Shape"(%arg0) : (tensor<1x2x3x?xf32>) -> tensor<4xi32> + %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> + return %4 : tensor<3xi32> + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[3, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32> + // CHECK: return %[[CST]] +} + +// CHECK-LABEL: testFoldStridedSliceShapeWithEmptySlice +func @testFoldStridedSliceShapeWithEmptySlice(%arg0: tensor) -> (tensor<0xi32>) { + %0 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> + %3 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi32> + %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + return %4 : tensor<0xi32> + // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> + // CHECK: return %[[CST]] +} From 7cd52d03c423c27e5daf4e981ec44a5c84362d2c Mon Sep 17 00:00:00 2001 From: David Majnemer Date: Fri, 5 Feb 2021 16:41:06 -0800 Subject: [PATCH 12/19] [XLA] Initialize fields of RematerializationSizes by default PiperOrigin-RevId: 355954409 Change-Id: I0da0d3ce320c53321778ad66ab9703307bc231c7 --- tensorflow/compiler/xla/service/hlo_rematerialization.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 878bb2a8eef..fc7db859b51 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ 
b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -45,8 +45,8 @@ class HloRematerialization : public HloModulePass { // Helper struct that communicates the before / after sizes for the // rematerialization process. struct RematerializationSizes { - int64 before_bytes; - int64 after_bytes; + int64 before_bytes = -1; + int64 after_bytes = -1; }; // Mode in which the rematerialization algorithm should be run. From b71b5ac77ff68405f2a326a67a0a61b84c0e2a33 Mon Sep 17 00:00:00 2001 From: Xiao Yu Date: Fri, 5 Feb 2021 17:05:00 -0800 Subject: [PATCH 13/19] Move custom device placement from eager/execute.cc to c_api.cc. Then it can be reused by TFRT. PiperOrigin-RevId: 355957931 Change-Id: Ibd22404359ed8a84d25a7d358b9e062b7f32a36f --- tensorflow/c/eager/BUILD | 2 + tensorflow/c/eager/c_api.cc | 31 ++-- .../c/eager/immediate_execution_context.h | 16 ++ .../c/eager/immediate_execution_operation.h | 11 ++ tensorflow/core/common_runtime/eager/BUILD | 25 +++ .../core/common_runtime/eager/context.cc | 29 +-- .../core/common_runtime/eager/context.h | 14 +- tensorflow/core/common_runtime/eager/core.cc | 27 +-- .../eager/custom_device_op_handler.cc | 169 ++++++++++++++++++ .../eager/custom_device_op_handler.h | 51 ++++++ .../eager/custom_device_test.cc | 30 ++-- .../common_runtime/eager/eager_operation.cc | 64 +++---- .../common_runtime/eager/eager_operation.h | 18 +- .../common_runtime/eager/placement_utils.cc | 70 -------- .../common_runtime/eager/placement_utils.h | 8 +- 15 files changed, 360 insertions(+), 205 deletions(-) create mode 100644 tensorflow/core/common_runtime/eager/custom_device_op_handler.cc create mode 100644 tensorflow/core/common_runtime/eager/custom_device_op_handler.h diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 700912f3eff..f9d726cb5d2 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -73,9 +73,11 @@ tf_cuda_library( "//tensorflow/core/common_runtime/eager:context_distributed_manager", "//tensorflow/core/common_runtime/eager:core", "//tensorflow/core/common_runtime/eager:custom_device", + "//tensorflow/core/common_runtime/eager:custom_device_op_handler", "//tensorflow/core/common_runtime/eager:eager_executor", "//tensorflow/core/common_runtime/eager:execute", "//tensorflow/core/common_runtime/eager:tensor_handle", + "//tensorflow/core/common_runtime/eager:placement_utils", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 204db3078f4..5a31c434eaa 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -41,7 +41,9 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/custom_device.h" +#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h" #include "tensorflow/core/common_runtime/eager/execute.h" +#include "tensorflow/core/common_runtime/eager/placement_utils.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/attr_value.pb.h" @@ -532,7 +534,8 @@ TFE_TensorHandle* TFE_NewCustomDeviceTensorHandle( tensorflow::EagerContext* context = tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); tensorflow::CustomDevice* device = nullptr; - if (!context->FindCustomDeviceFromName(device_name, &device)) { + if (!context->GetCustomDeviceOpHandler().FindCustomDeviceFromName(device_name, + &device)) { deallocator(data, arg); status->status = tensorflow::errors::InvalidArgument(device_name, " unknown device."); @@ -562,7 +565,8 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory( status->status = context->FindDeviceFromName(device_name, &device); tensorflow::CustomDevice* custom_device = nullptr; if (!status->status.ok()) { - if (!context->FindCustomDeviceFromName(device_name, &custom_device)) { + if (!context->GetCustomDeviceOpHandler().FindCustomDeviceFromName( + device_name, &custom_device)) { deallocator(data, len, deallocator_arg); status->status = tensorflow::errors::InvalidArgument(device_name, " unknown device."); @@ -654,8 +658,7 @@ const char* TFE_OpGetName(const TFE_Op* op, TF_Status* status) { } TFE_Context* TFE_OpGetContext(const TFE_Op* op, TF_Status* status) { - return tensorflow::wrap( - &(OperationFromInterface(tensorflow::unwrap(op))->EagerContext())); + return tensorflow::wrap(tensorflow::unwrap(op)->GetContext()); } void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) { @@ -889,11 +892,15 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op, void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, TF_Status* status) { - status->status = tensorflow::unwrap(op)->Execute( - absl::MakeSpan(reinterpret_cast( - tensorflow::unwrap(retvals)), - *num_retvals), - num_retvals); + tensorflow::ImmediateExecutionOperation* unwrapped_op = + tensorflow::unwrap(op); + + status->status = + unwrapped_op->GetContext()->GetCustomDeviceOpHandler().Execute( + unwrapped_op, + reinterpret_cast( + retvals), + num_retvals); } TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, @@ -1150,10 +1157,8 @@ void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device, } auto custom_device = std::make_unique( ctx, device, device_info, device_name); - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - status->status = - context->RegisterCustomDevice(device_name, std::move(custom_device)); + status->status = tensorflow::unwrap(ctx)->RegisterCustomDevice( + device_name, std::move(custom_device)); } } // extern "C" diff --git a/tensorflow/c/eager/immediate_execution_context.h b/tensorflow/c/eager/immediate_execution_context.h index abb24cb0c54..6c2231017d3 100644 --- a/tensorflow/c/eager/immediate_execution_context.h +++ b/tensorflow/c/eager/immediate_execution_context.h @@ -38,6 +38,9 @@ limitations under the License. 
namespace tensorflow { class EagerExecutor; +class EagerContext; +class CustomDevice; +class CustomDeviceOpHandler; // LINT.IfChange // Note: Keep in sync with exported copy of enum in eager/c_api.h. @@ -122,6 +125,7 @@ class ImmediateExecutionContext : public AbstractContext { // Return the ParsedName of Host CPU device. virtual const DeviceNameUtils::ParsedName& HostCPUParsedName() const = 0; + virtual const string& HostCPUName() const = 0; // Configure soft device placement policy. virtual void SetAllowSoftPlacement(bool enable) = 0; @@ -147,6 +151,18 @@ class ImmediateExecutionContext : public AbstractContext { return ptr->getKind() == kEager || ptr->getKind() == kTfrt; } + //===--------------------------------------------------------------------===// + // Experimental Custom Device. + //===--------------------------------------------------------------------===// + virtual CustomDeviceOpHandler& GetCustomDeviceOpHandler() = 0; + + // Register a custom device. It will return error is the device name is + // already registered. + // TODO(tfrt-devs): Remove this method. Let caller register it directly into + // CustomDeviceOpHandler. + virtual Status RegisterCustomDevice(const string& name, + std::unique_ptr device) = 0; + //===--------------------------------------------------------------------===// // Following are features in current TF Eager Runtime. // TODO(tfrt-devs): Figure out a way to deprecate following features after diff --git a/tensorflow/c/eager/immediate_execution_operation.h b/tensorflow/c/eager/immediate_execution_operation.h index 85af5a706e1..a23177b56d5 100644 --- a/tensorflow/c/eager/immediate_execution_operation.h +++ b/tensorflow/c/eager/immediate_execution_operation.h @@ -33,6 +33,8 @@ struct TFE_Op; namespace tensorflow { +class ImmediateExecutionContext; + // Abstract interface to an operation. class ImmediateExecutionOperation : public AbstractOperation { public: @@ -41,6 +43,15 @@ class ImmediateExecutionOperation : public AbstractOperation { // Returns the inputs of this op. virtual absl::Span GetInputs() const = 0; + virtual Status SetInput(size_t index, + ImmediateExecutionTensorHandle* input) = 0; + + virtual ImmediateExecutionContext* GetContext() const = 0; + + // Following two methods are used to support custom device. + // Return true if the inputs contain custom device tensor handle. It means + // that the argument need to be handled by a custom device. 
+ virtual bool HasCustomDeviceInput() const = 0; virtual const tensorflow::OpDef* OpDef() const = 0; diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD index 8549c32417a..dddfe47de6b 100644 --- a/tensorflow/core/common_runtime/eager/BUILD +++ b/tensorflow/core/common_runtime/eager/BUILD @@ -87,6 +87,7 @@ tf_cuda_library( deps = [ ":eager_executor", ":kernel_and_device", + ":custom_device_op_handler", ":custom_device", "@com_google_absl//absl/container:flat_hash_map", "//tensorflow/c:tf_tensor_internal", @@ -140,6 +141,28 @@ tf_cuda_library( }), ) +tf_cuda_library( + name = "custom_device_op_handler", + srcs = ["custom_device_op_handler.cc"], + hdrs = ["custom_device_op_handler.h"], + visibility = ["//tensorflow:internal"], + deps = [ + ":custom_device", + ] + select({ + "//tensorflow:android": [ + "//tensorflow/core:portable_tensorflow_lib_lite", + ], + "//conditions:default": [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/c/eager:immediate_execution_operation", + "//tensorflow/core/lib/core:status", + ], + }), +) + tf_cc_test( name = "custom_device_test", srcs = ["custom_device_test.cc"], @@ -647,6 +670,7 @@ tf_cuda_library( ":custom_device", ":attr_builder", ":eager_operation", + "//tensorflow/c/eager:immediate_execution_operation", "//tensorflow/c/eager:immediate_execution_tensor_handle", ] + select({ "//tensorflow:android": [ @@ -714,6 +738,7 @@ filegroup( "attr_builder.h", "context.h", "custom_device.h", + "custom_device_op_handler.h", "eager_executor.h", "eager_operation.h", "kernel_and_device.h", diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 7c20766a1ce..7fe6e00928c 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -522,7 +522,7 @@ EagerContext::~EagerContext() { // Custom devices may have obtained references to various context components // (executors, thread pool). It's safer to run their destructors early. - custom_devices_.clear(); + custom_device_op_handler_.Clear(); ClearCachesAndThreadExecutors(); std::unordered_map executors_copy; @@ -904,38 +904,15 @@ Status EagerContext::FindCompositeDeviceFromName( return errors::NotFound("Unknown composite device: ", device_name); } -bool EagerContext::FindCustomDeviceFromName(const string& device_name, - CustomDevice** dev) const { - auto dev_it = custom_devices_.find(device_name); - if (dev_it == custom_devices_.end()) { - return false; - } - *dev = dev_it->second.get(); - return true; -} - Status EagerContext::RegisterCustomDevice( const string& device_name, std::unique_ptr device) { - DeviceNameUtils::ParsedName parsed; - if (!DeviceNameUtils::ParseFullName(device_name, &parsed) || - !parsed.has_job || !parsed.has_replica || !parsed.has_task || - !parsed.has_type || !parsed.has_id) { - return errors::InvalidArgument( - device_name, - " could not be parsed as a device name. 
Use the full " - "/job:/replica:/task:/device:: " - "format."); - } Device* existing_physical_device = nullptr; if (FindDeviceFromName(device_name.c_str(), &existing_physical_device).ok()) { return errors::AlreadyExists(device_name, " already registered as a physical device."); } - if (!custom_devices_.emplace(device_name, std::move(device)).second) { - return errors::AlreadyExists(device_name, - " already registered as a custom device."); - } - return Status::OK(); + return custom_device_op_handler_.RegisterCustomDevice(device_name, + std::move(device)); } Status EagerContext::FindOrCreateCompositeDevice( diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 28c0b40f43b..fd6d896d9dd 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/eager/custom_device.h" +#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h" #include "tensorflow/core/common_runtime/eager/eager_executor.h" #include "tensorflow/core/common_runtime/eager/kernel_and_device.h" #include "tensorflow/core/common_runtime/function.h" @@ -204,6 +205,8 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { return HostCPU()->parsed_name(); } + const string& HostCPUName() const override { return HostCPU()->name(); } + GraphCollector* GetGraphCollector() { return &graph_collector_; } EagerExecutor& Executor() override; @@ -469,11 +472,12 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { Status FindCompositeDeviceFromName(StringPiece device_name, CompositeDevice** device) const; - bool FindCustomDeviceFromName(const string& device_name, - CustomDevice** dev) const; - Status RegisterCustomDevice(const string& name, - std::unique_ptr device); + std::unique_ptr device) override; + + CustomDeviceOpHandler& GetCustomDeviceOpHandler() override { + return custom_device_op_handler_; + }; // Find or create a composite device with the given `underlying_devices` and // `device_name` (if not empty). 
@@ -583,7 +587,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { TF_GUARDED_BY(device_type_list_mu_); Rendezvous* rendezvous_; std::function rendezvous_creator_; - std::unordered_map> custom_devices_; + CustomDeviceOpHandler custom_device_op_handler_; mutable mutex composite_devices_mu_; // Maps from the fingerprint of a set of device names to a virtual diff --git a/tensorflow/core/common_runtime/eager/core.cc b/tensorflow/core/common_runtime/eager/core.cc index 81b1e3594f2..905b1d94dad 100644 --- a/tensorflow/core/common_runtime/eager/core.cc +++ b/tensorflow/core/common_runtime/eager/core.cc @@ -111,7 +111,7 @@ ImmediateExecutionTensorHandle* EagerContext::CopyTensorHandleToDevice( *status = this->FindDeviceFromName(device_name, &device); if (!status->ok()) { tensorflow::CustomDevice* dev; - if (this->FindCustomDeviceFromName(device_name, &dev)) { + if (custom_device_op_handler_.FindCustomDeviceFromName(device_name, &dev)) { *status = dev->CopyTensorToDevice(handle, &result); if (status->ok()) { return result; @@ -128,7 +128,8 @@ ImmediateExecutionTensorHandle* EagerContext::CopyTensorHandleToDevice( return nullptr; } tensorflow::CustomDevice* dev; - if (this->FindCustomDeviceFromName(handle_device_name, &dev)) { + if (custom_device_op_handler_.FindCustomDeviceFromName(handle_device_name, + &dev)) { *status = dev->CopyTensorFromDevice(handle, device_name, &result); if (status->ok()) { return result; @@ -202,28 +203,8 @@ Status EagerOperation::Execute(absl::Span retvals, } } - // Decide to either run the operation on a custom device or copy off all of - // the custom device inputs. - VariantDevice maybe_custom_device = Device(); - if (absl::holds_alternative(maybe_custom_device) || - !inputs_are_tensor_handles_) { - // If the op wasn't placed on a custom device explicitly and there are no - // non-TensorHandle inputs, the op will definitely be placed on a physical - // device. Otherwise we need to check the inputs one by one. - TF_RETURN_IF_ERROR( - eager::MaybePinToCustomDevice(&maybe_custom_device, *this)); - if (absl::holds_alternative(maybe_custom_device)) { - ImmediateExecutionTensorHandle** retval_array = - reinterpret_cast(retvals.data()); - return absl::get(maybe_custom_device) - ->Execute(this, retval_array, num_retvals); - } else { - TF_RETURN_IF_ERROR(CopyOffCustomDeviceInputs()); - } - } - // Run eager placement logic. - class Device* device = absl::get(maybe_custom_device); + class Device* device = absl::get(Device()); if (device == nullptr) { TF_RETURN_IF_ERROR(eager::MaybePinToResourceDevice(&device, *this)); } diff --git a/tensorflow/core/common_runtime/eager/custom_device_op_handler.cc b/tensorflow/core/common_runtime/eager/custom_device_op_handler.cc new file mode 100644 index 00000000000..719d113ea2a --- /dev/null +++ b/tensorflow/core/common_runtime/eager/custom_device_op_handler.cc @@ -0,0 +1,169 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h" + +#include "tensorflow/core/platform/errors.h" + +namespace tensorflow { + +void CustomDeviceOpHandler::Clear() { custom_devices_.clear(); } + +Status CustomDeviceOpHandler::RegisterCustomDevice( + const string& device_name, std::unique_ptr device) { + DeviceNameUtils::ParsedName parsed; + if (!DeviceNameUtils::ParseFullName(device_name, &parsed) || + !parsed.has_job || !parsed.has_replica || !parsed.has_task || + !parsed.has_type || !parsed.has_id) { + return errors::InvalidArgument( + device_name, + " could not be parsed as a device name. Use the full " + "/job:/replica:/task:/device:: " + "format."); + } + + if (!custom_devices_.emplace(device_name, std::move(device)).second) { + return errors::AlreadyExists(device_name, + " already registered as a custom device."); + } + return Status::OK(); +} + +bool CustomDeviceOpHandler::FindCustomDeviceFromName( + const string& name, CustomDevice** device) const { + auto dev_it = custom_devices_.find(name); + if (dev_it == custom_devices_.end()) { + return false; + } + *device = dev_it->second.get(); + return true; +} + +Status CustomDeviceOpHandler::Execute(ImmediateExecutionOperation* op, + ImmediateExecutionTensorHandle** retvals, + int* num_retvals) { + tensorflow::CustomDevice* custom_device = nullptr; + + TF_RETURN_IF_ERROR(MaybePinToCustomDevice(&custom_device, *op)); + + if (custom_device != nullptr) { + return custom_device->Execute(op, retvals, num_retvals); + } + + // The op will be placed on physical device. However, it contains custom + // device tensor handles. The tensor handles will be copy to physical device + // first. + if (op->HasCustomDeviceInput()) { + auto inputs = op->GetInputs(); + for (int i = 0; i < inputs.size(); ++i) { + auto target_device = op->DeviceName(); + if (target_device.empty()) { + target_device = op->GetContext()->HostCPUName(); + } + // TODO(b/175427838): It would be nice to be able to use tensorflow::isa + // here. + if (tensorflow::CustomDeviceTensorHandle::classof(inputs[i])) { + tensorflow::CustomDeviceTensorHandle* previous = + tensorflow::down_cast( + inputs[i]); + tensorflow::ImmediateExecutionTensorHandle* new_tesnor; + TF_RETURN_IF_ERROR(previous->device()->CopyTensorFromDevice( + previous, target_device, &new_tesnor)); + Status s = op->SetInput(i, new_tesnor); + new_tesnor->Unref(); + TF_RETURN_IF_ERROR(s); + } + } + } + + return op->Execute( + absl::MakeSpan( + reinterpret_cast(retvals), + *num_retvals), + num_retvals); +} + +Status CustomDeviceOpHandler::MaybePinToCustomDevice( + CustomDevice** device, const ImmediateExecutionOperation& op) const { + CustomDevice* requested_device = nullptr; + if (!FindCustomDeviceFromName(op.DeviceName(), &requested_device) && + !op.HasCustomDeviceInput()) { + return Status::OK(); + } + + // Ops are placed on a custom device if there's no other explicit requested + // placement and there is only one custom device in the op + // inputs. + // + // Resource-dtype inputs take precedence over non-resource inputs and explicit + // placements; this function pins ops with a resource-dtype custom device + // input to that custom device. + CustomDevice* first = nullptr; + if (!op.GetInputs().empty()) { + for (const ImmediateExecutionTensorHandle* generic_input : op.GetInputs()) { + // TODO(b/175427838): It would be nice to be able to use tensorflow::isa + // here. 
+ if (CustomDeviceTensorHandle::classof(generic_input)) { + const CustomDeviceTensorHandle* input = + down_cast(generic_input); + CustomDevice* current = input->device(); + if (first == nullptr) { + first = current; + } else if (first != current) { + return errors::InvalidArgument(absl::StrCat( + "If an operation has one of its inputs in a custom device, then " + "all inputs should be on that same custom device or another " + "physical device. Operation ", + op.Name(), + " has one input in custom " + "device ", + first->name(), + " and at least one input in a different custom device ", + current->name())); + } + } + } + for (const ImmediateExecutionTensorHandle* generic_input : op.GetInputs()) { + if (generic_input->DataType() == DT_RESOURCE) { + if (CustomDeviceTensorHandle::classof(generic_input)) { + const CustomDeviceTensorHandle* input = + down_cast(generic_input); + // There's only one custom device input, and it's a resource input, so + // we'll force-place the op on to that custom device. As with physical + // devices, this overrides any explicit placement for the op. + *device = input->device(); + return Status::OK(); + } else { + // Don't set a custom device if there's a physical-device resource + // input. + return Status::OK(); + } + } + } + } + // Since there are no resource-dtype inputs, we'll respect explicit placements + // before considering input-based placement. + if (requested_device != nullptr) { + *device = requested_device; + } else if (op.DeviceName().empty() && first != nullptr) { + // If there are non-resource inputs on a custom device we will default the + // op to that custom device, but not override an explicit op placement. + *device = first; + return Status::OK(); + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/custom_device_op_handler.h b/tensorflow/core/common_runtime/eager/custom_device_op_handler.h new file mode 100644 index 00000000000..00ac5f324ba --- /dev/null +++ b/tensorflow/core/common_runtime/eager/custom_device_op_handler.h @@ -0,0 +1,51 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_OP_HANDLER_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_OP_HANDLER_H_ + +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/core/common_runtime/eager/custom_device.h" +#include "tensorflow/core/lib/core/status.h" +namespace tensorflow { + +// TODO(tfrt-devs): Figure out a way to unify it with OpHandler in TFRT. +class CustomDeviceOpHandler { + public: + ~CustomDeviceOpHandler() {} + // Register a new custom device. + Status RegisterCustomDevice(const string& device_name, + std::unique_ptr device); + + // Find the custom device from given name. Return true if it finds one. 
+ bool FindCustomDeviceFromName(const string& name, + CustomDevice** device) const; + + Status Execute(ImmediateExecutionOperation* op, + ImmediateExecutionTensorHandle** retvals, int* num_retvals); + + // Determine whether to place an op on a custom device. This method is + // exposed as public for test only. + Status MaybePinToCustomDevice(CustomDevice** device, + const ImmediateExecutionOperation& op) const; + + void Clear(); + + private: + std::unordered_map> custom_devices_; +}; +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_OP_HANDLER_H_ diff --git a/tensorflow/core/common_runtime/eager/custom_device_test.cc b/tensorflow/core/common_runtime/eager/custom_device_test.cc index a642a816c76..cd7340e8463 100644 --- a/tensorflow/core/common_runtime/eager/custom_device_test.cc +++ b/tensorflow/core/common_runtime/eager/custom_device_test.cc @@ -138,43 +138,47 @@ TEST(CustomDevice, TestResourcePlacement) { TF_ASSERT_OK(op.Reset("AssignVariableOp", "")); TF_ASSERT_OK(op.AddInput(physical_resource_tensor.get())); TF_ASSERT_OK(op.AddInput(custom_float_tensor.get())); - VariantDevice placed_device(kVariantDeviceNull); - TF_ASSERT_OK(MaybePinToCustomDevice(&placed_device, op)); + CustomDevice* placed_device = nullptr; + TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice( + &placed_device, op)); // MaybePinToCustomDevice has no opinion about ops which have physical // resource-dtype inputs. They'll get placed on physical devices. - EXPECT_EQ(kVariantDeviceNull, placed_device); + EXPECT_EQ(nullptr, placed_device); op.Clear(); TF_ASSERT_OK(op.Reset("AssignVariableOp", custom_device_name.c_str())); TF_ASSERT_OK(op.AddInput(physical_resource_tensor.get())); TF_ASSERT_OK(op.AddInput(custom_float_tensor.get())); - placed_device = kVariantDeviceNull; - TF_ASSERT_OK(MaybePinToCustomDevice(&placed_device, op)); + placed_device = nullptr; + TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice( + &placed_device, op)); // Explicit placement onto a custom device also doesn't trigger custom device // placement if there's a physical device resource input. - EXPECT_EQ(kVariantDeviceNull, placed_device); + EXPECT_EQ(nullptr, placed_device); op.Clear(); TF_ASSERT_OK( op.Reset("Identity", "/job:localhost/replica:0/task:0/device:CPU:0")); TF_ASSERT_OK(op.AddInput(physical_float_tensor.get())); - placed_device = kVariantDeviceNull; - TF_ASSERT_OK(MaybePinToCustomDevice(&placed_device, op)); + placed_device = nullptr; + TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice( + &placed_device, op)); // Explicit placements typically override input-based placement onto a custom // device. - EXPECT_EQ(kVariantDeviceNull, placed_device); + EXPECT_EQ(nullptr, placed_device); op.Clear(); TF_ASSERT_OK(op.Reset("AssignVariableOp", "/job:localhost/replica:0/task:0/device:CPU:0")); TF_ASSERT_OK(op.AddInput(custom_resource_tensor.get())); TF_ASSERT_OK(op.AddInput(physical_float_tensor.get())); - placed_device = kVariantDeviceNull; - TF_ASSERT_OK(MaybePinToCustomDevice(&placed_device, op)); + placed_device = nullptr; + TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice( + &placed_device, op)); // Even with an explicit physical device placement, custom device resource // inputs place the op on the custom device. 
- ASSERT_TRUE(absl::holds_alternative(placed_device)); - EXPECT_EQ(&custom_device, absl::get(placed_device)); + ASSERT_NE(placed_device, nullptr); + EXPECT_EQ(&custom_device, placed_device); } } // namespace diff --git a/tensorflow/core/common_runtime/eager/eager_operation.cc b/tensorflow/core/common_runtime/eager/eager_operation.cc index 883e9a8a8b0..de4a4495e87 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.cc +++ b/tensorflow/core/common_runtime/eager/eager_operation.cc @@ -36,7 +36,7 @@ void EagerOperation::Clear() { h->Unref(); } inputs_.clear(); - inputs_are_tensor_handles_ = true; + custom_device_tensor_handles_count_ = 0; ClearInferenceState(); } @@ -269,7 +269,7 @@ Status EagerOperation::AddInput(AbstractTensorHandle* input) { down_cast(input); // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here. if (CustomDeviceTensorHandle::classof(h)) { - inputs_are_tensor_handles_ = false; + custom_device_tensor_handles_count_++; } AddTensorHandle(h); return MaybeInferSingleInputAttrs(h); @@ -281,7 +281,7 @@ Status EagerOperation::AddInputList( // TODO(b/175427838): It would be nice to be able to use tensorflow::isa // here. if (CustomDeviceTensorHandle::classof(input)) { - inputs_are_tensor_handles_ = false; + custom_device_tensor_handles_count_++; } ImmediateExecutionTensorHandle* h = down_cast(input); @@ -290,6 +290,25 @@ Status EagerOperation::AddInputList( return InferInputListAttrs(inputs.size()); } +Status EagerOperation::SetInput(size_t index, + ImmediateExecutionTensorHandle* input) { + if (index >= inputs_.size()) { + return errors::InvalidArgument("Index >= inputs.size: %d >= %d", index, + inputs_.size()); + } + auto* previous = inputs_[index]; + if (CustomDeviceTensorHandle::classof(previous)) { + custom_device_tensor_handles_count_--; + } + if (CustomDeviceTensorHandle::classof(input)) { + custom_device_tensor_handles_count_++; + } + input->Ref(); + inputs_[index] = input; + previous->Unref(); + return Status::OK(); +} + Status EagerOperation::Reset( const char* op, const char* device_name, bool remote, EagerExecutor* executor, @@ -407,7 +426,7 @@ Status EagerOperation::InferInputListAttrs(int num_inputs) { Status EagerOperation::TensorHandleInputs( const absl::InlinedVector** inputs) const { - if (TF_PREDICT_TRUE(inputs_are_tensor_handles_)) { + if (TF_PREDICT_TRUE(!HasCustomDeviceInput())) { *inputs = reinterpret_cast*>( &inputs_); return Status::OK(); @@ -418,7 +437,7 @@ Status EagerOperation::TensorHandleInputs( Status EagerOperation::MutableTensorHandleInputs( absl::InlinedVector** inputs) { - if (TF_PREDICT_TRUE(inputs_are_tensor_handles_)) { + if (TF_PREDICT_TRUE(!HasCustomDeviceInput())) { *inputs = reinterpret_cast*>(&inputs_); return Status::OK(); @@ -436,14 +455,7 @@ Status EagerOperation::SetDeviceName(const char* c_name) { } last_set_device_name_ = name; device_name_ = DeviceNameUtils::ParsedNameToString(device_parsed_name_); - CustomDevice* custom_device; - if (ctx_.FindCustomDeviceFromName(device_name_, &custom_device)) { - device_ = custom_device; - } else { - // Device placement for physical devices happens lazily in - // EagerExecute/EagerRemoteExecute, and can depend on the inputs. 
- device_ = kVariantDeviceNull; - } + device_ = kVariantDeviceNull; } return Status::OK(); } @@ -495,30 +507,4 @@ void EagerOperation::AddTensorHandle(ImmediateExecutionTensorHandle* h) { attrs_.NumInputs(static_cast(inputs_.size())); } -Status EagerOperation::CopyOffCustomDeviceInputs() { - if (absl::holds_alternative(device_)) { - return errors::Internal( - "Trying to copy inputs to a custom device op off a custom device."); - } - for (int i = 0; i < inputs_.size(); ++i) { - // TODO(b/175427838): It would be nice to be able to use tensorflow::isa - // here. - if (CustomDeviceTensorHandle::classof(inputs_[i])) { - CustomDeviceTensorHandle* previous = - down_cast(inputs_[i]); - class Device* target_device; - if (device_ == kVariantDeviceNull) { - target_device = ctx_.HostCPU(); - } else { - target_device = absl::get(device_); - } - TF_RETURN_IF_ERROR(previous->device()->CopyTensorFromDevice( - previous, target_device->name(), &inputs_[i])); - previous->Unref(); - } - } - inputs_are_tensor_handles_ = true; - return Status::OK(); -} - } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/eager_operation.h b/tensorflow/core/common_runtime/eager/eager_operation.h index e440a4a79dd..e1cb20b7575 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.h +++ b/tensorflow/core/common_runtime/eager/eager_operation.h @@ -55,6 +55,8 @@ class EagerOperation : public ImmediateExecutionOperation { const string& DeviceName() const override { return device_name_; } + ImmediateExecutionContext* GetContext() const override { return &ctx_; } + const DeviceNameUtils::ParsedName& GetDeviceParsedName() const { return device_parsed_name_; } @@ -83,7 +85,11 @@ class EagerOperation : public ImmediateExecutionOperation { Status AddInput(AbstractTensorHandle* input) override; Status AddInputList(absl::Span inputs) override; + Status SetInput(size_t index, ImmediateExecutionTensorHandle* input) override; absl::Span GetInputs() const override; + bool HasCustomDeviceInput() const override { + return custom_device_tensor_handles_count_ > 0; + } Status Execute(absl::Span retvals, int* num_retvals) override; const tensorflow::OpDef* OpDef() const override { return op_def_; }; @@ -207,20 +213,14 @@ class EagerOperation : public ImmediateExecutionOperation { void InferMixedTypeInputListAttrs(const OpDef::ArgDef& input_def, const std::vector& dtypes); - // Replaces input tensors placed on custom devices with physical device - // equivalents. Used if an op is placed on a physical device but may have - // custom device inputs. - Status CopyOffCustomDeviceInputs(); - tensorflow::EagerContext& ctx_; const char* op_name_ = nullptr; AttrBuilder attrs_; const AttrTypeMap* attr_types_; - // Toggled to indicate whether all inputs are known to be TensorHandles and - // not another type (e.g. custom device tensor handles). Explicitly set to - // false when custom device TensorHandles are added. - bool inputs_are_tensor_handles_ = true; + // The number of custom device TensorHandle inputs. These inputs need to be + // processed by CustomDeviceOpHandler first. + int custom_device_tensor_handles_count_ = 0; absl::InlinedVector inputs_; // The last device name given to SetDeviceName. 
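(Editorial note: the pieces above -- `HasCustomDeviceInput()`, the custom-device handle counting in `EagerOperation`, and `CustomDeviceOpHandler::Execute()` -- combine into one dispatch rule. The sketch below is an illustration with made-up types, not TensorFlow API: pin to the custom device when the placement rules select one; otherwise copy any custom-device inputs to the physical target, or to host CPU when no device was requested, and then execute normally.)

```cpp
// Illustrative sketch only -- Tensor, Op, Dispatch, Place, and CopyTarget are
// made-up names, not TensorFlow APIs.
#include <string>
#include <vector>

struct Tensor {
  bool on_custom_device = false;  // true for custom-device tensor handles
};

struct Op {
  std::string requested_device;  // "" means "let placement decide"
  std::vector<Tensor> inputs;
};

enum class Dispatch { kCustomDevice, kCopyInputsThenPhysical, kPhysical };

// Mirrors the order of checks in CustomDeviceOpHandler::Execute: a pinned
// custom device wins outright; otherwise any custom-device input forces a
// copy onto the physical target before normal execution.
Dispatch Place(const Op& op, bool pinned_to_custom_device) {
  if (pinned_to_custom_device) return Dispatch::kCustomDevice;
  for (const Tensor& t : op.inputs) {
    if (t.on_custom_device) return Dispatch::kCopyInputsThenPhysical;
  }
  return Dispatch::kPhysical;
}

// Where copied inputs land: the op's requested device, or host CPU if none
// was requested (the real code asks the context for HostCPUName; the CPU
// device string here is a placeholder).
std::string CopyTarget(const Op& op) {
  return op.requested_device.empty() ? "/device:CPU:0" : op.requested_device;
}

int main() {
  Op op{/*requested_device=*/"", /*inputs=*/{{/*on_custom_device=*/true}}};
  bool copy_first = Place(op, /*pinned_to_custom_device=*/false) ==
                    Dispatch::kCopyInputsThenPhysical;
  return copy_first && CopyTarget(op) == "/device:CPU:0" ? 0 : 1;
}
```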
diff --git a/tensorflow/core/common_runtime/eager/placement_utils.cc b/tensorflow/core/common_runtime/eager/placement_utils.cc index 77514d67e3a..3b9fa7bb2d1 100644 --- a/tensorflow/core/common_runtime/eager/placement_utils.cc +++ b/tensorflow/core/common_runtime/eager/placement_utils.cc @@ -77,11 +77,6 @@ bool IsFunction(StringPiece op_name) { return false; } -bool IsCustomDevice(StringPiece device_name, const EagerContext& ctx) { - CustomDevice* custom_device; - return ctx.FindCustomDeviceFromName(string(device_name), &custom_device); -} - Status MaybePinSmallOpsToCpu( bool* result, StringPiece op_name, absl::Span args, @@ -182,70 +177,5 @@ Status MaybePinToResourceDevice(Device** device, const EagerOperation& op) { return Status::OK(); } -Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op) { - // Ops are placed on a custom device if there's no other explicit requested - // placement and there is only one custom device in the op - // inputs. - // - // Resource-dtype inputs take precedence over non-resource inputs and explicit - // placements; this function pins ops with a resource-dtype custom device - // input to that custom device. - CustomDevice* first = nullptr; - if (!op.Inputs().empty()) { - for (const ImmediateExecutionTensorHandle* generic_input : op.Inputs()) { - // TODO(b/175427838): It would be nice to be able to use tensorflow::isa - // here. - if (CustomDeviceTensorHandle::classof(generic_input)) { - const CustomDeviceTensorHandle* input = - down_cast(generic_input); - CustomDevice* current = input->device(); - if (first == nullptr) { - first = current; - } else if (first != current) { - return errors::InvalidArgument(absl::StrCat( - "If an operation has one of its inputs in a custom device, then " - "all inputs should be on that same custom device or another " - "physical device. Operation ", - op.Name(), - " has one input in custom " - "device ", - first->name(), - " and at least one input in a different custom device ", - current->name())); - } - } - } - for (const ImmediateExecutionTensorHandle* generic_input : op.Inputs()) { - if (generic_input->DataType() == DT_RESOURCE) { - if (CustomDeviceTensorHandle::classof(generic_input)) { - const CustomDeviceTensorHandle* input = - down_cast(generic_input); - // There's only one custom device input, and it's a resource input, so - // we'll force-place the op on to that custom device. As with physical - // devices, this overrides any explicit placement for the op. - *device = input->device(); - return Status::OK(); - } else { - // Don't set a custom device if there's a physical-device resource - // input. - return Status::OK(); - } - } - } - } - // Since there are no resource-dtype inputs, we'll respect explicit placements - // before considering input-based placement. - if (absl::holds_alternative(op.Device())) { - *device = op.Device(); - } else if (op.DeviceName().empty() && first != nullptr) { - // If there are non-resource inputs on a custom device we will default the - // op to that custom device, but not override an explicit op placement. - *device = first; - return Status::OK(); - } - - return Status::OK(); -} - } // namespace eager } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/placement_utils.h b/tensorflow/core/common_runtime/eager/placement_utils.h index 7676fe01b43..9435f9848d3 100644 --- a/tensorflow/core/common_runtime/eager/placement_utils.h +++ b/tensorflow/core/common_runtime/eager/placement_utils.h @@ -16,6 +16,7 @@ limitations under the License. 
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PLACEMENT_UTILS_H_ #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PLACEMENT_UTILS_H_ +#include "tensorflow/c/eager/immediate_execution_operation.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/stringpiece.h" @@ -27,8 +28,6 @@ bool IsColocationExempt(StringPiece op_name); bool IsFunction(StringPiece op_name); -bool IsCustomDevice(StringPiece device_name, const EagerContext& ctx); - // TODO(b/154234908): Unify placement logic. // TODO(b/159647422): Add C++ unit tests for placement logic. @@ -44,11 +43,6 @@ Status MaybePinSmallOpsToCpu( // the device the resource is, regardless of anything else that has been // specified. This is identical to the graph mode behavior. Status MaybePinToResourceDevice(Device** device, const EagerOperation& op); - -// If all the inputs are on the same custom device, use that custom -// device. Otherwise, it is an error to have a custom device as an input. -Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op); - } // namespace eager } // namespace tensorflow From cd3b1686951e524425a557efc18cd310c560dba0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 5 Feb 2021 17:31:39 -0800 Subject: [PATCH 14/19] Add space for reporting memory bandwidth statistics PiperOrigin-RevId: 355961338 Change-Id: I0f5216bb48058e93d88d2171c91e4f0c94b2eaae --- tensorflow/core/kernels/data/BUILD | 1 - tensorflow/core/kernels/data/iterator_ops.cc | 17 +++-------------- tensorflow/core/platform/default/port.cc | 2 +- tensorflow/core/platform/mem.h | 1 - tensorflow/core/platform/windows/port.cc | 2 +- 5 files changed, 5 insertions(+), 18 deletions(-) diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index 448a5def807..e5adb2dfba2 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -463,7 +463,6 @@ tf_kernel_library( "//tensorflow/core:session_options", "//tensorflow/core/kernels:ops_util", "//tensorflow/core/profiler/lib:traceme", - "//tensorflow/core/profiler/lib:traceme_encode", "@com_google_absl//absl/memory", ], ) diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index dce6ebd49bc..f3d2382db1f 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -47,12 +47,10 @@ limitations under the License. 
#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/refcount.h" #include "tensorflow/core/platform/resource.h" #include "tensorflow/core/profiler/lib/traceme.h" -#include "tensorflow/core/profiler/lib/traceme_encode.h" #include "tensorflow/core/public/session_options.h" namespace tensorflow { @@ -967,18 +965,9 @@ void RecordElementSize(const std::vector element, Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) { profiler::TraceMe traceme( [&] { - int64 mem_bw = port::GetMemoryInfo().bw_used; - - if (mem_bw != INT64_MAX) { - return profiler::TraceMeEncode( - "IteratorGetNextOp::DoCompute", - {{"id", ctx->step_id()}, - {"iter_num", ctx->frame_iter().iter_id}, - {"mem_bw_used", mem_bw}}); - } - return profiler::TraceMeEncode( - "IteratorGetNextOp::DoCompute", - {{"id", ctx->step_id()}, {"iter_num", ctx->frame_iter().iter_id}}); + return strings::StrCat( + "IteratorGetNextOp::DoCompute#id=", ctx->step_id(), + ",iter_num=", ctx->frame_iter().iter_id, "#"); }, profiler::kInfo); tensorflow::ResourceTagger tag(kTFDataResourceTag, diff --git a/tensorflow/core/platform/default/port.cc b/tensorflow/core/platform/default/port.cc index 6e82c67be99..e25ed074844 100644 --- a/tensorflow/core/platform/default/port.cc +++ b/tensorflow/core/platform/default/port.cc @@ -357,7 +357,7 @@ double NominalCPUFrequency() { } MemoryInfo GetMemoryInfo() { - MemoryInfo mem_info = {INT64_MAX, INT64_MAX, INT64_MAX}; + MemoryInfo mem_info = {INT64_MAX, INT64_MAX}; #if defined(__linux__) && !defined(__ANDROID__) struct sysinfo info; int err = sysinfo(&info); diff --git a/tensorflow/core/platform/mem.h b/tensorflow/core/platform/mem.h index 65411eeac28..27ad3574182 100644 --- a/tensorflow/core/platform/mem.h +++ b/tensorflow/core/platform/mem.h @@ -62,7 +62,6 @@ std::size_t MallocExtension_GetAllocatedSize(const void* p); struct MemoryInfo { int64 total = 0; int64 free = 0; - int64 bw_used = 0; // memory bandwidth used across all CPU }; // Retrieves the host memory information. If any of the fields in the returned diff --git a/tensorflow/core/platform/windows/port.cc b/tensorflow/core/platform/windows/port.cc index 256f525a38d..16b5a328256 100644 --- a/tensorflow/core/platform/windows/port.cc +++ b/tensorflow/core/platform/windows/port.cc @@ -192,7 +192,7 @@ double NominalCPUFrequency() { } MemoryInfo GetMemoryInfo() { - MemoryInfo mem_info = {INT64_MAX, INT64_MAX, INT64_MAX}; + MemoryInfo mem_info = {INT64_MAX, INT64_MAX}; MEMORYSTATUSEX statex; statex.dwLength = sizeof(statex); if (GlobalMemoryStatusEx(&statex)) { From 739392739a5fb77d50747206f73c5e1e738ac150 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Fri, 5 Feb 2021 18:13:11 -0800 Subject: [PATCH 15/19] Internal change PiperOrigin-RevId: 355966509 Change-Id: I98a1428a66f1ddeec970a367572b1c5c1ea10ab2 --- .../mlir/tensorflow/ir/tf_generated_ops.td | 2 - .../compiler/mlir/tensorflow/ir/tf_ops_n_z.cc | 111 --------- .../mlir/tensorflow/tests/canonicalize.mlir | 214 ++---------------- 3 files changed, 17 insertions(+), 310 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index bc61d27daf1..cce5f71651a 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -14412,8 +14412,6 @@ receive 0, 0, and 1, respectively. The appropriate bits in `begin_mask` and TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr Index = TF_DerivedOperandTypeAttr<1>; - let hasFolder = 1; - let verifier = [{ return VerifyStridedSliceBase(*this); }]; let extraClassDeclaration = [{ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc index ae255f6db00..0f8a423124f 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc @@ -1886,117 +1886,6 @@ bool StridedSliceOp::GetSlicedBoundRanges( return true; } -OpFoldResult StridedSliceOp::fold(ArrayRef operands) { - // Fold StridedSlice operation if it extracts statically known dimensions. - // - // For example, - // - // %shape = tf.Shape(%arg) // %arg: tensor - // %height = tf.StridedSlice(%shape, 1, 2, 1) - // - // In this case %height can be replaced with a constant 2. - // - // Or, - // - // %shape = tf.Shape(%arg) // %arg: tensor - // %spatial_shape = tf.StridedSlice(%shape, 1, 3, 1) - // - // In this case %spatial_shape can be replaced with a constant [2, 3]. - - // Input to strided slice op is defined by shape operation. - auto shape_op = input().getDefiningOp(); - if (!shape_op) { - return {}; - } - - // `begin`, `end` and `strides` should be constant in order to infer static - // dimension. - DenseIntElementsAttr begin_attr, end_attr, strides_attr; - if (!matchPattern(begin(), m_Constant(&begin_attr)) || - !matchPattern(end(), m_Constant(&end_attr)) || - !matchPattern(strides(), m_Constant(&strides_attr)) || - begin_attr.getNumElements() != 1 || end_attr.getNumElements() != 1 || - strides_attr.getNumElements() != 1) { - return {}; - } - - auto tensor_ty = shape_op.input().getType().dyn_cast(); - // Only ranked tensor can be folded. - if (!tensor_ty) return {}; - - int64_t rank = tensor_ty.getRank(); - int64_t begin_int = begin_attr.getValue(0).getSExtValue(); - int64_t end_int = end_attr.getValue(0).getSExtValue(); - int64_t strides_int = strides_attr.getValue(0).getSExtValue(); - - // Canonicalize `begin` and `end` in case of negative index. - if (begin_int < 0) begin_int += rank; - if (end_int < 0) end_int += rank; - - // Create `begin` and `end` from `*_mask`. Note that we don't care about - // `new_axis_mask` as it can be inferred from `output_ty`. - if (shrink_axis_mask() == 1) { - // When `shrink_axis_mask` is set, output is always a scalar so only - // one element is sliced. - end_int = begin_int + 1; - } - if (begin_mask() == 1) { - begin_int = (strides_int > 0) ? 0 : rank - 1; - } - if (end_mask() == 1) { - end_int = (strides_int > 0) ? 
rank : -1; - } - if (ellipsis_mask() == 1) { - begin_int = 0; - end_int = rank; - } - - // It's possible that `begin` and `end` are out of bound. See - // https://docs.python.org/3/library/stdtypes.html#common-sequence-operations. - if (strides_int > 0) { - begin_int = std::min(begin_int, rank); - end_int = std::min(end_int, rank); - } else { - begin_int = std::min(begin_int, rank - 1); - end_int = std::min(end_int, rank - 1); - } - - SmallVector sub_shape; - // Only handle cases that have something to slice to avoid infinite for-loop. - if ((end_int > begin_int && strides_int > 0) || - (end_int < begin_int && strides_int < 0)) { - // Extract sub-shape only if all of those dimensions are static. - for (int64_t i = begin_int; (strides_int > 0) ? i < end_int : i > end_int; - i += strides_int) { - if (tensor_ty.isDynamicDim(i)) { - return {}; - } - sub_shape.push_back(tensor_ty.getDimSize(i)); - } - } - - // Down-cast to 32 bit int if needed. - auto output_elt_ty = output().getType().cast().getElementType(); - - auto output_ty = output().getType().dyn_cast(); - if (!output_ty) { - // If the output is unranked, we infer a result using a 0-ranked tensor for - // scalar element. - if (sub_shape.size() == 1) - output_ty = RankedTensorType::get({}, output_elt_ty); - else - output_ty = RankedTensorType::get( - {static_cast(sub_shape.size())}, output_elt_ty); - } - if (output_elt_ty.isInteger(32)) { - SmallVector sub_shape_i32(sub_shape.size()); - std::transform(sub_shape.begin(), sub_shape.end(), sub_shape_i32.begin(), - [](int64_t d) { return static_cast(d); }); - return DenseIntElementsAttr::get(output_ty, sub_shape_i32); - } - return DenseIntElementsAttr::get(output_ty, sub_shape); -} - //===----------------------------------------------------------------------===// // StridedSliceGradOp //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index 64cb3480bca..e2a0552ef48 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -486,7 +486,7 @@ func @testBroadcastToNoOp(%arg0: tensor<2x4xf32>, %arg1: tensor<2xi32>) -> tenso } // CHECK-LABEL: func @testPackShapeComputation -func @testPackShapeComputation(%arg0: tensor, %arg1: tensor, %arg2: tensor<*xf32>) -> (tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32>) { +func @testPackShapeComputation(%arg0: tensor, %arg1: tensor, %arg2: tensor<*xf32>) -> (tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32>) { // Test dimensions sizes. 
%d1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor %d2 = "tf.Const"() {value = dense<2> : tensor} : () -> tensor @@ -526,20 +526,26 @@ func @testPackShapeComputation(%arg0: tensor, %arg1: tensor, %15 = "tf.Pack"(%14, %d2, %d1) {axis = 0 : i64} : (tensor, tensor, tensor) -> tensor<3xi32> // CHECK: %[[PACK0:.*]] = "tf.Pack" - // Packed dimensions have higher rank than the reshape operand: - // [?, 1] vs [?, 1, 1] - %16 = "tf.StridedSlice"(%3, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %17 = "tf.Pack"(%16, %d1, %d1) {axis = 0 : i64} : (tensor, tensor, tensor) -> tensor<3xi32> + // StridedSlice takes second dimension from the shape: + // begin = [1], end = [2], stride = [1] + %17 = "tf.StridedSlice"(%7, %1, %2, %1) {shrink_axis_mask = 1 : i64} : (tensor<3xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %18 = "tf.Pack"(%17, %d1, %d2) {axis = 0 : i64} : (tensor, tensor, tensor) -> tensor<3xi32> // CHECK: %[[PACK1:.*]] = "tf.Pack" - // Make sure a dynamic ranked shape doesn't crash the "canonicalize" pass - %18 = "tf.Shape"(%arg2) : (tensor<*xf32>) -> tensor<*xi32> - %19 = "tf.StridedSlice"(%18, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<*xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32> - %20 = "tf.Pack"(%19, %d1) {axis = 0 : i64} : (tensor<*xi32>, tensor) -> tensor<*xi32> + // Packed dimensions have higher rank than the reshape operand: + // [?, 1] vs [?, 1, 1] + %20 = "tf.StridedSlice"(%3, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %21 = "tf.Pack"(%20, %d1, %d1) {axis = 0 : i64} : (tensor, tensor, tensor) -> tensor<3xi32> // CHECK: %[[PACK2:.*]] = "tf.Pack" - // CHECK: return %[[SHAPE0]], %[[SHAPE1]], %[[PACK0]], %[[PACK1]], %[[PACK2]] - return %5, %9, %15, %17, %20 : tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32> + // Make sure a dynamic ranked shape doesn't crash the "canonicalize" pass + %23 = "tf.Shape"(%arg2) : (tensor<*xf32>) -> tensor<*xi32> + %24 = "tf.StridedSlice"(%23, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<*xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32> + %25 = "tf.Pack"(%24, %d1) {axis = 0 : i64} : (tensor<*xi32>, tensor) -> tensor<*xi32> + // CHECK: %[[PACK3:.*]] = "tf.Pack" + + // CHECK: return %[[SHAPE0]], %[[SHAPE1]], %[[PACK0]], %[[PACK1]], %[[PACK2]], %[[PACK3]] + return %5, %9, %15, %18, %21, %25 : tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32> } // CHECK-LABEL: testSelectScalarPred @@ -1367,189 +1373,3 @@ func @testUnpackAndCwiseUnary(%arg0: tensor) -> (tensor, tensor< // CHECK: return %[[UNPACK]]#0, %[[UNPACK]]#1 return %0, %1 : tensor, tensor } - -// CHECK-LABEL: testFoldStridedSliceShapeI32 -func @testFoldStridedSliceShapeI32(%arg0: tensor) -> (tensor<2xi32>) { - %0 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32> - %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - %2 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi32> - %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - return %3 : tensor<2xi32> - // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32> - // CHECK: return %[[CST]] -} - -// 
CHECK-LABEL: testFoldStridedSliceShapeI64 -func @testFoldStridedSliceShapeI64(%arg0: tensor) -> (tensor<2xi64>) { - %0 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32> - %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - %2 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi64> - %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi64> - return %3 : tensor<2xi64> - // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64> - // CHECK: return %[[CST]] -} - -// CHECK-LABEL: testFoldStridedSliceShapeWithShrinkAxisMaskI32 -func @testFoldStridedSliceShapeWithShrinkAxisMaskI32(%arg0: tensor) -> (tensor) { - %0 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> - %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - %2 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi32> - %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - return %3 : tensor - // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - // CHECK: return %[[CST]] -} - -// CHECK-LABEL: testFoldStridedSliceShapeWithShrinkAxisMaskI64 -func @testFoldStridedSliceShapeWithShrinkAxisMaskI64(%arg0: tensor) -> (tensor) { - %0 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> - %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - %2 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi64> - %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - return %3 : tensor - // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - // CHECK: return %[[CST]] -} - -// CHECK-LABEL: testFoldStridedSliceShapeWithShrinkAxisMaskNegativeBegin1 -func @testFoldStridedSliceShapeWithShrinkAxisMaskNegativeBegin1(%arg0: tensor) -> (tensor) { - %0 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> - %1 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> - %2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - %3 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi32> - %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - return %4 : tensor - // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<3> : tensor} : () -> tensor - // CHECK: return %[[CST]] -} - -// CHECK-LABEL: testFoldStridedSliceShapeWithShrinkAxisMaskNegativeBegin2 -func @testFoldStridedSliceShapeWithShrinkAxisMaskNegativeBegin2(%arg0: tensor) -> (tensor) { - %0 = "tf.Const"() {value = dense<-2> : tensor<1xi32>} : () -> tensor<1xi32> - %1 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> - %2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - %3 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi32> - %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 
0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - return %4 : tensor - // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<2> : tensor} : () -> tensor - // CHECK: return %[[CST]] -} - -// CHECK-LABEL: testUnfoldedStridedSliceShape -func @testUnfoldedStridedSliceShape(%arg0: tensor) -> (tensor<2xi32>) { - %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> - %1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> - %2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - %3 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi32> - %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - return %4 : tensor<2xi32> - // CHECK: %[[SLICE:.*]] = "tf.StridedSlice" - // CHECK: return %[[SLICE]] -} - -// CHECK-LABEL: testFoldStridedSliceShapeWithBeginMask -func @testFoldStridedSliceShapeWithBeginMask(%arg0: tensor<1x2x3x?xf32>) -> (tensor<2xi32>) { - %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> - %1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> - %2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - %3 = "tf.Shape"(%arg0) : (tensor<1x2x3x?xf32>) -> tensor<4xi32> - %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 1 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - return %4 : tensor<2xi32> - // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32> - // CHECK: return %[[CST]] -} - -// CHECK-LABEL: testFoldStridedSliceShapeWithEndMask -func @testFoldStridedSliceShapeWithEndMask(%arg0: tensor) -> (tensor<3xi32>) { - %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> - %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - %2 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi32> - %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> - return %3 : tensor<3xi32> - // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32> - // CHECK: return %[[CST]] -} - -// CHECK-LABEL: testFoldStridedSliceShapeWithPositiveStrides -func @testFoldStridedSliceShapeWithPositiveStrides(%arg0: tensor<1x2x3x4x?xf32>) -> (tensor<2xi32>) { - %0 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - %1 = "tf.Const"() {value = dense<4> : tensor<1xi32>} : () -> tensor<1xi32> - %2 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> - %3 = "tf.Shape"(%arg0) : (tensor<1x2x3x4x?xf32>) -> tensor<5xi32> - %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<5xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - return %4 : tensor<2xi32> - // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[2, 4]> : tensor<2xi32>} : () -> tensor<2xi32> - // CHECK: return %[[CST]] -} - -// CHECK-LABEL: 
testFoldStridedSliceShapeWithPositiveStridesOutOfBoundEnd -func @testFoldStridedSliceShapeWithPositiveStridesOutOfBoundEnd(%arg0: tensor) -> (tensor<3xi32>) { - %0 = "tf.Const"() {value = dense<20> : tensor<1xi32>} : () -> tensor<1xi32> - %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - %2 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi32> - %3 = "tf.StridedSlice"(%2, %1, %0, %1) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> - return %3 : tensor<3xi32> - // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32> - // CHECK: return %[[CST]] -} - -// CHECK-LABEL: testFoldStridedSliceShapeWithNegativeStrides -func @testFoldStridedSliceShapeWithNegativeStrides(%arg0: tensor<1x2x3x?xf32>) -> (tensor<1xi32>) { - %0 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> - %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - %2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> - %3 = "tf.Shape"(%arg0) : (tensor<1x2x3x?xf32>) -> tensor<4xi32> - %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> - return %4 : tensor<1xi32> - // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: return %[[CST]] -} - -// CHECK-LABEL: testFoldStridedSliceShapeWithNegativeStridesOutOfBoundBegin -func @testFoldStridedSliceShapeWithNegativeStridesOutOfBoundBegin(%arg0: tensor) -> (tensor<2xi32>) { - %0 = "tf.Const"() {value = dense<20> : tensor<1xi32>} : () -> tensor<1xi32> - %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - %2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> - %3 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi32> - %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - return %4 : tensor<2xi32> - // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[3, 2]> : tensor<2xi32>} : () -> tensor<2xi32> - // CHECK: return %[[CST]] -} - -// CHECK-LABEL: testFoldStridedSliceShapeWithNegativeStridesBeginMask -func @testFoldStridedSliceShapeWithNegativeStridesBeginMask(%arg0: tensor) -> (tensor<2xi32>) { - %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> - %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - %2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> - %3 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi32> - %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 1 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - return %4 : tensor<2xi32> - // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[3, 2]> : tensor<2xi32>} : () -> tensor<2xi32> - // CHECK: return %[[CST]] -} - -// CHECK-LABEL: testFoldStridedSliceShapeWithNegativeStridesEndMask -func @testFoldStridedSliceShapeWithNegativeStridesEndMask(%arg0: tensor<1x2x3x?xf32>) -> (tensor<3xi32>) { - %0 = 
"tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> - %1 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> - %2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> - %3 = "tf.Shape"(%arg0) : (tensor<1x2x3x?xf32>) -> tensor<4xi32> - %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> - return %4 : tensor<3xi32> - // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<[3, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32> - // CHECK: return %[[CST]] -} - -// CHECK-LABEL: testFoldStridedSliceShapeWithEmptySlice -func @testFoldStridedSliceShapeWithEmptySlice(%arg0: tensor) -> (tensor<0xi32>) { - %0 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - %1 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32> - %2 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> - %3 = "tf.Shape"(%arg0) : (tensor) -> tensor<4xi32> - %4 = "tf.StridedSlice"(%3, %0, %1, %2) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - return %4 : tensor<0xi32> - // CHECK: %[[CST:.*]] = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> - // CHECK: return %[[CST]] -} From e1c5fa72f29b360f972d8411e666ada1dfd54981 Mon Sep 17 00:00:00 2001 From: Reed Wanderman-Milne Date: Fri, 5 Feb 2021 18:24:02 -0800 Subject: [PATCH 16/19] Fix issue preventing debug mode from working. Some constants were declared but not defined. This worked in opt mode as they were inlined, but not debug mode. PiperOrigin-RevId: 355967655 Change-Id: I9d0a4b7a91f98fa35b73f218c31177b90044e37b --- tensorflow/core/common_runtime/bfc_allocator.cc | 1 + tensorflow/core/framework/model.cc | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/tensorflow/core/common_runtime/bfc_allocator.cc b/tensorflow/core/common_runtime/bfc_allocator.cc index 3551472e9c6..b271bc5058f 100644 --- a/tensorflow/core/common_runtime/bfc_allocator.cc +++ b/tensorflow/core/common_runtime/bfc_allocator.cc @@ -35,6 +35,7 @@ limitations under the License. namespace tensorflow { constexpr BFCAllocator::ChunkHandle BFCAllocator::kInvalidChunkHandle; +constexpr uint64 BFCAllocator::kMemDebugHistorySize; BFCAllocator::BFCAllocator(SubAllocator* sub_allocator, size_t total_memory, bool allow_growth, const string& name, diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc index 4bd2af5f9c6..d48e88be011 100644 --- a/tensorflow/core/framework/model.cc +++ b/tensorflow/core/framework/model.cc @@ -25,6 +25,10 @@ limitations under the License. namespace tensorflow { namespace data { namespace model { + +constexpr int64 Model::kOptimizationPeriodMinMs; +constexpr int64 Model::kOptimizationPeriodMaxMs; + namespace { // Helper function for node traversal that doesn't skip any nodes. From d487303b13a83362573be4c60b7b1c27a3990f8e Mon Sep 17 00:00:00 2001 From: Karim Nosir Date: Fri, 5 Feb 2021 18:40:58 -0800 Subject: [PATCH 17/19] [lite] Update tf_tfl_passes to run variable related passes post TF legalization. This includes - Legalizing Assign/Read variable ops. - Initialize the variables at graph start. - Remove unused function attrs and global tensors. All this is controlled by a flag (Default to false). 
PiperOrigin-RevId: 355969495 Change-Id: Ia6d346e4f8d3df6c373310acc8eb284a98726bfa --- .../compiler/mlir/lite/flatbuffer_export.cc | 31 +++++++++++++++++++ .../compiler/mlir/lite/tf_tfl_passes.cc | 5 +++ 2 files changed, 36 insertions(+) diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index 62841afa2e9..90d13b39289 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -38,6 +38,7 @@ limitations under the License. #include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/ToolOutputFile.h" @@ -525,6 +526,13 @@ class Translator { BufferOffset BuildNumericVerifyOperator( mlir::TFL::NumericVerifyOp op, const std::vector& operands, const std::vector& results); + + // Builds Assign/Read Variable ops. + template + BufferOffset BuildVariableOperator( + T op, const std::string& op_name, const std::vector& operands, + const std::vector& results); + BufferOffset BuildCustomOperator( Operation* inst, mlir::TFL::CustomOp op, const std::vector& operands, @@ -936,6 +944,17 @@ BufferOffset Translator::BuildNumericVerifyOperator( tflite::CustomOptionsFormat_FLEXBUFFERS); } +// Builds Assign/Read Variable ops. +template +BufferOffset Translator::BuildVariableOperator( + T op, const std::string& op_name, const std::vector& operands, + const std::vector& results) { + auto opcode_index = GetOpcodeIndex(op_name, tflite::BuiltinOperator_CUSTOM); + return tflite::CreateOperator( + builder_, opcode_index, builder_.CreateVector(operands), + builder_.CreateVector(results), tflite::BuiltinOptions_NONE); +} + BufferOffset Translator::BuildCustomOperator( Operation* inst, mlir::TFL::CustomOp op, const std::vector& operands, const std::vector& results) { @@ -1077,6 +1096,18 @@ Optional> Translator::BuildOperator( return llvm::None; } + // TODO(b/149099381): Remove this once the kernels are promoted as + // builtin TFLite kernels. + // We export the Assign/Read variable ops as custom ops. + if (auto read_op = llvm::dyn_cast(inst)) { + return BuildVariableOperator( + read_op, "ReadVariable", operands, results); + } else if (auto assign_op = + llvm::dyn_cast(inst)) { + return BuildVariableOperator( + assign_op, "AssignVariable", operands, results); + } + // If TFLite built in op, create operator as a builtin op. 
if (dialect == tfl_dialect_) { // Only if built-in TFLite op emission is enabled, would legalization have diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index 4555d1f8514..510f49ed41f 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -214,6 +214,11 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, pass_manager->addNestedPass( mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification)); + if (pass_config.enable_tflite_variables) { + pass_manager->addPass(mlir::TFL::CreateInitializeVariablesPass()); + pass_manager->addPass(mlir::TFL::CreateLegalizeVariablesPass()); + pass_manager->addPass(mlir::TFL::CreateRemoveArgsAndGlobalTensors()); + } pass_manager->addNestedPass(mlir::TFL::CreateOptimizePass()); // This pass operates on TensorFlow ops but is triggered after legalization // so that it can target constants introduced once TensorFlow Identity ops From 8dd21a79d413a00df6fafd7fa98fa7408ec86006 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 5 Feb 2021 18:42:23 -0800 Subject: [PATCH 18/19] Move custom device placement from eager/execute.cc to c_api.cc. Then it can be reused by TFRT. PiperOrigin-RevId: 355969651 Change-Id: Ia3bf837fd15b171c06ede798005767879138922d --- tensorflow/c/eager/BUILD | 2 - tensorflow/c/eager/c_api.cc | 31 ++-- .../c/eager/immediate_execution_context.h | 16 -- .../c/eager/immediate_execution_operation.h | 11 -- tensorflow/core/common_runtime/eager/BUILD | 25 --- .../core/common_runtime/eager/context.cc | 29 ++- .../core/common_runtime/eager/context.h | 14 +- tensorflow/core/common_runtime/eager/core.cc | 27 ++- .../eager/custom_device_op_handler.cc | 169 ------------------ .../eager/custom_device_op_handler.h | 51 ------ .../eager/custom_device_test.cc | 30 ++-- .../common_runtime/eager/eager_operation.cc | 64 ++++--- .../common_runtime/eager/eager_operation.h | 18 +- .../common_runtime/eager/placement_utils.cc | 70 ++++++++ .../common_runtime/eager/placement_utils.h | 8 +- 15 files changed, 205 insertions(+), 360 deletions(-) delete mode 100644 tensorflow/core/common_runtime/eager/custom_device_op_handler.cc delete mode 100644 tensorflow/core/common_runtime/eager/custom_device_op_handler.h diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index f9d726cb5d2..700912f3eff 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -73,11 +73,9 @@ tf_cuda_library( "//tensorflow/core/common_runtime/eager:context_distributed_manager", "//tensorflow/core/common_runtime/eager:core", "//tensorflow/core/common_runtime/eager:custom_device", - "//tensorflow/core/common_runtime/eager:custom_device_op_handler", "//tensorflow/core/common_runtime/eager:eager_executor", "//tensorflow/core/common_runtime/eager:execute", "//tensorflow/core/common_runtime/eager:tensor_handle", - "//tensorflow/core/common_runtime/eager:placement_utils", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 5a31c434eaa..204db3078f4 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -41,9 +41,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/custom_device.h" -#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h" #include "tensorflow/core/common_runtime/eager/execute.h" -#include "tensorflow/core/common_runtime/eager/placement_utils.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/attr_value.pb.h" @@ -534,8 +532,7 @@ TFE_TensorHandle* TFE_NewCustomDeviceTensorHandle( tensorflow::EagerContext* context = tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); tensorflow::CustomDevice* device = nullptr; - if (!context->GetCustomDeviceOpHandler().FindCustomDeviceFromName(device_name, - &device)) { + if (!context->FindCustomDeviceFromName(device_name, &device)) { deallocator(data, arg); status->status = tensorflow::errors::InvalidArgument(device_name, " unknown device."); @@ -565,8 +562,7 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory( status->status = context->FindDeviceFromName(device_name, &device); tensorflow::CustomDevice* custom_device = nullptr; if (!status->status.ok()) { - if (!context->GetCustomDeviceOpHandler().FindCustomDeviceFromName( - device_name, &custom_device)) { + if (!context->FindCustomDeviceFromName(device_name, &custom_device)) { deallocator(data, len, deallocator_arg); status->status = tensorflow::errors::InvalidArgument(device_name, " unknown device."); @@ -658,7 +654,8 @@ const char* TFE_OpGetName(const TFE_Op* op, TF_Status* status) { } TFE_Context* TFE_OpGetContext(const TFE_Op* op, TF_Status* status) { - return tensorflow::wrap(tensorflow::unwrap(op)->GetContext()); + return tensorflow::wrap( + &(OperationFromInterface(tensorflow::unwrap(op))->EagerContext())); } void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) { @@ -892,15 +889,11 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op, void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, TF_Status* status) { - tensorflow::ImmediateExecutionOperation* unwrapped_op = - tensorflow::unwrap(op); - - status->status = - unwrapped_op->GetContext()->GetCustomDeviceOpHandler().Execute( - unwrapped_op, - reinterpret_cast( - retvals), - num_retvals); + status->status = tensorflow::unwrap(op)->Execute( + absl::MakeSpan(reinterpret_cast( + tensorflow::unwrap(retvals)), + *num_retvals), + num_retvals); } TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, @@ -1157,8 +1150,10 @@ void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device, } auto custom_device = std::make_unique( ctx, device, device_info, device_name); - status->status = tensorflow::unwrap(ctx)->RegisterCustomDevice( - device_name, std::move(custom_device)); + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); + status->status = + context->RegisterCustomDevice(device_name, std::move(custom_device)); } } // extern "C" diff --git a/tensorflow/c/eager/immediate_execution_context.h b/tensorflow/c/eager/immediate_execution_context.h index 6c2231017d3..abb24cb0c54 100644 --- a/tensorflow/c/eager/immediate_execution_context.h +++ b/tensorflow/c/eager/immediate_execution_context.h @@ -38,9 +38,6 @@ limitations under the License. 
namespace tensorflow { class EagerExecutor; -class EagerContext; -class CustomDevice; -class CustomDeviceOpHandler; // LINT.IfChange // Note: Keep in sync with exported copy of enum in eager/c_api.h. @@ -125,7 +122,6 @@ class ImmediateExecutionContext : public AbstractContext { // Return the ParsedName of Host CPU device. virtual const DeviceNameUtils::ParsedName& HostCPUParsedName() const = 0; - virtual const string& HostCPUName() const = 0; // Configure soft device placement policy. virtual void SetAllowSoftPlacement(bool enable) = 0; @@ -151,18 +147,6 @@ class ImmediateExecutionContext : public AbstractContext { return ptr->getKind() == kEager || ptr->getKind() == kTfrt; } - //===--------------------------------------------------------------------===// - // Experimental Custom Device. - //===--------------------------------------------------------------------===// - virtual CustomDeviceOpHandler& GetCustomDeviceOpHandler() = 0; - - // Register a custom device. It will return error is the device name is - // already registered. - // TODO(tfrt-devs): Remove this method. Let caller register it directly into - // CustomDeviceOpHandler. - virtual Status RegisterCustomDevice(const string& name, - std::unique_ptr device) = 0; - //===--------------------------------------------------------------------===// // Following are features in current TF Eager Runtime. // TODO(tfrt-devs): Figure out a way to deprecate following features after diff --git a/tensorflow/c/eager/immediate_execution_operation.h b/tensorflow/c/eager/immediate_execution_operation.h index a23177b56d5..85af5a706e1 100644 --- a/tensorflow/c/eager/immediate_execution_operation.h +++ b/tensorflow/c/eager/immediate_execution_operation.h @@ -33,8 +33,6 @@ struct TFE_Op; namespace tensorflow { -class ImmediateExecutionContext; - // Abstract interface to an operation. class ImmediateExecutionOperation : public AbstractOperation { public: @@ -43,15 +41,6 @@ class ImmediateExecutionOperation : public AbstractOperation { // Returns the inputs of this op. virtual absl::Span GetInputs() const = 0; - virtual Status SetInput(size_t index, - ImmediateExecutionTensorHandle* input) = 0; - - virtual ImmediateExecutionContext* GetContext() const = 0; - - // Following two methods are used to support custom device. - // Return true if the inputs contain custom device tensor handle. It means - // that the argument need to be handled by a custom device. 
- virtual bool HasCustomDeviceInput() const = 0; virtual const tensorflow::OpDef* OpDef() const = 0; diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD index dddfe47de6b..8549c32417a 100644 --- a/tensorflow/core/common_runtime/eager/BUILD +++ b/tensorflow/core/common_runtime/eager/BUILD @@ -87,7 +87,6 @@ tf_cuda_library( deps = [ ":eager_executor", ":kernel_and_device", - ":custom_device_op_handler", ":custom_device", "@com_google_absl//absl/container:flat_hash_map", "//tensorflow/c:tf_tensor_internal", @@ -141,28 +140,6 @@ tf_cuda_library( }), ) -tf_cuda_library( - name = "custom_device_op_handler", - srcs = ["custom_device_op_handler.cc"], - hdrs = ["custom_device_op_handler.h"], - visibility = ["//tensorflow:internal"], - deps = [ - ":custom_device", - ] + select({ - "//tensorflow:android": [ - "//tensorflow/core:portable_tensorflow_lib_lite", - ], - "//conditions:default": [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/c/eager:immediate_execution_context", - "//tensorflow/c/eager:immediate_execution_tensor_handle", - "//tensorflow/c/eager:immediate_execution_operation", - "//tensorflow/core/lib/core:status", - ], - }), -) - tf_cc_test( name = "custom_device_test", srcs = ["custom_device_test.cc"], @@ -670,7 +647,6 @@ tf_cuda_library( ":custom_device", ":attr_builder", ":eager_operation", - "//tensorflow/c/eager:immediate_execution_operation", "//tensorflow/c/eager:immediate_execution_tensor_handle", ] + select({ "//tensorflow:android": [ @@ -738,7 +714,6 @@ filegroup( "attr_builder.h", "context.h", "custom_device.h", - "custom_device_op_handler.h", "eager_executor.h", "eager_operation.h", "kernel_and_device.h", diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 7fe6e00928c..7c20766a1ce 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -522,7 +522,7 @@ EagerContext::~EagerContext() { // Custom devices may have obtained references to various context components // (executors, thread pool). It's safer to run their destructors early. - custom_device_op_handler_.Clear(); + custom_devices_.clear(); ClearCachesAndThreadExecutors(); std::unordered_map executors_copy; @@ -904,15 +904,38 @@ Status EagerContext::FindCompositeDeviceFromName( return errors::NotFound("Unknown composite device: ", device_name); } +bool EagerContext::FindCustomDeviceFromName(const string& device_name, + CustomDevice** dev) const { + auto dev_it = custom_devices_.find(device_name); + if (dev_it == custom_devices_.end()) { + return false; + } + *dev = dev_it->second.get(); + return true; +} + Status EagerContext::RegisterCustomDevice( const string& device_name, std::unique_ptr device) { + DeviceNameUtils::ParsedName parsed; + if (!DeviceNameUtils::ParseFullName(device_name, &parsed) || + !parsed.has_job || !parsed.has_replica || !parsed.has_task || + !parsed.has_type || !parsed.has_id) { + return errors::InvalidArgument( + device_name, + " could not be parsed as a device name. 
Use the full " + "/job:/replica:/task:/device:: " + "format."); + } Device* existing_physical_device = nullptr; if (FindDeviceFromName(device_name.c_str(), &existing_physical_device).ok()) { return errors::AlreadyExists(device_name, " already registered as a physical device."); } - return custom_device_op_handler_.RegisterCustomDevice(device_name, - std::move(device)); + if (!custom_devices_.emplace(device_name, std::move(device)).second) { + return errors::AlreadyExists(device_name, + " already registered as a custom device."); + } + return Status::OK(); } Status EagerContext::FindOrCreateCompositeDevice( diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index fd6d896d9dd..28c0b40f43b 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -31,7 +31,6 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/eager/custom_device.h" -#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h" #include "tensorflow/core/common_runtime/eager/eager_executor.h" #include "tensorflow/core/common_runtime/eager/kernel_and_device.h" #include "tensorflow/core/common_runtime/function.h" @@ -205,8 +204,6 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { return HostCPU()->parsed_name(); } - const string& HostCPUName() const override { return HostCPU()->name(); } - GraphCollector* GetGraphCollector() { return &graph_collector_; } EagerExecutor& Executor() override; @@ -472,12 +469,11 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { Status FindCompositeDeviceFromName(StringPiece device_name, CompositeDevice** device) const; - Status RegisterCustomDevice(const string& name, - std::unique_ptr device) override; + bool FindCustomDeviceFromName(const string& device_name, + CustomDevice** dev) const; - CustomDeviceOpHandler& GetCustomDeviceOpHandler() override { - return custom_device_op_handler_; - }; + Status RegisterCustomDevice(const string& name, + std::unique_ptr device); // Find or create a composite device with the given `underlying_devices` and // `device_name` (if not empty). 
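The registration path above rejects anything short of a fully specified device name. The predicate below is the same check pulled out in isolation; it assumes a build inside the TensorFlow tree (the usual `device_name_utils.h` header location) and uses an illustrative custom device name, not one defined by this patch:

```cpp
#include <string>

#include "tensorflow/core/util/device_name_utils.h"

// Accepts only names of the form
//   /job:<job>/replica:<replica>/task:<task>/device:<type>:<id>
// e.g. "/job:localhost/replica:0/task:0/device:CUSTOM:0" (illustrative).
// Partial names such as "/device:CUSTOM:0" fail, matching the
// InvalidArgument error raised by EagerContext::RegisterCustomDevice.
bool IsFullySpecifiedDeviceName(const std::string& device_name) {
  tensorflow::DeviceNameUtils::ParsedName parsed;
  return tensorflow::DeviceNameUtils::ParseFullName(device_name, &parsed) &&
         parsed.has_job && parsed.has_replica && parsed.has_task &&
         parsed.has_type && parsed.has_id;
}
```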
@@ -587,7 +583,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { TF_GUARDED_BY(device_type_list_mu_); Rendezvous* rendezvous_; std::function rendezvous_creator_; - CustomDeviceOpHandler custom_device_op_handler_; + std::unordered_map> custom_devices_; mutable mutex composite_devices_mu_; // Maps from the fingerprint of a set of device names to a virtual diff --git a/tensorflow/core/common_runtime/eager/core.cc b/tensorflow/core/common_runtime/eager/core.cc index 905b1d94dad..81b1e3594f2 100644 --- a/tensorflow/core/common_runtime/eager/core.cc +++ b/tensorflow/core/common_runtime/eager/core.cc @@ -111,7 +111,7 @@ ImmediateExecutionTensorHandle* EagerContext::CopyTensorHandleToDevice( *status = this->FindDeviceFromName(device_name, &device); if (!status->ok()) { tensorflow::CustomDevice* dev; - if (custom_device_op_handler_.FindCustomDeviceFromName(device_name, &dev)) { + if (this->FindCustomDeviceFromName(device_name, &dev)) { *status = dev->CopyTensorToDevice(handle, &result); if (status->ok()) { return result; @@ -128,8 +128,7 @@ ImmediateExecutionTensorHandle* EagerContext::CopyTensorHandleToDevice( return nullptr; } tensorflow::CustomDevice* dev; - if (custom_device_op_handler_.FindCustomDeviceFromName(handle_device_name, - &dev)) { + if (this->FindCustomDeviceFromName(handle_device_name, &dev)) { *status = dev->CopyTensorFromDevice(handle, device_name, &result); if (status->ok()) { return result; @@ -203,8 +202,28 @@ Status EagerOperation::Execute(absl::Span retvals, } } + // Decide to either run the operation on a custom device or copy off all of + // the custom device inputs. + VariantDevice maybe_custom_device = Device(); + if (absl::holds_alternative(maybe_custom_device) || + !inputs_are_tensor_handles_) { + // If the op wasn't placed on a custom device explicitly and there are no + // non-TensorHandle inputs, the op will definitely be placed on a physical + // device. Otherwise we need to check the inputs one by one. + TF_RETURN_IF_ERROR( + eager::MaybePinToCustomDevice(&maybe_custom_device, *this)); + if (absl::holds_alternative(maybe_custom_device)) { + ImmediateExecutionTensorHandle** retval_array = + reinterpret_cast(retvals.data()); + return absl::get(maybe_custom_device) + ->Execute(this, retval_array, num_retvals); + } else { + TF_RETURN_IF_ERROR(CopyOffCustomDeviceInputs()); + } + } + // Run eager placement logic. - class Device* device = absl::get(Device()); + class Device* device = absl::get(maybe_custom_device); if (device == nullptr) { TF_RETURN_IF_ERROR(eager::MaybePinToResourceDevice(&device, *this)); } diff --git a/tensorflow/core/common_runtime/eager/custom_device_op_handler.cc b/tensorflow/core/common_runtime/eager/custom_device_op_handler.cc deleted file mode 100644 index 719d113ea2a..00000000000 --- a/tensorflow/core/common_runtime/eager/custom_device_op_handler.cc +++ /dev/null @@ -1,169 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h" - -#include "tensorflow/core/platform/errors.h" - -namespace tensorflow { - -void CustomDeviceOpHandler::Clear() { custom_devices_.clear(); } - -Status CustomDeviceOpHandler::RegisterCustomDevice( - const string& device_name, std::unique_ptr device) { - DeviceNameUtils::ParsedName parsed; - if (!DeviceNameUtils::ParseFullName(device_name, &parsed) || - !parsed.has_job || !parsed.has_replica || !parsed.has_task || - !parsed.has_type || !parsed.has_id) { - return errors::InvalidArgument( - device_name, - " could not be parsed as a device name. Use the full " - "/job:/replica:/task:/device:: " - "format."); - } - - if (!custom_devices_.emplace(device_name, std::move(device)).second) { - return errors::AlreadyExists(device_name, - " already registered as a custom device."); - } - return Status::OK(); -} - -bool CustomDeviceOpHandler::FindCustomDeviceFromName( - const string& name, CustomDevice** device) const { - auto dev_it = custom_devices_.find(name); - if (dev_it == custom_devices_.end()) { - return false; - } - *device = dev_it->second.get(); - return true; -} - -Status CustomDeviceOpHandler::Execute(ImmediateExecutionOperation* op, - ImmediateExecutionTensorHandle** retvals, - int* num_retvals) { - tensorflow::CustomDevice* custom_device = nullptr; - - TF_RETURN_IF_ERROR(MaybePinToCustomDevice(&custom_device, *op)); - - if (custom_device != nullptr) { - return custom_device->Execute(op, retvals, num_retvals); - } - - // The op will be placed on physical device. However, it contains custom - // device tensor handles. The tensor handles will be copy to physical device - // first. - if (op->HasCustomDeviceInput()) { - auto inputs = op->GetInputs(); - for (int i = 0; i < inputs.size(); ++i) { - auto target_device = op->DeviceName(); - if (target_device.empty()) { - target_device = op->GetContext()->HostCPUName(); - } - // TODO(b/175427838): It would be nice to be able to use tensorflow::isa - // here. - if (tensorflow::CustomDeviceTensorHandle::classof(inputs[i])) { - tensorflow::CustomDeviceTensorHandle* previous = - tensorflow::down_cast( - inputs[i]); - tensorflow::ImmediateExecutionTensorHandle* new_tesnor; - TF_RETURN_IF_ERROR(previous->device()->CopyTensorFromDevice( - previous, target_device, &new_tesnor)); - Status s = op->SetInput(i, new_tesnor); - new_tesnor->Unref(); - TF_RETURN_IF_ERROR(s); - } - } - } - - return op->Execute( - absl::MakeSpan( - reinterpret_cast(retvals), - *num_retvals), - num_retvals); -} - -Status CustomDeviceOpHandler::MaybePinToCustomDevice( - CustomDevice** device, const ImmediateExecutionOperation& op) const { - CustomDevice* requested_device = nullptr; - if (!FindCustomDeviceFromName(op.DeviceName(), &requested_device) && - !op.HasCustomDeviceInput()) { - return Status::OK(); - } - - // Ops are placed on a custom device if there's no other explicit requested - // placement and there is only one custom device in the op - // inputs. - // - // Resource-dtype inputs take precedence over non-resource inputs and explicit - // placements; this function pins ops with a resource-dtype custom device - // input to that custom device. - CustomDevice* first = nullptr; - if (!op.GetInputs().empty()) { - for (const ImmediateExecutionTensorHandle* generic_input : op.GetInputs()) { - // TODO(b/175427838): It would be nice to be able to use tensorflow::isa - // here. 
- if (CustomDeviceTensorHandle::classof(generic_input)) { - const CustomDeviceTensorHandle* input = - down_cast(generic_input); - CustomDevice* current = input->device(); - if (first == nullptr) { - first = current; - } else if (first != current) { - return errors::InvalidArgument(absl::StrCat( - "If an operation has one of its inputs in a custom device, then " - "all inputs should be on that same custom device or another " - "physical device. Operation ", - op.Name(), - " has one input in custom " - "device ", - first->name(), - " and at least one input in a different custom device ", - current->name())); - } - } - } - for (const ImmediateExecutionTensorHandle* generic_input : op.GetInputs()) { - if (generic_input->DataType() == DT_RESOURCE) { - if (CustomDeviceTensorHandle::classof(generic_input)) { - const CustomDeviceTensorHandle* input = - down_cast(generic_input); - // There's only one custom device input, and it's a resource input, so - // we'll force-place the op on to that custom device. As with physical - // devices, this overrides any explicit placement for the op. - *device = input->device(); - return Status::OK(); - } else { - // Don't set a custom device if there's a physical-device resource - // input. - return Status::OK(); - } - } - } - } - // Since there are no resource-dtype inputs, we'll respect explicit placements - // before considering input-based placement. - if (requested_device != nullptr) { - *device = requested_device; - } else if (op.DeviceName().empty() && first != nullptr) { - // If there are non-resource inputs on a custom device we will default the - // op to that custom device, but not override an explicit op placement. - *device = first; - return Status::OK(); - } - return Status::OK(); -} - -} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/custom_device_op_handler.h b/tensorflow/core/common_runtime/eager/custom_device_op_handler.h deleted file mode 100644 index 00ac5f324ba..00000000000 --- a/tensorflow/core/common_runtime/eager/custom_device_op_handler.h +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_OP_HANDLER_H_ -#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_OP_HANDLER_H_ - -#include "tensorflow/c/eager/immediate_execution_operation.h" -#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" -#include "tensorflow/core/common_runtime/eager/custom_device.h" -#include "tensorflow/core/lib/core/status.h" -namespace tensorflow { - -// TODO(tfrt-devs): Figure out a way to unify it with OpHandler in TFRT. -class CustomDeviceOpHandler { - public: - ~CustomDeviceOpHandler() {} - // Register a new custom device. - Status RegisterCustomDevice(const string& device_name, - std::unique_ptr device); - - // Find the custom device from given name. Return true if it finds one. 
- bool FindCustomDeviceFromName(const string& name, - CustomDevice** device) const; - - Status Execute(ImmediateExecutionOperation* op, - ImmediateExecutionTensorHandle** retvals, int* num_retvals); - - // Determine whether to place an op on a custom device. This method is - // exposed as public for test only. - Status MaybePinToCustomDevice(CustomDevice** device, - const ImmediateExecutionOperation& op) const; - - void Clear(); - - private: - std::unordered_map> custom_devices_; -}; -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_OP_HANDLER_H_ diff --git a/tensorflow/core/common_runtime/eager/custom_device_test.cc b/tensorflow/core/common_runtime/eager/custom_device_test.cc index cd7340e8463..a642a816c76 100644 --- a/tensorflow/core/common_runtime/eager/custom_device_test.cc +++ b/tensorflow/core/common_runtime/eager/custom_device_test.cc @@ -138,47 +138,43 @@ TEST(CustomDevice, TestResourcePlacement) { TF_ASSERT_OK(op.Reset("AssignVariableOp", "")); TF_ASSERT_OK(op.AddInput(physical_resource_tensor.get())); TF_ASSERT_OK(op.AddInput(custom_float_tensor.get())); - CustomDevice* placed_device = nullptr; - TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice( - &placed_device, op)); + VariantDevice placed_device(kVariantDeviceNull); + TF_ASSERT_OK(MaybePinToCustomDevice(&placed_device, op)); // MaybePinToCustomDevice has no opinion about ops which have physical // resource-dtype inputs. They'll get placed on physical devices. - EXPECT_EQ(nullptr, placed_device); + EXPECT_EQ(kVariantDeviceNull, placed_device); op.Clear(); TF_ASSERT_OK(op.Reset("AssignVariableOp", custom_device_name.c_str())); TF_ASSERT_OK(op.AddInput(physical_resource_tensor.get())); TF_ASSERT_OK(op.AddInput(custom_float_tensor.get())); - placed_device = nullptr; - TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice( - &placed_device, op)); + placed_device = kVariantDeviceNull; + TF_ASSERT_OK(MaybePinToCustomDevice(&placed_device, op)); // Explicit placement onto a custom device also doesn't trigger custom device // placement if there's a physical device resource input. - EXPECT_EQ(nullptr, placed_device); + EXPECT_EQ(kVariantDeviceNull, placed_device); op.Clear(); TF_ASSERT_OK( op.Reset("Identity", "/job:localhost/replica:0/task:0/device:CPU:0")); TF_ASSERT_OK(op.AddInput(physical_float_tensor.get())); - placed_device = nullptr; - TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice( - &placed_device, op)); + placed_device = kVariantDeviceNull; + TF_ASSERT_OK(MaybePinToCustomDevice(&placed_device, op)); // Explicit placements typically override input-based placement onto a custom // device. - EXPECT_EQ(nullptr, placed_device); + EXPECT_EQ(kVariantDeviceNull, placed_device); op.Clear(); TF_ASSERT_OK(op.Reset("AssignVariableOp", "/job:localhost/replica:0/task:0/device:CPU:0")); TF_ASSERT_OK(op.AddInput(custom_resource_tensor.get())); TF_ASSERT_OK(op.AddInput(physical_float_tensor.get())); - placed_device = nullptr; - TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice( - &placed_device, op)); + placed_device = kVariantDeviceNull; + TF_ASSERT_OK(MaybePinToCustomDevice(&placed_device, op)); // Even with an explicit physical device placement, custom device resource // inputs place the op on the custom device. 
- ASSERT_NE(placed_device, nullptr); - EXPECT_EQ(&custom_device, placed_device); + ASSERT_TRUE(absl::holds_alternative(placed_device)); + EXPECT_EQ(&custom_device, absl::get(placed_device)); } } // namespace diff --git a/tensorflow/core/common_runtime/eager/eager_operation.cc b/tensorflow/core/common_runtime/eager/eager_operation.cc index de4a4495e87..883e9a8a8b0 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.cc +++ b/tensorflow/core/common_runtime/eager/eager_operation.cc @@ -36,7 +36,7 @@ void EagerOperation::Clear() { h->Unref(); } inputs_.clear(); - custom_device_tensor_handles_count_ = 0; + inputs_are_tensor_handles_ = true; ClearInferenceState(); } @@ -269,7 +269,7 @@ Status EagerOperation::AddInput(AbstractTensorHandle* input) { down_cast(input); // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here. if (CustomDeviceTensorHandle::classof(h)) { - custom_device_tensor_handles_count_++; + inputs_are_tensor_handles_ = false; } AddTensorHandle(h); return MaybeInferSingleInputAttrs(h); @@ -281,7 +281,7 @@ Status EagerOperation::AddInputList( // TODO(b/175427838): It would be nice to be able to use tensorflow::isa // here. if (CustomDeviceTensorHandle::classof(input)) { - custom_device_tensor_handles_count_++; + inputs_are_tensor_handles_ = false; } ImmediateExecutionTensorHandle* h = down_cast(input); @@ -290,25 +290,6 @@ Status EagerOperation::AddInputList( return InferInputListAttrs(inputs.size()); } -Status EagerOperation::SetInput(size_t index, - ImmediateExecutionTensorHandle* input) { - if (index >= inputs_.size()) { - return errors::InvalidArgument("Index >= inputs.size: %d >= %d", index, - inputs_.size()); - } - auto* previous = inputs_[index]; - if (CustomDeviceTensorHandle::classof(previous)) { - custom_device_tensor_handles_count_--; - } - if (CustomDeviceTensorHandle::classof(input)) { - custom_device_tensor_handles_count_++; - } - input->Ref(); - inputs_[index] = input; - previous->Unref(); - return Status::OK(); -} - Status EagerOperation::Reset( const char* op, const char* device_name, bool remote, EagerExecutor* executor, @@ -426,7 +407,7 @@ Status EagerOperation::InferInputListAttrs(int num_inputs) { Status EagerOperation::TensorHandleInputs( const absl::InlinedVector** inputs) const { - if (TF_PREDICT_TRUE(!HasCustomDeviceInput())) { + if (TF_PREDICT_TRUE(inputs_are_tensor_handles_)) { *inputs = reinterpret_cast*>( &inputs_); return Status::OK(); @@ -437,7 +418,7 @@ Status EagerOperation::TensorHandleInputs( Status EagerOperation::MutableTensorHandleInputs( absl::InlinedVector** inputs) { - if (TF_PREDICT_TRUE(!HasCustomDeviceInput())) { + if (TF_PREDICT_TRUE(inputs_are_tensor_handles_)) { *inputs = reinterpret_cast*>(&inputs_); return Status::OK(); @@ -455,7 +436,14 @@ Status EagerOperation::SetDeviceName(const char* c_name) { } last_set_device_name_ = name; device_name_ = DeviceNameUtils::ParsedNameToString(device_parsed_name_); - device_ = kVariantDeviceNull; + CustomDevice* custom_device; + if (ctx_.FindCustomDeviceFromName(device_name_, &custom_device)) { + device_ = custom_device; + } else { + // Device placement for physical devices happens lazily in + // EagerExecute/EagerRemoteExecute, and can depend on the inputs. 
+ device_ = kVariantDeviceNull; + } } return Status::OK(); } @@ -507,4 +495,30 @@ void EagerOperation::AddTensorHandle(ImmediateExecutionTensorHandle* h) { attrs_.NumInputs(static_cast(inputs_.size())); } +Status EagerOperation::CopyOffCustomDeviceInputs() { + if (absl::holds_alternative(device_)) { + return errors::Internal( + "Trying to copy inputs to a custom device op off a custom device."); + } + for (int i = 0; i < inputs_.size(); ++i) { + // TODO(b/175427838): It would be nice to be able to use tensorflow::isa + // here. + if (CustomDeviceTensorHandle::classof(inputs_[i])) { + CustomDeviceTensorHandle* previous = + down_cast(inputs_[i]); + class Device* target_device; + if (device_ == kVariantDeviceNull) { + target_device = ctx_.HostCPU(); + } else { + target_device = absl::get(device_); + } + TF_RETURN_IF_ERROR(previous->device()->CopyTensorFromDevice( + previous, target_device->name(), &inputs_[i])); + previous->Unref(); + } + } + inputs_are_tensor_handles_ = true; + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/eager_operation.h b/tensorflow/core/common_runtime/eager/eager_operation.h index e1cb20b7575..e440a4a79dd 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.h +++ b/tensorflow/core/common_runtime/eager/eager_operation.h @@ -55,8 +55,6 @@ class EagerOperation : public ImmediateExecutionOperation { const string& DeviceName() const override { return device_name_; } - ImmediateExecutionContext* GetContext() const override { return &ctx_; } - const DeviceNameUtils::ParsedName& GetDeviceParsedName() const { return device_parsed_name_; } @@ -85,11 +83,7 @@ class EagerOperation : public ImmediateExecutionOperation { Status AddInput(AbstractTensorHandle* input) override; Status AddInputList(absl::Span inputs) override; - Status SetInput(size_t index, ImmediateExecutionTensorHandle* input) override; absl::Span GetInputs() const override; - bool HasCustomDeviceInput() const override { - return custom_device_tensor_handles_count_ > 0; - } Status Execute(absl::Span retvals, int* num_retvals) override; const tensorflow::OpDef* OpDef() const override { return op_def_; }; @@ -213,14 +207,20 @@ class EagerOperation : public ImmediateExecutionOperation { void InferMixedTypeInputListAttrs(const OpDef::ArgDef& input_def, const std::vector& dtypes); + // Replaces input tensors placed on custom devices with physical device + // equivalents. Used if an op is placed on a physical device but may have + // custom device inputs. + Status CopyOffCustomDeviceInputs(); + tensorflow::EagerContext& ctx_; const char* op_name_ = nullptr; AttrBuilder attrs_; const AttrTypeMap* attr_types_; - // The number of custom device TensorHandle inputs. These inputs need to be - // processed by CustomDeviceOpHandler first. - int custom_device_tensor_handles_count_ = 0; + // Toggled to indicate whether all inputs are known to be TensorHandles and + // not another type (e.g. custom device tensor handles). Explicitly set to + // false when custom device TensorHandles are added. + bool inputs_are_tensor_handles_ = true; absl::InlinedVector inputs_; // The last device name given to SetDeviceName. 
diff --git a/tensorflow/core/common_runtime/eager/placement_utils.cc b/tensorflow/core/common_runtime/eager/placement_utils.cc index 3b9fa7bb2d1..77514d67e3a 100644 --- a/tensorflow/core/common_runtime/eager/placement_utils.cc +++ b/tensorflow/core/common_runtime/eager/placement_utils.cc @@ -77,6 +77,11 @@ bool IsFunction(StringPiece op_name) { return false; } +bool IsCustomDevice(StringPiece device_name, const EagerContext& ctx) { + CustomDevice* custom_device; + return ctx.FindCustomDeviceFromName(string(device_name), &custom_device); +} + Status MaybePinSmallOpsToCpu( bool* result, StringPiece op_name, absl::Span args, @@ -177,5 +182,70 @@ Status MaybePinToResourceDevice(Device** device, const EagerOperation& op) { return Status::OK(); } +Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op) { + // Ops are placed on a custom device if there's no other explicit requested + // placement and there is only one custom device in the op + // inputs. + // + // Resource-dtype inputs take precedence over non-resource inputs and explicit + // placements; this function pins ops with a resource-dtype custom device + // input to that custom device. + CustomDevice* first = nullptr; + if (!op.Inputs().empty()) { + for (const ImmediateExecutionTensorHandle* generic_input : op.Inputs()) { + // TODO(b/175427838): It would be nice to be able to use tensorflow::isa + // here. + if (CustomDeviceTensorHandle::classof(generic_input)) { + const CustomDeviceTensorHandle* input = + down_cast(generic_input); + CustomDevice* current = input->device(); + if (first == nullptr) { + first = current; + } else if (first != current) { + return errors::InvalidArgument(absl::StrCat( + "If an operation has one of its inputs in a custom device, then " + "all inputs should be on that same custom device or another " + "physical device. Operation ", + op.Name(), + " has one input in custom " + "device ", + first->name(), + " and at least one input in a different custom device ", + current->name())); + } + } + } + for (const ImmediateExecutionTensorHandle* generic_input : op.Inputs()) { + if (generic_input->DataType() == DT_RESOURCE) { + if (CustomDeviceTensorHandle::classof(generic_input)) { + const CustomDeviceTensorHandle* input = + down_cast(generic_input); + // There's only one custom device input, and it's a resource input, so + // we'll force-place the op on to that custom device. As with physical + // devices, this overrides any explicit placement for the op. + *device = input->device(); + return Status::OK(); + } else { + // Don't set a custom device if there's a physical-device resource + // input. + return Status::OK(); + } + } + } + } + // Since there are no resource-dtype inputs, we'll respect explicit placements + // before considering input-based placement. + if (absl::holds_alternative(op.Device())) { + *device = op.Device(); + } else if (op.DeviceName().empty() && first != nullptr) { + // If there are non-resource inputs on a custom device we will default the + // op to that custom device, but not override an explicit op placement. + *device = first; + return Status::OK(); + } + + return Status::OK(); +} + } // namespace eager } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/placement_utils.h b/tensorflow/core/common_runtime/eager/placement_utils.h index 9435f9848d3..7676fe01b43 100644 --- a/tensorflow/core/common_runtime/eager/placement_utils.h +++ b/tensorflow/core/common_runtime/eager/placement_utils.h @@ -16,7 +16,6 @@ limitations under the License. 
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PLACEMENT_UTILS_H_ #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PLACEMENT_UTILS_H_ -#include "tensorflow/c/eager/immediate_execution_operation.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/stringpiece.h" @@ -28,6 +27,8 @@ bool IsColocationExempt(StringPiece op_name); bool IsFunction(StringPiece op_name); +bool IsCustomDevice(StringPiece device_name, const EagerContext& ctx); + // TODO(b/154234908): Unify placement logic. // TODO(b/159647422): Add C++ unit tests for placement logic. @@ -43,6 +44,11 @@ Status MaybePinSmallOpsToCpu( // the device the resource is, regardless of anything else that has been // specified. This is identical to the graph mode behavior. Status MaybePinToResourceDevice(Device** device, const EagerOperation& op); + +// If all the inputs are on the same custom device, use that custom +// device. Otherwise, it is an error to have a custom device as an input. +Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op); + } // namespace eager } // namespace tensorflow From 5bace7f1ca1dc43da4534df9f2e55d0f270601cf Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Fri, 5 Feb 2021 19:03:12 -0800 Subject: [PATCH 19/19] [libtpu] Don't include extra nullptr in GetLibTpuInitArguments(). PiperOrigin-RevId: 355971710 Change-Id: Ia169526087b1e8292c41aca38e1795ea1a7baa3f --- tensorflow/core/tpu/tpu_initializer_helper.cc | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tensorflow/core/tpu/tpu_initializer_helper.cc b/tensorflow/core/tpu/tpu_initializer_helper.cc index 0856957d39c..c97a09b1cb0 100644 --- a/tensorflow/core/tpu/tpu_initializer_helper.cc +++ b/tensorflow/core/tpu/tpu_initializer_helper.cc @@ -26,23 +26,22 @@ std::pair, std::vector> GetLibTpuInitArguments() { // We make copies of the arguments returned by getenv because the memory // returned may be altered or invalidated by further calls to getenv. - std::vector argv; - std::vector argv_ptr; + std::vector args; + std::vector arg_ptrs; // Retrieve arguments from environment if applicable. char* env = getenv("LIBTPU_INIT_ARGS"); if (env != nullptr) { // TODO(frankchn): Handles quotes properly if necessary. - argv = absl::StrSplit(env, ' '); + args = absl::StrSplit(env, ' '); } - argv_ptr.reserve(argv.size()); - for (int i = 0; i < argv.size(); ++i) { - argv_ptr.push_back(argv[i].data()); + arg_ptrs.reserve(args.size()); + for (int i = 0; i < args.size(); ++i) { + arg_ptrs.push_back(args[i].data()); } - argv_ptr.push_back(nullptr); - return {argv, argv_ptr}; + return {args, arg_ptrs}; } } // namespace tpu
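
Note on the last patch: after this change the pointer vector returned by GetLibTpuInitArguments() no longer carries a trailing nullptr sentinel, so its size() equals the number of arguments. The standalone sketch below is illustrative only and not part of the patch; the stub values and variable names are invented. It shows the resulting convention and how a caller that still needs a C-style, nullptr-terminated argv could append the terminator at the call site.

```cpp
// Hypothetical caller-side sketch (not TensorFlow code): consume an
// argument list whose pointer vector has no trailing nullptr, and add
// the sentinel only where a C-style argv is actually required.
#include <cstdio>
#include <string>
#include <vector>

int main() {
  // Stand-in for arguments parsed from LIBTPU_INIT_ARGS.
  std::vector<std::string> args = {"--example_flag=1", "--another_flag=2"};

  // Build the pointer view; size() matches args.size(), no sentinel.
  std::vector<const char*> arg_ptrs;
  arg_ptrs.reserve(args.size());
  for (const std::string& a : args) {
    arg_ptrs.push_back(a.c_str());
  }

  // If an API expects a nullptr-terminated argv, terminate a copy here
  // instead of baking the sentinel into the returned vector.
  std::vector<const char*> c_argv = arg_ptrs;
  c_argv.push_back(nullptr);

  std::printf("argc = %zu\n", arg_ptrs.size());
  return 0;
}
```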